{- |
    Module      :  $Header$
    Description :  Elimination of recursive data declarations
    Copyright   :  (c) 2009, Holger Siegel
    License     :  BSD-3-clause

    Maintainer  :  bjp@informatik.uni-kiel.de
    Stability   :  experimental
    Portability :  portable

    Turn recursive data declarations into recursive function calls.

    Only single, self-recursive declarations are transformed; mutually
    recursive declarations are left unchanged.
    You should therefore use the transformation UnMutual first.
-}

module Curry.ExtendedFlat.LiftLetrec (liftLetrecProg) where

import Data.List
import Control.Monad.State (State, get, put, modify, runState)
import Data.Maybe
import qualified Data.Map as Map
import qualified Data.Set as Set

import Curry.ExtendedFlat.Type
import Curry.ExtendedFlat.Goodies
import Curry.ExtendedFlat.MonadicGoodies

-- |State threaded through the lifting pass.
data LifterState = LifterState
  { modname       :: String
    -- ^ Name of the program being transformed; used as the module part of
    --   every generated function name.
  , currentFunc   :: String
    -- ^ Local name of the function currently being processed
    --   (@"anonymous"@ while no function is being processed); used as the
    --   base of generated names.
  , globals       :: Set.Set QName
    -- ^ Function names that are already taken: the program's own functions
    --   plus every function generated so far.
  , globalCounter :: Map.Map QName Int
    -- ^ Next numeric suffix to try for each base name (defaults to 1 for
    --   names that have no entry).
  , localCounter  :: Int
    -- ^ Set to the largest variable index of the function being processed.
    --   Not read anywhere in the visible part of this module.
  , lifted        :: Map.Map QName FuncDecl
    -- ^ Functions generated by the pass, keyed by their name.
  }

-- |A single let binding: the bound variable together with its right-hand side.
type Bind = (VarIndex, Expr)    -- (name, value)
-- |The monad of the lifting pass: a state monad threading a 'LifterState'.
type LiftMonad = State LifterState

-- |Convert recursive data declarations into recursive function calls.
--
-- Every function of the program is run through 'liftRecursion' (via
-- 'updFuncLetsM'); the functions generated along the way are collected in
-- the lifter state and appended to the result by the @(++ fdecls)@ updater.
liftLetrecProg :: Prog -> Prog
liftLetrecProg prog = updProg id id id (++ fdecls) id prog' where
  -- Initial state: all existing function names are reserved, and the suffix
  -- counter of each of them starts at 1.
  state = LifterState
            { modname = progName prog
            , currentFunc = "anonymous"
            , globals = Set.fromList g
            , globalCounter = Map.fromList $ zip g (repeat 1)
            , localCounter = 0
            , lifted = Map.empty
            }
  g = allGlobals prog
  (prog', state') = runState (updProgFuncsM run prog) state
  fdecls = Map.elems (lifted state')
  -- Largest variable index used in a function.  'maximum' is partial, so a
  -- function without any variables falls back to 0 instead of leaving an
  -- error thunk in the state.
  maxVarIndex fdecl = case map idxOf (allVarsInFunc fdecl) of
    []      -> 0
    indices -> maximum indices
  -- Process one function: record its name (the base for generated names)
  -- and its largest variable index, rewrite its lets, then reset the name.
  run fdecl = do
      let fname = localName (funcName fdecl)
      modify (\st -> st { currentFunc  = fname
                        , localCounter = maxVarIndex fdecl
                        })
      fdecl' <- updFuncLetsM liftRecursion fdecl
      modify (\st -> st { currentFunc = "anonymous" })
      return fdecl'

-- |Lift a single self-recursive let binding into a top-level function.
--
-- A let with exactly one binding whose right-hand side mentions the bound
-- variable is rewritten so that the variable is bound to a call of a newly
-- generated function (see 'mkLiftedFunction'); the other free variables of
-- the right-hand side become the arguments of that call.  Every other let
-- (a non-recursive single binding, or zero or several bindings) is rebuilt
-- unchanged.
liftRecursion :: [Bind] -> Expr -> LiftMonad Expr
liftRecursion [(b, rhs)] body
  | b `elem` fv = do
      globalcall <- mkLiftedFunction (typeofVar b) b rhs params
      return (Let [(b, globalcall)] body)
  | otherwise  = return (Let [(b, rhs)] body)
  where
    fv = fvs rhs
    -- Parameters of the lifted function: the free variables other than the
    -- recursive one.  @fv \\\\ [b]@ would drop only the first occurrence of
    -- @b@, so every occurrence is filtered out and duplicates are removed;
    -- this does not rely on 'fvs' returning a duplicate-free list.
    params = nub (filter (/= b) fv)
liftRecursion bs body = return (Let bs body)

-- |Generate the top-level function for a recursive binding.
--
-- Arguments: the type annotation of the bound variable, the bound variable,
-- its right-hand side, and the variables to pass as parameters.  A fresh
-- name is drawn with 'newGlobalName'; the new function takes the given
-- variables as parameters and its body re-binds the variable to a call of
-- the function itself before evaluating the original right-hand side.  The
-- declaration is stored in 'lifted', its name is reserved in 'globals', and
-- the call expression is returned.
mkLiftedFunction :: Maybe TypeExpr -> VarIndex -> Expr -> [VarIndex] -> LiftMonad Expr
mkLiftedFunction t v rhs fv = do
  name <- newGlobalName t
  let call     = Comb FuncCall name (map Var fv)
      arity    = length fv
      -- NOTE(review): the declared type is the variable's own type (or
      -- @TVar 0@ when unknown) even when arity > 0 -- confirm that no later
      -- pass expects a full function type here.
      declType = fromMaybe (TVar 0) t
      funBody  = Let [(v, call)] rhs
      decl     = Func name arity Private declType (Rule fv funBody)
  modify $ \st -> st { lifted  = Map.insert name decl (lifted st)
                     , globals = Set.insert name (globals st)
                     }
  return call

-- |Produce a function name that is not yet in 'globals'.
--
-- The name is built from the current function's name, an underscore and a
-- running number.  The number for that base name is advanced on every
-- attempt, and the search is repeated until an unused name is found.
newGlobalName :: Maybe TypeExpr -> LiftMonad QName
newGlobalName t = do
  st <- get
  let counters  = globalCounter st
      base      = QName Nothing t (modname st) (currentFunc st)
      suffix    = Map.findWithDefault 1 base counters
      candidate = QName Nothing t (modname st)
                        (localName base ++ "_" ++ show suffix)
  -- Advance the counter first so that a retry tries the next number.
  put st { globalCounter = Map.insert base (suffix + 1) counters }
  if Set.member candidate (globals st)
    then newGlobalName t
    else return candidate

-- |Names of all functions declared in the program, in declaration order.
allGlobals :: Prog -> [QName]
allGlobals prog = do
  Func name _ _ _ _ <- progFuncs prog
  return name