import jax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
import equinox as eqx
import optax
import numpy as np
from typing import Tuple, Dict, List, Callable, Union
from diffrax import ODETerm, diffeqsolve, Tsit5, SaveAt, PIDController
from prepare_data import PICalibData
import sys
import pickle
import copy
from utils import check_or_make_folder, GaussianMSELoss, MSELoss, MeanStdevFilter
import pandas as pd
import os

class Func(eqx.Module):
    """Right-hand side of the ODE: an MLP evaluated on the current state.

    The network maps a vector of size ``data_size`` to a vector of the same
    size and uses softplus activations.
    """
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        """Build the MLP.

        Args:
            data_size: used as both the input and the output size of the MLP.
            width_size: forwarded to ``eqx.nn.MLP`` as ``width_size``.
            depth: forwarded to ``eqx.nn.MLP`` as ``depth``.
            key: PRNG key forwarded to ``eqx.nn.MLP`` (keyword-only).
            **kwargs: forwarded to the parent ``__init__``.
        """
        super().__init__(**kwargs)
        mlp_config = {
            "in_size": data_size,
            "out_size": data_size,
            "width_size": width_size,
            "depth": depth,
            "activation": jnn.softplus,
            "key": key,
        }
        self.mlp = eqx.nn.MLP(**mlp_config)

    def __call__(self, t, y, args):
        """Return ``mlp(y)``; ``t`` and ``args`` are accepted but not used."""
        derivative = self.mlp(y)
        return derivative

class NeuralODE(eqx.Module):
    """Neural ODE: integrates the learned vector field ``Func`` with Diffrax."""
    func: Func
    # Run configuration; only params['delta_t'] is read by this class.
    params: Dict

    def __init__(self, data_size, width_size, depth, params, key, **kwargs):
        """Create the vector field and keep a reference to ``params``.

        Args:
            data_size, width_size, depth: forwarded to ``Func``.
            params: configuration dict; must contain ``'delta_t'``.
            key: PRNG key forwarded to ``Func``.
            **kwargs: forwarded to the parent ``__init__``.
        """
        super().__init__(**kwargs)
        vector_field = Func(data_size, width_size, depth, key=key)
        self.func = vector_field
        self.params = params

    def __call__(self, ts, y0):
        """Solve the ODE from ``ts[0]`` to ``ts[-1]`` starting at ``y0``.

        Args:
            ts: time points; the solution is saved at exactly these times.
            y0: initial state passed to the solver.

        Returns:
            The ``ys`` attribute of the Diffrax solution.
        """
        term = ODETerm(self.func)
        solver = Tsit5()
        # Adaptive step size; params['delta_t'] is only the initial step.
        controller = PIDController(rtol=1e-3, atol=1e-6)
        save_points = SaveAt(ts=ts)
        t_start, t_end = ts[0], ts[-1]

        result = diffeqsolve(
            term,
            solver,
            t0=t_start,
            t1=t_end,
            dt0=self.params['delta_t'],
            y0=y0,
            stepsize_controller=controller,
            saveat=save_points,
        )
        return result.ys

