from dataclasses import dataclass
import torch 
import numpy as np
from utils import MeanStdevFilter, prepare_data, PICalibData
from typing import List, Tuple, Dict
import pickle 
import jax
import jax.numpy as jnp
import jax.random as jr
import sys

# Use the first CUDA device when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class DataProcessor:
    """Load pickled ODE trajectories and build normalised train/validation splits.

    ``params`` is a plain dict. The keys read by this class are
    ``dataset_name``, ``ode_name``, ``train_val_ratio``, ``train_horizon``,
    ``val_horizon``, ``interp_horizon``, ``delta_t`` and (only in
    ``irregular_samp``) ``seed`` and ``n_sample``. The keys ``input_dim``,
    ``output_dim`` and ``input_filter`` are written back into it.
    """

    # Per-trajectory dict keys that are concatenated (in this order) along the
    # last axis to form the state vector of each supported ODE system.
    # Only 'LVolt' includes the derivative fields (xdot, ydot).
    _STATE_KEYS: Dict[str, Tuple[str, ...]] = {
        'lorenz': ('x', 'y', 'z'),
        'glycolytic': ('S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7'),
        'LVolt': ('x', 'y', 'xdot', 'ydot'),
        'lorenz96': ('X1', 'X2', 'X3', 'X4', 'X5'),
        'FHNag': ('v', 'w'),
    }

    def __init__(self, params):
        self.params: Dict = params
        self.input_filter: MeanStdevFilter = None
        self.input_dim: int = None
        self.output_dim: int = None
        self.output_filter: MeanStdevFilter = None
        self.train_horizon: int = self.params['train_horizon']
        self.val_horizon: int = self.params['val_horizon']
        self.interp_horizon: int = self.params['interp_horizon']

    def get_data(self, load_model) -> "PICalibData":
        """Read the dataset, split it by trajectory and normalise it.

        Args:
            load_model: when truthy, only the dimensions are recorded (on the
                instance and in ``params``) and ``None`` is returned; no
                filter is fitted and no dataset is built.

        Returns:
            A ``PICalibData`` with normalised train/val arrays and time grids,
            or ``None`` when ``load_model`` is truthy.

        Raises:
            ValueError: if ``params['ode_name']`` is not a supported system.
        """
        ode_name = self.params['ode_name']
        try:
            state_keys = self._STATE_KEYS[ode_name]
        except KeyError:
            raise ValueError(
                f"Unsupported ode_name {ode_name!r}; expected one of {sorted(self._STATE_KEYS)}"
            ) from None

        # NOTE: pickle can execute arbitrary code on load; only use trusted dataset files.
        with open(f"{self.params['dataset_name']}.pkl", 'rb') as fh:
            data_dicts = pickle.load(fh)

        # One entry per trajectory; each trajectory's variables are joined along
        # the last axis and the trajectories are stacked along a new leading axis.
        # Assumes every variable is a 2-D array with matching leading length
        # (the original code indexed the joined array with [np.newaxis, :, :]).
        data = np.stack(
            [np.concatenate([traj[key] for key in state_keys], axis=-1) for traj in data_dicts],
            axis=0,
        )

        # input and output dim (the model maps a state to the next state)
        self.input_dim = data.shape[-1]
        self.output_dim = data.shape[-1]
        self.params['input_dim'] = self.input_dim
        self.params['output_dim'] = self.output_dim

        # one-step pairs: state at step t -> state at step t + 1
        X_prev = data[:, :-1, :]
        X_next = data[:, 1:, :]

        # Random split by trajectory; params['train_val_ratio'] is the training fraction.
        # torch's RNG is kept as the source of randomness, but the permutation is
        # converted to NumPy so it is a proper NumPy index array.
        total_traj = data.shape[0]
        train_sz = int(total_traj * self.params["train_val_ratio"])
        rand_perm = torch.randperm(total_traj).numpy()
        train_idx = rand_perm[:train_sz]
        val_idx = rand_perm[train_sz:]

        if load_model:
            return None

        # NOTE(review): training trajectories are also cut at val_horizon here;
        # prepare_dataclass slices them further to train_horizon / interp_horizon.
        # Confirm this is intended when interp_horizon > val_horizon.
        train_X_traj = X_prev[train_idx, :self.val_horizon, :].copy()
        train_Y_traj = X_next[train_idx, :self.val_horizon, :].copy()
        val_X_traj = X_prev[val_idx, :self.val_horizon, :].copy()
        val_Y_traj = X_next[val_idx, :self.val_horizon, :].copy()

        # Fit the normalisation statistics on the flattened training inputs only.
        self.input_filter = MeanStdevFilter(self.input_dim)
        self.calculate_mean_var(train_X_traj.reshape(-1, self.input_dim).copy())

        norm_train_X, norm_train_Y, norm_val_X, norm_val_Y = self.prepare_datapoints(
            train_X_traj, train_Y_traj, val_X_traj, val_Y_traj
        )
        return self.prepare_dataclass(norm_train_X, norm_train_Y, norm_val_X, norm_val_Y)

    def prepare_datapoints(self, train_X_traj: np.ndarray, train_Y_traj: np.ndarray,
                           val_X_traj: np.ndarray, val_Y_traj: np.ndarray):
        """Normalise the four trajectory arrays with ``self.input_filter``.

        Each array is flattened to ``(-1, input_dim)``, passed through
        ``prepare_data`` and reshaped back to ``(-1, seq_len, input_dim)``,
        where ``seq_len`` is the array's own second dimension. Targets use the
        input filter as well.

        Returns:
            ``(norm_train_X, norm_train_Y, norm_val_X, norm_val_Y)``.
        """
        def _normalise(traj):
            seq_len = traj.shape[1]
            flat = traj.reshape(-1, self.input_dim)
            return prepare_data(flat, self.input_filter).reshape(-1, seq_len, self.input_dim)

        # Tuple elements are evaluated left to right, i.e. in the original call order.
        return (
            _normalise(train_X_traj),
            _normalise(train_Y_traj),
            _normalise(val_X_traj),
            _normalise(val_Y_traj),
        )

    def calculate_mean_var(self, input_data: np.ndarray) -> None:
        """Feed every row of ``input_data`` to ``self.input_filter.update``.

        The filter is also stored under ``params['input_filter']``.
        """
        for row in input_data:
            self.input_filter.update(row)

        self.params['input_filter'] = self.input_filter

    def irregular_samp(self, data: "PICalibData"):
        """Create irregularly sampled time series.

        Keeps a sorted random subset of ``int(n_sample * n_steps)`` time
        indices for the train arrays and, separately, for the ``*_`` train
        arrays. ``data`` is modified in place and returned.
        """
        tr_steps = data.timesteps_train.shape[0]
        tr_steps_ = data.timesteps_train_.shape[0]
        frac = self.params['n_sample']

        # NOTE(review): the same PRNG key is used for both permutations, which JAX
        # advises against; left unchanged so seeded subsets stay the same. Consider
        # jr.split(key) if the two subsets are meant to be independent.
        key = jr.PRNGKey(self.params['seed'])
        tr_ran = jnp.sort(jr.permutation(key, tr_steps)[:int(frac * tr_steps)])
        tr_ran_ = jnp.sort(jr.permutation(key, tr_steps_)[:int(frac * tr_steps_)])

        data.norm_train_X = data.norm_train_X[:, tr_ran, :]
        data.norm_train_Y = data.norm_train_Y[:, tr_ran, :]
        data.norm_train_X_ = data.norm_train_X_[:, tr_ran_, :]
        data.norm_train_Y_ = data.norm_train_Y_[:, tr_ran_, :]
        data.timesteps_train = data.timesteps_train[tr_ran]
        data.timesteps_train_ = data.timesteps_train_[tr_ran_]

        return data

    def _timesteps(self, n_steps: int):
        """Return ``[0, dt, 2*dt, ...]`` with exactly ``n_steps`` entries.

        Built from an integer count because ``torch.arange(0, n*dt, dt)`` can
        return ``n + 1`` entries when ``n*dt`` rounds up (e.g. 3 * 0.1). The
        grid is computed in float64 and cast to torch's default dtype.
        """
        grid = torch.arange(n_steps, dtype=torch.float64) * self.params['delta_t']
        return grid.to(torch.get_default_dtype())

    def prepare_dataclass(self, norm_train_X: np.ndarray, norm_train_Y: np.ndarray,
                          norm_val_X: np.ndarray, norm_val_Y: np.ndarray) -> "PICalibData":
        """Pack the normalised arrays and their time grids into a ``PICalibData``.

        The plain train fields are cut at ``train_horizon``, the ``*_`` train
        fields at ``interp_horizon`` and the val fields at ``val_horizon``.
        """
        timesteps_train = self._timesteps(self.interp_horizon)
        timesteps_val = self._timesteps(self.val_horizon)

        # NOTE(review): the result of this expression is discarded; its only effect
        # is advancing torch's global RNG. Kept so seeded runs stay reproducible;
        # remove once confirmed that nothing downstream depends on the RNG state.
        torch.randperm(100)[0:50].sort()[0]

        return PICalibData(norm_train_X=jnp.array(norm_train_X[:, :self.train_horizon, :]),
                           norm_train_Y=jnp.array(norm_train_Y[:, :self.train_horizon, :]),
                           norm_train_X_=jnp.array(norm_train_X[:, :self.interp_horizon, :]),
                           norm_train_Y_=jnp.array(norm_train_Y[:, :self.interp_horizon, :]),
                           norm_val_X=jnp.array(norm_val_X[:, :self.val_horizon]),
                           norm_val_Y=jnp.array(norm_val_Y[:, :self.val_horizon]),
                           timesteps_train=jnp.array(timesteps_train[:self.train_horizon]),
                           timesteps_val=jnp.array(timesteps_val),
                           timesteps_train_=jnp.array(timesteps_train))
        
   

    
 
