import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev, vmap
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid, softmax
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from torch.distributions import Dirichlet, Categorical, Normal, Uniform, Bernoulli
from torchdiffeq import odeint_adjoint

from zuko.utils import odeint
from zuko.distributions import DiagNormal
from unet import *

from dataloader.dataloader_mnist import inv_transform

# from dataloader.dataloader_pinwheel import *


# Import-time global torch configuration; this affects the whole process,
# not just this module.
# Print tensors with 3 digits of precision.
torch.set_printoptions(precision=3)
# Make float64 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float64)


def sum_except_batch(x):
    """Sum a tensor over every dimension except the leading (batch) one.

    Args:
        x: Tensor whose first dimension indexes the batch.

    Returns:
        A 1-D tensor of length ``x.size(0)`` with the per-sample sums.
    """
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
    # the result of ``permute``/``transpose``; it copies only when it has to.
    return x.reshape(x.size(0), -1).sum(-1)

def log_normal(x: Tensor) -> Tensor:
    """Log-density of a standard normal, summed over the last dimension.

    For each element the contribution is ``-(x**2 + log(2*pi)) / 2``; the
    contributions are added up along ``dim=-1``.
    """
    log_two_pi = math.log(2 * math.pi)
    per_element = x.square() + log_two_pi
    summed = per_element.sum(dim=-1)
    return -summed / 2

def first_eigen_proj(x):
    """Project samples onto the leading eigen-direction of their covariance.

    Args:
        x: Array-like of shape (n_samples, n_features); rows are observations
            (the covariance is computed with ``rowvar=False``).

    Returns:
        A real 1-D array of length n_samples: the dot product of every
        (uncentered) row of ``x`` with the unit eigenvector that belongs to
        the largest covariance eigenvalue. The sign of an eigenvector is
        arbitrary, so the projection is defined up to a global sign.
    """
    x = np.asarray(x)

    # np.cov removes the column means itself, so no manual centering is
    # needed. With a single feature it returns a 0-d array; lift it to 1x1 so
    # the decomposition below still works.
    cov_matrix = np.atleast_2d(np.cov(x, rowvar=False))

    # A covariance matrix is symmetric: eigh is the dedicated solver and
    # always yields real eigenvalues/eigenvectors, whereas the general eig
    # may return complex output because of round-off.
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    first_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

    return np.dot(x, first_eigenvector)



class MLP(nn.Sequential):
    """Fully connected network assembled as an ``nn.Sequential``.

    Every ``nn.Linear`` is followed by the activation ``fct`` (and, when
    ``batch_norm`` is set, by an ``nn.BatchNorm1d`` placed before the
    activation). The trailing activation is removed, so the network ends with
    the last ``nn.Linear`` -- or with its ``nn.BatchNorm1d`` when
    ``batch_norm`` is True.

    Args:
        in_features: Size of the input.
        out_features: Size of the output.
        hidden_features: Widths of the hidden layers.
        fct: Activation module inserted between layers. The same instance is
            reused at every position.
        batch_norm: Whether to add ``nn.BatchNorm1d`` after each linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        # Immutable default: a list default would be shared by every call.
        hidden_features: Sequence[int] = (64, 64),
        fct=nn.Tanh(),
        batch_norm=False,
    ):
        widths = (in_features, *hidden_features, out_features)

        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(a, b))
            if batch_norm:
                layers.append(nn.BatchNorm1d(b))
            layers.append(fct)

        # Drop the activation that would otherwise follow the output layer.
        super().__init__(*layers[:-1])


# Sample from the Gumbel-Softmax distribution and optionally discretize.
class GumbelSoftmax(nn.Module):
    """Gumbel-softmax relaxation over the last dimension of a logits tensor.

    Args:
        c_dim: Number of categories (stored; not used by the methods below).
        temperature: Divisor applied to the perturbed logits before softmax.
        hard: If True, ``sample`` returns straight-through one-hot vectors.
    """

    def __init__(self, c_dim, temperature=1.0, hard=False):
        super().__init__()
        self.c_dim = c_dim
        self.temperature = temperature
        self.hard = hard

    def sample_gumbel(self, shape, eps=1e-20, device=None):
        """Draw Gumbel(0, 1) noise of the given shape.

        Args:
            shape: Shape of the noise tensor.
            eps: Small constant keeping both logarithms finite.
            device: Device on which to create the noise; ``None`` keeps
                torch's default device.
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits):
        """Return a soft sample: softmax of (logits + Gumbel noise) / temperature."""
        # Create the noise on the logits' device so that non-CPU logits do
        # not trigger a device-mismatch error.
        noise = self.sample_gumbel(logits.size(), device=logits.device)
        return F.softmax((logits + noise) / self.temperature, dim=-1)

    def sample(self, logits):
        """Straight-through Gumbel-softmax sample.

        Args:
            logits: Tensor of shape [*, n_class].

        Returns:
            Tensor of shape [*, n_class]. The soft sample when ``hard`` is
            False; otherwise a one-hot vector whose gradient is that of the
            soft sample.
        """
        y = self.gumbel_softmax_sample(logits)

        if not self.hard:
            return y

        # One-hot encoding of the arg-max category along the last dimension.
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).scatter_(-1, ind.unsqueeze(-1), 1.0)
        # Forward value is y_hard, backward gradient flows through y.
        return (y_hard - y).detach() + y

    def forward(self, logits):
        """Return the (noise-free) softmax probabilities of ``logits``."""
        return F.softmax(logits, dim=-1)


