import torch
import torch.nn as nn
from data_classes import PICalibData, FcModel, TrainLoader, ValLoader
import torchcde
from typing import Tuple
from torch.utils.data import DataLoader, SequentialSampler
import sys

# Module-wide compute device: the default CUDA device when PyTorch can see one,
# otherwise the CPU. Tensors and models below are moved here explicitly.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class TrainNeuralCDE:
    """Build and train a ``NeuralCDE`` that regresses calibration errors.

    ``params`` is a dict of hyper-parameters. The keys read here are
    ``cde_nodes``, ``cde_layers``, ``decoder_nodes``, ``decoder_layers``,
    ``hidden_channels``, ``epochs`` and ``batch_size``.
    """

    # Validation runs on every epoch whose index is a multiple of this value.
    _VAL_EVERY = 5

    def __init__(self, params):
        self.params = params
        self.criterion = torch.nn.MSELoss()

    def train(self, calib_data: Tuple[PICalibData, ...]):
        """Train a fresh ``NeuralCDE`` on ``calib_data`` and return it.

        Args:
            calib_data: Indexable collection of four ``PICalibData``. As used by
                ``data_loaders``: [0]/[1] supply the "true" contexts for
                train/val, and [2]/[3] supply the "sim" contexts and ``error``
                targets for train/val.

        Returns:
            The trained ``NeuralCDE``, on the module-level ``device``.
            (Previously nothing was returned, so the model was lost.)

        Each epoch prints the mean training loss and the most recent
        validation loss. Validation is recomputed every ``_VAL_EVERY`` epochs.
        """
        train_loader, val_loader = self.data_loaders(calib_data)
        model = self._build_model(calib_data)
        optimizer = torch.optim.Adam(model.parameters())

        val_loss = float("nan")  # overwritten at epoch 0, before the first print
        for epoch in range(self.params['epochs']):
            train_loss = self._train_epoch(model, optimizer, train_loader)
            if epoch % self._VAL_EVERY == 0:
                val_loss = self._evaluate(model, val_loader)
            print(f"Epoch: {epoch} | Train loss: {train_loss} | Val loss: {val_loss}")

        return model

    def _build_model(self, calib_data):
        """Instantiate a NeuralCDE sized from ``self.params`` and ``calib_data``, moved to ``device``."""
        cde_hidden = (self.params['cde_nodes'],) * self.params['cde_layers']
        decoder_hidden = (self.params['decoder_nodes'],) * self.params['decoder_layers']
        model = NeuralCDE(input_channels=calib_data[0].X_ctx.shape[-1],
                          hidden_channels=self.params['hidden_channels'],
                          cde_hidden=cde_hidden,
                          output_channels=calib_data[2].error.shape[-1],
                          decoder_hidden=decoder_hidden)
        return model.to(device)

    def _train_epoch(self, model, optimizer, loader):
        """Run one optimisation pass over ``loader``; return the mean batch loss (NaN if empty)."""
        model.train()
        losses = []
        # Batch layout is (true ctx coeffs, sim ctx coeffs, errors); only the
        # sim coefficients feed the model. NOTE(review): assumed from
        # TrainLoader's constructor keyword order — confirm against data_classes.
        for _, x_coeffs, errors in loader:
            pred = model(x_coeffs.to(device))
            loss = self.criterion(pred, errors.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return torch.tensor(losses).mean().item()

    def _evaluate(self, model, loader):
        """Return the mean validation loss over ``loader`` (NaN if empty), without tracking gradients."""
        model.eval()
        losses = []
        # no_grad: evaluation must not build autograd graphs (memory) and
        # eval() disables any train-only layer behaviour inside the model.
        with torch.no_grad():
            for _, x_coeffs, errors, _, _ in loader:
                pred = model(x_coeffs.to(device))
                losses.append(self.criterion(pred, errors.to(device)).item())
        return torch.tensor(losses).mean().item()

    def data_loaders(self, calib_data: Tuple[PICalibData, ...]):
        """Wrap ``calib_data`` into a shuffled training loader and an ordered validation loader.

        Returns:
            ``(train_loader, val_loader)``. Both use ``self.params['batch_size']``.
            Validation batches are served in dataset order (``shuffle=False``).
        """
        train_dataset = TrainLoader(X_ctx_true=calib_data[0].X_ctx_coeffs,
                                    X_ctx_sim=calib_data[2].X_ctx_coeffs,
                                    errors=calib_data[2].error)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.params['batch_size'],
                                  shuffle=True)
        val_dataset = ValLoader(X_ctx_true=calib_data[1].X_ctx_coeffs,
                                X_ctx_sim=calib_data[3].X_ctx_coeffs,
                                errors=calib_data[3].error,
                                Y=calib_data[1].Y,
                                Y_pred=calib_data[3].Y)
        # shuffle=False already gives sequential sampling; no explicit sampler needed.
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=self.params['batch_size'],
                                shuffle=False)
        return train_loader, val_loader


