import jax
import diffrax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
import equinox as eqx
import optax
import numpy as np
from typing import Tuple, Dict, List, Callable, Union
from diffrax import ODETerm, diffeqsolve, Tsit5, SaveAt
from data_classes import PICalibData
from calib_calc import z_scores, select_quantiles, ensemble_pred, save_scores
import sys
import pickle
from tabulate import tabulate
import copy
from utils import check_or_make_folder, GaussianMSELoss

# ---------------------------
# Model Components
# ---------------------------

class TemporalConvEmbed(eqx.Module):
    """Embed one (unbatched) time series into a fixed-size feature vector.

    Pipeline: 1-D convolution along time -> ReLU -> max over the time axis.
    Note the pooling is a *max*, not an average.
    """
    conv: eqx.nn.Conv

    def __init__(self, in_channels, out_channels, kernel_size, key):
        # eqx.nn.Conv with a single spatial dimension behaves as a Conv1d.
        # With stride 1, padding kernel_size // 2 preserves the sequence
        # length for odd kernel sizes.
        self.conv = eqx.nn.Conv(
            num_spatial_dims=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            key=key,
        )

    def __call__(self, x):
        # x: (length, in_channels). eqx.nn.Conv wants channels-first
        # (in_channels, length) for an unbatched input.
        channels_first = jnp.transpose(x, axes=(1, 0))
        features = jnn.relu(self.conv(channels_first))  # (out_channels, length')
        return features.max(axis=-1)                    # (out_channels,)

class CDEFunc(eqx.Module):
    """Vector field f(y) of the neural CDE.

    Maps a hidden state of size ``hidden_size`` to a
    ``(hidden_size, input_size)`` matrix. diffrax's ControlTerm contracts
    this matrix with the derivative of the control path.
    """
    mlp: eqx.nn.MLP
    input_size: int
    hidden_size: int

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.input_size = input_size
        # The tanh output activation bounds every entry of the vector field to
        # (-1, 1). This stops the hidden state from blowing up, much as GRU/LSTM
        # gating limits how fast their hidden states can change.
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * input_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        # `t` and `args` are ignored: the field depends only on the state.
        flat = self.mlp(y)
        return flat.reshape(self.hidden_size, self.input_size)

class Decoder(eqx.Module):
    """ReLU MLP mapping a ``hidden_size`` state to ``output_size`` outputs."""
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, output_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        mlp_options = dict(
            width_size=width_size,
            depth=depth,
            activation=jnn.relu,
            key=key,
        )
        self.mlp = eqx.nn.MLP(in_size=hidden_size, out_size=output_size, **mlp_options)

    def __call__(self, zx):
        decoded = self.mlp(zx)
        return decoded

class BayesianNNDecoder(eqx.Module):
    """SiLU MLP with a mean head (``delta``) and a bounded log-variance head.

    ``__call__`` returns both heads concatenated on the last axis
    (``2 * output_dim`` values). The log-variance is softly clamped between
    ``min_logvar`` and ``max_logvar``. Those bounds are stored as float
    arrays, so they are pytree leaves: an ``eqx.is_inexact_array`` filter
    picks them up and they are learnable.
    """
    layers: List[eqx.nn.Linear]
    delta: eqx.nn.Linear
    logvar: eqx.nn.Linear
    activation: Callable 
    max_logvar: jnp.ndarray
    min_logvar: jnp.ndarray

    def __init__(self, input_dim: int, output_dim: int, width_size, depth, *, key):
        widths = [width_size] * depth
        # One key per hidden layer, plus one for each output head.
        subkeys = jr.split(key, len(widths) + 2)
        self.activation = jnn.silu

        dims = [input_dim, *widths]
        self.layers = [
            eqx.nn.Linear(d_in, d_out, key=subkeys[j])
            for j, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:]))
        ]
        feat_dim = dims[-1]

        self.delta = eqx.nn.Linear(feat_dim, output_dim, key=subkeys[-2])
        self.logvar = eqx.nn.Linear(feat_dim, output_dim, key=subkeys[-1])

        # Initial soft bounds for the log-variance (learnable, see class doc).
        self.max_logvar = jnp.full((output_dim,), 0.5)
        self.min_logvar = jnp.full((output_dim,), -2.0)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        h = x
        for lin in self.layers:
            h = self.activation(lin(h))
        mean_delta = self.delta(h)
        raw_logvar = self.logvar(h)
        # Smoothly cap from above, then smoothly floor from below.
        capped = self.max_logvar - jnn.softplus(self.max_logvar - raw_logvar)
        bounded = self.min_logvar + jnn.softplus(capped - self.min_logvar)
        return jnp.concatenate([mean_delta, bounded], axis=-1)

