import jax
import diffrax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
import equinox as eqx
import optax
import numpy as np
from typing import Tuple, Dict
from calib_calc import generate_pi
from data_classes import PICalibData
import sys
import pickle
from utils import check_or_make_folder

class CDEFunc(eqx.Module):
    """Vector field of the controlled differential equation.

    An MLP maps a hidden state of length ``hidden_size`` to
    ``hidden_size * input_size`` values, which are reshaped into a
    ``(hidden_size, input_size)`` matrix.
    """
    # NOTE: keep the field declaration order; it fixes the pytree leaf order.
    mlp: eqx.nn.MLP
    input_size: int
    hidden_size: int

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        """Build the underlying MLP.

        Args:
            input_size: Number of columns of the matrix returned by ``__call__``.
            hidden_size: Length of the hidden state (MLP input size, and number
                of rows of the returned matrix).
            width_size: Width of the MLP's hidden layers.
            depth: Depth of the MLP.
            key: PRNG key used to initialise the MLP.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        # One MLP output per entry of the (hidden_size, input_size) matrix.
        flat_out_size = hidden_size * input_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=flat_out_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # The tanh final activation bounds the output, which is important
            # to stop the model blowing up (much like GRUs and LSTMs constrain
            # the rate of change of their hidden states).
            final_activation=jnn.tanh,
            key=key,
        )
        self.hidden_size = hidden_size
        self.input_size = input_size

    def __call__(self, t, y, args):
        """Evaluate the vector field at hidden state ``y``.

        ``t`` and ``args`` are accepted but not used.
        """
        matrix_shape = (self.hidden_size, self.input_size)
        return jnp.reshape(self.mlp(y), matrix_shape)
    
class VelocityFieldDecoder(eqx.Module):
    """MLP that predicts a velocity from flow time, flow state and hidden state."""
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, output_size, width_size, depth, *, key, **kwargs):
        """Build the decoder MLP.

        Args:
            hidden_size: Input size of the MLP. Despite the name, this must be
                the length of the concatenated ``[t_flow, z_t, h_t]`` vector
                built in ``__call__`` (``NeuralCDE`` passes
                ``hidden_size + output_size + 1``).
            output_size: Length of the predicted velocity vector.
            width_size: Width of the MLP's hidden layers.
            depth: Depth of the MLP.
            key: PRNG key used to initialise the MLP.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        mlp_config = dict(
            in_size=hidden_size,
            out_size=output_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.relu,  # experiment with activation functions
        )
        self.mlp = eqx.nn.MLP(key=key, **mlp_config)

    def __call__(self, t_flow, z_t, h_t):
        """Return the MLP output for the concatenation ``[t_flow, z_t, h_t]``.

        The three inputs are joined along their last axis, in that order.
        """
        features = jnp.concatenate((t_flow, z_t, h_t), axis=-1)
        velocity = self.mlp(features)
        return velocity