class CDEFunc(FcModel):
    """Vector field of the neural CDE, used as ``func`` in ``torchcde.cdeint``.

    The ``FcModel`` is built with ``hidden_channels`` inputs and
    ``hidden_channels * input_channels`` outputs. ``forward`` reshapes that
    output into a ``(hidden_channels, input_channels)`` matrix per state.
    NOTE(review): assumes FcModel's positional signature is
    (in_features, out_features, hidden_sizes, dropout, ...) — confirm.

    Args:
        input_channels: Number of channels of the control path X.
        hidden_channels: Width of the hidden state z.
        hidden: Hidden-layer specification forwarded to ``FcModel``.
        dropout, dropout_at_first, dropout_after_last, dropout_intermediate,
        tanh_after_last: Forwarded unchanged to ``FcModel``.
    """

    def __init__(self, input_channels, hidden_channels, hidden,
                 dropout=0, dropout_at_first=False, dropout_after_last=False,
                 dropout_intermediate=False, tanh_after_last=False):
        # Plain int attributes may be set before nn.Module.__init__; they are
        # assigned first so they exist while FcModel.__init__ runs.
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.cde_func_out_dim = self.input_channels * self.hidden_channels
        super().__init__(self.hidden_channels, self.cde_func_out_dim, hidden,
                         dropout, dropout_at_first, dropout_after_last,
                         dropout_intermediate, tanh_after_last)

    def forward(self, t, z):
        """Evaluate the vector field at hidden state ``z``.

        ``t`` is unused because the field is autonomous. It is kept because
        torchcde calls ``func(t, z)``.

        Returns:
            Tensor of shape ``(*leading, hidden_channels, input_channels)``,
            where ``leading`` are the leading dimensions of the linear stack's
            output (the batch dimensions of ``z`` for ordinary linear layers).
        """
        out = self.linear_stack(z)
        # Keep every leading dimension instead of flattening them with -1.
        # Identical for 2-D z, but it also works for unbatched or
        # multi-batch-dim states, as torchcde's (..., hidden, input) contract expects.
        return out.reshape(*out.shape[:-1], self.hidden_channels, self.input_channels)

class Decoder(FcModel):
    """Read-out network: maps CDE hidden states to predicted errors.

    A thin ``FcModel`` wrapper from ``hidden_channels`` to ``output_channels``
    features. The sizes and the ``hidden`` layer specification are also kept
    as attributes.

    Args:
        hidden_channels: Width of the incoming hidden state.
        output_channels: Number of predicted error features (noted by the
            author as equal to the state dimension).
        hidden: Hidden-layer specification forwarded to ``FcModel``.
        dropout, dropout_at_first, dropout_after_last, dropout_intermediate,
        tanh_after_last: Forwarded unchanged to ``FcModel``.
    """

    def __init__(self, hidden_channels, output_channels, hidden,
                 dropout=0, dropout_at_first=False, dropout_after_last=False,
                 dropout_intermediate=False, tanh_after_last=False):
        # Record sizes before delegating to FcModel (assigned left to right).
        self.hidden_channels, self.output_channels, self.hidden = (
            hidden_channels, output_channels, hidden)
        fc_args = (self.hidden_channels, self.output_channels, hidden,
                   dropout, dropout_at_first, dropout_after_last,
                   dropout_intermediate, tanh_after_last)
        super().__init__(*fc_args)

    def forward(self, z):
        """Decode hidden state(s) ``z`` with the linear stack and return the predicted errors."""
        return self.linear_stack(z)

