import torch
import torch.nn as nn
import torch.nn.functional as F

import geotorch

import numpy as np
from typing import Tuple
from lib.jax_compat import associative_scan

from model.vit import VisionTransformer, vit_tiny, vit_small, vit_base
from model.utils import forward_fill_3d, fill_zero_padding

from functools import partial

VALID_ENCODERS = ["vit-custom", "vit-tiny", "vit-small", "vit-base"]

def model_factory(encoder_type, position, input_dim=450, **kwargs) :
    """Build a Vision Transformer encoder by name.

    Args:
        encoder_type: one of ``VALID_ENCODERS``.
        position: forwarded unchanged to the ViT constructor.
        input_dim: forwarded unchanged to the ViT constructor.
        **kwargs: read only for ``"vit-custom"``; the preset types
            (``vit-tiny`` / ``vit-small`` / ``vit-base``) ignore them.
            Required (must be present and not None): ``embed_dim``,
            ``mlp_ratio``, ``depth``.
            Optional: ``drop_out`` (default 0.0, used as both the dropout and
            the attention-dropout rate), ``n_layer_embedder`` (default 1),
            ``mask_fill_mode`` (default ``"zero"``).

    Returns:
        The constructed encoder module.

    Raises:
        ValueError: if ``encoder_type`` is not in ``VALID_ENCODERS``, or a
            required ``vit-custom`` setting is missing or None.
    """
    # Explicit raise instead of ``assert``: with assertions stripped (-O) an
    # unknown type would otherwise fall off the elif chain and return None.
    if encoder_type not in VALID_ENCODERS:
        raise ValueError(f"Unknown encoder : {encoder_type}")

    if encoder_type == "vit-custom" :

        # .get() so a missing key yields the config error below rather than
        # a bare KeyError.
        embed_dim = kwargs.get("embed_dim")
        mlp_ratio = kwargs.get("mlp_ratio")
        depth = kwargs.get("depth")
        if embed_dim is None or mlp_ratio is None or depth is None:
            raise ValueError("check your model configs (vit-custom)")

        dropout = kwargs.get("drop_out", 0.0)
        n_layer_embedder = kwargs.get("n_layer_embedder", 1)
        mask_fill_mode = kwargs.get("mask_fill_mode", "zero")

        model = VisionTransformer(
            position=position,
            input_dim=input_dim,
            embed_dim=embed_dim,
            depth=depth,
            # one attention head per 64 channels, never fewer than one
            num_heads=max(int(embed_dim // 64), 1),
            mlp_ratio=mlp_ratio,
            qkv_bias=True,
            drop_rate=dropout,
            attn_drop_rate=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            n_layer_embedder=n_layer_embedder,
            mask_fill_mode=mask_fill_mode)

        return model

    elif encoder_type == "vit-tiny" :
        return vit_tiny(position=position, input_dim=input_dim)
    elif encoder_type == "vit-small" :
        return vit_small(position=position, input_dim=input_dim)
    elif encoder_type == "vit-base" :
        return vit_base(position=position, input_dim=input_dim)


def get_state_dim(encoder_type) :
    """Return the default embedding width associated with an encoder name.

    The name is matched by substring against the known size tags in a fixed
    order (tiny, small, base, brainlm_11m); the first tag found wins.

    Raises:
        ValueError: if none of the known tags occurs in ``encoder_type``.
    """
    width_by_tag = (
        ("tiny", 192),
        ("small", 384),
        ("base", 768),
        ("brainlm_11m", 512),
    )
    for tag, width in width_by_tag:
        if tag in encoder_type:
            return width
    raise ValueError(f"Unknown encoder : {encoder_type}")

class RNN_layers(nn.Module):
    """
    Optional recurrent layers. This is inspired by the fact that adding
    recurrent layers on top of the Transformer helps language modeling.

    A batch-first GRU (hidden size == input size == ``width``) followed by a
    ``width -> width`` linear projection.
    """

    def __init__(self, width, num_layers=1):
        super().__init__()

        self.gru = nn.GRU(width, width, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(width, width)

    def forward(self, inp, non_pad_mask=None):
        """Run the GRU over every time step of ``inp`` and project the outputs.

        Args:
            inp: batch-first sequence, (batch, time, width).
            non_pad_mask: accepted for interface compatibility but ignored.
                The sequence is always processed at its full length, as the
                previous implementation (which overwrote the mask-derived
                lengths with ``inp.size(1)``) also did.

        Returns:
            Tensor of shape (batch, time, width).
        """
        # Packing with every length equal to inp.size(1) is a no-op round trip,
        # so the GRU is run directly. This also avoids the .cpu() device sync
        # the unused length computation used to force.
        out, _ = self.gru(inp)
        return self.linear(out)

class JEPAEncoder(nn.Module) : # context & target encoder
    """Temporal ViT encoder used for both the JEPA context and target branches.

    The backbone is built by ``model_factory`` with ``position="temporal"``.
    Config fields read from ``args``: ``encoder``, ``input_dim``, ``state_dim``
    (embedding width) and ``drop_out``. Optional fields and their defaults:
    ``mlp_ratio`` (4.0), ``mask_fill_mode`` ("zero"),
    ``n_layer_temporal_embedder`` (1), and ``n_layer_encoder``, which falls
    back to ``n_layer_temporal_encoder`` when missing or falsy.
    """

    def __init__(self, args) :
        super().__init__()

        ratio = args.get("mlp_ratio", 4.0)
        self.mask_fill_mode = args.get("mask_fill_mode", "zero")
        embedder_layers = args.get("n_layer_temporal_embedder", 1)
        # Prefer the newer key; use the legacy name when it is missing/falsy.
        depth = args.get("n_layer_encoder") or args.get("n_layer_temporal_encoder")

        encoder_type = args.encoder
        self.temporal_encoder = model_factory(
            encoder_type,
            position="temporal",
            input_dim=args.input_dim,
            depth=depth,
            mlp_ratio=ratio,
            embed_dim=args.state_dim,
            drop_out=args.drop_out,
            n_layer_embedder=embedder_layers,
            mask_fill_mode=self.mask_fill_mode,
        )

    def forward(self, obs, times, temporal_mask=None, spatial_mask=None, history_only=False) :
        """Encode ``obs`` along time.

        Args:
            obs, times: forwarded to the temporal encoder.
            temporal_mask: forwarded as ``mask``. When given and
                ``mask_fill_mode == "zero"``, the output is additionally passed
                through ``fill_zero_padding`` (presumably zeroing masked steps;
                helper not shown here).
            spatial_mask: accepted but not used by this module.
            history_only: forwarded to the temporal encoder.
        """
        encoded = self.temporal_encoder(obs, times=times, mask=temporal_mask, history_only=history_only)
        if self.mask_fill_mode != "zero" or temporal_mask is None:
            return encoded
        return fill_zero_padding(encoded, temporal_mask)
    
class JEPAPredictor(nn.Module) :
    """Predictor head: forward-fill the context, optional GRU + LayerNorm, then LinearSDE.

    Config fields read from ``args``:
        n_layer_rnn: number of GRU layers (default 1); 0 disables the GRU and
            uses ``nn.Identity`` instead.
        use_rnn_norm: apply a LayerNorm after the recurrence (default False).
        encoder / state_dim: latent width; ``state_dim=None`` falls back to
            ``get_state_dim(encoder)``.
    The whole ``args`` object is also handed to ``LinearSDE``.
    """

    def __init__(self, args) :

        super().__init__()

        n_layer_rnn = args.get("n_layer_rnn", 1)
        norm = args.get("use_rnn_norm", False)

        state_dim = get_state_dim(args.encoder) if args.state_dim is None else args.state_dim
        self.rnn = RNN_layers(state_dim, num_layers=n_layer_rnn) if n_layer_rnn > 0 else nn.Identity()
        self.norm = nn.LayerNorm(state_dim) if norm else None

        self.dynamics = LinearSDE(args)

    def forward(self, C, times, mask) :
        """Predict SDE moments from context encodings.

        Args:
            C: context encodings. They are passed through ``forward_fill_3d``
                first (presumably filling padded steps; helper not shown here).
            times: forwarded to ``LinearSDE``.
            mask: forwarded to the GRU block (only when the GRU is enabled).

        Returns:
            ``((means, stds), alphas)`` as produced by ``self.dynamics``.
        """
        C = forward_fill_3d(C)
        # nn.Identity.forward accepts exactly one input, so the mask may only
        # be passed when the GRU block is in use (n_layer_rnn > 0).
        if isinstance(self.rnn, nn.Identity):
            C = self.rnn(C)
        else:
            C = self.rnn(C, mask)
        if self.norm is not None:
            C = self.norm(C)

        means, stds, alphas = self.dynamics(C, times)

        return (means, stds), alphas


######################################################################################################
# predictor class : LinearSDE
######################################################################################################

def elup(x: torch.Tensor) -> torch.Tensor:
    """Elementwise exponential, used here as a positivity transform.

    NOTE: despite the name (which suggests ``elu(x) + 1``), this returns
    ``exp(x)``.
    """
    return x.exp()

@torch.jit.script
def binary_operator(q_i: Tuple[torch.Tensor, torch.Tensor], q_j: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Associative combine for the elementwise recurrence x -> a * x + b.

    Composing step ``q_i = (a_i, b_i)`` followed by ``q_j = (a_j, b_j)`` gives
    x -> a_j * (a_i * x + b_i) + b_j = (a_j * a_i) * x + (b_j + a_j * b_i).
    """
    decay_first = q_i[0]
    drive_first = q_i[1]
    decay_second = q_j[0]
    drive_second = q_j[1]
    merged_decay = decay_second * decay_first
    merged_drive = torch.addcmul(drive_second, decay_second, drive_first)
    return merged_decay, merged_drive

def init_normal(m):
    """Xavier-uniform initialise the weight of a plain ``nn.Linear``, in place.

    Intended for ``Module.apply``. Only modules whose exact type is
    ``nn.Linear`` are touched (subclasses are skipped), and the bias is left
    as is. Despite the name, the distribution is uniform.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.xavier_uniform_(m.weight)

def init_orthogonal(m):
    """Orthogonally initialise (gain 1) the weight of a plain ``nn.Linear``, in place.

    Intended for ``Module.apply``. Only modules whose exact type is
    ``nn.Linear`` are touched (subclasses are skipped), and the bias is left
    as is.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.orthogonal_(m.weight, gain=1)

class LinearSDE(torch.nn.Module):
    """Linear SDE dynamics evaluated over a whole sequence with an associative scan.

    Each step of the latent state follows an elementwise (diagonal) drift ``A``
    that is strictly negative (``-(exp(D) + 1e-6)`` mixed with softmax weights
    from ``coeff_net``) and is driven by an input ``alpha = B(z)``.
    ``get_matrix`` builds the closed-form per-step transition terms for the
    mean and the variance. These are composed with ``binary_operator`` via
    ``associative_scan``, and the resulting latent mean / std are mapped to the
    output through ``E`` and then ``M``.

    Config fields read from ``args``:
        encoder / state_dim: latent width; ``state_dim=None`` falls back to
            ``get_state_dim(encoder)``.
        init_sigma: scale applied to the initial variance in ``forward``.
        ts: stored as ``self.ts``; not used inside this class.
        num_basis: number of drift basis vectors (rows of ``D``).
    """
    def __init__(self, args):
        super(LinearSDE, self).__init__()
        
        state_dim = get_state_dim(args.encoder) if args.state_dim is None else args.state_dim
        
        self.init_sigma = args.init_sigma
        self.ts = args.ts
        self.ld = state_dim      # latent width
        self.nb = args.num_basis # number of drift bases
        
        #### Construct the basis for A: an orthogonal map E (geotorch keeps
        #### its weight orthogonal) and per-basis drift-rate parameters D.
        self.E = nn.Linear(self.ld, self.ld, bias=False)
        self.E.apply(init_orthogonal)
        geotorch.orthogonal(self.E, "weight")
        
        self.D = nn.Parameter(torch.randn(self.nb, self.ld))
        
        ### Init mean and covariance (log-variances are mapped through elup,
        ### i.e. exp, where they are used)
        self.init_mean = torch.nn.Parameter(torch.randn(self.ld))
        self.init_log_var = torch.nn.Parameter(torch.randn(self.ld))
        self.y_log_var = torch.nn.Parameter(torch.randn(self.ld))
        
        #### Construct the coefficient net for A: softmax weights over the nb bases.
        self.coeff_net = nn.Sequential(nn.Linear(self.ld, self.nb),
                                       nn.Softmax(dim=-1))

        # B maps context features to the drift input alpha; M is the final
        # output projection applied after E.
        self.B = nn.Linear(self.ld, self.ld, bias=False)
        self.B.apply(init_normal)
        self.M = nn.Linear(self.ld, self.ld, bias=False)
        self.M.apply(init_normal)
        
    def get_matrix(self, alpha, obs_times, sigma=1):    
        """Closed-form per-step transition terms for the mean and variance.

        With an elementwise drift ``A`` (always negative) and step ``t``:
            mean:     x <- exp(A t) * x + (exp(A t) - 1) / A * alpha
            variance: v <- exp(2 A t) * v + sigma**2 * (exp(2 A t) - 1) / (2 A) + 1e-6

        Args:
            alpha: drift input. It is also fed to ``coeff_net`` to choose the
                mix of drift bases. The ``.sum(1)`` below reduces over the
                basis axis only when ``alpha`` is 2-D (steps, state_dim), as it
                is when called per slice under vmap from ``parallel_compute``.
            obs_times: step durations, broadcast against ``A``. ``forward``
                appends a trailing singleton axis. NOTE(review): the recursion
                treats these as per-step increments; confirm that callers pass
                deltas, not absolute times.
            sigma: diffusion scale. ``parallel_compute`` leaves it at 1.

        Returns:
            ``(multiplicative terms, additive terms)``, each the concatenation
            ``[mean part, variance part]`` along the last axis.
        """
        
        # Vector of ones (a scalar 1 per state dim), not an identity matrix.
        Identity = torch.ones(alpha.shape[-1], device=alpha.device)
        
        # Strictly negative drift rates per basis -> decaying dynamics.
        A_basis = - (elup(self.D) + 1e-6)
        
        # Convex (softmax-weighted) combination of the bases per step.
        A_coeff = self.coeff_net(alpha)
        A_mat = (A_coeff[..., None] * A_basis[None]).sum(1)
        
        exp_A_mat_m = torch.exp(A_mat * obs_times)
        exp_B_mat_m = (1/A_mat) * (exp_A_mat_m - Identity) * alpha

        exp_A_mat_v = torch.exp(2 * A_mat * obs_times)
        # 1e-6 keeps the per-step variance increment strictly positive.
        exp_B_mat_v = 0.5 * sigma**2 * (1/A_mat) * (exp_A_mat_v - Identity) + 1e-6

        return torch.cat([exp_A_mat_m, exp_A_mat_v], dim=-1), torch.cat([exp_B_mat_m, exp_B_mat_v], dim=-1)
    
    def parallel_compute(self, init, E, Z, obs_times):
        """Scan the per-step transitions and map latent moments to the output.

        Args:
            init: ``[initial mean, initial variance]`` concatenated, in the
                latent basis.
            E: the orthogonal output map (``self.E``), applied before ``self.M``.
            Z: context features, vmapped over axis 0 (see ``get_matrix`` for
                the per-slice shape it requires).
            obs_times: step durations with a trailing singleton axis.

        Returns:
            ``(means, stds, alphas)`` where ``alphas = B(Z)``.
        """
        
        alphas = torch.vmap(lambda u: self.B(u))(Z)

        mats_A, mats_B = torch.vmap(lambda a, t: self.get_matrix(a, t))(alphas, obs_times)
        # Prefix-compose (a, b) pairs: cum_initial = running product of a,
        # cum_integral = accumulated input term.
        # NOTE(review): mats_A / mats_B keep the leading vmapped axis, so the
        # scan axis is whatever lib.jax_compat.associative_scan uses by default
        # (JAX's default is axis 0). Confirm it scans over steps, not that axis.
        cum_initial, cum_integral = associative_scan(binary_operator, (mats_A, mats_B))
        
        # x_k = (prod of a) * x_0 + accumulated input; then split each result
        # back into its mean and variance halves.
        init_mean_var = torch.vmap(lambda cum_init : cum_init * init)(cum_initial)
        init_mean, init_var = torch.vmap(lambda mean_var : torch.chunk(mean_var, chunks=2, dim=1))(init_mean_var)
        xs_mean, xs_var = torch.vmap(lambda mean_var : torch.chunk(mean_var, chunks=2, dim=1))(cum_integral)
        
        # Additive variance (named as observation noise), added in the latent
        # basis before mapping to the output.
        y_var = (elup(self.y_log_var) + 1e-6)
        
        means = torch.vmap(lambda mean : self.M(torch.vmap(lambda X: E(X))(mean)))(xs_mean + init_mean)
        # NOTE(review): E and M are linear maps applied to a standard deviation
        # rather than to a variance/covariance, so this is not the std of the
        # mapped distribution and entries can be negative. Confirm how
        # downstream code uses `stds`.
        stds = torch.vmap(lambda std : self.M(torch.vmap(lambda X: E(X))(std)))(torch.sqrt(xs_var + init_var + y_var))
        
        return means, stds, alphas

    def forward(self, Z_context, times):
        """Predict output means/stds for the given context and step times.

        Args:
            Z_context: context features, passed to ``parallel_compute`` as ``Z``.
            times: step durations; a trailing singleton axis is appended here so
                they broadcast against the per-dimension drift.

        Returns:
            ``(means, stds, alphas)`` from ``parallel_compute``.
        """
        
        times = times[..., None]

        # get init_mean & init_var
        # NOTE(review): `.data` detaches the geotorch-parametrised weight from
        # autograd, so E's parametrisation receives no gradient through the
        # initial state (init_mean / init_log_var still do). Confirm intended.
        E = self.E.weight.data
        # E(x) computes W x, so W^T maps the stored initial mean into the latent basis.
        init_mean = E.t() @ self.init_mean
        # Covariance W^T diag(init_sigma * exp(init_log_var)) W; only its
        # diagonal is kept below.
        init_var = E.t() @ (self.init_sigma * (elup(self.init_log_var) + 1e-6).diag_embed()) @ E
        init_var = torch.diag(init_var)
        init_mean_var = torch.cat([init_mean, init_var], dim=0)

        # compute means, stds, alphas
        means, stds, alphas = self.parallel_compute(init_mean_var, self.E, Z_context, times)

        return means, stds, alphas

