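# Backbone: LSTM (see RNN_LSTM_cond below). NOTE: this header previously said "use resnet50 as backbone", but no ResNet is defined or used in the code below.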

import torch
import math
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torchdyn
from torchdyn.core import NeuralODE
import pytorch_lightning as pl
from torch import optim
import torch.functional as F
import wandb

from utils.visualize import *
from utils.metric_calc import *
from utils.sde import SDE
from utils.positional_encoding import *
from utils.wrapper import torch_wrapper_tv

class RNN_LSTM_cond(torch.nn.Module):
    """LSTM network predicting a target state and the time left to reach it.

    The network consumes ``[coords, conditioning features, encoded time]`` and
    emits ``dim + 1`` values per row: ``dim`` predicted coordinates followed by
    one predicted "time until the target" scalar.

    Args:
        dim: Number of coordinate columns at the front of the input.
        num_layers: Number of stacked LSTM layers.
        hidden_size: LSTM hidden width.
        treatment_cond: Number of conditioning columns (excluding memory).
        time_dim: Width of the encoded time feature.
        clip: Lower bound applied to the predicted remaining time before it is
            used as a divisor in ``forward``; ``None`` disables the bound.
        memory: Number of past states carried in the input; contributes
            ``dim * memory`` extra input columns.
    """

    def __init__(self,
                 dim,
                 num_layers=50,
                 hidden_size=3,
                 treatment_cond=0,
                 time_dim=NUM_FREQS*2,
                 clip=1e-2,
                 memory=3,

                 ):
        super().__init__()
        # NOTE: self.dim is the OUTPUT width (coordinates + 1 time scalar),
        # not the coordinate count; the coordinate count is self.dim - 1.
        self.dim = dim + 1
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.treatment_cond = treatment_cond  # number of conditionals
        self.time_dim = time_dim
        # torch.nn is referenced explicitly so the class does not depend on
        # `nn` leaking in through a wildcard import.
        self.lstm = torch.nn.LSTM(
            input_size=(dim + self.treatment_cond + self.time_dim + (dim * memory)),
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = torch.nn.Linear(hidden_size, self.dim)
        self.clip = clip
        self.memory = memory

    def train_net(self, x):
        """Run the LSTM followed by the linear head and return the head output."""
        # The final hidden/cell states are discarded (only 1 step usually).
        out, _ = self.lstm(x)
        return self.fc(out)

    def encoding_function(self, x):
        """Encode raw time values with ``positional_encoding_tensor``."""
        return positional_encoding_tensor(x)

    def forward_train(self, x):
        """Predict the target coordinates and remaining time for each row.

        Args:
            x: 2-D tensor laid out as ``[coords, conditioning..., t]`` with the
                raw time in the last column.

        Returns:
            Tensor laid out as ``[predicted coords, conditioning (copied from
            x), predicted remaining time]``.
        """
        coord_dim = self.dim - 1
        time_tensor = x[:, -1]
        encoded_time_span = self.encoding_function(time_tensor).reshape(-1, NUM_FREQS * 2)
        # Replace the raw time column with its encoding.
        new_x = torch.cat([x[:, :-1], encoded_time_span], dim=1).to(torch.float32)

        # @TODO: Separate memory and other features (for now, just concat)
        # NOTE(review): new_x is 2-D here, which torch.nn.LSTM treats as a
        # single unbatched sequence, so rows are processed as consecutive
        # steps — confirm this is intended.
        result = self.train_net(new_x)
        return torch.cat(
            [result[:, :-1], x[:, coord_dim:-1], result[:, -1].unsqueeze(1)],
            dim=1,
        )

    def forward(self, x):
        """Return the velocity field used when integrating trajectories.

        ``x`` is the current state x_t. With xt = (t)x1 + (1-t)x0 the velocity
        towards the predicted target is ``(x1 - xt) / remaining_time``.

        Args:
            x: 2-D tensor laid out as ``[coords, conditioning..., t]``.

        Returns:
            Tensor of width ``x.shape[1] - 1``: the coordinate velocity
            followed by zeros for every conditioning column.
        """
        coord_dim = self.dim - 1
        x1 = self.forward_train(x)
        x1_coord = x1[:, :coord_dim]
        x_coord = x[:, :coord_dim]
        # Keep the remaining time as a column so the division is per-row;
        # a 1-D divisor would broadcast against the coordinate axis instead.
        pred_time_till_t1 = x1[:, -1].unsqueeze(1)
        if self.clip is not None:
            pred_time_till_t1 = torch.clip(pred_time_till_t1, min=self.clip)
        vt = (x1_coord - x_coord) / pred_time_till_t1
        # Conditioning columns get zero velocity.
        final_vt = torch.cat([vt, torch.zeros_like(x[:, coord_dim:-1])], dim=1)
        return final_vt
    


### module ###
class RNNLSTM_Memory_Module(pl.LightningModule):
    def __init__(self, 
                 treatment_cond,
                 memory=3, # can increase / tune to see effect
                 dim=2, 
                 num_layers=50,
                 hidden_size=3, 
                 time_varying=True, 
                 conditional=True,
                 lr=1e-6,
                 sigma = 0.1, 
                 loss_fn = torch.nn.functional.mse_loss,
                 metrics = ['mse_loss', 'l1_loss'],
                 implementation = "ODE", # can be SDE
                 sde_noise = 0.1,
                 clip = None,
                 naming = None,

                 ):
        """Build the LSTM model and store the training/evaluation settings.

        Args:
            treatment_cond: Number of conditioning columns, forwarded to
                ``RNN_LSTM_cond``.
            memory: Number of past states carried as input, forwarded to
                ``RNN_LSTM_cond``.
            dim: Number of coordinate columns.
            num_layers: Number of LSTM layers.
            hidden_size: LSTM hidden width.
            time_varying: Stored on the module; not read in this method.
            conditional: Stored on the module; not read in this method.
            lr: Learning rate stored for the optimizer.
            sigma: Scale of the Gaussian noise added to interpolated samples.
            loss_fn: Callable ``(prediction, target) -> loss``.
            metrics: Metric names passed to ``metrics_calculation``. The
                default list is shared between instances; treat it as
                read-only.
            implementation: ``"ODE"`` or ``"SDE"``.
            sde_noise: Noise level used by the SDE variant.
            clip: Lower bound for the predicted remaining time, forwarded to
                ``RNN_LSTM_cond``; ``None`` disables it.
            naming: Optional run name; a default derived from
                ``implementation`` is used when ``None``.
        """
        super().__init__()
        self.model = RNN_LSTM_cond(dim=dim, 
                                    num_layers=num_layers,
                                    hidden_size=hidden_size,
                                    treatment_cond=treatment_cond,
                                    memory=memory,
                                    clip = clip)
        self.loss_fn = loss_fn
        # Must run inside __init__ so the constructor arguments are captured.
        self.save_hyperparameters()
        self.dim = dim
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.time_varying = time_varying
        self.conditional = conditional
        self.treatment_cond = treatment_cond
        self.lr = lr
        self.sigma = sigma
        self.metrics = metrics
        self.implementation = implementation
        self.memory = memory
        self.sde_noise = sde_noise
        self.clip = clip
        self.naming = "RNN_Cond_Module_"+implementation if naming is None else naming
        # Tag the run name with the memory length when more than one past
        # state is used.
        if self.memory > 1:
            self.naming += "_Memory_"+str(self.memory)
            
    def __convert_tensor__(self, tensor):
        """Return ``tensor`` cast to ``torch.float32``."""
        as_float32 = tensor.to(dtype=torch.float32)
        return as_float32

    def __x_processing__(self, x0, x1, t0, t1):
        """Sample noisy interpolated states between ``x0`` and ``x1``.

        Args:
            x0: Start states; a leading axis of size 1 is dropped.
            x1: Target states; a leading axis of size 1 is dropped.
            t0: Times of ``x0``; flattened to 1-D.
            t1: Times of ``x1``; flattened to 1-D.

        Returns:
            Tuple ``(x, ut, t_model, futuretime, t)`` where
            ``x`` is the noisy interpolated state,
            ``ut`` is ``(x1 - x0) / (t1 - t0 + 1e-4)``,
            ``t_model`` is the sampled absolute time as a column ``(n, 1)``,
            ``futuretime`` is ``t1 - t_model`` per sample with shape ``(n,)``,
            ``t`` is the sampled interpolation fraction in ``[0, 1)`` as a
            column ``(n, 1)``.
        """
        # squeeze xs (prevent mismatch)
        x0 = x0.squeeze(0)
        x1 = x1.squeeze(0)
        # reshape(-1) rather than squeeze(): squeeze() collapses a single
        # time value to a 0-d tensor, which cannot be unsqueezed on axis 1.
        t0 = t0.reshape(-1)
        t1 = t1.reshape(-1)

        t = torch.rand(x0.shape[0],1).to(x0.device)
        mu_t = x0 * (1 - t) + x1 * t
        data_t_diff = (t1 - t0).unsqueeze(1)
        x = mu_t + self.sigma * torch.randn(x0.shape[0], self.dim).to(x0.device)

        # 1e-4 guards against division by zero for coincident time stamps.
        ut = (x1 - x0) / (data_t_diff + 1e-4)
        t_model = t * data_t_diff + t0.unsqueeze(1)
        # Subtract per sample: t1 is (n,) and t_model is (n, 1), so the
        # column must be flattened first or the result broadcasts to (n, n).
        futuretime = t1 - t_model.squeeze(1)

        return x, ut, t_model, futuretime, t
    
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        The loss is the sum of a coordinate term (predicted target vs ``x1``)
        and a time term (predicted remaining time vs ``futuretime``).

        Args:
            batch (list of output): x0_values, x0_classes, x1_values, times_x0, times_x1
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss returned by ``self.loss_fn`` summed over both terms.
        """
        x0, x0_class, x1, x0_time, x1_time = batch
        x0 = self.__convert_tensor__(x0)
        x0_class = self.__convert_tensor__(x0_class)
        x1 = self.__convert_tensor__(x1)
        x0_time = self.__convert_tensor__(x0_time)
        x1_time = self.__convert_tensor__(x1_time)

        # The second return value (ut) is not used by this loss.
        x, _, t_model, futuretime, t = self.__x_processing__(x0, x1, x0_time, x1_time)

        if len(x0_class.shape) == 3:
            x0_class = x0_class.squeeze(0)
        # Drop the leading batch axis of the target as well so the loss
        # compares equally-shaped tensors instead of relying on broadcasting.
        if len(x1.shape) == 3:
            x1 = x1.squeeze(0)

        in_tensor = torch.cat([x, x0_class, t_model], dim=-1)
        xt = self.model.forward_train(in_tensor)

        pred_coords = xt[:, :self.dim]
        pred_time = xt[:, -1]
        if self.implementation == "SDE":
            # Perturb the prediction with noise whose variance is
            # t * (1 - t) * sde_noise.
            variance = t*(1-t)*self.sde_noise
            pred_coords = pred_coords + torch.randn_like(pred_coords) * torch.sqrt(variance)

        loss = self.loss_fn(pred_coords, x1) + self.loss_fn(pred_time, futuretime)
        self.log('train_loss', loss)
        return loss

    def config_optimizer(self):
        """Return an Adam optimizer over all module parameters at ``self.lr``."""
        trainable = self.parameters()
        adam = torch.optim.Adam(trainable, lr=self.lr)
        return adam
    
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log its loss and metrics.

        Args:
            batch: batch size of 1 (since uneven)
            batch_idx: Index of the batch, forwarded to ``test_func_step``.

        Returns:
            dict with keys ``'val_loss'`` and ``'traj_pairs'``.
        """
        val_loss, traj_pairs, metric_values = self.test_func_step(batch, batch_idx, mode='val')
        self.log('val_loss', val_loss)
        # Each metric is logged under "<metric name>_val".
        for metric_name in metric_values:
            self.log(metric_name + "_val", metric_values[metric_name])
        return dict(val_loss=val_loss, traj_pairs=traj_pairs)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log its loss and metrics.

        Returns:
            dict with keys ``'test_loss'`` and ``'traj_pairs'``.
        """
        test_loss, traj_pairs, metric_values = self.test_func_step(batch, batch_idx, mode='test')
        self.log('test_loss', test_loss)
        # Each metric is logged under "<metric name>_test".
        for metric_name in metric_values:
            self.log(metric_name + "_test", metric_values[metric_name])
        return dict(test_loss=test_loss, traj_pairs=traj_pairs)
    
    def test_func_step(self, batch, batch_idx, mode='none'):
        """Roll out one trajectory, plot it, and compute its metrics.

        Assumes each batch holds a single patient (leading axis of size 1).

        Args:
            batch: ``(x0_values, x0_classes, x1_values, times_x0, times_x1)``.
            batch_idx: Index of the batch; used in the plot title / log key.
            mode: Label placed in the plot title / log key (e.g. ``'val'``).

        Returns:
            Tuple ``(mean loss, [[ground-truth tensor, predicted tensor]],
            metrics)`` where ``metrics`` is whatever ``metrics_calculation``
            returns for the numpy versions of both trajectories.
        """
        total_loss = []
        traj_pairs = []

        x0_values, x0_classes, x1_values, times_x0, times_x1 = batch
        # reshape(-1) keeps the times 1-D even for a single-step trajectory,
        # where squeeze() would produce 0-d tensors that cannot be indexed.
        times_x0 = times_x0.reshape(-1)
        times_x1 = times_x1.reshape(-1)

        # Ground truth = first start point followed by every target point.
        full_traj = torch.cat([x0_values[0,0,:self.dim].unsqueeze(0), 
                               x1_values[0,:,:self.dim]], 
                               dim=0)
        full_time = torch.cat([times_x0[:1], times_x1], dim=0)
        ind_loss, pred_traj = self.test_trajectory(batch)
        total_loss.append(ind_loss)
        # The returned pairs keep the tensors; numpy copies are made below.
        traj_pairs.append([full_traj, pred_traj])

        full_traj = full_traj.detach().cpu().numpy()
        pred_traj = pred_traj.detach().cpu().numpy()
        full_time = full_time.detach().cpu().numpy()

        # graph
        plot_name = "{}_trajectory_patient_{}".format(mode, batch_idx)
        fig = plot_3d_path_ind(pred_traj, 
                               full_traj, 
                               t_span=full_time,
                               title=plot_name)
        try:
            if self.logger:
                # may cause problem if wandb disabled
                self.logger.experiment.log({plot_name: wandb.Image(fig)})
        finally:
            # Always release the figure, even if logging raises.
            plt.close(fig)

        # metrics
        metricD = metrics_calculation(pred_traj, full_traj, metrics=self.metrics)
        return np.mean(total_loss), traj_pairs, metricD

    def test_trajectory(self,pt_tensor):
        """Dispatch trajectory roll-out to the configured solver.

        Args:
            pt_tensor: ``(x0_values, x0_classes, x1_values, times_x0, times_x1)``.

        Returns:
            ``(mse_all, total_pred_tensor)`` from the ODE or SDE roll-out.

        Raises:
            ValueError: If ``self.implementation`` is neither ``"ODE"`` nor
                ``"SDE"``. Previously this fell through and returned ``None``,
                which only failed later when the caller unpacked the result.
        """
        if self.implementation == "ODE":
            return self.test_trajectory_ode(pt_tensor)
        if self.implementation == "SDE":
            return self.test_trajectory_sde(pt_tensor)
        raise ValueError(
            "Unknown implementation {!r}; expected 'ODE' or 'SDE'".format(self.implementation)
        )
    
    def test_trajectory_ode(self,pt_tensor):
        """Roll out one trajectory autoregressively with a neural ODE solver.

        Each segment starts from the previous prediction (the true start point
        for the first segment) and is integrated from ``times_x0[i]`` to
        ``times_x1[i]``. The memory part of the conditioning vector is a
        sliding window that is updated with every new prediction.

        Args:
            pt_tensor: ``(x0_values, x0_classes, x1_values, times_x0, times_x1)``,
                each with a leading batch axis of size 1. The last
                ``memory * dim`` conditioning columns are read as the
                flattened history window.

        Returns:
            mse_all, total_pred_tensor: mean per-segment loss, and the stacked
            trajectory (start point followed by every segment end point).
        """
        node = NeuralODE(
            torch_wrapper_tv(self.model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )

        x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
        # Drop only the leading batch axis / flatten times. A bare squeeze()
        # would also collapse a single-step trajectory (path length 1).
        x0_values = x0_values.squeeze(0)
        x1_values = x1_values.squeeze(0)
        times_x0 = times_x0.reshape(-1)
        times_x1 = times_x1.reshape(-1)
        x0_classes = x0_classes.squeeze(0)

        if len(x0_classes.shape) == 1:
            x0_classes = x0_classes.unsqueeze(1)

        len_path = x0_values.shape[0]
        assert len_path == x1_values.shape[0]

        history_width = self.memory * self.dim
        # The first segment starts from the true start point.
        pred_traj = x0_values[0].unsqueeze(0)
        total_pred = [pred_traj]
        mse = []
        time_history = x0_classes[0][-history_width:]

        for i in range(len_path): 
            # Integrate over the actual time interval of this segment.
            time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)

            # Static conditioning of step i + the rolling history window.
            new_x_classes = torch.cat([x0_classes[i][:-history_width].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
            testpt = torch.cat([pred_traj, new_x_classes], dim=1)
            with torch.no_grad():
                traj = node.trajectory(
                    testpt,
                    t_span=time_span, 
                )
            pred_traj = traj[-1,:,:self.dim]
            total_pred.append(pred_traj)
            # Give the target the same (1, dim) shape as the prediction.
            ground_truth_coords = x1_values[i, :self.dim].unsqueeze(0)
            mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())
            
            # history update: drop the oldest state, append the newest one.
            time_history = torch.cat([time_history[self.dim:], pred_traj.flatten()], dim=0)

        mse_all = np.mean(mse)
        total_pred_tensor = torch.stack(total_pred).squeeze(1)
        return mse_all, total_pred_tensor
    
    def configure_optimizers(self):
        """Return the Adam optimizer used for training (learning rate ``self.lr``)."""
        return optim.Adam(self.parameters(), lr=self.lr)
    
    def test_trajectory_sde(self,pt_tensor):
        """Roll out one trajectory autoregressively with the SDE solver.

        Mirrors ``test_trajectory_ode`` but integrates each segment with
        ``self._sde_solver`` and an ``SDE`` built from ``self.model`` using the
        configured ``self.sde_noise``.

        Args:
            pt_tensor: ``(x0_values, x0_classes, x1_values, times_x0, times_x1)``,
                each with a leading batch axis of size 1. The last
                ``memory * dim`` conditioning columns are read as the
                flattened history window.

        Returns:
            mse_all, total_pred_tensor: mean per-segment loss, and the stacked
            trajectory (start point followed by every segment end point).
        """
        # Use the configured noise level rather than a hard-coded 0.1 so that
        # evaluation matches the value used in training_step.
        sde = SDE(self.model, noise=self.sde_noise)

        x0_values, x0_classes, x1_values, times_x0, times_x1 = pt_tensor
        # Drop only the leading batch axis / flatten times. A bare squeeze()
        # would also collapse a single-step trajectory (path length 1).
        x0_values = x0_values.squeeze(0)
        x1_values = x1_values.squeeze(0)
        times_x0 = times_x0.reshape(-1)
        times_x1 = times_x1.reshape(-1)
        x0_classes = x0_classes.squeeze(0)

        if len(x0_classes.shape) == 1:
            x0_classes = x0_classes.unsqueeze(1)

        len_path = x0_values.shape[0]
        assert len_path == x1_values.shape[0]

        history_width = self.memory * self.dim
        # The first segment starts from the true start point.
        pred_traj = x0_values[0].unsqueeze(0)
        total_pred = [pred_traj]
        mse = []
        time_history = x0_classes[0][-history_width:]

        for i in range(len_path): 
            # Integrate over the actual time interval of this segment.
            time_span = self.__convert_tensor__(torch.linspace(times_x0[i], times_x1[i], 10)).to(x0_values.device)

            # Static conditioning of step i + the rolling history window.
            new_x_classes = torch.cat([x0_classes[i][:-history_width].unsqueeze(0), time_history.unsqueeze(0)], dim=1)
            testpt = torch.cat([pred_traj, new_x_classes], dim=1)
            with torch.no_grad():
                traj = self._sde_solver(sde, testpt, time_span)

            pred_traj = traj[-1,:,:self.dim]
            total_pred.append(pred_traj)
            # Give the target the same (1, dim) shape as the prediction.
            ground_truth_coords = x1_values[i, :self.dim].unsqueeze(0)
            mse.append(self.loss_fn(pred_traj, ground_truth_coords).detach().cpu().numpy())

            # history update: drop the oldest state, append the newest one.
            time_history = torch.cat([time_history[self.dim:], pred_traj.flatten()], dim=0)

        mse_all = np.mean(mse)
        total_pred_tensor = torch.stack(total_pred).squeeze(1)
        return mse_all, total_pred_tensor


    def _sde_solver(self, sde, initial_state, time_span):
        """Integrate ``sde`` over ``time_span`` with Euler-Maruyama steps.

        Args:
            sde: Object exposing ``f(t, state)`` (drift) and ``g(t, state)``
                (diffusion).
            initial_state: State tensor at ``time_span[0]``.
            time_span: 1-D tensor of increasing time points.

        Returns:
            Tensor stacking the state at every time point, initial state
            first. If ``time_span`` has fewer than two points only the
            initial state is returned.
        """
        current_state = initial_state
        trajectory = [current_state]

        # Each step advances from t_cur to t_next. Drift and diffusion are
        # evaluated at t_cur, the time the current state belongs to, and the
        # step size is taken per interval so non-uniform grids also work.
        for t_cur, t_next in zip(time_span[:-1], time_span[1:]):
            dt = t_next - t_cur
            drift = sde.f(t_cur, current_state)
            diffusion = sde.g(t_cur, current_state)
            # Brownian increment ~ N(0, dt).
            noise = torch.randn_like(current_state) * torch.sqrt(dt)
            current_state = current_state + drift * dt + diffusion * noise
            trajectory.append(current_state)

        return torch.stack(trajectory)