from typing import Callable, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from scipy.special import logit

import utils
from dct import LinearDCT
from disentangle.AbstractTrainer import AbstractLitModule
from models.encoders import MLP, TransformerEncoder
from solver.ode_layer import ODESYSLayer

class TI_MNN_Module(torch.nn.Module):
    """Decode ODE terms from an observed trajectory and solve the resulting ODE.

    The (optionally DCT-compressed) trajectory is flattened into a parameter
    vector. Three heads map that vector to the right-hand side, the
    coefficients (one set, repeated over all time steps) and the step sizes of
    an ODE system, which ``ODESYSLayer`` then solves from the given initial
    values.

    NOTE: the parameter encoder is currently an identity map (see the TODO in
    ``__init__``), so ``param_dim`` always equals the flattened input size.
    """

    def __init__(
        self,
        batch_size: int = 10,
        order: int = 2,  # order of polynomial of MNN
        state_dim: int = 1,
        n_step: int = 60,
        mlp_enc: bool = True,
        dct_layer: bool = False,
        code_sharing: Optional[Dict[int, List[int]]] = None,
        feature_sharing_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """Build the ODE layer and the decoder heads.

        Args:
            batch_size: Per-view batch size. The ODE layer is built for
                ``batch_size * n_views`` systems and ``forward`` reshapes with
                this value, so inputs must match it.
            order: Order of derivatives used to approximate the ODE; not
                necessarily the ground-truth order.
            state_dim: Dimension of the state.
            n_step: Number of time steps per trajectory.
            mlp_enc: Accepted for interface compatibility; not read here.
            dct_layer: If true, inputs are moved to a truncated DCT basis
                before being used as parameters.
            code_sharing: Stored on the module; not used in this class.
            feature_sharing_fn: Stored on the module; not used in this class.
            **kwargs: Optional settings read with defaults: ``hidden_dim``
                (1024), ``n_views`` (2), ``n_iv_steps`` (1), ``param_dim``
                (20, currently overwritten), ``embedding_dim`` (64) and
                ``freq_frac_to_keep`` (0.5).
        """
        super().__init__()

        self.n_step = n_step
        # the order of derivatives we use to approximate,
        # not necessarily the ground truth order of the ODE
        self.order = order

        # state dimension
        self.state_dim = state_dim

        # hidden dim for submodules
        self.hidden_dim = kwargs.get("hidden_dim", 1024)
        self.n_views = kwargs.get("n_views", 2)

        self.batch_size = batch_size
        self.n_coeff = self.n_step * (self.order + 1)
        # how many of the first steps are given as initial-value information
        self.n_iv_steps = kwargs.get("n_iv_steps", 1)

        # for initial value problem  n_iv_steps = 1
        self.step_dim = (self.n_step - 1) * self.state_dim
        self.ode_layer = ODESYSLayer(
            bs=batch_size * self.n_views,
            n_ind_dim=1,
            order=self.order,
            n_equations=self.state_dim,  # equals number of states
            n_dim=self.state_dim,
            n_iv=1,
            n_step=self.n_step,
            n_iv_steps=self.n_iv_steps,
            solver_dbl=True,
        )
        # define the dimensions
        self.rhs_dim = self.state_dim * self.n_step  # time_steps * state_dim

        # One coefficient set shared by all time steps (time-invariant), hence
        # no factor of n_step here; decode_from_params repeats it over time.
        self.coeff_dim = (
            self.ode_layer.n_ind_dim
            * self.ode_layer.n_equations
            * self.ode_layer.n_dim
            * (self.order + 1)
        )
        self.param_dim = kwargs.get("param_dim", 20)
        self.embedding_dim = kwargs.get("embedding_dim", 64)

        # add dct transform if necessary:
        self.dct_layer: bool = dct_layer
        self.freq_frac_to_keep: float = kwargs.get("freq_frac_to_keep", 0.5)
        if dct_layer:
            self.dct: nn.Module = LinearDCT(self.n_step, "dct", norm="ortho").double()
            self.idct: nn.Module = LinearDCT(self.n_step, "idct", norm="ortho").double()
            input_dim = int(self.freq_frac_to_keep * self.n_step) * self.state_dim
        else:
            input_dim = self.rhs_dim

        # TODO: only for checking mnn -- the encoder is bypassed, so the
        # flattened input itself is the parameter vector and param_dim is
        # overwritten with the input size.
        self.param_dim = input_dim
        # nn.Identity instead of a lambda keeps the module picklable; it has no
        # parameters, so state_dict keys are unaffected.
        self.params_enc = nn.Identity()
        # decode from params to rhs
        self.rhs_t = MLP(
            input_dim=self.param_dim,
            output_dim=self.rhs_dim,
            hidden_dim=self.hidden_dim,
            num_layers=3,
        )

        self.coeffs_mlp = MLP(
            input_dim=self.param_dim,
            output_dim=self.coeff_dim,
            hidden_dim=self.hidden_dim,
            num_layers=3,
        )

        self.pre_steps_mlp = nn.Sequential(
            nn.Linear(self.param_dim, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LeakyReLU(),
        )

        self.steps_layer = nn.Linear(self.hidden_dim, self.step_dim)

        # Zero weights plus a bias of logit(0.1) make every step start at
        # sigmoid(logit(0.1)) = 0.1, independent of the input.
        step_bias = logit(0.1)
        self.steps_layer.weight.data.fill_(0.0)
        self.steps_layer.bias.data.fill_(step_bias)

        self.code_sharing = code_sharing
        self.feature_sharing_fn = feature_sharing_fn

    def state_transform(self, states: torch.Tensor):
        """Apply ``self.dct`` and keep only the leading frequency components.

        Requires ``dct_layer=True`` at construction (otherwise ``self.dct``
        does not exist).

        Args:
            states: Tensor of shape ``[..., n_step, state_dim]``.

        Returns:
            Tensor of shape
            ``[..., int(freq_frac_to_keep * n_step), state_dim]``.
        """
        # The step axis is moved last before the DCT and moved back afterwards
        # (presumably LinearDCT acts on the last axis -- TODO confirm).
        freqs: torch.Tensor = self.dct(states.swapaxes(-1, -2)).swapaxes(-1, -2)
        return freqs[..., : int(self.freq_frac_to_keep * self.n_step), :]

    def state_inverse_transform(self, freqs: torch.Tensor):
        """Zero-pad truncated frequencies back to ``n_step`` and apply ``self.idct``.

        Args:
            freqs: Tensor of shape ``[..., n_freqs_to_keep, state_dim]`` with
                any number of leading batch dimensions.

        Returns:
            Tensor of shape ``[..., n_step, state_dim]``.
        """
        # Fill the high frequencies dropped by state_transform with zeros.
        # Using shape[:-2] keeps this valid for any number of leading dims.
        n_missing = self.n_step - freqs.shape[-2]
        padding = freqs.new_zeros(*freqs.shape[:-2], n_missing, freqs.shape[-1])
        freqs = torch.cat([freqs, padding], dim=-2)
        return self.idct(freqs.swapaxes(-1, -2)).swapaxes(-1, -2)

    def decode_from_params(self, params: torch.Tensor):
        """Map parameter vectors to the ODE right-hand side, coefficients and steps.

        Args:
            params: Tensor of shape ``(bs, param_dim)``.

        Returns:
            Tuple ``(rhs, coeffs, steps)`` with shapes ``(bs, rhs_dim)``,
            ``(bs, n_step, coeff_dim)`` and ``(bs, step_dim)``.
        """
        # Right-hand side of the ODE
        rhs: torch.Tensor = self.rhs_t(params)  # (bs, n_step * state_dim)
        # Time-invariant ODE coefficients: one set, repeated for every step
        coeffs: torch.Tensor = self.coeffs_mlp(params)  # (bs, coeff_dim)
        coeffs = coeffs[:, None, :].repeat(1, self.n_step, 1)  # (bs, n_step, coeff_dim)
        # Learned steps, squashed into (0, 1) and kept away from both ends
        hidden = self.pre_steps_mlp(params)  # (bs, hidden_dim)
        steps: torch.Tensor = self.steps_layer(hidden)  # (bs, (n_step - 1) * state_dim)
        steps = torch.sigmoid(steps).clip(min=0.001, max=0.999)
        return rhs, coeffs, steps

    def solve(self, params: torch.Tensor, iv_rhs: torch.Tensor):
        """Decode ``params`` and solve the ODE system from the initial values.

        Args:
            params: Tensor of shape ``(n_views * bs, param_dim)``.
            iv_rhs: Initial values passed through to the ODE layer.

        Returns:
            First output of the ODE layer, reshaped to
            ``[n_views, -1, n_step, state_dim]``.
        """
        rhs, coeffs, steps = self.decode_from_params(params)
        # Only the first of the five outputs is used here.
        u0, _u1, _u2, _eps, _steps = self.ode_layer(coeffs=coeffs, rhs=rhs, iv_rhs=iv_rhs, steps=steps)
        u0 = u0.squeeze(1)  # (n_views*bs, ts, state_dim)
        return u0.reshape(self.n_views, -1, self.n_step, self.state_dim)

    def feature_sharing(self, params: torch.Tensor, **kwargs):
        """Return ``params`` unchanged (placeholder; currently not called by ``forward``)."""
        # this should be inherent to the data generating process, so it should be an attribute
        # to the corresponding dataset
        return params

    def forward(self, states: torch.Tensor, **kwargs):
        """Encode a batch of trajectories and reconstruct them by solving the ODE.

        Args:
            states: Tensor whose last two axes are ``(n_step, state_dim)``; the
                leading axes must flatten to a multiple of ``batch_size``.
            **kwargs: Ignored.

        Returns:
            Tuple ``(states, u0s, params, shared)``. ``states`` and ``u0s`` are
            in the truncated frequency domain when ``dct_layer`` is set,
            otherwise in the time domain. ``params`` has shape
            ``(n_views, bs, param_dim)`` and ``shared`` is the same tensor.
        """
        # extract iv steps before the dct layer so they stay in the time domain
        iv_rhs = states[..., : self.n_iv_steps, :]  # (..., n_iv_steps, state_dim)
        if self.dct_layer:
            states = self.state_transform(states)
        # parameter encoding
        flat_states = states.reshape(-1, states.shape[-2] * states.shape[-1])
        params: torch.Tensor = self.params_enc(flat_states)  # (n_views * bs, param_dim)
        params = params.reshape(-1, self.batch_size, self.param_dim)  # (n_views, bs, param_dim)
        iv_rhs = iv_rhs.reshape(
            -1, self.batch_size, self.n_iv_steps, self.state_dim
        )  # (n_views, bs, n_iv_steps, state_dim)
        # feature sharing is currently bypassed; params are used as-is
        shared = params
        # no matter apply dct layer or not, u0 always in time domain
        # shape: [n_views, bs, ts, state_dim]
        # reshape (not view) so non-contiguous inputs are handled as well
        u0s = self.solve(
            shared.reshape(-1, self.param_dim),
            iv_rhs.reshape(-1, self.n_iv_steps, self.state_dim),
        )
        if self.dct_layer:
            # convert u0s to the freq domain; the DCT layer is double precision
            u0s = self.state_transform(u0s.double())
        return states, u0s, params, shared


class TIMechanisticLitModule(AbstractLitModule):
    """Lightning wrapper that trains ``TI_MNN_Module`` with an MSE reconstruction loss."""

    def __init__(
        self,
        learning_rate: float = 1e-5,
        alignment_reg=10,
        eval_metrics: Optional[List[str]] = None,
        **model_kwargs,
    ):
        """Create the wrapped model and the loss.

        Args:
            learning_rate: Forwarded to the base class.
            alignment_reg: Stored on the module; not used in this class.
            eval_metrics: Forwarded to the base class; defaults to a fresh
                empty list.
            **model_kwargs: Forwarded to the base class, set as attributes on
                this module, and used to construct ``TI_MNN_Module``.
        """
        # Avoid a mutable default shared between instances. Rebinding the same
        # name keeps the value captured by save_hyperparameters() an empty list.
        eval_metrics = [] if eval_metrics is None else eval_metrics
        super().__init__(
            learning_rate=learning_rate,
            eval_metrics=eval_metrics,
            **model_kwargs,
        )
        # presumably feature_sharing_fn is provided by AbstractLitModule -- verify.
        # NOTE(review): this mutates model_kwargs before save_hyperparameters(),
        # so the bound method likely ends up in the saved hparams -- confirm
        # that is intended.
        model_kwargs['feature_sharing_fn'] = self.feature_sharing_fn

        for k, v in model_kwargs.items():
            setattr(self, k, v)
        # save hyperparameters
        self.save_hyperparameters()

        self.model = TI_MNN_Module(**model_kwargs).double()

        self.loss = nn.MSELoss().double()
        self.alignment_reg = alignment_reg
        # NOTE(review): this attribute shadows nn.Module.type() on the
        # instance; kept because callers may read it.
        self.type = torch.float64

    def forward(self, states: torch.Tensor, **kwargs):
        """Delegate to the wrapped ``TI_MNN_Module``."""
        return self.model(states, **kwargs)

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        ``batch`` must contain ``"states"``; all entries are passed to
        ``forward`` as keyword arguments.
        """
        batch["states"] = batch["states"].to(self.type)
        # depending on if we have dct layer or not, the output states could be in freq space
        states, u0s, params, shared = self.forward(**batch)
        # u0s comes back as [n_views, bs, ...]; align it with the layout of states
        recon_loss = self.loss(u0s.double().reshape(*states.shape), states.double())
        self.log("train_loss", recon_loss, prog_bar=True, on_step=True, on_epoch=True)

        return recon_loss

    def validation_step(self, batch, batch_idx):
        """Collect predicted ODE coefficients (and ground truth if present) in ``self.misc``.

        ``self.misc`` is presumably set up by ``AbstractLitModule`` -- verify.
        """
        states = batch["states"].to(self.type)
        if self.model.dct_layer:
            states = self.model.state_transform(states)
        params = self.model.params_enc(states.reshape(-1, states.shape[-2] * states.shape[-1]))
        coeffs = self.model.coeffs_mlp(params.reshape(-1, self.model.param_dim))
        # detach so the numpy conversion also works if autograd is enabled
        self.misc["pred_params"].append(coeffs.detach().cpu().numpy())
        if "gt_params" in batch:
            gt_params = batch["gt_params"]
            if isinstance(gt_params, dict):
                gt_params = torch.stack(list(gt_params.values()), -1)
            self.misc["gt_params"].append(gt_params.cpu().numpy())

    # TODO: possible to forecast over the whole trajectory
    def predict_step(self, batch, batch_id):
        """Forecast the steps after the first ``n_step`` and return the MSE.

        The first ``n_step`` steps are the encoder input; the first
        ``n_iv_steps`` of the remaining steps are the initial values. The
        remaining steps must hold exactly as many elements as the solver
        output (i.e. ``n_step`` steps).
        """
        states = batch["states"].to(self.type)
        input_states = states[..., : self.model.n_step, :]
        future_states = states[..., self.model.n_step :, :]
        if self.model.dct_layer:
            input_states = self.model.state_transform(input_states)

        params = self.model.params_enc(input_states.reshape(-1, input_states.shape[-2] * input_states.shape[-1]))

        iv_rhs = future_states[..., : self.model.n_iv_steps, :].reshape(
            -1, self.model.n_iv_steps, self.model.state_dim
        )  # (n_views * bs, n_iv_steps, state_dim)
        # no matter apply dct layer or not, u0 always in time domain
        # shape: [n_views, bs, ts, state_dim]
        u0s = self.model.solve(params.reshape(-1, self.model.param_dim), iv_rhs)
        # As in training_step, align u0s with the target layout so the loss
        # never relies on broadcasting between mismatched shapes.
        forecast_loss = self.loss(u0s.double().reshape(future_states.shape), future_states.double())
        return forecast_loss
