import jax
import diffrax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
import equinox as eqx
import optax
import numpy as np
from typing import Tuple, Dict, List, Callable, Union
from diffrax import ODETerm, diffeqsolve, Tsit5, SaveAt
from data_classes import PICalibData
from calib_calc import z_scores, select_quantiles, ensemble_pred, save_scores
import sys
import pickle
import copy
from utils import check_or_make_folder, GaussianMSELoss

# ---------------------------
# Model Components
# ---------------------------

class CDEFunc(eqx.Module):
    """Vector field of the neural CDE.

    Maps a hidden state of size ``hidden_size`` to a matrix of shape
    ``(hidden_size, input_size)``. It is meant for use with
    ``diffrax.ControlTerm``, which multiplies that matrix by the derivative
    of the control path.
    """

    # Field declaration order is kept stable: it fixes the pytree leaf order
    # used by (de)serialisation.
    mlp: eqx.nn.MLP
    input_size: int
    hidden_size: int

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.input_size = input_size
        n_outputs = input_size * hidden_size
        # The tanh output layer bounds every entry of the vector field to
        # (-1, 1). That keeps the hidden state from blowing up, much like the
        # gating in GRUs/LSTMs bounds their rate of change.
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=n_outputs,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        # `t` and `args` belong to diffrax's vector-field signature and are
        # not used here.
        flat = self.mlp(y)
        return jnp.reshape(flat, (self.hidden_size, self.input_size))

class Decoder(eqx.Module):
    """Deterministic readout: a ReLU MLP from hidden state to output."""

    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, output_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        # Positional order of eqx.nn.MLP: in_size, out_size, width_size, depth.
        self.mlp = eqx.nn.MLP(
            hidden_size, output_size, width_size, depth,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, zx):
        # eqx.nn.MLP works on a single unbatched vector; callers vmap for
        # sequences.
        decoded = self.mlp(zx)
        return decoded

class BayesianNNDecoder(eqx.Module):
    """Heteroscedastic decoder that predicts a mean and a bounded log-variance.

    The input goes through ``depth`` SiLU-activated linear layers of width
    ``width_size``. Two linear heads then produce the mean (``delta``) and
    the raw log-variance. The raw log-variance is soft-clamped between
    ``min_logvar`` and ``max_logvar``. The output is
    ``concatenate([mean, logvar])``, so its last axis is ``2 * output_dim``.
    """

    # Field order unchanged (pytree / serialisation layout).
    layers: List[eqx.nn.Linear]
    delta: eqx.nn.Linear
    logvar: eqx.nn.Linear
    activation: Callable
    # The log-variance bounds are array leaves of the module, so any optimiser
    # that updates all inexact arrays also learns them.
    max_logvar: jnp.ndarray
    min_logvar: jnp.ndarray

    def __init__(self, input_dim: int, output_dim: int, width_size, depth, *, key):
        hidden_widths = [width_size] * depth
        # One key per hidden layer, plus one for each of the two heads.
        subkeys = jr.split(key, len(hidden_widths) + 2)
        self.activation = jnn.silu

        widths = [input_dim] + hidden_widths
        self.layers = [
            eqx.nn.Linear(fan_in, fan_out, key=layer_key)
            for fan_in, fan_out, layer_key in zip(widths[:-1], widths[1:], subkeys)
        ]
        feature_size = widths[-1]

        self.delta = eqx.nn.Linear(feature_size, output_dim, key=subkeys[-2])
        self.logvar = eqx.nn.Linear(feature_size, output_dim, key=subkeys[-1])

        self.max_logvar = 0.5 * jnp.ones(output_dim)
        self.min_logvar = -2.0 * jnp.ones(output_dim)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        features = x
        for linear in self.layers:
            features = self.activation(linear(features))
        mean = self.delta(features)
        raw_logvar = self.logvar(features)
        # Smooth, differentiable clamp: first a soft upper bound at max_logvar,
        # then a soft lower bound at min_logvar.
        capped = self.max_logvar - jnn.softplus(self.max_logvar - raw_logvar)
        bounded = self.min_logvar + jnn.softplus(capped - self.min_logvar)
        return jnp.concatenate([mean, bounded], axis=-1)

class Initial(eqx.Module):
    """Encoder mapping a feature vector to the CDE's initial hidden state."""

    mlp: eqx.nn.MLP

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        # ReLU MLP from input_size to hidden_size, with `depth` hidden layers
        # of width `width_size`.
        encoder = eqx.nn.MLP(
            input_size,
            hidden_size,
            width_size,
            depth,
            activation=jnn.relu,
            key=key,
        )
        self.mlp = encoder

    def __call__(self, z):
        initial_state = self.mlp(z)
        return initial_state

