from typing import Final, Dict, Optional,  List, TypeVar, Sequence, Literal

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import Normal
from core import EnvInfo
import numpy as np

from core import VType, Taskinfo, VarID, CausalGraph, EnvObjClass
from utils.typings import ObjectTensors, NamedTensors, ObjectDistributions, NamedDistributions, EnvModel
import utils
import alg.functional as F

from .modules import MultiLinear, Attention, attention
from utils.typings import EnvModel, TransitionModel

class RewardPredictor(nn.Module):
    """Abstract base class for modules that predict a reward distribution.

    Subclasses must override ``forward``. ``make_envmodel`` pairs a reward
    predictor with a transition model to form a single environment model.
    """

    def forward(self, raw_attributes: ObjectTensors,
                objmasks: Optional[NamedTensors] = None) -> Normal:
        """Return the predicted reward distribution; must be overridden."""
        raise NotImplementedError

    def make_envmodel(self, transition_model: TransitionModel) -> EnvModel:
        """Combine ``transition_model`` and this predictor into one callable.

        The returned function maps ``(attrs, objmask)`` to the pair
        ``(transition_model(attrs, objmask), self.forward(attrs, objmask))``.
        The transition model is evaluated first, then the reward predictor.
        """
        def envmodel(attrs: ObjectTensors, objmask: Optional[NamedTensors]):
            next_state = transition_model(attrs, objmask)
            return next_state, self.forward(attrs, objmask)

        return envmodel