class GaussianNet(nn.Module):
    """Network mapping ``pi`` (optionally concatenated with ``x``) to a ``DiagNormal``.

    Args:
        K: Width of the ``pi`` input.
        z_features: Dimensionality of the variable the distribution is over.
        x_features: Width of the optional ``x`` input (0 when unused).
        **kwargs: Forwarded to ``MLP``.
    """

    def __init__(self, K, z_features, x_features=0, **kwargs):
        super().__init__()

        self.K = K
        self.z_features = z_features
        # A single MLP emits both halves of the output (mu and log_sigma),
        # hence the 2 * z_features outputs.
        self.hyper = nn.Sequential(
            MLP(K + x_features, 2 * z_features, **kwargs),
        )

    def forward(self, pi: Tensor, x: Tensor):
        """Build the distribution; ``x`` may be None, in which case only ``pi`` is used."""
        inputs = pi if x is None else torch.cat((pi, x), dim=-1)
        mu, log_sigma = self.hyper(inputs).chunk(2, dim=-1)
        # The second half is exponentiated before being passed to DiagNormal.
        return DiagNormal(mu, log_sigma.exp())

    def rsample(self, pi, x):
        """Draw a sample via the distribution's ``rsample``."""
        dist = self(pi, x)
        return dist.rsample()

    def log_prob(self, pi, x, z):
        """Evaluate the log-probability of ``z`` under the distribution for (pi, x)."""
        return self(pi, x).log_prob(z)