class NeuralCDE(eqx.Module):
    """Neural CDE: encode an initial state, integrate along a spline path, decode.

    Fields:
        initial: MLP producing the initial hidden state from context features.
        func: CDE vector field, wrapped in a ``diffrax.ControlTerm``.
        decoder: ``Decoder``, or ``BayesianNNDecoder`` when
            ``params['bayesian']`` is truthy. The Bayesian decoder returns the
            mean and log-variance concatenated on the last axis.
        params: the configuration dict the model was built from.
    """

    # Field order unchanged (pytree / serialisation layout).
    initial: Initial
    func: CDEFunc
    decoder: Union[Decoder, BayesianNNDecoder]
    params: Dict

    def __init__(self, input_size: int, hidden_size: int,
                 output_size: int, params: Dict, key):
        """Build the three sub-networks.

        Args:
            input_size: number of channels of the control path. The vector
                field outputs a ``(hidden_size, input_size)`` matrix.
            hidden_size: dimension of the CDE hidden state.
            output_size: output dimension passed to the decoder.
            params: config dict. Reads ``'cde_nodes'``, ``'cde_layers'``,
                ``'decoder_nodes'``, ``'decoder_layers'``, ``'input_dim'``
                and ``'bayesian'``.
            key: JAX PRNG key, split three ways for the sub-networks.
        """
        ikey, fkey, dkey = jr.split(key, 3)
        # The initial-state encoder uses a fixed size rather than a config entry.
        initial_width = 128
        initial_depth = 1
        cde_width = params['cde_nodes']
        cde_depth = params['cde_layers']
        dec_width = params['decoder_nodes']
        dec_depth = params['decoder_layers']
        self.params = params
        # The encoder sees only the first `input_dim` context features.
        self.initial = Initial(self.params['input_dim'], hidden_size,
                               initial_width, initial_depth, key=ikey)
        self.func = CDEFunc(input_size, hidden_size, cde_width, cde_depth, key=fkey)
        if params['bayesian']:
            self.decoder = BayesianNNDecoder(hidden_size, output_size,
                                             dec_width, dec_depth, key=dkey)
        else:
            self.decoder = Decoder(hidden_size, output_size, dec_width, dec_depth, key=dkey)

    def __call__(self, ts, coeffs, X_ctx, evolving_out=True):
        """Solve the CDE for one unbatched sample and decode the hidden state.

        Args:
            ts: 1-D array of observation times. Integration runs from
                ``ts[0]`` to ``ts[-1]``.
            coeffs: cubic-spline coefficients of the control path, in the
                form ``diffrax.CubicInterpolation`` accepts.
            X_ctx: context features. The initial state is encoded from
                ``X_ctx[0, :params['input_dim']]``. NOTE(review): this
                presumably means (time, features) per sample; confirm.
            evolving_out: if True, decode the hidden state at every time in
                ``ts``, giving a leading axis of length ``len(ts)``. Otherwise
                decode only the state at ``ts[-1]``.

        Returns:
            The decoder output(s), as described for ``evolving_out``.
        """
        # ts and coeffs define a continuous-time control path through the
        # cubic interpolation.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        # dt0=None lets the adaptive PID controller choose the first step size.
        dt0 = None

        y0 = self.initial(X_ctx[0, :self.params['input_dim']])

        if evolving_out:
            saveat = diffrax.SaveAt(ts=ts)
        else:
            saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
        )

        if evolving_out:
            # solution.ys holds one hidden state per time in ts; decode each.
            prediction = jax.vmap(self.decoder)(solution.ys)
        else:
            # SaveAt(t1=True) stores a single state with a leading axis of
            # size 1, so decode that final state directly.
            prediction = self.decoder(solution.ys[-1])
        return prediction

# ---------------------------
# Training Infrastructure
# ---------------------------
class DebugExit(Exception):
    """Marker exception; not raised in this file (presumably for aborting a debug run)."""

def _model_folder(params: Dict, seed) -> str:
    """Return the checkpoint folder for ``seed`` and the run's key hyper-parameters.

    Saving after training and loading for inference both use this name, so
    the two stay in sync.
    """
    return (
        f"./seed{seed}_{params['dataset_name']}_{params['epochs']}_{params['lr']}"
        f"_({params['cde_nodes']}_{params['cde_layers']})cde"
        f"_({params['decoder_nodes']}_{params['decoder_layers']})decoder"
        f"_bayesian{params['bayesian']}"
    )


