from dataclasses import dataclass
import torch 
import numpy as np
from utils import MeanStdevFilter, prepare_data, PICalibData
from typing import List, Tuple, Dict
import pickle 
import jax
import jax.numpy as jnp
import jax.random as jr
import sys

# Prefer the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class DataProcessor:
    """Load a pickled trajectory dataset, split it into train/val trajectories and normalise it.

    ``params`` is a mutable configuration dict shared with the caller. This class reads
    ``train_horizon``, ``val_horizon``, ``interp_horizon``, ``dataset_name``, ``ode_name``,
    ``train_val_ratio`` and ``delta_t`` from it, and writes ``input_dim``, ``output_dim``,
    ``input_filter`` and ``tmax`` back into it.
    """

    # Feature keys concatenated (in this order) along the last axis for each supported
    # ``params['ode_name']``. Note that 'LVolt' includes the derivative channels
    # 'xdot'/'ydot', while 'lorenz' and 'glycolytic' use the state variables only.
    _FEATURE_KEYS: Dict[str, List[str]] = {
        'lorenz': ['x', 'y', 'z'],
        'glycolytic': [f'S{i}' for i in range(1, 8)],
        'LVolt': ['x', 'y', 'xdot', 'ydot'],
        'lorenz96': [f'X{i}' for i in range(1, 6)],
        'FHNag': ['v', 'w'],
        'hopper': [f'S{i}' for i in range(11)],
        'walker2d': [f'S{i}' for i in range(17)],
        'halfcheetah': [f'S{i}' for i in range(17)],
        'pen_expert': [f'S{i}' for i in range(45)],
        'hammer': [f'S{i}' for i in range(46)],
    }

    def __init__(self, params):
        self.params: Dict = params
        self.input_filter: MeanStdevFilter = None
        self.input_dim: int = None
        self.output_dim: int = None
        self.output_filter: MeanStdevFilter = None  # currently never assigned in this class
        self.train_horizon: int = self.params['train_horizon']
        self.val_horizon: int = self.params['val_horizon']
        self.interp_horizon: int = self.params['interp_horizon']

    def _load_trajectories(self) -> np.ndarray:
        """Read ``<dataset_name>.pkl`` and stack its trajectories into one 3-D array.

        The pickle is expected to hold an iterable of dicts; for each dict the arrays
        under ``_FEATURE_KEYS[ode_name]`` are concatenated along their last axis
        (presumably each is shaped (time, 1) — TODO confirm against the data generator).

        Raises:
            ValueError: if ``params['ode_name']`` is not a supported dataset.
        """
        ode_name = self.params['ode_name']
        try:
            keys = self._FEATURE_KEYS[ode_name]
        except KeyError:
            raise ValueError(f"Unsupported ode_name: {ode_name!r}") from None

        # NOTE: pickle can execute arbitrary code on load; only use trusted dataset files.
        with open(f"{self.params['dataset_name']}.pkl", 'rb') as f:
            data_dicts = pickle.load(f)

        # The [np.newaxis, :, :] requires each per-trajectory block to be 2-D, so the
        # stacked result is 3-D with trajectories along axis 0.
        return np.concatenate(
            [np.concatenate([data_dict[key] for key in keys], axis=-1)[np.newaxis, :, :]
             for data_dict in data_dicts],
            axis=0,
        )

    def get_data(self, load_model) -> PICalibData:
        """Load the dataset, split it into train/val trajectories and build a PICalibData.

        Args:
            load_model: when truthy, dimensions are still recorded in ``params`` but no
                normalisation filter is fitted and ``None`` is returned.

        Returns:
            The normalised ``PICalibData`` when ``load_model`` is falsy, else ``None``.

        Raises:
            ValueError: if ``params['ode_name']`` is not a supported dataset.
        """
        data = self._load_trajectories()

        # input and output share the same feature space
        self.input_dim = data.shape[-1]
        self.output_dim = data.shape[-1]
        self.params['input_dim'] = self.input_dim
        self.params['output_dim'] = self.output_dim

        # One-step pairs: X_next[:, t] is the sample that follows X_prev[:, t].
        X_prev = data[:, :-1, :]
        X_next = data[:, 1:, :]

        # Random trajectory-level split; params['train_val_ratio'] is the training fraction.
        total_traj = data.shape[0]
        train_sz = int(total_traj * self.params["train_val_ratio"])
        # Convert to NumPy before indexing: NumPy treats a 1-element torch tensor as a
        # scalar index (via __index__), which would silently drop the trajectory axis
        # whenever a split contains exactly one trajectory.
        rand_perm = torch.randperm(total_traj).numpy()
        train_idx = rand_perm[:train_sz]
        val_idx = rand_perm[train_sz:]

        # Fancy indexing already returns fresh copies, so no explicit .copy() is needed.
        train_X_traj = X_prev[train_idx, :self.val_horizon, :]
        train_Y_traj = X_next[train_idx, :self.val_horizon, :]
        val_X_traj = X_prev[val_idx, :self.val_horizon, :]
        val_Y_traj = X_next[val_idx, :self.val_horizon, :]

        if load_model:
            return None

        # Normalisation statistics are fitted on training inputs only.
        self.input_filter = MeanStdevFilter(self.input_dim)
        self.calculate_mean_var(train_X_traj.reshape(-1, self.input_dim))

        norm_train_X, norm_train_Y, norm_val_X, norm_val_Y = self.prepare_datapoints(
            train_X_traj, train_Y_traj, val_X_traj, val_Y_traj)
        return self.prepare_dataclass(norm_train_X, norm_train_Y, norm_val_X, norm_val_Y)

    def prepare_datapoints(self, train_X_traj: np.ndarray, train_Y_traj: np.ndarray,
                           val_X_traj: np.ndarray, val_Y_traj: np.ndarray):
        """Normalise input and target trajectories with ``self.input_filter``.

        Each array is flattened to ``(-1, input_dim)``, passed through ``prepare_data``
        and reshaped back to ``(-1, seq_len, input_dim)``. The train pair uses
        ``train_X_traj.shape[1]`` and the val pair uses ``val_X_traj.shape[1]`` as seq_len.
        Targets use the input filter too, since inputs and targets share one feature space.

        Returns:
            ``(norm_train_X, norm_train_Y, norm_val_X, norm_val_Y)``.
        """
        seq_len_train = train_X_traj.shape[1]
        seq_len_val = val_X_traj.shape[1]

        def _normalise(traj: np.ndarray, seq_len: int):
            flat = traj.reshape(-1, self.input_dim)
            return prepare_data(flat, self.input_filter).reshape(-1, seq_len, self.input_dim)

        # Tuple elements evaluate left to right, preserving the original call order.
        return (_normalise(train_X_traj, seq_len_train),
                _normalise(train_Y_traj, seq_len_train),
                _normalise(val_X_traj, seq_len_val),
                _normalise(val_Y_traj, seq_len_val))

    def calculate_mean_var(self, input_data: np.ndarray) -> None:
        """Feed every row of ``input_data`` to ``self.input_filter.update``.

        The fitted filter is also stored in ``params['input_filter']``.
        """
        for row in input_data:
            self.input_filter.update(row)
        self.params['input_filter'] = self.input_filter

    def prepare_dataclass(self, norm_train_X: np.ndarray, norm_train_Y: np.ndarray,
                          norm_val_X: np.ndarray, norm_val_Y: np.ndarray):
        """Slice normalised trajectories to their horizons and wrap them in PICalibData.

        Also sets ``params['tmax'] = val_horizon * delta_t``.
        """
        delta_t = self.params['delta_t']
        # Build time grids as index * delta_t instead of torch.arange(0, n*delta_t, delta_t):
        # the float-step form can return n + 1 points when n*delta_t rounds up
        # (e.g. n=3, delta_t=0.1), misaligning timesteps with the samples.
        timesteps_train = torch.arange(self.interp_horizon) * delta_t
        timesteps_val = torch.arange(self.val_horizon) * delta_t
        self.params['tmax'] = self.val_horizon * delta_t  # used in pad_input method of ContiFormer class

        # NOTE(review): the *_Y fields are sliced from the X trajectories (longer horizon),
        # not from norm_train_Y / norm_val_Y, which are currently unused. Kept as-is
        # because downstream training may depend on it — confirm this is intended.
        return PICalibData(norm_train_X=torch.tensor(norm_train_X[:, :self.train_horizon, :]),
                           norm_train_Y=torch.tensor(norm_train_X[:, :self.val_horizon, :]),
                           norm_train_X_=torch.tensor(norm_train_X[:, :self.interp_horizon, :]),
                           norm_train_Y_=torch.tensor(norm_train_X[:, :self.interp_horizon, :]),
                           norm_val_X=torch.tensor(norm_val_X[:, :self.train_horizon]),
                           norm_val_Y=torch.tensor(norm_val_X[:, :self.val_horizon]),
                           timesteps_train=np.array(timesteps_train[:self.train_horizon]),
                           timesteps_val=np.array(timesteps_val),
                           timesteps_train_=np.array(timesteps_train))
        
   

    
 