class LLKNet(nn.Module):
    """Likelihood network: an MLP produces the first ``DiagNormal`` argument; the second is all ones.

    Args:
        x_features: Size of the MLP output.
        z_features: Size of the MLP input.
        freqs: Accepted but not used by this class.
        **kwargs: Forwarded to ``MLP``.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super().__init__()

        self.x_features = x_features
        self.hyper = nn.Sequential(MLP(z_features, x_features, **kwargs))

    def forward(self, c):
        """Return ``DiagNormal(hyper(c), ones)``."""
        mean = self.hyper(c)
        # Vector of ones of length x_features, moved to the input's device.
        unit_scale = torch.ones(self.x_features).to(c.device)
        return DiagNormal(mean, unit_scale)

    def sample(self, c, n=1):
        """Draw ``n`` samples (leading sample dimension of size ``n``)."""
        return self(c).sample((n,))

    def log_prob(self, c, x):
        """Evaluate the log-probability of ``x`` given ``c``."""
        likelihood = self(c)
        return likelihood.log_prob(x)


class CatNet(nn.Module):
    """Categorical network with Gumbel-softmax sampling.

    An MLP maps ``x`` to ``k`` logits; samples are drawn with the
    Gumbel-softmax relaxation, optionally discretised with the
    straight-through trick.

    Args:
        x_features: Size of the MLP input.
        k: Number of categories (size of the MLP output).
        temp: Divisor applied to the perturbed logits before softmax.
        hard: If True, samples are straight-through one-hot vectors.
        **kwargs: Forwarded to ``MLP``.
    """

    def __init__(self, x_features: int, k: int, temp=1., hard=False, **kwargs):
        super().__init__()

        self.temp = temp
        self.hard = hard

        self.hyper = nn.Sequential(MLP(x_features, k, **kwargs))

    def _sample_gumbel(self, shape, device, eps=1e-20):
        """Draw Gumbel(0, 1) noise; ``eps`` keeps both logarithms finite."""
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def _gumbel_softmax_sample(self, logits):
        """Return a soft sample: softmax of (logits + Gumbel noise) / temp."""
        y = logits + self._sample_gumbel(logits.size(), logits.device)
        return F.softmax(y / self.temp, dim=-1)

    def _sample(self, logits):
        """Straight-through Gumbel-softmax sample.

        Args:
            logits: Tensor of shape [*, n_class].

        Returns:
            Tensor of shape [*, n_class]. The soft sample when ``hard`` is
            False; otherwise a one-hot vector whose gradient is that of the
            soft sample.
        """
        y = self._gumbel_softmax_sample(logits)

        if not self.hard:
            return y

        # One-hot encoding of the arg-max category along the last dimension.
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).scatter_(-1, ind.unsqueeze(-1), 1.0)
        # Forward value is y_hard, backward gradient flows through y.
        return (y_hard - y).detach() + y

    def _forward(self, logits):
        """Return the (noise-free) softmax probabilities of ``logits``."""
        return F.softmax(logits, dim=-1)

    def forward(self, x: Tensor):
        """Return the category probabilities for ``x``."""
        # The logits network is stored as ``self.hyper``; the class defines
        # no ``fc`` attribute.
        return self._forward(self.hyper(x))

    def rsample(
        self, x: Optional[Tensor], logits: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Draw a Gumbel-softmax sample.

        Args:
            x: Network input. When given, the logits are computed from it and
                any ``logits`` argument is ignored.
            logits: Pre-computed logits, used only when ``x`` is None.

        Returns:
            ``(zt, logits, prob)``: the sample, the logits it was drawn from
            and their softmax probabilities.

        Raises:
            ValueError: If both ``x`` and ``logits`` are None.
        """
        if x is None and logits is None:
            raise ValueError("rsample requires either `x` or `logits`")
        if x is not None:
            logits = self.hyper(x)
        zt = self._sample(logits)
        prob = F.softmax(logits, dim=-1)
        return zt, logits, prob



class ELBO(nn.Module):
    """Training objective combining three terms computed from four sub-models.

    Args:
        pxz: Object exposing ``log_prob(z, x)``.
        qzpix: Object exposing ``rsample(pi, x)`` and ``log_prob(pi, x, z)``.
        qpix: Object exposing ``rsample(x)`` returning ``(sample, logits, prob)``.
        priorz: Object exposing ``log_prob(pi, None, z)``.
        k: Number of categories; ``log(k)`` enters the second term.
    """

    def __init__(self, pxz, qzpix, qpix, priorz, k=10):
        super().__init__()
        self.pxz = pxz
        self.qzpix = qzpix
        self.qpix = qpix
        self.priorz = priorz
        self.k = k

    @staticmethod
    def entropy(logits, target):
        """Return ``-mean(sum(target * log_softmax(logits), dim=-1))``.

        When ``target`` is ``softmax(logits)`` this is the mean entropy of
        the categorical distribution defined by ``logits``.
        """
        log_q = F.log_softmax(logits, dim=-1)
        return -torch.mean(torch.sum(target * log_q, dim=-1))

    def forward(self, x):
        """Return the scalar loss ``0.1 * loss1 + loss2 + 0.1 * loss3`` for ``x``."""
        # Posterior samples of pi and z.
        pic, logits, prob = self.qpix.rsample(x)
        zc = self.qzpix.rsample(pic, x)

        # Mean log-ratio between qzpix and priorz at the sampled z.
        loss1 = (self.qzpix.log_prob(pic, x, zc) - self.priorz.log_prob(pic, None, zc)).mean()

        # -entropy - log(1/k) == log(k) - entropy. A Python scalar is used so
        # no CPU tensor is created per call, which would clash with logits
        # living on another device.
        loss2 = math.log(self.k) - self.entropy(logits, prob)

        # Mean negative log-probability of x under pxz.
        loss3 = -self.pxz.log_prob(zc, x).mean()

        return 0.1 * loss1 + loss2 + 0.1 * loss3