class Initial(eqx.Module):
    """ReLU MLP lifting an ``input_size`` vector into the ``hidden_size`` space."""
    mlp: eqx.nn.MLP

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        mlp_options = dict(
            width_size=width_size,
            depth=depth,
            activation=jnn.relu,
            key=key,
        )
        self.mlp = eqx.nn.MLP(in_size=input_size, out_size=hidden_size, **mlp_options)

    def __call__(self, z):
        lifted = self.mlp(z)
        return lifted

class NeuralCDE(eqx.Module):
    """Neural controlled differential equation with an MLP decoder ensemble.

    The hidden state ``y`` (size ``hidden_size``) evolves as ``dy = func(y) dX``.
    ``X`` is a ``diffrax.CubicInterpolation`` of the control path built from
    ``coeffs``. Each decoder maps every saved hidden state to ``output_size``
    outputs, and the decoders' outputs are averaged.
    """
    initial: Initial
    func: CDEFunc
    decoder: List[Decoder]  # ensemble of decoders whose outputs are averaged
    conv1d: TemporalConvEmbed
    params: Dict
    
    def __init__(self, input_size: int, hidden_size: int, \
                 output_size: int, params: Dict, key):
        ikey, fkey, dkey = jr.split(key, 3)
        num_decoders = 1
        initial_width = 128
        initial_depth = 1
        cde_width = params['cde_nodes']
        cde_depth = params['cde_layers']
        dec_width = params['decoder_nodes']
        dec_depth = params['decoder_layers']
        self.params = params
        # `initial` is not used by __call__ (y0 is fixed at zero). It is kept so
        # the pytree structure, and therefore saved checkpoints, stay unchanged.
        self.initial = Initial(params['input_dim'], hidden_size, initial_width, initial_depth, key=ikey)
        self.func = CDEFunc(input_size, hidden_size, cde_width, cde_depth, key=fkey)
        #if params['bayesian']:
        #    self.decoder = BayesianNNDecoder(hidden_size, output_size, dec_width, dec_depth, key=dkey)
        # A single decoder reuses `dkey` directly, so its initialisation is
        # unchanged. Ensembles get distinct keys so members are not identical copies.
        dec_keys = [dkey] if num_decoders == 1 else list(jr.split(dkey, num_decoders))
        self.decoder = [Decoder(hidden_size, output_size, dec_width, dec_depth, key=k) for k in dec_keys]
        # NOTE(review): in_channels=17 and the fixed PRNGKey(0) are hard-coded.
        # conv1d is only referenced by commented-out code in __call__. Kept as-is
        # because changing its shape would break loading existing checkpoints.
        self.conv1d = TemporalConvEmbed(in_channels=17, out_channels=17, kernel_size=3, key=jr.PRNGKey(0))
        
    def __call__(self, ts, coeffs, X_ctx, evolving_out=True):
        """Solve the CDE for one sample and decode the hidden states.

        Args:
            ts: observation times. The solve runs from ``ts[0]`` to ``ts[-1]``.
            coeffs: interpolation coefficients for ``diffrax.CubicInterpolation``
                over ``ts``.
            X_ctx: currently unused. Kept for interface compatibility; the
                commented-out initial-condition variants read it.
            evolving_out: if True, save and decode the state at every entry of
                ``ts``; otherwise only at the final time.

        Returns:
            ``(out, stats, ys)``:
            - ``out``: decoder outputs averaged over the ensemble, with shape
              ``(len(ts), output_size)``, or ``(1, output_size)`` when
              ``evolving_out`` is False.
            - ``stats``: the solver statistics.
            - ``ys``: the saved hidden states.
        """
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = None  # let the adaptive PID controller choose the first step

        # Zero initial hidden state: float dtype, sized from the vector field.
        # Alternatives tried:
        #y0 = self.initial(control.evaluate(ts[0]))
        #y0 = control.evaluate(ts[0])[:self.params['input_dim']]
        #y0 = X_ctx[0,:self.params['input_dim']]
        #y0 = self.initial(X_ctx[0,:self.params['input_dim']])
        #y0 = self.conv1d(X_ctx[:,:self.params['input_dim']])
        y0 = jnp.zeros((self.func.hidden_size,))

        saveat = diffrax.SaveAt(ts=ts) if evolving_out else diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
        )

        # solution.ys has shape (n_saved, hidden_size). Decode every saved state
        # with each decoder and average over the ensemble. This covers both save
        # modes; the evolving_out=False path previously left `out` unbound.
        per_decoder = jnp.stack([jax.vmap(dec)(solution.ys) for dec in self.decoder])
        out = jnp.mean(per_decoder, axis=0)
        return out, solution.stats, solution.ys

