from data_classes import PICalibData
import torch 
import numpy as np
import pandas as pd
from utils import MeanStdevFilter, prepare_data
from build_ctxt import BuildContext
from typing import List, Tuple
import jax
import jax.numpy as jnp
import diffrax
import sys, os

# Tensors built in this module are moved to the first CUDA device when torch
# reports one as available, and to the CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class DataProcessor:
    def __init__(self, params):
        """Cache the calibration settings and reset the context width.

        ``params`` is stored by reference: it is read here and mutated both
        here (``'ctxt_dim'``) and by later methods of this class.
        """
        self.params = params

        # Filters and feature widths stay unset until ``get_data`` runs.
        self.input_filter = self.output_filter = None
        self.input_dim = self.output_dim = None

        # Settings the other methods read as plain attributes.
        self.calib_horizon = params['calib_horizon']
        self.past_ts_ctxt = params['past_ts_ctxt']
        self.init_cond_ctxt = params['init_cond_ctxt']
        self.past_feat_ctxt = params['past_feat_ctxt']

        # The context builder is created before 'ctxt_dim' is reset below.
        self._build_ctxt = BuildContext(params)

        # The context width starts at zero and ``get_data`` adds the feature
        # part.  The past-timestep width is deliberately NOT added: the time
        # dimension must not be normalized for the spline coefficient
        # calculation.
        params['ctxt_dim'] = 0

    def get_data(self) -> np.ndarray:
        """Load the prediction / ground-truth table and stack it into one array.

        Reads ``params['data_path']`` as CSV, groups the rows by ``'horizon'``
        (when ``params['data_type']`` is truthy) or by ``'pred_dt'``
        (otherwise) and stacks the ``*_pred`` followed by the ``*_gr`` columns
        of every group along a new leading axis.

        Side effects: sets ``input_dim``, ``output_dim`` and ``ctxt_dim``,
        mirrors them into ``params``, creates ``input_filter`` and
        ``output_filter`` and, for ``'gmat'``, sets ``traj_len`` (one length
        per row of the returned array, in the same order).

        Returns:
            Array whose last axis is ``[pred columns, ground-truth columns]``.
            The first two axes are swapped when ``params['data_type']`` is
            truthy.

        Raises:
            ValueError: if ``params['ode_name']`` is not a known system or the
                CSV produces no groups.
        """
        df = pd.read_csv(self.params['data_path'])
        group_key = 'horizon' if self.params['data_type'] else 'pred_dt'
        grouped = df.groupby(group_key)

        # State names per ODE system.  Each state is stored in the CSV as
        # '<name>_pred' and '<name>_gr'.  For lorenz/gmat only x, y, z are
        # used; the *_dot columns are intentionally left out.
        state_names = {
            'lorenz': ('x', 'y', 'z'),
            'gmat': ('x', 'y', 'z'),
            'glycolytic': ('S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7'),
            'LVolt': ('x', 'y'),
            'lorenz96': ('x1', 'x2', 'x3', 'x4', 'x5'),
            'FHNag': ('v', 'w'),
        }
        ode_name = self.params['ode_name']
        if ode_name not in state_names:
            # Previously this fell through the elif chain and crashed later
            # with an UnboundLocalError on pred_cols.
            raise ValueError(
                f"Unknown ode_name {ode_name!r}; expected one of {sorted(state_names)}"
            )
        pred_cols = [f"{name}_pred" for name in state_names[ode_name]]
        gr_cols = [f"{name}_gr" for name in state_names[ode_name]]

        # One (1, rows, 2 * n_states) array per group.
        group_arrs = [
            np.expand_dims(group[pred_cols + gr_cols].values, axis=0)
            for _, group in grouped
        ]
        if not group_arrs:
            raise ValueError(
                f"No groups found in {self.params['data_path']!r} when grouping by {group_key!r}"
            )

        if ode_name == 'gmat':
            lengths = [arr.shape[1] for arr in group_arrs]
            max_horizon = max(lengths)

            # NOTE(review): only trajectories that already span max_horizon
            # are kept, so no padding is ever applied and shorter
            # trajectories are dropped.  Confirm whether they should instead
            # be padded up to max_horizon.
            kept = [i for i, length in enumerate(lengths) if length == max_horizon]

            print(f"The Max Horizon is: {max_horizon}")
            print(f"The number of Traj are: {len(group_arrs)}")

            data = np.concatenate([group_arrs[i] for i in kept], axis=0)
            rand_perm = np.random.permutation(data.shape[0])
            data = data[rand_perm]
            # Keep the per-trajectory lengths aligned with the rows of `data`:
            # apply the same selection and the same permutation.
            self.traj_len = [lengths[kept[j]] for j in rand_perm]
            print(f"The max value in data is: {np.max(data)}")
        else:
            data = np.concatenate(group_arrs, axis=0)

        if self.params['data_type']:
            data = data.transpose(1, 0, 2)
        print(f"The data shape is: {data.shape}")

        self.input_dim = len(pred_cols)  # Using this to store norm data in PICalibData
        self.output_dim = len(pred_cols)

        # NOTE: this accumulates onto the current params['ctxt_dim'], so a
        # second get_data() call on the same instance counts the width twice.
        self.params['ctxt_dim'] += self.past_feat_ctxt * self.input_dim
        if self.init_cond_ctxt:
            self.params['ctxt_dim'] += self.input_dim

        self.ctxt_dim = self.params['ctxt_dim']
        self.params['input_dim'] = self.input_dim
        self.params['output_dim'] = self.output_dim
        self.input_filter = MeanStdevFilter(self.ctxt_dim)
        self.output_filter = MeanStdevFilter(self.output_dim)

        return data

    def normalize_calib_data(self, calib_data: Tuple[PICalibData]):
        """Normalize contexts, inputs and errors of the four calibration sets.

        ``calib_data`` is indexed as (train_true, val_true, train_sim,
        val_sim).  The input filter statistics are accumulated from the
        ground-truth contexts (train + val) and applied to both ground-truth
        and simulated contexts; the output filter statistics are accumulated
        from the simulated-set errors (train + val).

        The entries are modified in place (``X_ctx``, ``X`` and, for the
        simulated sets, ``error``) and the same tuple is returned.
        """
        train_len = calib_data[0].X_ctx.shape[0]

        gr_input_ctxt = np.array(
            torch.cat([calib_data[0].X_ctx, calib_data[1].X_ctx], dim=0))
        sim_input_ctxt = np.array(
            torch.cat([calib_data[2].X_ctx, calib_data[3].X_ctx], dim=0))

        self.calculate_mean_var(gr_input_ctxt, self.ctxt_dim)
        gr_norm, sim_norm = self.normalize_data(gr_input_ctxt, sim_input_ctxt, self.ctxt_dim)
        gr_norm = torch.Tensor(gr_norm)
        sim_norm = torch.Tensor(sim_norm)

        # Split back into train / val.  The validation part is taken as
        # "everything after the training rows" rather than `[-val_len:]`,
        # because `x[-0:]` is the WHOLE tensor when the validation set is empty.
        calib_data[0].X_ctx, calib_data[1].X_ctx = gr_norm[:train_len], gr_norm[train_len:]
        calib_data[2].X_ctx, calib_data[3].X_ctx = sim_norm[:train_len], sim_norm[train_len:]

        # X is the leading `input_dim` slice of the normalized context.
        # If we have past_feat_ctxt > 1, then we need to change the slicing here
        for entry in calib_data[:4]:
            entry.X = entry.X_ctx[:, :, :self.input_dim]

        ############## Errors normalization ##############
        errors = np.array(torch.cat([calib_data[2].error, calib_data[3].error], dim=0))
        self.calculate_mean_var_out(errors, self.output_dim)
        error_norm = torch.Tensor(self.normalize_data_out(errors, self.output_dim))

        calib_data[2].error, calib_data[3].error = error_norm[:train_len], error_norm[train_len:]

        return calib_data

    def calculate_mean_var(self, gr_input_data: np.ndarray, input_dim: int) -> None:

        gr_input_data = gr_input_data.reshape(-1, input_dim)

        total_points = gr_input_data.shape[0]

        for i in range(total_points):
            self.input_filter.update(gr_input_data[i,:])

        self.params['input_filter'] = self.input_filter  

        return     
    
    def normalize_data(self, gr_input_data: np.ndarray, \
                       sim_input_data: np.ndarray, input_dim: int) -> Tuple[np.ndarray]:
        """Apply ``prepare_data`` with the input filter to both input sets.

        Each array is flattened to ``(-1, input_dim)``, passed through
        ``prepare_data`` together with ``self.input_filter`` and reshaped to
        ``(-1, calib_horizon - 1, input_dim)``.

        Returns:
            ``(ground_truth, simulated)`` in that order.
        """
        # Flatten both sets first, ground truth before simulated.
        flat_gr, flat_sim = (
            block.reshape(-1, input_dim) for block in (gr_input_data, sim_input_data)
        )

        scaled_gr = prepare_data(flat_gr, self.input_filter)
        scaled_sim = prepare_data(flat_sim, self.input_filter)

        # Restore the (sequence, step, feature) layout.
        steps = self.calib_horizon - 1
        return (scaled_gr.reshape(-1, steps, input_dim),
                scaled_sim.reshape(-1, steps, input_dim))

    def calculate_mean_var_out(self, errors: np.ndarray, output_dim: int) -> None:

        errors = errors.reshape(-1, output_dim)

        total_points = errors.shape[0]

        for i in range(total_points):
            self.output_filter.update(errors[i,:])

        self.params['output_filter'] = self.output_filter  

        return    

    def normalize_data_out(self, errors: np.ndarray, output_dim: int) -> np.ndarray:
        """Apply ``prepare_data`` with the output filter to the errors.

        The errors are flattened to ``(-1, output_dim)``, passed through
        ``prepare_data`` together with ``self.output_filter`` and reshaped to
        ``(-1, calib_horizon - 1, output_dim)``.

        Returns:
            A single array (not a tuple) holding the processed errors.
        """
        flat_errors = errors.reshape(-1, output_dim)
        scaled = prepare_data(flat_errors, self.output_filter)
        return scaled.reshape(-1, self.calib_horizon - 1, output_dim)

    def normalize_var_length_data(self, calib_data: Tuple[PICalibData]):

        train_len = calib_data[0].X_ctx.shape[0]
        val_len = calib_data[1].X_ctx.shape[0]

        gr_input_ctxt = torch.cat([calib_data[0].X_ctx, calib_data[1].X_ctx], dim=0)
        gr_input_ctxt = jnp.array(gr_input_ctxt)

        sim_input_ctxt = torch.cat([calib_data[2].X_ctx, calib_data[3].X_ctx], dim=0)
        sim_input_ctxt = jnp.array(sim_input_ctxt)
        
        gr_ctxt_filter, sim_ctxt_filter = self.normalize_var_len_input(gr_input_ctxt, sim_input_ctxt, self.ctxt_dim)
        #gr_ctxt_filter, sim_ctxt_filter = self.normalize_data(gr_input_ctxt, sim_input_ctxt, self.ctxt_dim)

        gr_ctxt_filter = torch.Tensor(gr_ctxt_filter)
        sim_ctxt_filter = torch.Tensor(sim_ctxt_filter)

        # saving X_ctxt as normalized 
        calib_data[0].X_ctx, calib_data[1].X_ctx = gr_ctxt_filter[:train_len], gr_ctxt_filter[-val_len:]
        calib_data[2].X_ctx, calib_data[3].X_ctx = sim_ctxt_filter[:train_len], sim_ctxt_filter[-val_len:]

        # saving X as normalized 
        # If we have past_feat_ctxt > 1, then we need to change the slicing here
        calib_data[0].X, calib_data[1].X = gr_ctxt_filter[:train_len,:,:self.input_dim], gr_ctxt_filter[-val_len:,:,:self.input_dim]
        calib_data[2].X, calib_data[3].X = sim_ctxt_filter[:train_len,:,:self.input_dim], sim_ctxt_filter[-val_len:,:,:self.input_dim] 

        ############## Errors normalization ##############
        errors = torch.cat([calib_data[2].error, calib_data[3].error], dim=0)
        errors = jnp.array(errors)
        error_filter = self.normalize_var_len_output(errors, self.output_dim)
        #error_filter = self.normalize_data_out(errors, self.output_dim)

        error_filter = torch.Tensor(error_filter)
        calib_data[2].error, calib_data[3].error = error_filter[:train_len], error_filter[-val_len:]

        return calib_data   

    def normalize_var_len_input(self, gr_input_ctxt: jax.Array, \
                                            sim_input_ctxt: jax.Array, \
                                            input_dim: int) -> None:
        seq_len = gr_input_ctxt.shape[1]
        gr_input_ctxt = gr_input_ctxt.reshape(-1, input_dim)
        sim_input_ctxt = sim_input_ctxt.reshape(-1, input_dim)

        means_in = jax.vmap(jnp.nanmean, in_axes=(1))(gr_input_ctxt)
        std_in = jax.vmap(jnp.nanstd, in_axes=(1))(gr_input_ctxt)
        self.params['input_filter_var_length'] = {'input_var_len_mean': means_in, 
                                                  'input_var_len_std': std_in}  
        
        gr_ctxt_filter = (gr_input_ctxt - means_in) / (std_in + 1e-6)
        sim_ctxt_filter = (sim_input_ctxt - means_in) / (std_in + 1e-6)

        gr_ctxt_filter = gr_ctxt_filter.reshape(-1,seq_len,input_dim)
        sim_ctxt_filter = sim_ctxt_filter.reshape(-1,seq_len,input_dim)

        return np.array(gr_ctxt_filter), np.array(sim_ctxt_filter)   

    def normalize_var_len_output(self, errors: jax.Array, \
                                        output_dim: int) -> None:
        seq_len = errors.shape[1]
        errors = errors.reshape(-1, output_dim)

        means_in = jax.vmap(jnp.nanmean, in_axes=(1))(errors)
        std_in = jax.vmap(jnp.nanstd, in_axes=(1))(errors)

        self.params['output_filter_var_length'] = {'output_var_len_mean': means_in, 
                                                  'output_var_len_std': std_in}  
        
        errors_filter = (errors - means_in) / (std_in + 1e-6)
        errors_filter = errors_filter.reshape(-1,seq_len,output_dim)

        return np.array(errors_filter)  
        
    def data_tuples(self, data: np.ndarray) -> Tuple[PICalibData]:
        """Split ``data`` into shifted input/target pairs and train/val sets.

        The first ``input_dim`` features of ``data`` are treated as simulated
        states and the remaining ones as ground truth.  Inputs are the
        sequence without its last step and targets the sequence without its
        first step.  The split is 80/20 when ``params['data_type']`` is truthy
        and 90/10 otherwise; one sequence between the two parts is currently
        left out (see the TODO on ``val_len``).

        Side effects: sets ``params['norm_loss_wghts']`` (inverse variance of
        the ground-truth targets), and either ``tr_len`` / ``va_len`` (falsy
        ``data_type``, requires ``traj_len`` to exist) or a placeholder
        ``traj_len`` (truthy ``data_type``).

        Returns:
            ``(train_true, val_true, train_sim, val_sim)``, all holding
            un-normalized data.
        """
        sim_data = data[:, :, :self.input_dim]
        gr_tr_data = data[:, :, self.input_dim:]

        # dropping one point from seq due to shifting (calib_horizon-1) is seq_len now
        # TODO: Check this step below for ode_data (VALID or not)
        gr_input_data = gr_tr_data[:, :-1, :]
        gr_output_data = gr_tr_data[:, 1:, :]
        self.params['norm_loss_wghts'] = \
            torch.FloatTensor(1 / np.var(np.array(gr_output_data.reshape(-1, self.output_dim)), axis=0)).to(device)

        sim_input_data = sim_data[:, :-1, :]
        sim_output_data = sim_data[:, 1:, :]

        n_seq, seq_len = gr_input_data.shape[:2]
        if self.params['data_type']:
            train_len = int(0.80 * n_seq)
        else:
            train_len = int(0.90 * n_seq)
            self.tr_len = self.traj_len[:train_len]
            self.va_len = self.traj_len[train_len:][1:]  # see TODO: remove -1 later

        val_len = n_seq - train_len - 1  # TODO: remove -1 later
        print(f"Training seq. are: {train_len}")
        print(f"Validation seq. are: {val_len}")

        if self.params['data_type']:
            # Placeholders: random values with one entry per used sequence.
            mask = np.random.rand(train_len + val_len)
            self.traj_len = np.random.rand(train_len + val_len)
        else:
            mask = self.create_mask()

        def tail(seq):
            # Last `val_len` items.  Unlike `seq[-val_len:]`, this is empty
            # (not the whole sequence) when val_len is 0.
            return seq[max(len(seq) - val_len, 0):]

        delta_t = self.params['delta_t']
        end = seq_len * delta_t

        def timesteps():
            # A float step can make arange emit one element past `end`
            # because of rounding (e.g. 3 * 0.1), so cap it at seq_len.
            return torch.arange(0, end, delta_t)[:seq_len]

        gr_error = gr_output_data - sim_output_data

        # PICalibData contains the un-normalized data here for all cases
        calib_train_true = PICalibData(X=torch.Tensor(gr_input_data[:train_len]),
                                       Y=torch.Tensor(gr_output_data[:train_len]),
                                       mask=torch.Tensor(mask[:train_len]),
                                       Traj_len=torch.Tensor(self.traj_len[:train_len]))

        calib_val_true = PICalibData(X=torch.Tensor(tail(gr_input_data)),
                                     Y=torch.Tensor(tail(gr_output_data)),
                                     mask=torch.Tensor(tail(mask)),
                                     Traj_len=torch.Tensor(tail(self.traj_len)))

        calib_train_sim = PICalibData(X=torch.Tensor(sim_input_data[:train_len]),
                                      Y=torch.Tensor(sim_output_data[:train_len]),
                                      timesteps=timesteps(),
                                      error=torch.Tensor(gr_error[:train_len]))

        calib_val_sim = PICalibData(X=torch.Tensor(tail(sim_input_data)),
                                    Y=torch.Tensor(tail(sim_output_data)),
                                    timesteps=timesteps(),
                                    error=torch.Tensor(tail(gr_error)))

        return (calib_train_true, calib_val_true, calib_train_sim, calib_val_sim)

    def create_mask(self,):
        
        max_len = int(np.max(np.array(self.traj_len)))  
        indices = np.arange(max_len)
        broadcast = np.array(self.traj_len)-1
        mask = indices < broadcast[:, None]
        if self.params['bayesian']:
            # both for mean and variance
            mask = np.repeat(mask[:,:,np.newaxis], repeats=self.output_dim*2, axis=2)
        else:    
            mask = np.repeat(mask[:,:,np.newaxis], repeats=self.output_dim, axis=2)
        return mask

    def no_ctxt(self, calib_data: Tuple[PICalibData]):
        """
        In case of no context, simply copy the calib_data[0].X to
        calib_data[0].X_ctxt 
        """
        calib_data[0].X_ctx = calib_data[0].X.clone()
        calib_data[1].X_ctx = calib_data[1].X.clone()
        calib_data[2].X_ctx = calib_data[2].X.clone()
        calib_data[3].X_ctx = calib_data[3].X.clone()

        return calib_data    
    
    def spline_coeffs(self, calib_data: Tuple[PICalibData]):
        """Attach backward-Hermite spline coefficients of ``X_ctx`` to every entry.

        The timesteps of the simulated train / val entries are repeated once
        per sequence and used for the matching ground-truth entry as well.
        The result of ``diffrax.backward_hermite_coefficients``, vmapped over
        the leading axis, is stored as ``X_ctx_coeffs``.
        """
        def tiled_times(entry):
            # One copy of the timestep vector per sequence in `entry`.
            return jnp.repeat(entry.timesteps[None, :], repeats=entry.X.shape[0], axis=0)

        train_ts = tiled_times(calib_data[2])
        val_ts = tiled_times(calib_data[3])

        batched_coeffs = jax.vmap(diffrax.backward_hermite_coefficients)

        # Entries 0/2 are the training sets, 1/3 the validation sets.
        for idx, ts in enumerate((train_ts, val_ts, train_ts, val_ts)):
            calib_data[idx].X_ctx_coeffs = batched_coeffs(ts, calib_data[idx].X_ctx)

        return calib_data
    
    def torch_to_jax(self, calib_data: Tuple[PICalibData]):

        for i in range(2):
            
            calib_data[i].X = jnp.array(calib_data[i].X.detach().cpu().numpy())
            calib_data[i].Y = jnp.array(calib_data[i].Y.detach().cpu().numpy())
            calib_data[i].X_ctx = jnp.array(calib_data[i].X_ctx.detach().cpu().numpy())
            calib_data[i].Traj_len = jnp.array(calib_data[i].Traj_len.detach().cpu().numpy())
            calib_data[i].mask = jnp.array(calib_data[i].mask.detach().cpu().numpy())

        for i in range(2,4):
            
            calib_data[i].X = jnp.array(calib_data[i].X.detach().cpu().numpy())
            calib_data[i].Y = jnp.array(calib_data[i].Y.detach().cpu().numpy())
            calib_data[i].X_ctx = jnp.array(calib_data[i].X_ctx.detach().cpu().numpy())
            calib_data[i].timesteps = jnp.array(calib_data[i].timesteps.detach().cpu().numpy())
            calib_data[i].error = jnp.array(calib_data[i].error.detach().cpu().numpy())

        return calib_data    


    def build_context(self, calib_data: Tuple[PICalibData], add_time: bool):

        # CAUTION: calib_data gets modified in-place so be careful sending it multiple times to self._build_ctxt
        # CAUTION: DON'T MODIFY X ENTRY OF PICalibData TUPLE AS IT IS USEFUL FOR ADDING INITIAL CONDITION CONTEXT

        if not add_time:
            if calib_data[0].X_ctx is None: # True if self.past_feat_ctxt == 1
                calib_data = self.no_ctxt(calib_data)
                print("X has been copied to X_ctxt!") 

            if self.init_cond_ctxt:
                calib_data = self._build_ctxt.add_init_cond(calib_data)
            else:
                print("Initial condition won't be used as a context!")  
        else:        
            if self.past_ts_ctxt == 1:
                calib_data = self._build_ctxt.add_timestep(calib_data)
            elif self.past_ts_ctxt > 1: 
                calib_data = self._build_ctxt.add_k_timestep(calib_data)
            else: 
                print("The past timesteps will NOT be used as context!")  

        print(f"The context length is: {calib_data[0].X_ctx.shape[-1]}")    

        return calib_data      