{-# LANGUAGE CPP, TemplateHaskell #-}
module Language.Haskell.TH.Extras where

import Control.Monad
import Data.Generics
import Data.Maybe
import Language.Haskell.TH
import Language.Haskell.TH.Syntax

-- | 'True' when the native 'Int' is wider than 32 bits, i.e. its upper bound
-- exceeds @2^32@ (a 32-bit 'Int' tops out at @2^31 - 1@).
intIs64 :: Bool
intIs64 = widestInt > limit32
  where
    widestInt = toInteger (maxBound :: Int)
    limit32   = 2 ^ (32 :: Int) :: Integer

-- | @replace step x@ returns the value carried by @step x@ when there is one,
-- and leaves @x@ unchanged when @step x@ is 'Nothing'.
replace :: (a -> Maybe a) -> (a -> a)
replace step x = fromMaybe x (step x)

-- | Compose a list of expressions into the right-nested composition
-- @f1 . (f2 . (... . fn))@.  An empty list yields @id@ and a single
-- expression is returned unchanged (no trailing @. id@ is introduced).
composeExprs :: [ExpQ] -> ExpQ
composeExprs exprs = case exprs of
    [] -> [| id |]
    _  -> foldr1 compose2 exprs
  where
    compose2 :: ExpQ -> ExpQ -> ExpQ
    compose2 outer inner = [| $outer . $inner |]

-- | The constructor name declared by a 'Con'.  'ForallC' wrappers are looked
-- through.  GADT constructors are supported only when they declare exactly
-- one name; any other shape is a pattern-match failure.
nameOfCon :: Con -> Name
nameOfCon con = case con of
    NormalC n _      -> n
    RecC n _         -> n
    InfixC _ n _     -> n
    ForallC _ _ body -> nameOfCon body
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
    GadtC [n] _ _    -> n
    RecGadtC [n] _ _ -> n
#endif

-- | The field types of a constructor, in declaration order.  Strictness
-- annotations and record field names are dropped, and 'ForallC' wrappers are
-- looked through.
--
-- WARNING: discards binders in GADTs and existentially-quantified constructors
argTypesOfCon :: Con -> [Type]
argTypesOfCon con = case con of
    NormalC _ fields    -> [ty | (_, ty) <- fields]
    RecC _ fields       -> map thd fields
    InfixC lhs _ rhs    -> [snd lhs, snd rhs]
    ForallC _ _ inner   -> argTypesOfCon inner
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
    GadtC _ fields _    -> [ty | (_, ty) <- fields]
    RecGadtC _ fields _ -> map thd fields
#endif
  where
    -- Record fields are (name, strictness, type) triples; keep the type.
    thd (_, _, ty) = ty

-- | The variable name introduced by a type-variable binder, ignoring any kind
-- annotation.  Before GHC 7.0 binders are plain 'Name's, so this is the
-- identity there (and 'TyVarBndr' is defined locally as an alias).
nameOfBinder :: TyVarBndr -> Name
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 700
nameOfBinder bndr = case bndr of
    PlainTV v    -> v
    KindedTV v _ -> v
#else
nameOfBinder = id
type TyVarBndr = Name
#endif

-- | All type-variable binders introduced by the (possibly nested) 'ForallC'
-- wrappers of a constructor, outermost first.  A constructor without a
-- 'ForallC' wrapper binds nothing.
varsBoundInCon :: Con -> [TyVarBndr]
varsBoundInCon con = case con of
    ForallC outer _ inner -> outer ++ varsBoundInCon inner
    _                     -> []

-- | Every variable name bound by a pattern, in left-to-right order.
-- An as-pattern contributes its own name before the names bound by its inner
-- pattern.  Patterns that bind nothing (literals, wildcards) and any pattern
-- form not listed below yield @[]@.
namesBoundInPat :: Pat -> [Name]
namesBoundInPat (VarP name)        = [name]
namesBoundInPat (TupP pats)        = pats >>= namesBoundInPat
namesBoundInPat (ConP _ pats)      = pats >>= namesBoundInPat
namesBoundInPat (InfixP p1 _ p2)   = namesBoundInPat p1 ++ namesBoundInPat p2
namesBoundInPat (TildeP pat)       = namesBoundInPat pat
namesBoundInPat (AsP name pat)     = name : namesBoundInPat pat
namesBoundInPat (RecP _ fieldPats) = map snd fieldPats >>= namesBoundInPat
namesBoundInPat (ListP pats)       = pats >>= namesBoundInPat
namesBoundInPat (SigP pat _)       = namesBoundInPat pat

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
namesBoundInPat (BangP pat)        = namesBoundInPat pat
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 700
namesBoundInPat (ViewP _ pat)      = namesBoundInPat pat
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
namesBoundInPat (UnboxedTupP pats) = pats >>= namesBoundInPat
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 704
-- Parenthesised and unresolved-infix patterns bind exactly what their
-- sub-patterns bind; without these cases such variables would be missed.
namesBoundInPat (ParensP pat)      = namesBoundInPat pat
namesBoundInPat (UInfixP p1 _ p2)  = namesBoundInPat p1 ++ namesBoundInPat p2
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 802
-- An unboxed-sum pattern binds whatever its single alternative binds.
namesBoundInPat (UnboxedSumP pat _ _) = namesBoundInPat pat
#endif

namesBoundInPat _                  = []


-- | Names introduced at the top level by a declaration: a function, the
-- variables bound by a pattern binding, or the name of a data type, newtype,
-- type synonym, class, type/data family or foreign import.  Only the declared
-- entity itself is reported (data constructors, record fields and class
-- methods are not).  Unlisted declaration forms yield @[]@.
namesBoundInDec :: Dec -> [Name]
namesBoundInDec (FunD name _)                     = [name]
namesBoundInDec (ValD pat _ _)                    = namesBoundInPat pat

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
namesBoundInDec (DataD _ name _ _ _ _)            = [name]
namesBoundInDec (NewtypeD _ name _ _ _ _)         = [name]
#else
namesBoundInDec (DataD _ name _ _ _)              = [name]
namesBoundInDec (NewtypeD _ name _ _ _)           = [name]
#endif

namesBoundInDec (TySynD name _ _)                 = [name]
namesBoundInDec (ClassD _ name _ _ _)             = [name]
namesBoundInDec (ForeignD (ImportF _ _ _ name _)) = [name]

#if defined(__GLASGOW_HASKELL__)
#if __GLASGOW_HASKELL__ >= 800
namesBoundInDec (OpenTypeFamilyD (TypeFamilyHead name _ _ _))     = [name]
namesBoundInDec (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = [name]
-- Data families no longer share 'FamilyD' with type families here, so they
-- need their own case to keep being reported.
namesBoundInDec (DataFamilyD name _ _)                            = [name]
#elif __GLASGOW_HASKELL__ >= 612
namesBoundInDec (FamilyD _ name _ _)              = [name]
#endif
#endif

namesBoundInDec _                                 = []

-- | Rebuild a name from its base string alone, dropping any module qualifier
-- or uniqueness, so the result is a capturable name as made by 'mkName'.
genericalizeName :: Name -> Name
genericalizeName name = mkName (nameBase name)

-- | Replace every name bound at the top level of the given declarations with
-- a plain name built from its base string, at each of its occurrences
-- throughout the declarations.
--
-- Since GHC 7.2, names bound inside @[d| ... |]@ quotes are fresh.  Code
-- outside the quote then has no way to refer to them, short of digging them
-- out of the generated 'Dec's.  Rewriting them to plain names lets the
-- generated bindings be referred to by the names they were written with.
genericalizeDecs :: [Dec] -> [Dec]
genericalizeDecs decs = everywhere (mkT rename) decs
    where
        -- Association list from each bound name to its generic replacement.
        renaming :: [(Name, Name)]
        renaming = [ (old, genericalizeName old)
                   | old <- concatMap namesBoundInDec decs ]

        -- Names that are not bound by the decs are left untouched.
        rename :: Name -> Name
        rename = replace (`lookup` renaming)

-- | The name at the head of a type application: the type constructor or type
-- variable that the outermost arguments are applied to.  Foralls, kind
-- signatures and parentheses are looked through.  Tuple, unboxed-tuple,
-- arrow and list types map to the names of their type constructors.
--
-- Calls 'error', showing the offending type, for type forms that have no
-- single head name here (for example type literals or unresolved infix
-- chains).
headOfType :: Type -> Name
headOfType (ForallT _ _ ty) = headOfType ty
headOfType (VarT name)      = name
headOfType (ConT name)      = name
headOfType (TupleT n)       = tupleTypeName n
headOfType ArrowT           = ''(->)
headOfType ListT            = ''[]
headOfType (AppT t _)       = headOfType t

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
headOfType (SigT t _) = headOfType t
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
headOfType (UnboxedTupleT n) = unboxedTupleTypeName n
#endif

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
headOfType (ParensT t)     = headOfType t
-- A resolved infix application is its operator applied to both operands.
headOfType (InfixT _ op _) = op
#endif

-- Explicit failure instead of an anonymous pattern-match error.
headOfType ty = error ("Language.Haskell.TH.Extras.headOfType: no head name for type " ++ show ty)

-- | Whether the type variable @var@ occurs free in a type.  A variable bound
-- by an enclosing 'ForallT' is not free, and the forall's context is not
-- inspected.  Kind signatures are looked through without inspecting the kind.
-- Parentheses and infix applications are traversed.  Every other type form
-- (constructors, literals, ...) is treated as containing no occurrence.
occursInType :: Name -> Type -> Bool
occursInType var ty = case ty of
    ForallT bndrs _ body
        | var `elem` map nameOfBinder bndrs -> False
        | otherwise                         -> occursInType var body
    VarT name     -> name == var
    AppT fun arg  -> occursInType var fun || occursInType var arg
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
    SigT inner _  -> occursInType var inner
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
    -- Without these cases a variable inside parentheses or an infix
    -- application would be reported as absent.
    ParensT inner     -> occursInType var inner
    InfixT lhs _ rhs  -> occursInType var lhs || occursInType var rhs
    UInfixT lhs _ rhs -> occursInType var lhs || occursInType var rhs
#endif
    _ -> False