from data_classes import PICalibData
import torch 
import numpy as np
import pandas as pd
from utils import MeanStdevFilter, prepare_data
from build_ctxt import BuildContext
from typing import List, Tuple
import jax
import jax.numpy as jnp
import diffrax
import sys, os

class DataProcessor:
    """Prepare simulator / ground-truth trajectories for calibration.

    Reads the train and validation CSVs, packs them into ``PICalibData``
    objects (ground truth and simulator, for train and val), normalizes the
    context inputs and the simulator errors with ``MeanStdevFilter`` objects,
    converts the tensors to JAX arrays and computes cubic Hermite spline
    coefficients of the contexts.

    Several methods mutate the ``PICalibData`` objects they receive in place
    and also return them.

    ``self.params`` is shared by reference; this class writes ``ctxt_dim``,
    ``input_dim``, ``output_dim``, ``iter_correction``, ``input_filter`` and
    ``output_filter`` into it.
    """

    # Base state-variable names per supported system. The CSVs hold one
    # "<name>_pred" (simulator) and one "<name>_gr" (ground-truth) column
    # for every name, in this order.
    _STATE_NAMES = {
        'lorenz': ['x', 'y', 'z'],
        'glycolytic': [f'S{i}' for i in range(1, 8)],
        'LVolt': ['x', 'y'],
        'lorenz96': [f'x{i}' for i in range(1, 6)],
        'FHNag': ['v', 'w'],
        'walker2d': [f'S{i}' for i in range(17)],
        'halfcheetah': [f'S{i}' for i in range(17)],
        'pen_expert': [f'S{i}' for i in range(45)],
        'hopper': [f'S{i}' for i in range(11)],
        'hammer': [f'S{i}' for i in range(46)],
    }

    def __init__(self, params):
        self.params = params
        self.input_filter = None
        self.output_filter = None
        self.input_dim = None
        self.output_dim = None
        self.ctxt_dim = None  # set by get_data()
        self.train_horizon = self.params['train_horizon']
        self.val_horizon = self.params['val_horizon']
        self.interp_horizon = self.params['interp_horizon']
        self.past_ts_ctxt = self.params['past_ts_ctxt']
        self.init_cond_ctxt = self.params['init_cond_ctxt']
        self.past_feat_ctxt = self.params['past_feat_ctxt']
        self._build_ctxt = BuildContext(self.params)
        self.params['ctxt_dim'] = 0

    def _state_columns(self) -> Tuple[List[str], List[str]]:
        """Return the (``*_pred``, ``*_gr``) column names for ``params['ode_name']``.

        Raises:
            ValueError: if ``params['ode_name']`` is not a supported system.
        """
        ode_name = self.params['ode_name']
        try:
            names = self._STATE_NAMES[ode_name]
        except KeyError:
            raise ValueError(
                f"Unsupported ode_name {ode_name!r}; expected one of "
                f"{sorted(self._STATE_NAMES)}"
            ) from None
        return [f'{n}_pred' for n in names], [f'{n}_gr' for n in names]

    def _read_grouped(self, csv_path: str, columns: List[str]) -> np.ndarray:
        """Read ``csv_path`` and stack ``columns`` of each ``horizon`` group.

        Returns an array of shape (n_groups, rows_per_group, len(columns)),
        with axes 0 and 1 swapped when ``params['data_type']`` is truthy.
        All groups must have the same number of rows.
        """
        df = pd.read_csv(csv_path)
        data = np.stack(
            [group[columns].values for _, group in df.groupby('horizon')],
            axis=0,
        )
        if self.params['data_type']:
            data = data.transpose(1, 0, 2)
        return data

    def get_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Read the train and validation CSVs and set up dims and filters.

        The train CSV is ``params['data_path']``. The validation CSV sits in
        the same directory and is named after ``params['dataset_name']`` with
        its last ``_``-separated token replaced by ``val``.
        ``params['iter_correction']`` is set from the first character of the
        token before that one.

        Side effects: sets ``input_dim``, ``output_dim``, ``ctxt_dim``,
        ``input_filter`` and ``output_filter`` on ``self``, and writes
        ``ctxt_dim``, ``input_dim``, ``output_dim`` and ``iter_correction``
        into ``self.params``.

        Returns:
            ``(data_tr, data_val)``. The last axis of each holds the
            ``*_pred`` columns followed by the ``*_gr`` columns.

        Raises:
            ValueError: if ``params['ode_name']`` is not a supported system.
        """
        pred_cols, gr_cols = self._state_columns()
        columns = pred_cols + gr_cols

        ################## Read Train Data ##################
        data_tr = self._read_grouped(self.params['data_path'], columns)
        print(f"The Train data shape is: {data_tr.shape}")

        ################## Read Val Data ##################
        train_data_name = self.params['dataset_name'].split("_")[:-1]
        val_data_name = "_".join(train_data_name) + "_val.csv"
        # os.path keeps the val file relative when data_path has no directory
        # part (string splitting produced the root path "/<name>" there).
        val_data_path = os.path.join(
            os.path.dirname(self.params['data_path']), val_data_name
        )
        self.params['iter_correction'] = int(train_data_name[-1][0])

        data_val = self._read_grouped(val_data_path, columns)
        print(f"The Val data shape is: {data_val.shape}")

        self.input_dim = len(pred_cols)
        self.output_dim = len(pred_cols)

        # Context width: past_feat_ctxt * input_dim, plus input_dim more when
        # the initial condition is used as context.
        # NOTE: this accumulates onto params['ctxt_dim'] (zeroed in __init__),
        # so calling get_data() twice double-counts.
        self.params['ctxt_dim'] += self.past_feat_ctxt * self.input_dim
        if self.init_cond_ctxt:
            self.params['ctxt_dim'] += self.input_dim

        self.ctxt_dim = self.params['ctxt_dim']
        self.params['input_dim'] = self.input_dim
        self.params['output_dim'] = self.output_dim
        self.input_filter = MeanStdevFilter(self.ctxt_dim)
        self.output_filter = MeanStdevFilter(self.output_dim)

        return (data_tr, data_val)

    def _timesteps(self, n_steps: int) -> torch.Tensor:
        """Return ``n_steps`` evenly spaced times ``0, dt, 2*dt, ...``."""
        # torch.arange(0, n*dt, dt) with a float step can return n + 1 values
        # (e.g. 3*0.1 / 0.1 > 3 in floating point), which then disagrees with
        # the time axis of X. Scaling an integer range always gives n values.
        steps = torch.arange(n_steps, dtype=torch.float64)
        return (steps * self.params['delta_t']).to(torch.get_default_dtype())

    def _split_calib(self, data: np.ndarray) -> Tuple[PICalibData, PICalibData]:
        """Build the (ground-truth, simulator) ``PICalibData`` pair for one split.

        The first ``input_dim`` features are the simulator states and the rest
        are the ground-truth states. Axis 1 is treated as time: inputs are
        steps ``[0, T-1)`` and outputs are steps ``[1, T)``.
        """
        sim = data[:, :, :self.input_dim]
        gr = data[:, :, self.input_dim:]

        true = PICalibData(X=torch.Tensor(gr[:, :-1, :]),
                           Y=torch.Tensor(gr[:, 1:, :]))
        sim_calib = PICalibData(X=torch.Tensor(sim[:, :-1, :]),
                                Y=torch.Tensor(sim[:, 1:, :]),
                                timesteps=self._timesteps(true.X.shape[1]),
                                error=torch.Tensor(gr[:, 1:, :] - sim[:, 1:, :]))
        return true, sim_calib

    def data_tuples(self, data: Tuple[np.ndarray, np.ndarray]) -> Tuple[PICalibData, ...]:
        """Split ``(data_tr, data_val)`` into un-normalized ``PICalibData`` objects.

        The simulator objects also carry ``timesteps`` (spacing
        ``params['delta_t']``) and ``error`` = ground-truth output minus
        simulator output.

        Returns:
            ``(calib_train_true, calib_val_true, calib_train_sim, calib_val_sim)``.
        """
        data_tr, data_val = data
        calib_train_true, calib_train_sim = self._split_calib(data_tr)
        calib_val_true, calib_val_sim = self._split_calib(data_val)
        return (calib_train_true, calib_val_true, calib_train_sim, calib_val_sim)

    def normalize_calib_data(self, calib_data: Tuple[PICalibData, ...]):
        """Normalize contexts and errors of the four calib objects in place.

        The input filter is fitted on the train ground-truth contexts and then
        applied to every context. ``X`` is replaced by the first ``input_dim``
        features of the normalized context. The output filter is fitted on the
        train simulator errors and applied to the train and val errors.
        Returns ``calib_data``.
        """
        gr_tr_input_ctxt = np.array(calib_data[0].X_ctx)
        gr_val_input_ctxt = np.array(calib_data[1].X_ctx)
        sim_tr_input_ctxt = np.array(calib_data[2].X_ctx)
        sim_val_input_ctxt = np.array(calib_data[3].X_ctx)

        self.calculate_mean_var(gr_tr_input_ctxt, self.ctxt_dim)
        gr_tr_ctxt_filter, sim_tr_ctxt_filter = self.normalize_data(
            gr_tr_input_ctxt, sim_tr_input_ctxt, self.ctxt_dim)
        gr_val_ctxt_filter, sim_val_ctxt_filter = self.normalize_data(
            gr_val_input_ctxt, sim_val_input_ctxt, self.ctxt_dim)

        gr_tr_ctxt_filter = torch.Tensor(gr_tr_ctxt_filter)
        sim_tr_ctxt_filter = torch.Tensor(sim_tr_ctxt_filter)
        gr_val_ctxt_filter = torch.Tensor(gr_val_ctxt_filter)
        sim_val_ctxt_filter = torch.Tensor(sim_val_ctxt_filter)

        # saving X_ctxt as normalized
        calib_data[0].X_ctx, calib_data[1].X_ctx = gr_tr_ctxt_filter, gr_val_ctxt_filter
        calib_data[2].X_ctx, calib_data[3].X_ctx = sim_tr_ctxt_filter, sim_val_ctxt_filter

        # saving X as normalized
        # If we have past_feat_ctxt > 1, then we need to change the slicing here
        calib_data[0].X, calib_data[1].X = (gr_tr_ctxt_filter[:, :, :self.input_dim],
                                            gr_val_ctxt_filter[:, :, :self.input_dim])
        calib_data[2].X, calib_data[3].X = (sim_tr_ctxt_filter[:, :, :self.input_dim],
                                            sim_val_ctxt_filter[:, :, :self.input_dim])

        ############## Errors normalization ##############
        error_tr = np.array(calib_data[2].error)
        error_val = np.array(calib_data[3].error)

        self.calculate_mean_var_out(error_tr, self.output_dim)
        error_tr_filter = torch.Tensor(self.normalize_data_out(error_tr, self.output_dim))
        error_val_filter = torch.Tensor(self.normalize_data_out(error_val, self.output_dim))

        calib_data[2].error, calib_data[3].error = error_tr_filter, error_val_filter

        return calib_data

    def calculate_mean_var_out(self, errors: np.ndarray, output_dim: int) -> None:
        """Feed every ``output_dim``-wide row of ``errors`` to the output filter.

        Also stores the filter in ``params['output_filter']``.
        """
        for row in errors.reshape(-1, output_dim):
            self.output_filter.update(row)

        self.params['output_filter'] = self.output_filter

    def normalize_data_out(self, errors: np.ndarray, output_dim: int) -> np.ndarray:
        """Normalize a 3-D ``errors`` array with the output filter, keeping its shape."""
        errors_seq_len = errors.shape[1]
        errors = errors.reshape(-1, output_dim)

        out_filter = prepare_data(errors, self.output_filter)

        return out_filter.reshape(-1, errors_seq_len, output_dim)

    def calculate_mean_var(self, gr_input_data: np.ndarray, input_dim: int) -> None:
        """Feed every ``input_dim``-wide row of ``gr_input_data`` to the input filter.

        Also stores the filter in ``params['input_filter']``.
        """
        for row in gr_input_data.reshape(-1, input_dim):
            self.input_filter.update(row)

        self.params['input_filter'] = self.input_filter

    def normalize_data(self, gr_input_data: np.ndarray,
                       sim_input_data: np.ndarray, input_dim: int) -> Tuple[np.ndarray, np.ndarray]:
        """Normalize two 3-D arrays with the input filter, keeping their shapes."""
        gr_seq_len = gr_input_data.shape[1]
        sim_seq_len = sim_input_data.shape[1]

        gr_input_filter = prepare_data(gr_input_data.reshape(-1, input_dim), self.input_filter)
        sim_input_filter = prepare_data(sim_input_data.reshape(-1, input_dim), self.input_filter)

        gr_input_filter = gr_input_filter.reshape(-1, gr_seq_len, input_dim)
        sim_input_filter = sim_input_filter.reshape(-1, sim_seq_len, input_dim)

        return (gr_input_filter, sim_input_filter)

    def torch_to_jax(self, calib_data: Tuple[PICalibData, ...]):
        """Convert the torch tensors of the four calib objects to JAX arrays in place.

        ``X``, ``Y`` and ``X_ctx`` are converted for all four objects.
        ``timesteps`` and ``error`` are converted only for indices 2 and 3
        (the simulator objects). Returns ``calib_data``.
        """
        def to_jax(tensor):
            return jnp.array(tensor.detach().cpu().numpy())

        for idx in range(4):
            item = calib_data[idx]
            item.X = to_jax(item.X)
            item.Y = to_jax(item.Y)
            item.X_ctx = to_jax(item.X_ctx)
            if idx >= 2:
                item.timesteps = to_jax(item.timesteps)
                item.error = to_jax(item.error)

        return calib_data

    def spline_coeffs(self, calib_data: Tuple[PICalibData, ...]):
        """Use cubic spline to calculate coefficients.

        Sets ``X_ctx_coeffs`` on each of the four calib objects to the tuple
        returned by ``diffrax.backward_hermite_coefficients``, vmapped over
        axis 0. Train objects (0, 2) use the train simulator timesteps and val
        objects (1, 3) use the val simulator timesteps. Returns ``calib_data``.
        """
        # One copy of the timestep vector per trajectory, so vmap can map both args.
        train_ts = jnp.repeat(calib_data[2].timesteps[None, :],
                              repeats=calib_data[2].X.shape[0], axis=0)
        val_ts = jnp.repeat(calib_data[3].timesteps[None, :],
                            repeats=calib_data[3].X.shape[0], axis=0)

        # Each result is a tuple of 4 coefficient arrays (e.g. each (400, 297, 7)).
        hermite = jax.vmap(diffrax.backward_hermite_coefficients)
        for idx, ts in ((0, train_ts), (1, val_ts), (2, train_ts), (3, val_ts)):
            calib_data[idx].X_ctx_coeffs = hermite(ts, calib_data[idx].X_ctx)

        return calib_data

    def no_ctxt(self, calib_data: Tuple[PICalibData, ...]):
        """
        In case of no context, simply copy ``X`` to ``X_ctx`` for each of
        the four calib objects. Returns ``calib_data``.
        """
        for idx in range(4):
            calib_data[idx].X_ctx = calib_data[idx].X.clone()

        return calib_data

    def build_context(self, calib_data: Tuple[PICalibData, ...], add_time: bool):
        """Add context features to the calib objects via ``BuildContext``.

        With ``add_time`` False: copy ``X`` into ``X_ctx`` if it is unset, then
        add the initial-condition context when ``init_cond_ctxt`` is truthy.
        With ``add_time`` True: add one past timestep (``past_ts_ctxt == 1``)
        or ``past_ts_ctxt`` past timesteps (``> 1``) as context.
        Returns the (possibly new) ``calib_data``.
        """
        # CAUTION: calib_data gets modified in-place so be careful sending it multiple times to self._build_ctxt
        # CAUTION: DON'T MODIFY X ENTRY OF PICalibData TUPLE AS IT IS USEFUL FOR ADDING INITIAL CONDITION CONTEXT

        if not add_time:
            if calib_data[0].X_ctx is None:  # True if self.past_feat_ctxt == 1
                calib_data = self.no_ctxt(calib_data)
                print("X has been copied to X_ctxt!")

            if self.init_cond_ctxt:
                calib_data = self._build_ctxt.add_init_cond(calib_data)
            else:
                print("Initial condition won't be used as a context!")
        else:
            if self.past_ts_ctxt == 1:
                calib_data = self._build_ctxt.add_timestep(calib_data)
            elif self.past_ts_ctxt > 1:
                calib_data = self._build_ctxt.add_k_timestep(calib_data)
            else:
                print("The past timesteps will NOT be used as context!")

        print(f"The context length is: {calib_data[0].X_ctx.shape[-1]}")

        return calib_data