class Initial(eqx.Module):
    """MLP mapping an input vector to a vector of length ``hidden_size``."""
    mlp: eqx.nn.MLP

    def __init__(self, input_size, hidden_size, width_size, depth, *, key, **kwargs):
        """Build the MLP.

        Args:
            input_size: Length of the input vector.
            hidden_size: Length of the output vector.
            width_size: Width of the MLP's hidden layers.
            depth: Depth of the MLP.
            key: PRNG key used to initialise the MLP.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        mlp_config = dict(
            in_size=input_size,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.relu,
        )
        self.mlp = eqx.nn.MLP(key=key, **mlp_config)

    def __call__(self, z):
        """Apply the MLP to ``z``."""
        hidden = self.mlp(z)
        return hidden

class NeuralCDE(eqx.Module):
    """Neural CDE encoder bundled with a velocity-field decoder.

    Calling the module solves the CDE and returns the hidden-state trajectory
    only; ``decoder`` is stored on the module but is not applied in
    ``__call__`` (callers invoke ``model.decoder`` themselves).
    """
    initial: Initial
    func: CDEFunc
    decoder: VelocityFieldDecoder

    def __init__(self, input_size: int, hidden_size: int,
                 output_size: int, params: Dict, key):
        """Create the three sub-networks from independent PRNG keys.

        Args:
            input_size: Number of channels of the control path.
            hidden_size: Length of the CDE hidden state.
            output_size: Length of the decoder output.
            params: Must contain ``'cde_nodes'``, ``'cde_layers'``,
                ``'decoder_nodes'`` and ``'decoder_layers'``.
            key: PRNG key, split three ways (initial, func, decoder).
        """
        key_initial, key_func, key_decoder = jr.split(key, 3)

        # The initial-state network uses a fixed width of 128 and depth of 1.
        self.initial = Initial(input_size, hidden_size, 128, 1, key=key_initial)
        self.func = CDEFunc(
            input_size,
            hidden_size,
            params['cde_nodes'],
            params['cde_layers'],
            key=key_func,
        )
        # The decoder consumes [t_flow (1), z_t (output_size), h_t (hidden_size)].
        decoder_in_size = hidden_size + output_size + 1
        self.decoder = VelocityFieldDecoder(
            decoder_in_size,
            output_size,
            params['decoder_nodes'],
            params['decoder_layers'],
            key=key_decoder,
        )

    def __call__(self, ts, coeffs):
        """Solve the CDE along one sample's control path.

        Args:
            ts: Timestamps of the sample.
            coeffs: Coefficients parameterising the control path; together
                with ``ts`` they define a continuous-time cubic interpolation.

        Returns:
            ``solution.ys``: the hidden state saved at every time in ``ts``.
        """
        path = diffrax.CubicInterpolation(ts, coeffs)
        t_start, t_end = ts[0], ts[-1]

        # The hidden state is initialised from the control value at the start.
        h0 = self.initial(path.evaluate(t_start))

        solution = diffrax.diffeqsolve(
            diffrax.ControlTerm(self.func, path).to_ode(),
            diffrax.Tsit5(),
            t0=t_start,
            t1=t_end,
            dt0=None,
            y0=h0,
            saveat=diffrax.SaveAt(ts=ts),
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        )
        return solution.ys

# ---------------------------
# Training Infrastructure
# ---------------------------

def trainer(params: Dict, calib_data: Tuple[PICalibData, ...], train: bool):
    """Train the flow-matching Neural CDE, or reload it and draw samples.

    With ``train=True`` the model is fitted with Adam, the validation loss is
    evaluated every fifth epoch, and the model plus ``calib_data`` are written
    to a run folder derived from ``params``. With ``train=False`` the model and
    the pickled ``calib_data`` are read back from that folder, samples are
    drawn for the first validation batch, ``generate_pi`` is applied to them,
    the result is written to ``pred_dict.pkl`` and the process exits.

    Args:
        params: Run configuration. Keys read here: ``'seed'``, ``'lr'``,
            ``'epochs'``, ``'hidden_channels'``, ``'dataset_name'``,
            ``'cde_nodes'``, ``'cde_layers'``, ``'decoder_nodes'``,
            ``'decoder_layers'``.
        calib_data: Indexable collection of ``PICalibData``. Index 0 supplies
            ``X_ctx`` (input size), index 2 the training arrays and index 3
            the validation arrays; index 1 supplies ``Y`` (labelled "ground
            truth" by the original author — TODO confirm the split layout).
        train: Select training (``True``) or sampling (``False``).

    Returns:
        ``None``. NOTE(review): the sampling branch ends with ``sys.exit()``,
        so it terminates the interpreter instead of returning.
    """
    batch_size = 32

    model = NeuralCDE(
        input_size=calib_data[0].X_ctx.shape[-1],
        hidden_size=params['hidden_channels'],
        output_size=calib_data[2].error.shape[-1],
        params=params,
        key=jr.PRNGKey(params['seed'] + 12),
    )
    optim = optax.adam(params['lr'])

    # Run folder shared by the training (write) and sampling (read) branches.
    folder = (
        f"./{params['seed']}_{params['dataset_name']}_{params['epochs']}_{params['lr']}"
        f"_({params['cde_nodes']}_{params['cde_layers']})cde"
        f"_({params['decoder_nodes']}_{params['decoder_layers']})decoder"
    )

    @eqx.filter_jit
    def loss(model, t_flow, eps_flow, ts, targets, coeffs):
        """Flow-matching loss: MSE between predicted and target velocity."""
        # Point on the straight path from base noise (t=0) to target (t=1).
        z_t = (1 - t_flow) * eps_flow + t_flow * targets
        h_t = jax.vmap(model)(ts, coeffs)
        # Outer vmap over the batch, inner vmap over time steps; the decoder
        # itself only ever sees single vectors.
        pred_velocity = jax.vmap(jax.vmap(model.decoder))(t_flow, z_t, h_t)

        tar_velocity = targets - eps_flow
        aux = (pred_velocity, tar_velocity)
        return (jnp.mean((pred_velocity - tar_velocity) ** 2), aux)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, batch, opt_state):
        """Run one optimiser step and return (loss, new model, new opt state)."""
        t_flow, eps_flow, train_ts, errors, *X_ctx_sim_coeffs = batch

        (losses, _), grads = grad_loss(
            model, t_flow, eps_flow, train_ts, errors, X_ctx_sim_coeffs
        )

        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return losses, model, opt_state

    def sample_flow_inputs(key, shape):
        """Draw flow times t ~ U[0, 1) and base noise eps ~ N(0, I).

        The base noise is Gaussian so that training matches ``sample_flow``,
        which starts its integration from ``jr.normal`` draws.
        """
        bs, seq_len, outs = shape
        key_t, key_eps = jr.split(key)
        t_flow = jr.uniform(key_t, (bs, seq_len, 1))
        eps_flow = jr.normal(key_eps, (bs, seq_len, outs))
        return t_flow, eps_flow

    def create_data_loader(data_tuple, batch_size, key):
        """Yield shuffled mini-batches, each a tuple aligned with ``data_tuple``.

        Only full batches are yielded, so every batch has the same shape; a
        trailing partial batch is dropped.
        """
        dataset_size = data_tuple[0].shape[0]
        if any(data.shape[0] != dataset_size for data in data_tuple):
            raise ValueError(
                "All arrays passed to create_data_loader must share the same "
                "leading dimension."
            )

        perm = jr.permutation(key, jnp.arange(dataset_size))
        # `<=` so that the final batch is kept when it is exactly full.
        for start in range(0, dataset_size, batch_size):
            end = start + batch_size
            if end > dataset_size:
                break
            batch_perm = perm[start:end]
            yield tuple(data[batch_perm] for data in data_tuple)

    def data_loaders(calib_data: Tuple[PICalibData, ...], epoch=None):
        """Build the (train, validation) batch generators.

        When ``epoch`` is given it is folded into the training shuffle key so
        each epoch gets a different permutation. The validation order is fixed.
        """
        train_ts = jnp.repeat(
            calib_data[2].timesteps[None, :],
            repeats=calib_data[2].X.shape[0],
            axis=0,
        )
        val_ts = jnp.repeat(
            calib_data[3].timesteps[None, :],
            repeats=calib_data[3].X.shape[0],
            axis=0,
        )

        train_key = jr.PRNGKey(params['seed'] + 13)
        if epoch is not None:
            train_key = jr.fold_in(train_key, epoch)

        train_loader = create_data_loader(
            (train_ts, calib_data[2].error) + calib_data[2].X_ctx_coeffs,
            batch_size=batch_size,
            key=train_key,
        )
        val_loader = create_data_loader(
            (
                calib_data[1].Y,  # ground truth
                calib_data[3].Y,  # prediction from predictor
                val_ts,
                calib_data[3].error,
            ) + calib_data[3].X_ctx_coeffs,
            batch_size=batch_size,
            key=jr.PRNGKey(params['seed'] + 15),
        )

        return train_loader, val_loader

    if train:
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

        train_noise_seed = jr.PRNGKey(params['seed'] + 16)
        val_noise_seed = jr.PRNGKey(params['seed'] + 19)
        # Validation only runs every fifth epoch; in between, the most recent
        # value is reported.
        val_loss = float('nan')

        for epoch in range(params['epochs']):
            train_loader, val_loader = data_loaders(calib_data, epoch=epoch)
            epoch_key = jr.fold_in(train_noise_seed, epoch)

            train_losses = []
            for batch_idx, batch in enumerate(train_loader):
                # batch = (train_ts, error, *X_ctx_coeffs); noise is shaped
                # like the error targets.
                t_flow, eps_flow = sample_flow_inputs(
                    jr.fold_in(epoch_key, batch_idx), batch[1].shape
                )
                batch = (t_flow, eps_flow) + batch

                train_loss, model, opt_state = make_step(model, batch, opt_state)
                train_losses.append(train_loss)

            train_loss = jnp.mean(jnp.array(train_losses))

            if epoch % 5 == 0:
                val_losses = []
                for batch_idx, batch in enumerate(val_loader):
                    _, _, val_ts, errors, *X_ctx_sim_coeffs = batch
                    # Keyed by batch index only, so the validation noise is the
                    # same at every evaluation and losses stay comparable.
                    t_flow, eps_flow = sample_flow_inputs(
                        jr.fold_in(val_noise_seed, batch_idx), errors.shape
                    )
                    batch_loss, _ = loss(
                        model, t_flow, eps_flow, val_ts, errors, X_ctx_sim_coeffs
                    )
                    val_losses.append(batch_loss)

                val_loss = jnp.mean(jnp.array(val_losses))

            print(f"Epoch: {epoch} | Train loss: {train_loss:.2f} | Val loss: {val_loss:.2f}")

        # saving model
        check_or_make_folder(folder)
        eqx.tree_serialise_leaves(folder + "/model.eqx", model)
        data_dict = {"calib_data": calib_data}
        with open(folder + "/data_dict.pkl", "wb") as f:
            pickle.dump(data_dict, f)

    else:
        model = eqx.tree_deserialise_leaves(folder + "/model.eqx", model)

        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load run folders produced by a trusted training run.
        with open(folder + "/data_dict.pkl", "rb") as f:
            data_dict = pickle.load(f)
        calib_data = data_dict['calib_data']
        _, val_loader = data_loaders(calib_data)

        @eqx.filter_jit
        def sample_flow(h_t, errors, key, steps=50):
            """Euler-integrate the learned velocity field from t=0 to t=1."""
            bs, seq_len, outs = errors.shape[0], errors.shape[1], errors.shape[2]
            key, subkey = jr.split(key)
            z_t = jr.normal(subkey, (bs, seq_len, outs))
            dt = 1.0 / steps
            for i in range(steps):
                t_flow = jnp.full((bs, seq_len, 1), i * dt)
                v = jax.vmap(jax.vmap(model.decoder))(t_flow, z_t, h_t)
                z_t = z_t + v * dt
            return z_t

        # Only the first validation batch is processed.
        batch = next(val_loader, None)
        if batch is None:
            return

        _, _, val_ts, errors, *X_ctx_sim_coeffs = batch
        targets = errors
        h_t = jax.vmap(model)(val_ts, X_ctx_sim_coeffs)

        num_samples_desired = 1000
        samp_per_vmap = 100
        all_keys = jr.split(jr.PRNGKey(params['seed'] + 20), num_samples_desired)

        # Draw the samples in chunks of `samp_per_vmap` keys to bound peak
        # memory, then pass all of them to generate_pi at once.
        sample_many = jax.vmap(sample_flow, in_axes=(None, None, 0))
        sample_chunks = [
            sample_many(h_t, errors, all_keys[start:start + samp_per_vmap])
            for start in range(0, num_samples_desired, samp_per_vmap)
        ]
        preds = jnp.concatenate(sample_chunks, axis=0)
        quant = generate_pi(preds)

        pred_dict = {"quant": np.array(quant), "targets": np.array(targets)}
        with open(folder + "/pred_dict.pkl", "wb") as f:
            pickle.dump(pred_dict, f)
        # NOTE(review): terminating the interpreter from inside a function is
        # kept for backward compatibility; consider returning instead.
        sys.exit()