def trainer(params: Dict, dataset: PICalibData, train: bool):
    """Train a ``NeuralODE`` on ``dataset`` or evaluate a previously saved one.

    With ``train=True`` the model is optimised with Adam for
    ``params['epochs']`` epochs. Every second epoch it is scored on the
    validation loader, and the model with the lowest full-horizon validation
    loss is serialised to ``<folder>/model.eqx``; ``dataset`` and ``params``
    are pickled next to it as ``<folder>/data_dict.pkl``.

    With ``train=False`` the model and the dataset are loaded back from
    ``<folder>`` (the ``dataset`` argument is replaced by the pickled one),
    un-normalised predictions are computed, and ``pred_dict.pkl`` plus two
    CSV files (``<dataset_name>_errors_0Iter_train.csv`` and
    ``<dataset_name>_errors_0Iter_val.csv``) are written. An existing CSV is
    never overwritten.

    ``<folder>`` is derived from the hyper-parameters in ``params``.

    Args:
        params: run configuration. Keys read here: ``input_dim``,
            ``num_nodes``, ``num_layers``, ``seed``, ``lr``, ``input_filter``,
            ``dataset_name``, ``epochs``, ``train_horizon``, ``val_horizon``,
            ``interp_horizon``, ``batch_size``, ``delta_t``, ``n_sample`` and
            ``ode_name``.
        dataset: must expose ``timesteps_train``, ``timesteps_val``,
            ``timesteps_train_``, ``norm_train_X``, ``norm_val_X`` and
            ``norm_train_X_``.
        train: ``True`` to train and save, ``False`` to load and evaluate.

    Returns:
        None. All results are printed or written to disk.

    Raises:
        ValueError: if the arrays given to a data loader differ in their
            first dimension, or (evaluation only) if ``params['ode_name']``
            has no known column names for the error files.
    """

    model = NeuralODE(data_size=params['input_dim'],
                      width_size=params['num_nodes'],
                      depth=params['num_layers'],
                      params=params,
                      key=jr.PRNGKey(params['seed']+14))
    optim = optax.adam(params['lr'])
    loss_fn = MSELoss()
    input_filter: MeanStdevFilter = params['input_filter']

    folder = f"./seed{params['seed']}_{params['dataset_name']}_{params['epochs']}epochs_{params['lr']}lr_{params['num_nodes']}nodes_{params['num_layers']}layers_{params['train_horizon']}trHz_{params['val_horizon']}valHz_{params['interp_horizon']}intHz_{params['batch_size']}bs_{params['delta_t']}delT_{params['n_sample']}nsample"

    @eqx.filter_jit
    def loss(model, ti, yi, train: bool=True):
        """Roll the model out from ``yi[:, 0, :]`` and score it against ``yi``.

        Returns ``(loss, aux)``. ``aux`` is ``(preds, yi)`` when ``train`` is
        true; otherwise it additionally carries the loss restricted to the
        first ``params['train_horizon']`` steps.
        """
        # One ODE solve per trajectory, each started from its first state.
        preds = jax.vmap(model, in_axes=(0, 0))(ti, yi[:,0,:])

        if train:
            aux = (preds, yi)
        else:
            horizon = params['train_horizon']
            aux = (preds, yi, loss_fn(preds[:,:horizon,:], yi[:,:horizon,:]))
        return (loss_fn(preds, yi), aux)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, ti, yi, opt_state):
        """Apply one optimiser update; returns ``(loss, model, opt_state)``."""
        (losses, _), grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return losses, model, opt_state

    def data_loaders(dataset: PICalibData, seed: int):
        """Build the shuffled train and validation batch generators.

        The time grids are repeated once per trajectory so that they can be
        batched together with the states. Both loaders are shuffled with
        ``PRNGKey(seed)``.
        """
        train_ts = jnp.repeat(dataset.timesteps_train[None, :],
                              repeats=dataset.norm_train_X.shape[0], axis=0)
        val_ts = jnp.repeat(dataset.timesteps_val[None, :],
                            repeats=dataset.norm_val_X.shape[0], axis=0)
        # Trailing-underscore arrays: labelled "interpolation" by the author.
        train_ts_ = jnp.repeat(dataset.timesteps_train_[None, :],
                               repeats=dataset.norm_train_X_.shape[0], axis=0)

        train_loader = create_data_loader((train_ts,
                                           train_ts_,
                                           dataset.norm_train_X,
                                           dataset.norm_train_X_),
                                          batch_size=params['batch_size'],
                                          key=jr.PRNGKey(seed))
        val_loader = create_data_loader((val_ts,
                                         dataset.norm_val_X),
                                        batch_size=params['batch_size'],
                                        key=jr.PRNGKey(seed))

        return train_loader, val_loader

    def create_data_loader(data_tuple, batch_size, key):
        """Yield shuffled, equally sized batches from aligned arrays.

        Every batch has exactly ``batch_size`` rows; a trailing partial batch
        is dropped.
        """
        dataset_size = data_tuple[1].shape[0]
        if any(data.shape[0] != dataset_size for data in data_tuple):
            raise ValueError(
                "All arrays passed to the data loader must share the same first dimension"
            )

        perm = jr.permutation(key, jnp.arange(dataset_size))
        start = 0
        end = batch_size
        # `<=` (not `<`): the last full batch must be yielded when the
        # dataset size is an exact multiple of the batch size.
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(data[batch_perm] for data in data_tuple)
            start = end
            end = start + batch_size

    # State-variable names per ODE system; they become the `<name>_pred` and
    # `<name>_gr` columns of the error CSVs. Unknown systems map to None and
    # are rejected only when a CSV is built, so training is unaffected.
    state_names = {
        'lorenz': ['x', 'y', 'z'],
        'glycolytic': ['S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7'],
        'LVolt': ['x', 'y', 'x_dot', 'y_dot'],
        'lorenz96': ['x1', 'x2', 'x3', 'x4', 'x5'],
        'FHNag': ['v', 'w'],
    }.get(params['ode_name'])

    def save_errors_csv(preds, targets, split):
        """Write per-step predictions and ground truth to the ``split`` CSV.

        ``preds`` and ``targets`` are un-normalised arrays whose first two
        axes are (trajectory, step). One row is written per trajectory step,
        with a ``horizon`` column holding the step index within the
        trajectory. An existing file is left untouched.
        """
        if state_names is None:
            raise ValueError(
                f"Unknown ode_name {params['ode_name']!r}: no column names defined for the error file"
            )
        pred_cols = [f"{name}_pred" for name in state_names]
        gr_cols = [f"{name}_gr" for name in state_names]

        no_of_trajs, horizon_len = preds.shape[0], preds.shape[1]
        flat_preds = np.array(preds.reshape(-1, params['input_dim']))
        flat_targets = np.array(targets.reshape(-1, params['input_dim']))

        # [0, 1, ..., horizon_len-1] repeated once per trajectory.
        horizon = jnp.stack([jnp.linspace(0, horizon_len-1, horizon_len)]*no_of_trajs)

        # One column per state variable instead of a hard-coded first three;
        # zip stops at the shorter of (names, available feature columns).
        columns = {'horizon': np.asarray(horizon.reshape(-1,))}
        columns.update(zip(pred_cols, flat_preds.T))
        columns.update(zip(gr_cols, flat_targets.T))
        df = pd.DataFrame(columns)

        csv_path = folder + f"/{params['dataset_name']}_errors_0Iter_{split}.csv"
        if os.path.exists(csv_path):
            print("Error file already exists!")
        else:
            print('Error file Saved')
            df.to_csv(csv_path, index=False)

    def save_errors_train(train_loader):
        """Evaluate the current ``model`` on the underscore-suffixed train
        arrays of every batch and write the train error CSV."""
        preds_unnorm, targets_unnorm = [], []
        for batch in train_loader:
            _, ti_, _, yi_ = batch

            _, (preds, targets) = loss(model, ti_, yi_)

            preds_unnorm.append(input_filter.invert(preds))
            targets_unnorm.append(input_filter.invert(targets))

        save_errors_csv(jnp.concatenate(preds_unnorm, axis=0),
                        jnp.concatenate(targets_unnorm, axis=0),
                        "train")

    best_loss = float('inf')
    best_model = None
    if train:
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

        for epoch in range(params['epochs']):
            # Fresh generators every epoch, reshuffled with the epoch as seed.
            train_loader, val_loader = data_loaders(dataset, seed=epoch)
            train_losses = []

            for batch_id, batch in enumerate(train_loader):

                # NOTE(review): keys repeat for (epoch, batch) pairs with the
                # same sum, so those batches share a time-step subsample.
                key = jr.PRNGKey(epoch+batch_id)
                ti, _, yi, _ = batch

                # Train on a sorted random subset of the time steps; the
                # fraction kept is params['n_sample'].
                tr_steps = yi.shape[1]
                idx = jnp.sort(jax.random.permutation(key, tr_steps)[:int(params['n_sample']*tr_steps)])
                ti, yi = ti[:,idx], yi[:,idx]

                train_loss, model, opt_state = make_step(model, ti, yi, opt_state)

                train_losses.append(train_loss)

            train_loss = jnp.mean(jnp.array(train_losses))

            # "inter": loss on the first train_horizon steps;
            # "ext": loss on the full validation horizon.
            val_losses_inter, val_losses_ext = [], []

            if epoch % 2 == 0:
                for batch in val_loader:

                    ti, yi = batch

                    val_loss_ext, aux = loss(model, ti, yi, train=False)
                    val_losses_inter.append(aux[2])
                    val_losses_ext.append(val_loss_ext)

                avg_val_losses_inter = jnp.mean(jnp.array(val_losses_inter))
                avg_val_losses_ext = jnp.mean(jnp.array(val_losses_ext))

                if float(avg_val_losses_ext) < best_loss:
                    print("Better model found...!!")
                    best_loss = float(avg_val_losses_ext)
                    best_model = copy.deepcopy(model)

            # On odd epochs the validation numbers shown are those of the
            # preceding (even) epoch.
            print(f"Epoch: {epoch} | Train loss: {train_loss:.2f} | Val loss int: {avg_val_losses_inter:.2f} | Val loss ext: {avg_val_losses_ext:.2f}")

        # No checkpoint was ever selected (zero epochs, or every validation
        # loss was NaN): save the final model rather than serialising None.
        if best_model is None:
            print("No best model was selected during validation; saving the final model instead.")
            best_model = model

        # saving model
        check_or_make_folder(folder)
        eqx.tree_serialise_leaves(folder + "/model.eqx", best_model)
        data_dict = {"dataset": dataset, "params": params}
        with open(folder + "/data_dict.pkl", "wb") as f:
            pickle.dump(data_dict, f)

    else:
        preds_unnorm = []
        targets_unnorm = []
        # NOTE(review): `seed` is only printed; the folder (and therefore the
        # model that is loaded) depends on params['seed'], not on this value.
        for seed in range(8,9):
            print(f"Inference with Model with SEED {seed}")
            model = eqx.tree_deserialise_leaves(folder + "/model.eqx", model)

            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load data_dict.pkl files produced by a trusted run.
            with open(folder + "/data_dict.pkl", "rb") as f:
                data_dict = pickle.load(f)
            dataset = data_dict['dataset']
            # data loaders at the same seed 0 for ease of visualization
            train_loader, val_loader = data_loaders(dataset, seed=0)

            val_losses_inter, val_losses_ext = [], []
            for batch in val_loader:

                ti, yi = batch
                _, aux = loss(model, ti, yi, train=False)
                preds, targets, _ = aux

                pred_unnorm = input_filter.invert(preds)
                target_unnorm = input_filter.invert(targets)

                # MSE in original units: steps before train_horizon
                # ("interpolation") versus the remaining steps ("extrapolation").
                horizon = params['train_horizon']
                val_loss_inter = jnp.mean((target_unnorm[:,:horizon,:]-pred_unnorm[:,:horizon,:])**2)
                val_loss_ext = jnp.mean((target_unnorm[:,horizon:,:]-pred_unnorm[:,horizon:,:])**2)
                val_losses_inter.append(val_loss_inter)
                val_losses_ext.append(val_loss_ext)

                preds_unnorm.append(pred_unnorm)
                targets_unnorm.append(target_unnorm)

            mse_inter = jnp.mean(jnp.array(val_losses_inter))
            mse_ext = jnp.mean(jnp.array(val_losses_ext))
            print(f"MSE Interpolation {mse_inter} | MSE Extrapolation {mse_ext}")

        ##################### SAVING TRAIN ERRORS #####################
        save_errors_train(train_loader=train_loader)

        preds_unnorm = jnp.concatenate(preds_unnorm, axis=0)
        targets_unnorm = jnp.concatenate(targets_unnorm, axis=0)

        ##################### SAVING for Viz #####################
        pred_dict = {"pred": preds_unnorm, "targets": targets_unnorm,
                     "mse_inter": mse_inter, "mse_ext": mse_ext}

        with open(folder + "/pred_dict.pkl", "wb") as f:
            pickle.dump(pred_dict, f)

        ##################### SAVING VAL ERRORS #####################
        save_errors_csv(preds_unnorm, targets_unnorm, "val")