module Network.TLS.Crypto.ECDH
    (
    -- * ECDH types
      ECDHParams(..)
    , ECDHPublic
    , ECDHPrivate(..)

    -- * ECDH methods
    , ecdhPublic
    , ecdhPrivate
    , ecdhParams
    , ecdhGenerateKeyPair
    , ecdhGetShared
    , ecdhUnwrap
    , ecdhUnwrapPublic
    ) where

import Network.TLS.Util.Serialization (i2osp, lengthBytes)
import Network.TLS.Extension.EC
import qualified Crypto.PubKey.ECC.DH as ECDH
import qualified Crypto.Types.PubKey.ECC as ECDH
import qualified Crypto.PubKey.ECC.Prim as ECC (isPointValid)
import Crypto.Random (CPRG)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Word (Word16)

-- | A public elliptic-curve point paired with the byte size that travels
-- with it (see 'ecdhPublic' / 'ecdhUnwrapPublic').
data ECDHPublic = ECDHPublic ECDH.PublicPoint Int {- byte size -}
    deriving (Show, Eq)

-- | A private ECDH scalar.
newtype ECDHPrivate = ECDHPrivate ECDH.PrivateNumber
    deriving (Show, Eq)

-- | The curve to operate on, together with its name.
data ECDHParams = ECDHParams ECDH.Curve ECDH.CurveName
    deriving (Show, Eq)

-- | Shared-secret bytes as produced by 'ecdhGetShared'.
type ECDHKey = ByteString

-- | Build a public key from affine coordinates @x@ and @y@ and the byte
-- size to store alongside the point.
ecdhPublic :: Integer -> Integer -> Int -> ECDHPublic
ecdhPublic x y = ECDHPublic (ECDH.Point x y)

-- | Wrap a raw private scalar as an 'ECDHPrivate'.
ecdhPrivate :: Integer -> ECDHPrivate
ecdhPrivate scalar = ECDHPrivate scalar

-- | Build curve parameters from a 16-bit curve identifier.
--
-- The identifier is mapped to a curve name with 'toCurveName', and the
-- curve itself is fetched with 'ECDH.getCurveByName'.
--
-- An identifier that 'toCurveName' does not recognise is a caller error.
-- The result is still constructed lazily. Forcing either field of such a
-- value raises an error that includes the offending identifier.
ecdhParams :: Word16 -> ECDHParams
ecdhParams w16 = ECDHParams curve name
  where
    -- Kept as a lazy binding so an unknown id only fails when the name or
    -- curve is actually used, exactly as with the former partial pattern.
    name = case toCurveName w16 of
        Just n  -> n
        Nothing -> error ("ecdhParams: unsupported curve identifier " ++ show w16)
    curve = ECDH.getCurveByName name

-- | Generate a fresh key pair on the curve carried by the parameters.
--
-- The private scalar is drawn with 'ECDH.generatePrivate' from the supplied
-- generator. The public point is derived with 'ECDH.calculatePublic' and
-- tagged with 'pointSize' of the curve. The advanced generator is returned
-- alongside the pair.
ecdhGenerateKeyPair :: CPRG g => g -> ECDHParams -> ((ECDHPrivate, ECDHPublic), g)
ecdhGenerateKeyPair rng (ECDHParams curve _) = ((ECDHPrivate secret, public), rng')
  where
    (secret, rng') = ECDH.generatePrivate rng curve
    public         = ECDHPublic (ECDH.calculatePublic curve secret)
                                (pointSize curve)

-- | Compute the ECDH shared secret from our private scalar and the peer's
-- public point.
--
-- Returns 'Nothing' if the peer's point is rejected by 'ECC.isPointValid'
-- for the curve. Otherwise returns the x-coordinate of the shared point as
-- a big-endian octet string.
--
-- For prime curves, that string is left-padded with zero bytes to the byte
-- length of the field prime. RFC 4492 section 5.10 requires the premaster
-- secret to be the fixed-length field-element encoding with leading zeros
-- preserved. A bare 'i2osp' drops them, so the secret would disagree with
-- the peer whenever the x-coordinate has a leading zero byte.
ecdhGetShared :: ECDHParams -> ECDHPrivate -> ECDHPublic -> Maybe ECDHKey
ecdhGetShared (ECDHParams curve _)  (ECDHPrivate priv) (ECDHPublic point _)
    | ECC.isPointValid curve point =
        let ECDH.SharedKey sk = ECDH.getShared curve priv point
         in Just $ leftPad (fieldBytes curve) (i2osp sk)
    | otherwise =
        Nothing
  where
    -- Octet length of a field element. It is only derived here for prime
    -- curves; other curve kinds give 0, which leaves the encoding unpadded.
    fieldBytes (ECDH.CurveFP c) = lengthBytes (ECDH.ecc_p c)
    fieldBytes _                = 0
    -- Prepend zero bytes up to length n. 'B.replicate' with a count <= 0
    -- yields an empty string, so an already-long-enough input is unchanged.
    leftPad n bs = B.replicate (n - B.length bs) 0 `B.append` bs

-- | For server key exchange: split parameters and a public key into the
-- curve's 16-bit identifier (via 'fromCurveName'), the point's x and y
-- coordinates, and its byte size.
--
-- Raises an error if the curve name has no identifier.
ecdhUnwrap :: ECDHParams -> ECDHPublic -> (Word16,Integer,Integer,Int)
ecdhUnwrap (ECDHParams _ name) point =
    let (x, y, siz) = ecdhUnwrapPublic point
        curveId     = maybe (error "ecdhUnwrap") id (fromCurveName name)
     in (curveId, x, y, siz)

-- | For client key exchange: split a public key into its x and y
-- coordinates and its byte size.
--
-- Raises an error if the point was not built with the 'ECDH.Point'
-- constructor.
ecdhUnwrapPublic :: ECDHPublic -> (Integer,Integer,Int)
ecdhUnwrapPublic (ECDHPublic pt siz) =
    case pt of
        ECDH.Point x y -> (x, y, siz)
        _              -> error "ecdhUnwrapPublic"

-- | Byte length of a prime curve's field prime. This is the size stored in
-- public keys produced by 'ecdhGenerateKeyPair'. Other curve kinds are not
-- handled and raise an error.
pointSize :: ECDH.Curve -> Int
pointSize c = case c of
    ECDH.CurveFP prime -> lengthBytes (ECDH.ecc_p prime)
    _                  -> error "pointSize" -- FIXME: only prime curves handled