def trainer(params: Dict, calib_data: Tuple[PICalibData], train: bool):
    """Train a NeuralCDE error model, or run (ensemble) inference with saved models.

    Args:
        params: configuration dict. Reads 'hidden_channels', 'seed', 'lr',
            'output_filter', 'bayesian', 'epochs', 'dataset_name',
            'cde_nodes', 'cde_layers', 'decoder_nodes' and 'decoder_layers'.
            NeuralCDE also reads 'input_dim'. Optionally reads
            'ensemble_seeds' (inference only; defaults to ``range(8, 9)``).
        calib_data: tuple of ``PICalibData``.
            - Index 0 supplies the training masks and trajectory lengths; 1
              supplies the validation ones.
            - Index 2 supplies the training contexts, errors, timesteps and
              spline coefficients; 3 supplies the validation ones.
            - ``calib_data[1].Y`` and ``calib_data[3].Y`` are, per the loader
              comments, the ground truth and the predictor's output.
        train: if True, train the model and save the best one (by validation
            loss) together with ``calib_data``. Otherwise load saved model(s),
            predict on the first validation batch, and save quantile scores.
    """
    model = NeuralCDE(input_size=calib_data[0].X_ctx.shape[-1],
                      hidden_size=params['hidden_channels'],
                      output_size=calib_data[2].error.shape[-1],
                      params=params,
                      key=jr.PRNGKey(params['seed'] + 14))
    optim = optax.adam(params['lr'])
    loss_fn = GaussianMSELoss()
    output_filter = params['output_filter']

    @eqx.filter_jit
    def loss(model, X_ctx, mask, ts, targets, coeffs):
        """Return ``(scalar_loss, (preds, targets))`` for one batch (``mask`` is unused)."""
        preds = jax.vmap(model)(ts, coeffs, X_ctx)
        aux = (preds, targets)
        if params['bayesian']:
            # preds is passed whole (mean and log-variance concatenated);
            # GaussianMSELoss is assumed to split them itself.
            return loss_fn(preds, targets), aux
        return jnp.mean((preds - targets) ** 2), aux

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, batch, opt_state):
        """Take one Adam step on a training batch; return ``(loss, model, opt_state)``."""
        X_ctx, mask, traj_len, train_ts, errors, *coeffs = batch
        # The spline coefficients arrive as separate arrays; diffrax expects
        # them as a tuple.
        (batch_loss, _), grads = grad_loss(model, X_ctx, mask, train_ts, errors, tuple(coeffs))
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return batch_loss, model, opt_state

    def create_data_loader(data_tuple, batch_size, key):
        """Yield shuffled batches, each a tuple of indexed arrays.

        Only full batches are yielded. A trailing partial batch is dropped, so
        every batch has the same shape and the jitted step is not retraced.

        Raises:
            ValueError: if the arrays do not share their leading dimension.
        """
        dataset_size = data_tuple[0].shape[0]
        if any(data.shape[0] != dataset_size for data in data_tuple):
            raise ValueError("All arrays given to the data loader must share their leading dimension.")
        perm = jr.permutation(key, jnp.arange(dataset_size))
        # The upper bound `dataset_size - batch_size + 1` keeps a final batch
        # that ends exactly at dataset_size (the old `end < size` dropped it).
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(data[batch_perm] for data in data_tuple)

    def data_loaders(calib_data: Tuple[PICalibData], epoch: int = 0):
        """Build ``(train_loader, val_loader)`` generators with batch size 32.

        The training order is reshuffled per ``epoch``. The validation order
        is fixed, so validation losses stay comparable across epochs.
        """
        train_ts = jnp.repeat(calib_data[2].timesteps[None, :],
                              repeats=calib_data[2].X.shape[0], axis=0)
        val_ts = jnp.repeat(calib_data[3].timesteps[None, :],
                            repeats=calib_data[3].X.shape[0], axis=0)

        train_loader = create_data_loader((calib_data[2].X_ctx,
                                           calib_data[0].mask,
                                           calib_data[0].Traj_len,
                                           train_ts,
                                           calib_data[2].error) +
                                          calib_data[2].X_ctx_coeffs,
                                          batch_size=32,
                                          key=jr.fold_in(jr.PRNGKey(0), epoch))
        val_loader = create_data_loader((calib_data[1].Y,  # ground truth
                                         calib_data[3].Y,  # prediction from predictor
                                         calib_data[3].X_ctx,
                                         calib_data[1].mask,
                                         calib_data[1].Traj_len,
                                         val_ts,
                                         calib_data[3].error) +
                                        calib_data[3].X_ctx_coeffs,
                                        batch_size=32,
                                        key=jr.PRNGKey(1))
        return train_loader, val_loader

    best_loss = float("inf")
    best_model = None
    if train:
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

        for epoch in range(params['epochs']):
            train_loader, val_loader = data_loaders(calib_data, epoch)

            train_losses = []
            for batch in train_loader:
                train_loss, model, opt_state = make_step(model, batch, opt_state)
                train_losses.append(train_loss)
            train_loss = jnp.mean(jnp.array(train_losses))

            # Validation runs every 5th epoch only.
            if epoch % 5 != 0:
                print(f"Epoch: {epoch} | Train loss: {train_loss:.2f}")
                continue

            val_losses = []
            for batch in val_loader:
                _, _, X_ctx, mask, traj_len, val_ts, errors, *coeffs = batch
                val_loss, _ = loss(model, X_ctx, mask, val_ts, errors, tuple(coeffs))
                val_losses.append(val_loss)
            val_loss = jnp.mean(jnp.array(val_losses))

            if float(val_loss) < best_loss:
                print("Better model found...!!")
                best_loss = float(val_loss)
                # Equinox modules and JAX arrays are immutable, so keeping a
                # reference is enough; no deep copy is needed.
                best_model = model

            print(f"Epoch: {epoch} | Train loss: {train_loss:.2f} | Val loss: {val_loss:.2f}")

        if best_model is None:
            # No validation pass produced a comparable loss (NaN, empty loader
            # or zero epochs). Save the final weights rather than serialising None.
            print("No validation improvement recorded; saving the final model.")
            best_model = model

        folder = _model_folder(params, params['seed'])
        check_or_make_folder(folder)
        eqx.tree_serialise_leaves(folder + "/model.eqx", best_model)
        data_dict = {"calib_data": calib_data}
        with open(folder + "/data_dict.pkl", "wb") as f:
            pickle.dump(data_dict, f)

    else:
        # Per-model predicted means and log-variances, stacked into an ensemble below.
        pred_mu, pred_logvar = [], []
        folder = None
        targets = None
        for seed in params.get('ensemble_seeds', range(8, 9)):
            print(f"Inference with Model with SEED {seed}")
            folder = _model_folder(params, seed)
            model = eqx.tree_deserialise_leaves(folder + "/model.eqx", model)

            # NOTE(security): pickle.load can execute arbitrary code. Only load
            # folders this script wrote itself.
            with open(folder + "/data_dict.pkl", "rb") as f:
                data_dict = pickle.load(f)
            calib_data = data_dict['calib_data']
            _, val_loader = data_loaders(calib_data)

            # Only the first validation batch is evaluated.
            batch = next(val_loader, None)
            if batch is None:
                raise ValueError(f"Validation loader for seed {seed} produced no full batch.")
            y_true, y_pred, X_ctx, mask, traj_len, val_ts, errors, *coeffs = batch
            _, aux = loss(model, X_ctx, mask, val_ts, errors, tuple(coeffs))
            preds, targets = aux
            # Assumes the Bayesian decoder: the last axis is [mean, log-variance].
            pred_error, logvar_error = jnp.split(preds, 2, axis=-1)

            pred_mu.append(pred_error)
            pred_logvar.append(logvar_error)

            pred_error_unnorm = output_filter.invert(pred_error)
            mse_before = jnp.mean((y_true - y_pred) ** 2)
            mse_after = jnp.mean((y_true - (y_pred + pred_error_unnorm)) ** 2)
            print(f"MSE before: {mse_before} & MSE After: {mse_after}")

        if not pred_mu:
            raise ValueError("No ensemble seeds were given for inference.")

        # Leading axis: one entry per ensemble member.
        pred_mu = jnp.array(pred_mu)
        pred_logvar = jnp.array(pred_logvar)

        ens_mu, ens_var = ensemble_pred(pred_mu, pred_logvar)
        lower_z, upper_z = z_scores()
        lower_quant, upper_quant, mu_pi = select_quantiles(ens_mu, ens_var, lower_z, upper_z)

        # `targets` comes from the last evaluated model's batch.
        scores = save_scores(lower_quant, upper_quant, targets)

        pred_dict = {"quants": np.array(mu_pi), "targets": np.array(targets), "scores": scores}
        with open(folder + "/pred_dict.pkl", "wb") as f:
            pickle.dump(pred_dict, f)
                