# ---------------------------
# Training Infrastructure
# ---------------------------
class DebugExit(Exception):
    """Marker exception, presumably raised to abort a run while debugging."""

def trainer(params: Dict, calib_data: Tuple[PICalibData], train: bool):
    """Train a NeuralCDE error-correction model, or evaluate a saved one.

    The model predicts a normalised error for a predictor's output ``y_pred``.
    ``y_pred + output_filter.invert(pred)`` is then compared with the ground
    truth ``y_true`` to report the MSE before and after correction.

    Args:
        params: configuration dict.
            - Keys read here include 'seed', 'lr', 'batch_size', 'epochs',
              'hidden_channels', 'steer', 'train_horizon', 'val_horizon',
              'nsample', 'input_filter', 'output_filter', plus the keys that
              make up the run-folder name.
            - Optional 'patience' (default 3) is the number of consecutive
              epochs without validation improvement that triggers early
              stopping.
        calib_data: tuple of PICalibData.
            - ``[0].Y`` / ``[1].Y`` are the ground truth for train / val.
            - ``[2]`` / ``[3]`` supply the predictor output ``Y`` plus
              ``X_ctx``, ``timesteps``, ``error`` and ``X_ctx_coeffs`` for
              train / val.
        train: if True, train the model and write to the run folder:
            - the best model,
            - the calibration data,
            - the per-epoch mean solver step counts.
            Otherwise load the model and data from that folder, evaluate on
            the validation split, and write prediction and percent-reduction
            pickles.
    """

    def _model_folder() -> str:
        # Run folder derived from the hyper-parameters. Shared by the save
        # (train) and load (inference) paths so the two names cannot drift.
        return f"./seed{params['seed']}_{params['dataset_name']}_{params['epochs']}epch_{params['lr']}lr_{params['batch_size']}bs_({params['cde_nodes']}_{params['cde_layers']})cde_({params['decoder_nodes']}_{params['decoder_layers']})decoder_{params['hidden_channels']}cha_{params['delta_t']}deltat_{params['n_sample']}n_sample_{params['nsample']}nsample_{params['train_horizon']}trHz_{params['val_horizon']}valHz_{params['steer']}STEER_{params['iter_correction']}Iter"

    model = NeuralCDE(input_size=calib_data[0].X_ctx.shape[-1],
                      hidden_size=params['hidden_channels'],
                      output_size=calib_data[2].error.shape[-1],
                      params=params,
                      key=jr.PRNGKey(params['seed'] + 14))
    optim = optax.adam(params['lr'])
    # loss_fn / norm_var are only referenced by the commented-out Bayesian and
    # un-normalised loss variants inside `loss` below.
    loss_fn = GaussianMSELoss()
    norm_var = 1 / jnp.array(np.square(params['input_filter'].stdev))
    output_filter = params['output_filter']

    @eqx.filter_jit
    def loss(model, X_ctx, ts, targets, coeffs, y_true, y_pred):
        # Batched forward pass, then a plain MSE between predicted and target
        # (normalised) errors. Returns (loss, (preds, targets, stats, ys)).
        # NOTE: jit retraces for every new input shape, so each distinct
        # truncated/subsampled training length triggers a recompile.
        preds, stats, ys = jax.vmap(model)(ts, coeffs, X_ctx)
        aux = (preds, targets, stats, ys)
        #if params['bayesian']:
        #    mu, logvar = preds[:,:,:3], preds[:,:,3:]
        #    return (loss_fn(preds, targets), aux) # +0.01*jnp.mean(logvar)
        #unnorm_error = output_filter.invert(preds)
        #corr_pred = y_pred + unnorm_error
        #tr_loss = jnp.mean(norm_var*jnp.mean((corr_pred - y_true)**2, axis=(0,1)))
        tr_loss = jnp.mean((preds - targets)**2)
        return (tr_loss, aux)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, X_ctx, train_ts, errors, X_ctx_sim_coeffs, y_true, y_pred, opt_state):
        # One Adam step. Returns (loss, new_model, new_opt_state, preds, stats, ys).
        (losses, aux), grads = grad_loss(model, X_ctx, train_ts, errors, X_ctx_sim_coeffs, y_true, y_pred)
        preds, _, stats, ys = aux
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return losses, model, opt_state, preds, stats, ys

    def data_loaders(calib_data: Tuple[PICalibData], data_type: str, per_key=None, batch_sz: int = params['batch_size']):
        """Return a batch generator for the 'Train' split, or otherwise the val split.

        Each batch is ``(y_true, y_pred, X_ctx, ts, error, *X_ctx_coeffs)``.
        Training batches have size ``batch_sz`` and are shuffled with
        ``PRNGKey(per_key)``. The validation split is a single batch holding
        every sample; ``batch_sz`` is ignored for it.
        """
        if data_type == 'Train':
            train_ts = jnp.repeat(calib_data[2].timesteps[None, :],
                                  repeats=calib_data[2].X.shape[0], axis=0)
            return create_data_loader((calib_data[0].Y,  # ground truth
                                       calib_data[2].Y,  # predictor output
                                       calib_data[2].X_ctx,
                                       train_ts,
                                       calib_data[2].error) + calib_data[2].X_ctx_coeffs,
                                      batch_size=batch_sz,
                                      key=jr.PRNGKey(per_key))
        val_ts = jnp.repeat(calib_data[3].timesteps[None, :],
                            repeats=calib_data[3].X.shape[0], axis=0)
        return create_data_loader((calib_data[1].Y,  # ground truth
                                   calib_data[3].Y,  # predictor output
                                   calib_data[3].X_ctx,
                                   val_ts,
                                   calib_data[3].error) + calib_data[3].X_ctx_coeffs,
                                  batch_size=calib_data[3].Y.shape[0],
                                  key=jr.PRNGKey(1))  # single full batch; order is irrelevant

    def create_data_loader(data_tuple, batch_size, key):
        # Generator of shuffled mini-batches. Every array in `data_tuple` is
        # indexed with the same permutation along axis 0. A trailing partial
        # batch is dropped, so batch_size > dataset size yields nothing.
        dataset_size = data_tuple[2].shape[0]
        assert all(data.shape[0] == dataset_size for data in data_tuple)
        perm = jr.permutation(key, jnp.arange(dataset_size))
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(data[batch_perm] for data in data_tuple)

    best_loss = float('inf')
    best_model = None
    early_stop_ctr = 0
    patience = params.get('patience', 3)
    if train:
        # Resolve the run folder up front so a missing key fails before training.
        folder = _model_folder()
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        history = []
        headers = ["Epoch", "Train", "Val", "Train_Before", "Train_After", "Val_Before", "Val_After", "func_eval", "Note"]
        print(f"The STEER value is: {params['steer']}")
        func_evals = []

        for epoch in range(params['epochs']):
            # Reshuffle the training data every epoch, seeded by the epoch index.
            train_loader = data_loaders(calib_data, data_type='Train', per_key=epoch)
            val_loader = data_loaders(calib_data, data_type='Val')

            train_losses, tr_losses_before, tr_losses_after = [], [], []
            func_eval = []

            for batch_id, batch in enumerate(train_loader):
                # Per-batch randomness: fold the batch index into an epoch key,
                # then split it so the horizon draw and the time subsampling use
                # independent keys. The old PRNGKey(epoch + batch_id) had three
                # problems: it repeated across epochs, it coincided with the
                # loader's shuffle key, and it was reused by two samplers.
                batch_key = jr.fold_in(jr.PRNGKey(epoch), batch_id)
                horizon_key, subsample_key = jr.split(batch_key)
                # Random training horizon; randint's maxval is exclusive.
                end_T = int(jr.randint(horizon_key, shape=(1,),
                                       minval=params['train_horizon'] - params['steer'],
                                       maxval=params['train_horizon'])[0])

                # The precomputed coefficients are discarded: the spline is
                # refitted below on the truncated, subsampled time grid.
                y_true, y_pred, X_ctx, train_ts, errors, *_ = batch
                y_true, y_pred, X_ctx = y_true[:, :end_T], y_pred[:, :end_T], X_ctx[:, :end_T]
                train_ts, errors = train_ts[:, :end_T], errors[:, :end_T]
                tr_steps = y_true.shape[1]

                # Keep a random, time-ordered fraction `nsample` of the steps.
                n_keep = int(params['nsample'] * tr_steps)
                idx = jnp.sort(jr.permutation(subsample_key, tr_steps)[:n_keep])
                y_true, y_pred, X_ctx, train_ts, errors = \
                    y_true[:, idx], y_pred[:, idx], X_ctx[:, idx], train_ts[:, idx], errors[:, idx]
                X_ctx_sim_coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(train_ts, X_ctx)

                # ys is the hidden-state trajectory.
                train_loss, model, opt_state, pred_error, stats, ys = make_step(
                    model, X_ctx, train_ts, errors, X_ctx_sim_coeffs, y_true, y_pred, opt_state)
                func_eval.append(stats['num_steps'])

                pred_error_unnorm = output_filter.invert(pred_error)
                mse_before_tr = jnp.mean((y_true - y_pred)**2)
                mse_after_tr = jnp.mean((y_true - (y_pred + pred_error_unnorm))**2)

                train_losses.append(train_loss)
                tr_losses_before.append(mse_before_tr)
                tr_losses_after.append(mse_after_tr)

            train_loss = jnp.mean(jnp.array(train_losses))
            tr_loss_before = jnp.mean(jnp.array(tr_losses_before))
            tr_loss_after = jnp.mean(jnp.array(tr_losses_after))
            func_eval = jnp.mean(jnp.array(func_eval))
            func_evals.append(func_eval)

            # Validation on the first `val_horizon` steps (whole val set, one batch).
            vh = params['val_horizon']
            val_losses, loss_before, loss_after = [], [], []
            is_better = False
            for batch in val_loader:
                y_true, y_pred, X_ctx, val_ts, errors, *X_ctx_sim_coeffs = batch
                y_true, y_pred, X_ctx = y_true[:, :vh], y_pred[:, :vh], X_ctx[:, :vh]
                val_ts, errors = val_ts[:, :vh], errors[:, :vh]
                # Coefficients are per interval between steps, hence vh - 1.
                X_ctx_sim_coeffs = tuple(c[:, :vh - 1] for c in X_ctx_sim_coeffs[:4])
                val_loss, aux = loss(model, X_ctx, val_ts, errors, X_ctx_sim_coeffs, y_true, y_pred)
                val_losses.append(val_loss)
                pred_error, _, _, ys = aux  # ys is the hidden-state trajectory
                pred_error_unnorm = output_filter.invert(pred_error)
                loss_before.append(jnp.mean((y_true - y_pred)**2))
                loss_after.append(jnp.mean((y_true - (y_pred + pred_error_unnorm))**2))

            val_loss = jnp.mean(jnp.array(val_losses))
            loss_before_ = jnp.mean(jnp.array(loss_before))
            loss_after_ = jnp.mean(jnp.array(loss_after))

            if float(val_loss) < best_loss:
                early_stop_ctr = 0
                is_better = True
                best_loss = float(val_loss)
                # Equinox modules and JAX arrays are immutable, and make_step
                # returns a new model, so a plain reference is a safe snapshot.
                best_model = model
            else:
                early_stop_ctr += 1

            note = "Better model" if is_better else ""
            history.append([
                epoch, train_loss, val_loss,
                tr_loss_before, tr_loss_after,
                loss_before_, loss_after_, func_eval,
                note
            ])
            table = tabulate(history, headers=headers, floatfmt=".4f", tablefmt="grid")
            if epoch == 0:
                print(table)
            else:
                # Print only the newest row and the table's closing border.
                print(table.split("\n")[-2] + "\n" + table.split("\n")[-1])

            if early_stop_ctr >= patience:
                print("Stopping Early.....!!")
                break

        if best_model is None:
            # Zero epochs, or no finite improving validation loss (e.g. NaN):
            # save the last model rather than serialising None.
            print("No improving validation loss recorded; saving the last model.")
            best_model = model

        check_or_make_folder(folder)
        eqx.tree_serialise_leaves(folder + "/model.eqx", best_model)
        data_dict = {"calib_data": calib_data}
        nfe_dict = {"nfe_epochs": func_evals}
        with open(folder + "/data_dict.pkl", "wb") as f:
            pickle.dump(data_dict, f)
        with open(folder + "/nfe_dict.pkl", "wb") as f:
            pickle.dump(nfe_dict, f)

    else:
        seed = params['seed']
        print(f"Inference with Model with SEED {seed}")
        folder = _model_folder()

        # The freshly built model serves as the structural template for loading.
        model = eqx.tree_deserialise_leaves(folder + "/model.eqx", model)

        # NOTE: pickle is only acceptable here because the training branch above
        # wrote this file; never point this at untrusted data.
        with open(folder + "/data_dict.pkl", "rb") as f:
            data_dict = pickle.load(f)
        calib_data = data_dict['calib_data']
        # Only used by the commented-out training-set evaluation below.
        train_loader = data_loaders(calib_data, 'Train', per_key=0, batch_sz=calib_data[0].X_ctx.shape[0])

        def save_errors(loader, test_type: str, until: int):
            """Evaluate `model` on every batch of `loader` and print MSE statistics.

            For test_type 'Val', MSE before/after correction is also computed
            on the first `until` steps, along with the percentage reduction.

            Returns:
                (all_true, corr_pred, true_error, pred_preds, pred_errors, perc_red),
                stacked over batches. ``perc_red`` is None unless test_type is 'Val'.
            """
            all_true, pred_preds, corr_pred, true_error, pred_errors = [], [], [], [], []
            mses_before, mses_after, mses_before_interp, mses_after_interp = [], [], [], []
            perc_red = None  # only computed for 'Val' (previously unbound otherwise)
            for batch in loader:
                y_true, y_pred, X_ctx, train_ts, errors, *X_ctx_sim_coeffs = batch
                _, aux = loss(model, X_ctx, train_ts, errors, X_ctx_sim_coeffs, y_true, y_pred)
                preds, _, _, ys = aux  # ys is the hidden-state trajectory

                # Rewritten per batch: only the last batch's states are kept.
                with open(folder + "/hidden_state.pkl", mode='wb') as f:
                    pickle.dump(ys, f)
                pred_error = preds

                pred_error_unnorm = output_filter.invert(pred_error)
                true_error_unnorm = output_filter.invert(errors)
                mses_before.append(jnp.mean((y_true - y_pred)**2))
                mses_after.append(jnp.mean((y_true - (y_pred + pred_error_unnorm))**2))

                all_true.append(y_true)
                corr_pred.append(y_pred + pred_error_unnorm)
                true_error.append(true_error_unnorm)
                pred_errors.append(pred_error_unnorm)
                pred_preds.append(y_pred)

                if test_type == 'Val':
                    y_true_interp = y_true[:, :until, :]
                    y_pred_interp = y_pred[:, :until, :]
                    pred_error_unnorm_interp = pred_error_unnorm[:, :until, :]
                    mses_before_interp.append(jnp.mean((y_true_interp - y_pred_interp)**2))
                    mses_after_interp.append(
                        jnp.mean((y_true_interp - (y_pred_interp + pred_error_unnorm_interp))**2))

            print("============================================")
            print(f"Printing the results for: {test_type}")
            if test_type == 'Val':
                mse_b = jnp.mean(jnp.array(mses_before_interp))  # before correction
                mse_a = jnp.mean(jnp.array(mses_after_interp))   # after correction
                perc_red = ((mse_b - mse_a) / mse_b) * 100
                print(f"MSE before {until}: {mse_b:.6f} & MSE After {until}: {mse_a:.6f} & % Reduction: {perc_red}")
            print("============================================")

            return (jnp.array(all_true), jnp.array(corr_pred), jnp.array(true_error),
                    jnp.array(pred_preds), jnp.array(pred_errors), perc_red)

        #targets_all_tr, corr_pred_tr, _, _, _, _ = save_errors(train_loader, 'Train', None)
        # Horizons at which the percentage MSE reduction is reported.
        untils = [50, 60, 70, 90, 100, 110, 120, 130, 140, 150, 180, 190, 200]
        perc_reds = []
        for until in untils:
            # The loader is a one-shot generator, so rebuild it for each horizon.
            val_loader = data_loaders(calib_data, 'Val', batch_sz=calib_data[1].X_ctx.shape[0])
            gr_pred, corr_pred, true_error, pred_preds, pred_errors, perc_red = save_errors(val_loader, 'Val', until)
            perc_reds.append(perc_red)

        pred_dict = {"gr_pred": gr_pred, "corr_pred": corr_pred, "true_error": true_error,
                     "pred_preds": pred_preds, "pred_errors": pred_errors}
        perform_dict = {"untils": np.array(untils), "perc_reds": np.array(perc_reds)}
        # The Bayesian evaluation path (ensemble_pred / select_quantiles /
        # save_scores) is currently disabled.

        # NOTE(review): perform_dict is written to corr_pred_dict.pkl; the file
        # name is kept for downstream readers.
        with open(folder + "/corr_pred_dict.pkl", "wb") as f:
            pickle.dump(perform_dict, f)
        with open(folder + "/pred_dict.pkl", "wb") as f:
            pickle.dump(pred_dict, f)
                