{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# LANGUAGE RankNTypes #-}
module Data.Functor.Polynomial(
  Poly(..)
) where

import Data.Functor.Classes ( Eq1(..), eq1, Ord1(..), compare1 )
import Data.Kind (Type)
import Data.Type.Equality ((:~:)(..))

import GHC.TypeNats hiding (SNat)
import GHC.TypeNats.Extra
import qualified Data.Vector.Sized as SV

import Data.GADT.Compare
import Data.Finitary
import Data.GADT.HasFinitary

-- | Uniformly represented polynomial functor.
--
--   A value @'P' t rep@ pairs a shape tag @t :: tag n@, which fixes the
--   (existentially quantified) exponent type @n@, with a payload function
--   @rep :: n -> x@ that gives the element stored at each position @n@.
--
--   Given a @'HasFinitary' tag@ instance, @Poly tag@ is a polynomial functor.
--   When @tag n@ is the preimage of @n@ under some @α :: U -> Type@
--   (i.e. @tag n ≅ {u :: U | α(u) == n}@), @Poly tag@ is isomorphic to:
--
--   > Poly tag x
--   >  = ∑{n :: Type} (tag n, x^n)
--   >  = ∑{n :: Type} ∑{u :: U, α(u) == n} x^n
--   >  = ∑{u :: U} x^(α(u))
type Poly :: (Type -> Type) -> Type -> Type
data Poly tag x where
  -- | Build a value from a shape tag and the payload at each position.
  P :: tag n -> (n -> x) -> Poly tag x

-- | Mapping post-composes onto the payload function; the shape tag is
--   carried over unchanged.
instance Functor (Poly tag) where
  fmap g (P shape payload) = P shape (g . payload)

-- | Folding visits the payload at every inhabitant of the shape's exponent
--   type, in the order produced by 'toInhabitants'.
instance HasFinitary tag => Foldable (Poly tag) where
  -- The number of positions is the cardinality carried by the tag
  -- ('toSNat'), so 'null' and 'length' never evaluate the payload function.
  null (P tag _) = case toSNat tag of
    Zero -> True
    _    -> False

  length (P tag _) = fromIntegral (fromSNat (toSNat tag))

  foldMap f (P tag rep) = foldMap (f . rep) (toInhabitants tag)

-- | Traverses the payload position by position, keeping the shape tag.
--
--   NOTE(review): positions are visited in 'fromFinite' index order (see
--   'traverseFiniteFn'), while 'foldMap' uses 'toInhabitants' order; confirm
--   the two agree so that 'foldMapDefault' coincides with 'foldMap'.
instance HasFinitary tag => Traversable (Poly tag) where
  -- The exponent type is existential, so its 'Finitary' dictionary must be
  -- recovered from the tag with 'withFinitary' before it can be traversed.
  traverse f (P tag rep) = P tag <$> withFinitary tag (traverseFiniteFn f rep)

-- | Traverse a function with a finite domain as if it were the vector of its
--   values.
--
--   The function is first tabulated into a sized vector whose entry at index
--   @i@ is @fromN ('fromFinite' i)@. The effect @f@ is then run on every
--   entry in index order, and the resulting vector is turned back into a
--   function by looking up the 'toFinite' index of the argument.
traverseFiniteFn :: (Finitary n, Applicative g) => (a -> g b) -> (n -> a) -> g (n -> b)
traverseFiniteFn f fromN = indexByN <$> traverse f (SV.generate (fromN . fromFinite))
  where
    -- Read back the entry for a domain element via its 'toFinite' index.
    indexByN xs key = SV.index xs (toFinite key)

-- | Equality is delegated to the 'Eq1' instance: equal tags (per 'geq') and
--   pointwise-equal payloads.
instance (Eq x, GEq tag, HasFinitary tag) => Eq (Poly tag x) where
  (==) = eq1

-- | Two values are equal when their tags are equal according to 'geq' and
--   their payload functions agree, under the supplied element equality, at
--   every element of 'toInhabitants'.
instance (GEq tag, HasFinitary tag) => Eq1 (Poly tag) where
  liftEq eq (P tag rep) (P tag' rep') = case geq tag tag' of
    Nothing   -> False
    -- Matching on 'Refl' unifies the two existential exponent types, so both
    -- payload functions can be applied to the same inhabitant.
    Just Refl -> all (\i -> rep i `eq` rep' i) (toInhabitants tag)

-- | Ordering is delegated to the 'Ord1' instance.
--
--   __Does not preserve__ the order through 'toPoly' and 'fromPoly'.
instance (Ord x, GCompare tag, HasFinitary tag) => Ord (Poly tag x) where
  compare = compare1

-- | Lexicographic comparison. Values are compared first by tag (per
--   'gcompare'). For equal tags, the payloads are compared pointwise in
--   'toInhabitants' order, and the first non-'EQ' result decides.
--
--   __Does not preserve__ the order through 'toPoly' and 'fromPoly'.
instance (GCompare tag, HasFinitary tag) => Ord1 (Poly tag) where
  liftCompare cmpX (P tag rep) (P tag' rep') = case gcompare tag tag' of
    GLT -> LT
    -- The 'Ordering' monoid keeps the first non-'EQ' result, and the lazy
    -- list fold stops evaluating positions once it is found.
    GEQ -> foldMap (\i -> rep i `cmpX` rep' i) (toInhabitants tag)
    GGT -> GT