import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.categorical import Categorical
from torch.distributions import MultivariateNormal, Normal, Independent
from torch.distributions import Distribution

from unet import *
from sfa_lds import AttentionPooling

class MLP(nn.Sequential):
    """Fully connected network: Linear layers interleaved with an activation.

    Layer widths run ``in_features -> *hidden_features -> out_features``.
    Every Linear except the last is followed by an optional ``BatchNorm1d``
    and then the activation ``fct``. The final Linear is left bare, so the
    output is an unconstrained linear read-out.

    Parameters:
    - in_features: input dimensionality.
    - out_features: output dimensionality.
    - hidden_features: widths of the hidden layers (default ``(64, 64)``;
      may be empty for a single Linear).
    - fct: activation module placed after each hidden Linear. When omitted,
      a fresh ``nn.Tanh()`` is created for this network. A module passed in
      is reused at every hidden position, so a parametrised activation
      (e.g. ``nn.PReLU``) would share its parameters across layers.
    - batch_norm: insert ``nn.BatchNorm1d`` between each hidden Linear and
      its activation.

    Module indices (and therefore ``state_dict`` keys) are the same as in
    the previous implementation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct: Optional[nn.Module] = None,
        batch_norm: bool = False,
    ):
        # Build the default activation per network instead of sharing a
        # single module object created once at import time.
        if fct is None:
            fct = nn.Tanh()

        widths = (in_features, *hidden_features, out_features)
        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(a, b))
            if batch_norm:
                layers.append(nn.BatchNorm1d(b))
            layers.append(fct)

        # Drop the trailing [BatchNorm1d,] activation so the net ends on a Linear.
        tail = 2 if batch_norm else 1
        super().__init__(*layers[:-tail])



class LatentDynamicalSystem(nn.Module):
    """Linear-Gaussian state-space prior over a latent path z_1..z_S.

        z_1     ~ N(mu1, diag(exp(log_Q1)))
        z_{s+1} ~ N(A z_s, diag(exp(log_Q)))

    The parameters are fixed (plain tensors, not ``nn.Parameter``):
    mu1 = 0, A = I, and both diagonal variances are 0.5.
    """

    def __init__(self, latent_dim):
        """
        Initialize the latent dynamical system prior.

        Parameters:
        - latent_dim: Dimensionality of the latent space.
        """
        super(LatentDynamicalSystem, self).__init__()
        self.latent_dim = latent_dim
        # These are plain tensors, not buffers. They are not learned, not
        # saved in state_dict, and not moved by Module.to(), so every method
        # below moves them to the device of its inputs.
        self.mu1 = torch.zeros(latent_dim)  # mean of z_1
        self.log_Q1 = torch.log(torch.ones(latent_dim) * 0.5)  # log-diagonal covariance of z_1
        self.A = torch.eye(latent_dim)  # transition matrix
        self.log_Q = torch.log(torch.ones(latent_dim) * 0.5)  # log-diagonal transition covariance

    def _get_distribution(self, mu, log_var):
        """Return N(mu, diag(exp(log_var))) as a MultivariateNormal.

        No clamping is applied to ``mu`` or ``log_var``.
        """
        scale = torch.diag_embed(torch.exp(0.5 * log_var))
        return MultivariateNormal(mu, scale_tril=scale)

    def _transition_mean(self, z_prev):
        """Batched transition mean A z_{s-1}: (n, latent_dim) -> (n, latent_dim)."""
        return torch.einsum("ij,nj->ni", self.A.to(z_prev.device), z_prev)

    def rsample(self, n, S, device):
        """
        Draw n latent paths of length S with the reparameterisation trick.

        Parameters:
        - n: number of paths (batch size).
        - S: number of steps; must be >= 1.
        - device: device for the returned tensor.
        Returns:
        - zs: tensor of shape (S, n, latent_dim).
        """
        z = self._get_distribution(self.mu1.to(device), self.log_Q1.to(device)).rsample((n,))
        path = [z]
        log_Q = self.log_Q.to(device)
        for _ in range(1, S):
            z = self._get_distribution(self._transition_mean(z), log_Q).rsample()
            path.append(z)
        return torch.stack(path, dim=0)

    def sample(self, n, S, device):
        """
        Draw n latent paths of length S without gradient (``Distribution.sample``).

        Parameters:
        - n: number of paths (batch size).
        - S: number of steps; must be >= 1.
        - device: device for the returned tensor.
        Returns:
        - zs: tensor of shape (S, n, latent_dim).
        """
        z = self._get_distribution(self.mu1.to(device), self.log_Q1.to(device)).sample((n,))
        path = [z]
        log_Q = self.log_Q.to(device)
        for _ in range(1, S):
            # _transition_mean keeps the batch dimension, even when n == 1.
            z = self._get_distribution(self._transition_mean(z), log_Q).sample()
            path.append(z)
        return torch.stack(path, dim=0)

    def _log_prob(self, z_s, z_s_minus_1):
        """
        Log-density of one step.

        Parameters:
        - z_s: current latents, shape (n, latent_dim).
        - z_s_minus_1: previous latents (n, latent_dim), or None to score
          z_s under the initial distribution N(mu1, Q1).

        Returns:
        - log_prob: per-sample log-density, squeezed (a scalar when n == 1).
        """
        device = z_s.device
        if z_s_minus_1 is None:
            dist = self._get_distribution(self.mu1.to(device), self.log_Q1.to(device))
        else:
            # Same A z_{s-1} mean that sample()/rsample() use.
            dist = self._get_distribution(self._transition_mean(z_s_minus_1), self.log_Q.to(device))
        return dist.log_prob(z_s).squeeze()

    def log_prob(self, z, S):
        """Joint log-density of the first S steps of z (shape (>=S, n, latent_dim))."""
        out = self._log_prob(z[0], None)
        for s in range(1, S):
            out = out + self._log_prob(z[s], z[s - 1])
        return out

class sGRU(nn.Module):
    """GRU sequence encoder with a one-step recursion helper.

    Inputs are sequence-first, i.e. (seq_len, batch, input_size), because
    the GRU is built with ``batch_first=False``.
    """

    def __init__(self, input_size, hidden_size, num_layers, S, bidirectional=True):
        super(sGRU, self).__init__()
        self.S = S
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Not used inside this class. NOTE(review): presumably supplied by
        # `from unet import *`; confirm.
        self.positional_encoding = timestep_embedding
        self.bidirectional = bidirectional
        self.gru = nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional, batch_first=False)

    def forward(self, x):
        """Encode a whole sequence x of shape (seq_len, batch, input_size).

        Returns:
        - bidirectional: (batch, 2 * hidden_size), the final forward and
          backward states of the TOP GRU layer, concatenated.
        - unidirectional: the raw final state h_n, (num_layers, batch, hidden_size).
        """
        batch_size = x.shape[1]
        num_directions = 2 if self.bidirectional else 1
        h0 = torch.zeros(num_directions * self.num_layers, batch_size, self.hidden_size, device=x.device)
        _, hs = self.gru(x, h0)
        if self.bidirectional:
            # h_n rows are ordered (layer, direction), so the last two rows
            # belong to the top layer. hs[0]/hs[1] would read the first layer
            # and ignore every layer above it.
            return torch.cat([hs[-2], hs[-1]], dim=-1)
        return hs

    def recurse(self, xs, hsminus):
        """Advance the unidirectional GRU by one input step.

        Parameters:
        - xs: one input per sample, shape (batch, input_size).
        - hsminus: previous state (num_layers, batch, hidden_size), or None
          to start from zeros.

        Returns the new state (num_layers, batch, hidden_size). It is
        computed without autograd, so no gradient reaches ``self.gru`` or
        ``xs`` through this path.
        NOTE(review): confirm that leaving the GRU untrained here is intended.

        Raises ValueError for a bidirectional GRU, which cannot be stepped
        one input at a time.
        """
        if self.bidirectional:
            raise ValueError("sGRU.recurse requires bidirectional=False")
        with torch.no_grad():
            _, hs = self.gru(xs.unsqueeze(0), hsminus)
        return hs.detach()

class LLKc(nn.Module):
    """Conditional diagonal-Gaussian likelihood p(x | c) that ignores the step index.

    An MLP maps the context c to the mean and the pre-softplus scale of an
    ``Independent(Normal, 1)`` over ``features`` dimensions. The ``*S``
    methods apply the same map to each of the S steps.
    """

    def __init__(self, features: int, context: int, S: int, **kwargs):
        super().__init__()
        self.hyper = MLP(context, features*2, **kwargs)
        self.S = S
        self.features = features

    def forward(self, c: Tensor, min_variance=1e-6):
        """Return p(x | c).

        ``min_variance`` is added to the standard deviation (the Normal's
        scale), which keeps the scale strictly positive.
        """
        phi = self.hyper(c)
        mu, sigma = phi.chunk(2, dim=-1)
        return Independent(Normal(mu, F.softplus(sigma) + min_variance), 1)

    def sample(self, c):
        """Draw x ~ p(x | c) without gradient."""
        return self(c).sample()

    def sampleS(self, cS):
        """Draw one x per step from cS[0..S-1]; stacked along dim 0."""
        return torch.stack([self.sample(cS[s]) for s in range(self.S)], dim=0)

    def rsample(self, c):
        """Draw x ~ p(x | c) with the reparameterisation trick."""
        return self(c).rsample()

    def rsampleS(self, cS):
        """Reparameterised version of ``sampleS``; gradients flow to ``hyper`` and cS."""
        # Previously this called self.sample, which silently dropped the gradient.
        return torch.stack([self.rsample(cS[s]) for s in range(self.S)], dim=0)

    def log_prob(self, c, x):
        """log p(x | c), summed over the feature dimension."""
        return self(c).log_prob(x)

    def log_probS(self, cS, xS):
        """Sum of log p(xS[s] | cS[s]) over the S steps."""
        log_lik = 0
        for s in range(self.S):
            log_lik = log_lik + self.log_prob(cS[s], xS[s])
        return log_lik


class LLK(nn.Module):
    """Step-conditioned diagonal-Gaussian likelihood p(x | c, s).

    A sinusoidal embedding of the step value s is concatenated to the
    context c, and an MLP maps the result to the mean and the
    pre-sigmoid scale of an ``Independent(Normal, 1)``.
    ``freqs`` is accepted for interface compatibility but unused.
    """

    def __init__(self, features: int, context: int, S: int, dsemb: int = 8, freqs: int = 2, **kwargs):
        super().__init__()
        self.S = S
        self.dsemb = dsemb
        self.features = features
        # No trailing comma here. Previously `hyper` was a 1-tuple, so it was
        # neither callable nor registered as a submodule.
        self.hyper = MLP(self.dsemb + context, features * 2, **kwargs)

        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(S)), self.dsemb // 2))

        self.register_buffer('freqs_s', 2 * torch.pi / S * freqs_s)

    def forward(self, c: Tensor, s, min_variance=1e-6):
        """Return p(x | c, s).

        ``s`` may be a Python number or a tensor, broadcastable against the
        batch shape of c. ``min_variance`` is added to the sigmoid scale so
        the scale can never underflow to exactly zero.
        """
        s = torch.as_tensor(s, dtype=self.freqs_s.dtype, device=self.freqs_s.device)
        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*c.shape[:-1], -1)

        phi = self.hyper(torch.cat((c, semb), dim=-1))
        mu, sigma = phi.chunk(2, dim=-1)
        return Independent(Normal(mu, torch.sigmoid(sigma) + min_variance), 1)

    def sample(self, c, s):
        """Draw x ~ p(x | c, s) without gradient."""
        return self(c, s).sample()

    def rsample(self, c, s):
        """Draw x ~ p(x | c, s) with the reparameterisation trick."""
        return self(c, s).rsample()

    def log_prob(self, c, s, x):
        """log p(x | c, s), summed over the feature dimension."""
        return self(c, s).log_prob(x)


class GaussRecogNet(nn.Module):
    """Gaussian mean-field variational posterior q(z_s | x_{[S]}).

    The observed sequence c, shaped (seq_len, n, context), is encoded once
    by a bidirectional GRU (``rnnx``) into an (n, 2 * context) embedding.
    For each step s, an MLP applied to [embedding, sinusoidal embedding of
    s] gives the mean and log-std of a diagonal Gaussian over ``features``.
    ``freqs`` and ``nonlinearity`` are accepted for interface compatibility
    but unused.
    """

    def __init__(self, features: int, context: int, S: int, dsemb: int = 8, freqs: int = 2, nonlinearity="tanh", **kwargs):
        super(GaussRecogNet, self).__init__()
        self.S = S
        self.dsemb = dsemb
        self.features = features  # needed by sample/rsample
        self.num_hidden = context
        # sGRU takes no `freqs` argument; passing it raised a TypeError.
        self.rnnx = sGRU(context, self.num_hidden, S=self.S, num_layers=2)

        self.hyper = MLP(self.dsemb + self.num_hidden * 2, features * 2, **kwargs)

        # The class has no `F` attribute (that raised AttributeError), so the
        # step embedding is scaled by the number of steps S.
        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(self.S)), self.dsemb // 2))

        self.register_buffer('freqs_s', 2 * torch.pi / self.S * freqs_s)

    def _step_dist(self, cemb: Tensor, s: Tensor):
        """Per-step posterior given an already-encoded sequence cemb of shape (n, d)."""
        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        # Broadcast over the embedding's batch shape, not the raw sequence's.
        semb = semb.expand(*cemb.shape[:-1], -1)

        phi = self.hyper(torch.cat((cemb, semb), dim=-1))
        mu, log_sigma = phi.chunk(2, dim=-1)
        return Independent(Normal(mu, log_sigma.exp()), 1)

    @staticmethod
    def _step_index(s, device):
        """Integer step s as the (1,) long tensor that _step_dist expects."""
        return torch.tensor([s], dtype=torch.long, device=device)

    def forward(self, c: Tensor, s: int):
        """Return q(z_s | c) for the step tensor s (e.g. ``torch.tensor([s])``)."""
        return self._step_dist(self.rnnx(c), s)

    def sample(self, c):
        """Draw z_0..z_{S-1} without gradient; returns (S, n, features)."""
        cemb = self.rnnx(c)  # encode once, reuse for every step
        return torch.stack(
            [self._step_dist(cemb, self._step_index(s, c.device)).sample() for s in range(self.S)],
            dim=0,
        )

    def rsample(self, c):
        """Reparameterised draw of z_0..z_{S-1}; returns (S, n, features)."""
        cemb = self.rnnx(c)
        return torch.stack(
            [self._step_dist(cemb, self._step_index(s, c.device)).rsample() for s in range(self.S)],
            dim=0,
        )

    def log_prob(self, c, z):
        """Sum over steps of log q(z[s] | c); z is (S, n, features), result is (n,)."""
        cemb = self.rnnx(c)
        _log_prob = 0
        for s in range(self.S):
            _log_prob = _log_prob + self._step_dist(cemb, self._step_index(s, c.device)).log_prob(z[s])
        return _log_prob


class fullGaussRecogNet(nn.Module):
    """Autoregressive Gaussian variational posterior q(z_{1:S} | x).

    This posterior is not mean-field. Each step is a diagonal Gaussian whose
    parameters come from an MLP applied to:
      * an attention-pooled embedding of the observations c,
      * a sinusoidal embedding of the (normalised) step index, and
      * a GRU summary of the latents at earlier steps (zeros at the first step).
    """

    def __init__(self, features: int, context: int, S: int, F:int, dsemb: int = 8, freqs: int = 2, nonlinearity="tanh", num_layers=1, emb_dim=10, num_hidden=32,**kwargs):
        super(fullGaussRecogNet, self).__init__()
        self.S = S
        self.F = F
        self.dsemb = dsemb
        self.features = features
        self.context = context
        self.num_layers = num_layers
        self.num_hidden = emb_dim
        self.num_hidden_z = num_hidden
        # att(c) returns a pair; only its first element (the pooled embedding) is used.
        self.att = AttentionPooling(self.num_hidden)
        self.num_embx = self.num_hidden

        # Unidirectional GRU over past latents, stepped via sGRU.recurse.
        self.rnnz = sGRU(features, self.num_hidden_z, S=self.S, num_layers=self.num_layers, bidirectional=False)

        self.hyper = MLP(self.num_embx+self.num_hidden_z*self.num_layers+self.dsemb, features*2, **kwargs)

        # Only used by _add_semb.
        self.semb_project = nn.Linear(self.dsemb, self.num_hidden)

        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(self.F)), self.dsemb // 2))

        self.register_buffer('freqs_s', 2 * torch.pi / self.F * freqs_s)

    def _flatten_hidden(self, h):
        """Reshape GRU state (num_layers, n, num_hidden_z) to (n, num_layers * num_hidden_z).

        Each sample keeps its own row. A plain ``.view(-1, ...)`` on the
        layer-major tensor would mix different samples into one row when
        num_layers > 1. Returns None when h is None (no history yet).
        """
        if h is None:
            return None
        return h.transpose(0, 1).reshape(h.shape[1], self.num_layers * self.num_hidden_z)

    def forward(self, cemb: Tensor, zcemb: Tensor, s: int, min_variance=1e-6):
        """Per-step posterior.

        Parameters:
        - cemb: pooled observation embedding, (n, num_hidden).
        - zcemb: flattened history (n, num_layers * num_hidden_z), or None
          for zeros.
        - s: step value as a tensor (0-d or broadcastable).
        - min_variance: added to the softplus scale (the standard deviation).
        """
        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*cemb.shape[:-1], -1)

        if zcemb is None:
            zcemb = torch.zeros((cemb.shape[0], self.num_hidden_z * self.num_layers), device=cemb.device)
        phi = self.hyper(torch.cat((cemb, semb, zcemb), dim=-1))
        mu, sigma = phi.chunk(2, dim=-1)

        return Independent(Normal(mu, F.softplus(sigma) + min_variance), 1)

    def _add_semb(self, c, indices): # indices of shape (n,1)
        """Add a projected sinusoidal embedding of ``indices`` to c (unused in this class)."""
        s = self.freqs_s.unsqueeze(0).unsqueeze(0) * indices.unsqueeze(-1)
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        return c + self.semb_project(semb)

    def _draw_path(self, c, indices, steps, reparameterized):
        """Draw z_0..z_{steps-1} autoregressively; returns (steps, n, features).

        Each draw is fed through rnnz.recurse to condition the next step.
        """
        xcemb, _ = self.att(c)
        h = None
        zs = []
        for s in range(steps):
            if s > 0:
                h = self.rnnz.recurse(zs[-1], h)
            dist = self(xcemb, self._flatten_hidden(h), indices[s])
            zs.append(dist.rsample() if reparameterized else dist.sample())
        return torch.stack(zs, dim=0)

    def sample(self, c, indices=None):
        """Draw a latent path of S steps without gradient; returns (S, n, features)."""
        if indices is None:
            indices = torch.arange(self.S, device=c.device)/self.S
        return self._draw_path(c, indices, self.S, reparameterized=False)

    def rsample(self, c, indices=None):
        """Reparameterised draw of a latent path of S steps; returns (S, n, features)."""
        if indices is None:
            indices = torch.arange(self.S, device=c.device)/self.S
        return self._draw_path(c, indices, self.S, reparameterized=True)

    def predictF(self, F, zF, c, indices=None): # F>S
        """Roll the posterior forward for F steps (possibly more than S).

        ``zF`` supplies the output buffer. A detached copy of it is made and
        its first F rows are overwritten with the new draws.
        """
        if indices is None:
            indices = torch.arange(F, device=c.device)/F
        z0F = zF.clone().detach()
        z0F[:F] = self._draw_path(c, indices, F, reparameterized=False)
        return z0F

    def log_prob(self, c, z, indices=None):
        """Sum over steps of log q(z[s] | c, z[<s]); result is (n,).

        The history is built from the GIVEN z, not from fresh draws.
        """
        if indices is None:
            indices = torch.arange(self.S, device=c.device)/self.S

        xcemb, _ = self.att(c)
        h = None
        _log_prob = 0
        for s in range(self.S):
            if s > 0:
                h = self.rnnz.recurse(z[s - 1], h)
            _log_prob = _log_prob + self(xcemb, self._flatten_hidden(h), indices[s]).log_prob(z[s])
        return _log_prob


class ELBOc(nn.Module):
    """Scaled negative ELBO for the latent dynamical model, plus a smoothness penalty.

    - p: likelihood, called as ``p(z_s)`` when ``const`` else ``p(z_s, index_s)``;
      it must return a distribution with ``log_prob(c)``.
    - q: posterior with ``rsample(c, indices)`` and ``log_prob(c, z, indices)``.
    - prior: provides ``log_prob(z, S)``.
    - features: accepted for interface compatibility; unused.
    - alpha: relative weight of the smoothness penalty.
    """

    def __init__(self, p: Distribution, q: Distribution, prior: Distribution, S: int, features, alpha=0.001, const=True):
        super().__init__()

        self.pxz = p
        self.qzx = q
        self.prior = prior
        self.S = S
        self.const=const
        self.alpha=alpha

    def forward(self, c, indices=None):
        if indices is None:
            indices = torch.arange(self.S, device=c.device)/self.S

        zc = self.qzx.rsample(c, indices)
        post_log_prob = self.qzx.log_prob(c, zc, indices)
        prior_log_prob = self.prior.log_prob(zc, self.S)

        # Single-sample Monte Carlo estimate of -KL(q || prior).
        loss_kl = (prior_log_prob - post_log_prob)

        loss_llk = 0
        for s in range(self.S):
            pxz = self.pxz(zc[s]) if self.const else self.pxz(zc[s], indices[s])
            loss_llk = loss_llk + pxz.log_prob(c)

        # The likelihood term is down-weighted by a fixed factor of 0.01.
        loss = -loss_kl.mean()/self.S -loss_llk.mean()/self.S * 0.01

        if self.S < 3:
            # No second differences exist. Previously reg_sq stayed the int 0
            # here and `.mean()` raised AttributeError.
            return loss

        # Smoothness: squared second differences of the latent path, summed over steps.
        dz = zc[1:self.S] - zc[:self.S - 1]
        reg_sq = ((dz[1:] - dz[:-1]) ** 2).sum(dim=0)
        reg_loss2 = reg_sq.mean()/self.S
        # Rescale so the penalty's value is alpha * |loss|. abs() keeps the
        # weight non-negative: with a negative loss the old weight became
        # negative and rewarded rough paths instead of penalising them.
        beta2 = loss.detach().abs() / (torch.abs(reg_loss2.detach())+1e-8)

        return loss + reg_loss2 * self.alpha * beta2