class MLPRewardPredictor(RewardPredictor):
    """Reward predictor that feeds the task's input variables to a plain MLP.

    The task's input variables are extracted, concatenated into one feature
    vector and mapped by a three-hidden-layer ReLU MLP to two outputs: the
    mean and a softplus-transformed standard deviation of a Normal reward
    distribution.
    """

    class Args(utils.Struct):
        def __init__(self) -> None:
            # widths of the three hidden layers of the MLP
            self.dim_h1 = 128
            self.dim_h2 = 128
            self.dim_h3 = 128

    def __init__(self, info: Taskinfo, args: 'MLPRewardPredictor.Args',
                 device: torch.device, dtype: torch.dtype,):
        super().__init__()

        self.info = info
        self.dtype = dtype
        self.device = device

        # total width of the concatenated input variables
        d_in = sum(info.v(idx).size for idx in info.input_variables)

        self.f = nn.Sequential(
            nn.Linear(d_in, args.dim_h1, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(args.dim_h1, args.dim_h2, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(args.dim_h2, args.dim_h3, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(args.dim_h3, 2, device=device, dtype=dtype),
        )

    def forward(self, raw_attributes: ObjectTensors,
                objmasks: Optional[NamedTensors] = None) -> Normal:
        """Predict the reward distribution from raw object attributes.

        Args:
            raw_attributes: nested dict ``{clsname: {attrname: raw tensor}}``.
            objmasks: must be ``None``; this flat model has no notion of
                masked objects.

        Returns:
            A ``Normal`` whose mean is output 0 of the MLP and whose std is
            ``softplus(output 1) + 1e-6`` (kept strictly positive).

        Raises:
            ValueError: if ``objmasks`` is not ``None``.
        """
        if objmasks is not None:
            raise ValueError(
                "MLPRewardPredictor does not support object masks; "
                "objmasks must be None.")

        envinfo = self.info.envinfo
        # convert every raw attribute to its network-input representation
        converted = {
            clsname: {
                attrname: envinfo.v(clsname, attrname).raw2input(raw)
                for attrname, raw in attrs.items()}
            for clsname, attrs in raw_attributes.items()}
        # each input-variable id extracts its tensor from the converted dict;
        # assumes each result is (batch, feature) — TODO confirm
        features = [vidx(converted) for vidx in self.info.input_variables]
        x = torch.cat(features, dim=1)
        r: Tensor = self.f(x)
        mean = r[..., 0]
        std = nn.functional.softplus(r[..., 1]) + 1e-6

        return Normal(mean, std)


class _ClassEncoder(nn.Module):
    """Encode every object of a single class into a ``dim_encoding`` vector.

    Each attribute of the class is converted with ``raw2input``; the results
    are concatenated along the feature axis and passed through a two-layer
    MLP with a LeakyReLU in between.
    """

    def __init__(self, c: EnvObjClass,
                 dim_hidden: int, dim_encoding: int,
                 device: torch.device, dtype: torch.dtype):
        super().__init__()

        # width of all converted attributes laid side by side
        width_in = sum(c.v(name).size for name in c.attrnames())
        self.c = c

        first = nn.Linear(width_in, dim_hidden, device=device, dtype=dtype)
        second = nn.Linear(dim_hidden, dim_encoding, device=device, dtype=dtype)
        self.f = nn.Sequential(first, nn.LeakyReLU(), second)

    def forward(self, raw_attributes_c: NamedTensors):
        '''
        raw_attributes_c: dict[attrname, (batch, object, feature)]
        return: (batch, object, dim_encoding)
        '''
        obj_class = self.c
        pieces = []
        for name in obj_class.attrnames():
            pieces.append(obj_class.v(name).raw2input(raw_attributes_c[name]))
        encoded: Tensor = self.f(torch.cat(pieces, dim=2))
        return encoded


# class _ClassInferer(nn.Module):
#     def __init__(self, 
#                  dim_encoding: int, dim_k: int, dim_v: int,
#                  n_head: Optional[int], device: torch.device, dtype: torch.dtype):
#         super().__init__()
# 
#         self.attn = Attention(device, dtype, n_head, 
#             transform_k=(dim_encoding, dim_k),
#             transform_q=(dim_encoding, dim_k),
#             transform_v=(dim_encoding, dim_v)
#         )
# 
#         self.mlp1 = nn.Sequential(
#             nn.Linear(dim_encoding + dim_v, dim_encoding, device=device, dtype=dtype),
#             nn.ReLU(),
#             nn.Linear(dim_encoding, dim_encoding, device=device, dtype=dtype)
#         )
# 
#     def forward(self, encoding_c: Tensor, encodings: Tensor,
#                 objmask: Optional[Tensor] = None) -> Tensor:
#         if objmask is not None:
#             objmask = objmask.unsqueeze(1)  # batchsize * 1 * n_obj
#         
#         x =  self.attn.forward(encoding_c, encodings, encodings, objmask)  # batchsize * n_obj_c * dim_v
#         x = torch.cat((x, encoding_c), dim=-1)  # batchsize * n_obj_c * (dim_v + dim_e)
#         x = self.mlp1.forward(x)
# 
#         return x


class OORewardPredictor(RewardPredictor):
    """Object-oriented reward predictor built on attention over objects.

    Pipeline:
      1. each object class has its own ``_ClassEncoder`` mapping the raw
         attributes of its objects to ``dim_encoding``-wide vectors;
      2. self-attention (``attn_1``) over all objects; its output is
         concatenated to the encodings and refined by ``mlp1``;
      3. a learned query ``q`` pools all objects through ``attn_2``;
      4. ``decoder`` maps the pooled vector to the mean and a
         softplus-transformed std of a Normal reward distribution.
    """

    class Args(utils.Struct):
        def __init__(self) -> None:
            self.dim_encoder: int = 64  # hidden width of each per-class encoder
            self.dim_encoding: int = 64  # width of per-object encodings
            self.dim_k = 64  # key/query width of the attention layers
            self.dim_v = 64  # value width of the attention layers
            self.n_head: Optional[int] = 4  # attention heads
            self.dim_decoder: int = 32  # hidden width of the reward decoder

    def __init__(self, envinfo: EnvInfo, args: 'OORewardPredictor.Args',
                 device: torch.device, dtype: torch.dtype):
        super().__init__()

        # One encoder per object class, kept in a plain dict for lookup by
        # class name and registered via add_module so nn.Module tracks their
        # parameters (state-dict keys "<cname>_encoder.*").
        # NOTE: the hidden width is args.dim_encoder; it used to be
        # args.dim_decoder by mistake, so checkpoints saved before this fix
        # have a different encoder shape.
        self.encoders = {c.name: _ClassEncoder(c, args.dim_encoder, args.dim_encoding,
                                               device, dtype)
                         for c in envinfo.classes}
        for cname, module in self.encoders.items():
            self.add_module(f"{cname}_encoder", module)

        self.attn_1 = Attention(device, dtype, args.n_head,
            transform_k=(args.dim_encoding, args.dim_k),
            transform_q=(args.dim_encoding, args.dim_k),
            transform_v=(args.dim_encoding, args.dim_v)
        )

        self.mlp1 = nn.Sequential(
            nn.Linear(args.dim_encoding + args.dim_v, args.dim_encoding, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(args.dim_encoding, args.dim_encoding, device=device, dtype=dtype)
        )

        # learned pooling query; attn_2 has no query transform, so it is
        # used directly with width dim_k
        self.q = nn.parameter.Parameter(
            torch.randn(args.dim_k, dtype=dtype, device=device))

        self.attn_2 = Attention(device, dtype, args.n_head,
            transform_k=(args.dim_encoding, args.dim_k),
            transform_v=(args.dim_encoding, args.dim_v)
        )

        self.decoder = nn.Sequential(
            nn.Linear(args.dim_v, args.dim_decoder, dtype=dtype, device=device),
            nn.ReLU(),
            nn.Linear(args.dim_decoder, 2, device=device, dtype=dtype),
        )

    def forward(self, raws: ObjectTensors,
                objmasks: Optional[NamedTensors] = None) -> Normal:
        '''
        Predict the reward distribution for a batch of object-structured states.

        raws: dict[cname, dict[attrname, (batchsize, n_obj_c, feature)]],
            the raw attributes of every object class (layout as consumed
            by _ClassEncoder).
        objmasks: optional dict[cname, (batchsize, n_obj_c)]; when given,
            the per-class masks are concatenated in the same class order as
            the encodings and passed to both attention layers.
        return: Normal with mean = decoder output 0 and
            std = softplus(decoder output 1) + 1e-6.
        '''

        cnames = list(raws.keys())

        # (batchsize, n_obj, dim_encoding)
        enc: Tensor = torch.cat([
            self.encoders[cname].forward(raws[cname])
            for cname in cnames
        ], dim=1)

        if objmasks is None:
            mask = None
        else:
            mask = torch.cat([objmasks[cname] for cname in cnames], dim=1)  # batch * n_obj
            mask = mask.unsqueeze(1)  # batch * 1 * n_obj, broadcast over queries

        v = self.attn_1.forward(enc, enc, enc, mask)
        enc_ = torch.cat((enc, v), dim=2)  # (batchsize, n_obj, dim_encoding + dim_v)

        enc_ = self.mlp1.forward(enc_)  # (batchsize, n_obj, dim_encoding)
        q = self.q.unsqueeze(0)  # (1, dim_k)
        # presumably Attention broadcasts the single query over the batch
        x = self.attn_2.forward(q, enc_, enc_, mask)  # (batchsize, 1, dim_v)
        x = x.squeeze(1)  # (batchsize, dim_v)

        out: Tensor = self.decoder.forward(x)  # batchsize * 2
        mean = out[..., 0]
        # softplus + epsilon keeps the scale strictly positive for Normal
        std = nn.functional.softplus(out[..., 1]) + 1e-6

        return Normal(mean, std)