class Initial(FcModel):
    """Encoder from the first path observation to the CDE's initial hidden state.

    A thin ``FcModel`` wrapper from ``input_channels`` to ``hidden_channels``
    features.

    Args:
        input_channels: Number of channels of the observation.
        hidden_channels: Width of the produced hidden state.
        hidden: Hidden-layer specification forwarded to ``FcModel``.
        dropout, dropout_at_first, dropout_after_last, dropout_intermediate,
        tanh_after_last: Forwarded unchanged to ``FcModel``.
    """

    def __init__(self, input_channels, hidden_channels, hidden,
                 dropout=0, dropout_at_first=False, dropout_after_last=False,
                 dropout_intermediate=False, tanh_after_last=False):
        # Sizes are stored before FcModel.__init__ runs.
        self.input_channels, self.hidden_channels = input_channels, hidden_channels
        fc_args = (self.input_channels, self.hidden_channels, hidden,
                   dropout, dropout_at_first, dropout_after_last,
                   dropout_intermediate, tanh_after_last)
        super().__init__(*fc_args)

    def forward(self, x_0):
        """Encode observation ``x_0`` into the initial hidden state ``z_0``."""
        return self.linear_stack(x_0)

class NeuralCDE(nn.Module):
    """Neural controlled differential equation that predicts errors along a path.

    ``forward`` proceeds as follows:

    1. Build a cubic-spline path X from ``coeffs``.
    2. Set ``z0 = initial(X(t0))``.
    3. Solve ``dz = cde_func(z) dX`` with ``torchcde.cdeint`` at the spline's grid points.
    4. Decode every solved state with ``decoder``.

    Args:
        input_channels: Number of channels of the path X.
        hidden_channels: Width of the CDE hidden state.
        cde_hidden: Hidden-layer specification of the vector field ``CDEFunc``.
        output_channels: Number of predicted error features.
        decoder_hidden: Hidden-layer specification of the ``Decoder``.
    """

    def __init__(self, input_channels, hidden_channels, cde_hidden, output_channels, decoder_hidden):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.cde_hidden = cde_hidden
        self.output_channels = output_channels
        self.decoder_hidden = decoder_hidden
        self.cde_func = CDEFunc(self.input_channels, self.hidden_channels,
                                cde_hidden, tanh_after_last=True)
        self.decoder = Decoder(self.hidden_channels, self.output_channels,
                               self.decoder_hidden)
        # Empty hidden specification: presumably a single layer in FcModel — confirm.
        self.initial = Initial(self.input_channels, hidden_channels, ())

    def forward(self, coeffs):
        """Predict errors at every spline grid point of the path given by ``coeffs``.

        Args:
            coeffs: Cubic-spline coefficients as accepted by
                ``torchcde.CubicSpline`` (no explicit time grid is passed).

        Returns:
            The decoder output for the CDE solution evaluated at every grid
            point of the spline.
        """
        X = torchcde.CubicSpline(coeffs)

        # The initial hidden state must be a function of the first observation.
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        # Solve on the spline's own knots. With no explicit t these are
        # 0, 1, ..., N, the same values the previous torch.arange produced.
        # They live on coeffs' device and dtype, rather than being forced onto
        # the module-global `device`. The old version broke whenever the model
        # ran on a different device (e.g. CPU inference on a CUDA machine).
        z_full = torchcde.cdeint(X=X,
                                 z0=z0,
                                 func=self.cde_func,
                                 t=X.grid_points)
        return self.decoder(z_full)