import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.distributions import Bernoulli, Distribution, Independent, Normal
from functools import partial
import itertools
import numpy as np
from typing import *
from torchdiffeq import odeint_adjoint
import zuko
from zuko.distributions import DiagNormal
from dataloader.dataloader_pinwheel import *

import gc
# Force a garbage-collection pass when this module is imported.
# NOTE(review): looks like a leftover from interactive/notebook use — confirm
# it is still wanted, since it runs on every import of this module.
gc.collect()


# Print tensors with 3 digits after the decimal point.
torch.set_printoptions(precision=3)
# Floating-point tensors and module parameters created after this line default
# to float64 (all models below are constructed after it).
torch.set_default_dtype(torch.float64)

class MLP(nn.Sequential):
    """Fully connected network: ``nn.Linear`` layers separated by an activation.

    Layer widths run ``in_features -> *hidden_features -> out_features``. The
    activation follows every ``Linear`` except the last one, so the output is
    not passed through an activation.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        hidden_features: widths of the hidden layers. An empty sequence gives a
            single ``Linear(in_features, out_features)``.
        fct: activation module inserted between layers. The same instance is
            reused at every position, which is fine for stateless activations.
            Defaults to a fresh ``nn.Tanh()`` for each network.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct: Optional[nn.Module] = None,
    ):
        # Create the default activation per call instead of once at function
        # definition time, so separate networks don't share a module instance.
        if fct is None:
            fct = nn.Tanh()

        widths = (in_features, *hidden_features, out_features)
        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(a, b), fct])

        # Drop the trailing activation after the output layer.
        super().__init__(*layers[:-1])


class ELBO(nn.Module):
    """Weighted negative evidence lower bound (the training loss of a VAE).

    For a batch ``x`` and a single reparameterised sample ``z ~ q(z | x)``::

        loss = -recon_weight * mean(log p(x | z))
               + kl_weight  * mean(log q(z | x) - log p(z))

    The second term is a one-sample Monte Carlo estimate of KL(q(z|x) || p(z)).

    Args:
        decoder: callable mapping ``z`` to a distribution over ``x``.
        encoder: callable mapping ``x`` to a distribution over ``z`` that
            supports ``rsample`` and ``log_prob``.
        prior: any object exposing ``log_prob(z)`` (e.g. a flow module or a
            torch distribution).
        recon_weight: multiplier on the reconstruction term (default 0.08).
        kl_weight: multiplier on the KL term (default 1.0).
    """

    def __init__(
        self,
        decoder: "zuko.flows.LazyDistribution",
        encoder: "zuko.flows.LazyDistribution",
        prior: Any,
        recon_weight: float = 0.08,
        kl_weight: float = 1.0,
    ):
        super().__init__()

        self.decoder = decoder
        self.encoder = encoder
        self.prior = prior
        self.recon_weight = recon_weight
        self.kl_weight = kl_weight

    def forward(self, x: Tensor) -> Tensor:
        """Return the scalar weighted negative ELBO for the batch ``x``."""
        q = self.encoder(x)
        z = q.rsample().to(x.device)

        log_px = self.decoder(z).log_prob(x).mean()
        kl = (q.log_prob(z) - self.prior.log_prob(z)).mean()

        return -log_px * self.recon_weight + kl * self.kl_weight

class LLKNet(nn.Module):
    """Conditional diagonal Gaussian over ``x`` whose parameters come from an MLP of ``c``.

    The MLP maps ``z_features`` inputs to ``2 * x_features`` outputs. The first
    half is the mean. The second half goes through a sigmoid and gets 1e-6
    added, which keeps the scale strictly positive and below about 1.

    Note: ``freqs`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super().__init__()

        mlp = MLP(z_features, 2 * x_features, **kwargs)
        self.hyper = nn.Sequential(mlp)
        self.x_features = x_features

    def forward(self, c):
        """Return ``p(x | c)`` as a ``DiagNormal``."""
        mean, raw_scale = torch.chunk(self.hyper(c), 2, dim=-1)
        scale = F.sigmoid(raw_scale).to(c.device) + 1e-6
        return DiagNormal(mean, scale)

    def sample(self, c, n=1):
        """Draw ``n`` (non-reparameterised) samples from ``p(x | c)``."""
        return self(c).sample((n,))

    def log_prob(self, c, x):
        """Log-density of ``x`` under ``p(x | c)``."""
        return self(c).log_prob(x)


class GaussianNet(zuko.flows.LazyDistribution):
    """Conditional diagonal Gaussian ``p(y | c)`` parameterised by an MLP.

    The MLP maps a context of size ``context`` to ``2 * features`` values. If
    ``output_activation`` is given, it is applied to all of them. They are then
    split into a mean and a raw scale. The scale is ``softplus(raw) + 1e-6``.
    """

    def __init__(self, features: int, context: int, output_activation=None, **kwargs):
        super().__init__()
        self.output_activation = output_activation
        self.hyper = MLP(context, 2*features, **kwargs)

    def forward(self, c: Tensor) -> Distribution:
        """Return ``p(y | c)`` with event shape ``(features,)``."""
        params = self.hyper(c)
        activation = self.output_activation
        if activation is not None:
            params = activation(params)

        loc, raw_scale = torch.chunk(params, 2, dim=-1)
        scale = F.softplus(raw_scale) + 1e-6
        base = Normal(loc, scale)
        return Independent(base, 1)

    def rsample(self, c):
        """Reparameterised sample from ``p(y | c)``."""
        dist = self(c)
        return dist.rsample()

    def sample(self, c):
        """Non-differentiable sample from ``p(y | c)``."""
        dist = self(c)
        return dist.sample()

    def log_prob(self, c, z):
        """Log-density of ``z`` under ``p(y | c)``."""
        dist = self(c)
        return dist.log_prob(z)




class CNNEncoder(nn.Module):
    """Convolutional Gaussian encoder ``q(z | x)`` for 28x28 images.

    Three stride-2 convolutions reduce a ``(B, in_ch, 28, 28)`` input to a
    ``(B, 128, 4, 4)`` feature map. A linear layer maps the flattened map to the
    mean and raw scale of a diagonal Gaussian over ``z_dim`` latents.

    Args:
        in_ch: number of input image channels.
        z_dim: latent dimension.
        fct: activation module (e.g. ``nn.ReLU()``) applied after each conv.
        min_scale: constant added to ``softplus(raw_scale)``. Softplus
            underflows to exactly 0 for very negative inputs, and ``Normal``
            rejects a zero scale; 1e-6 matches the floor used by GaussianNet.
    """

    def __init__(self, in_ch: int, z_dim: int, fct, min_scale: float = 1e-6):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_ch, 32, 4, stride=2, padding=1),   # 28x28 -> (B, 32, 14, 14)
            fct,
            nn.Conv2d(32, 64, 4, stride=2, padding=1),      # 14x14 -> (B, 64, 7, 7)
            fct,
            nn.Conv2d(64, 128, 3, stride=2, padding=1),     # 7x7   -> (B, 128, 4, 4)
            fct,
        )
        # Input size of the flattened 128x4x4 map; this only holds for 28x28 images.
        self.fc = nn.Linear(128 * 4 * 4, z_dim * 2)
        self.min_scale = min_scale

    def forward(self, x: Tensor) -> Distribution:
        """Return ``q(z | x)`` as an ``Independent(Normal)`` with event shape ``(z_dim,)``."""
        h = self.conv_layers(x).flatten(start_dim=1)
        mu, log_sigma = self.fc(h).chunk(2, dim=-1)
        sigma = F.softplus(log_sigma) + self.min_scale
        return Independent(Normal(mu, sigma), 1)

    def rsample(self, x: Tensor) -> Tensor:
        """Sample ``z ~ q(z | x)`` with the reparameterisation trick."""
        return self(x).rsample()

    def sample(self, x: Tensor) -> Tensor:
        """Sample ``z ~ q(z | x)`` without gradients through the sample."""
        return self(x).sample()

    def log_prob(self, x: Tensor, z: Tensor) -> Tensor:
        """Log-density ``log q(z | x)``, one value per batch element."""
        return self(x).log_prob(z)


class CNNDecoder(nn.Module):
    """Convolutional Gaussian decoder ``p(x | z)`` producing 28x28 images.

    ``z`` is projected to a ``(B, 128, 4, 4)`` map. Three transposed
    convolutions upsample it to ``(B, 2 * out_ch, 28, 28)`` without
    interpolation. The first ``out_ch`` channels give the per-pixel mean
    (optionally passed through ``output_activation``). The remaining channels
    give the scale as ``softplus(raw) + min_scale``.

    Args:
        out_ch: number of output image channels (e.g. 1 for grayscale).
        z_dim: latent dimension.
        fct: activation module (e.g. ``nn.ReLU()``) between transposed convs.
        target_size: output height/width; only 28 is supported.
        output_activation: optional module applied to the mean (e.g.
            ``nn.Sigmoid()`` for data in [0, 1]).
        min_scale: constant added to ``softplus(raw_scale)``. Softplus
            underflows to exactly 0 for very negative inputs, and ``Normal``
            rejects a zero scale.

    Raises:
        ValueError: if ``target_size`` is not 28.
    """

    def __init__(self, out_ch: int, z_dim: int, fct, target_size=28, output_activation=None,
                 min_scale: float = 1e-6):
        super().__init__()

        if target_size != 28:
            # The transposed-conv parameters below are fixed to produce exactly 28x28.
            raise ValueError(f"This CNNDecoder version is designed for target_size=28, but got {target_size}. "
                             "To avoid interpolation, layer parameters are fixed for this output size.")

        self.target_size = target_size
        self.out_ch = out_ch
        self.z_dim = z_dim
        self.min_scale = min_scale

        # Project z to the same 128x4x4 shape the encoder produces before its FC layer.
        self.fc = nn.Linear(z_dim, 128 * 4 * 4)

        # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size + output_padding
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0),  # 4 -> 8
            fct,
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=2, output_padding=0),   # 8 -> 14
            fct,
            nn.ConvTranspose2d(32, self.out_ch * 2, kernel_size=4, stride=2, padding=1, output_padding=0),  # 14 -> 28
        )

        # Applied to the mean only (not to the scale channels).
        self.output_activation = output_activation

    def forward(self, z: Tensor) -> Distribution:
        """Return ``p(x | z)`` with event shape ``(out_ch, 28, 28)``; ``z`` is ``(B, z_dim)``."""
        h = self.fc(z)
        h = h.view(h.size(0), 128, 4, 4)

        params = self.deconv_layers(h)
        # Channel split: first out_ch channels -> mean, last out_ch -> raw scale.
        mu_raw, log_sigma = torch.chunk(params, 2, dim=1)

        mu = mu_raw if self.output_activation is None else self.output_activation(mu_raw)
        sigma = F.softplus(log_sigma) + self.min_scale

        # Reinterpret C, H, W as event dimensions so log_prob returns shape (B,).
        return Independent(Normal(loc=mu, scale=sigma), 3)

    def rsample(self, z: Tensor) -> Tensor:
        """Sample ``x ~ p(x | z)`` with the reparameterisation trick."""
        return self(z).rsample()

    def sample(self, z: Tensor) -> Tensor:
        """Sample ``x ~ p(x | z)``."""
        return self(z).sample()

    def log_prob(self, z: Tensor, x_target: Tensor) -> Tensor:
        """Log-probability of ``x_target`` under ``p(x | z)``.

        Args:
            z: latent variable, shape ``(B, z_dim)``.
            x_target: target image, shape ``(B, out_ch, 28, 28)``.

        Returns:
            Tensor of shape ``(B,)``.
        """
        return self(z).log_prob(x_target)




def log_normal(x: Tensor) -> Tensor:
    """Log-density of the standard normal ``N(0, I)`` evaluated at ``x``.

    The last dimension is treated as the event dimension and summed over, so
    the result has shape ``x.shape[:-1]``.
    """
    per_dim = x.pow(2) + math.log(2 * math.pi)
    return -0.5 * per_dim.sum(dim=-1)

class mCNF(nn.Module):
    """Unconditional continuous normalizing flow with an MLP vector field.

    The vector field ``f(t, x)`` is an MLP applied to the time features
    ``[cos(k*pi*t), sin(k*pi*t)]`` for ``k = 1..freqs``, concatenated with ``x``.
    Data sits at time 0 and a standard normal base distribution at time 1.
    ``log_prob`` integrates ``t -> 1``, and ``decode``/``sample`` integrate
    ``1 -> 0``. Every ODE solve uses ``odeint_adjoint`` with ``atol = rtol = 1e-8``.

    Args:
        features: dimension of ``x``.
        freqs: number of frequencies in the time embedding.
        device: stored as ``self.device``. Tensors created internally follow the
            device of the inputs or of the module's buffers.
        **kwargs: forwarded to ``MLP`` (e.g. ``hidden_features``, ``fct``).
    """

    def __init__(self, features: int, freqs: int = 3, device='cpu', **kwargs):
        super().__init__()
        self.features = features
        self.net = MLP(2 * freqs + features, features, **kwargs)
        self.device = device

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Vector field ``f(t, x)``; ``t`` is a 0-d time tensor, ``x`` is ``(..., features)``."""
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        # Follow x's device so the concat works wherever the module lives.
        t = t.expand(*x.shape[:-1], -1).to(x.device)

        return self.net(torch.cat((t, x), dim=-1))

    def decode(self, x: Tensor, t=1.) -> Tensor:
        """Integrate ``x`` from time ``t`` back to time 0 (base sample -> data)."""
        return odeint_adjoint(
            self,
            x,
            torch.tensor([t, 0.], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]

    def log_prob(self, x: Tensor, t=0.) -> Tensor:
        """Log-density of ``x`` (taken at time ``t``) under the flow.

        Uses the instantaneous change of variables
        ``log p(x_t) = log N(x_1; 0, I) + integral_t^1 tr(df/dx) ds``. The trace
        is exact: the full per-sample Jacobian is built with one batched
        vector-Jacobian product per dimension.
        """
        # Identity rows used as batched grad_outputs, shape (D, *x.shape).
        eye = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        eye = eye.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(s: Tensor, state):
            xs, _ = state
            with torch.enable_grad():
                xs = xs.requires_grad_()
                dx = self(s, xs)

            # jacobian[i, ..., j] = d dx[..., i] / d xs[..., j]; 'i...i' takes the trace.
            jacobian = torch.autograd.grad(dx, xs, eye, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # The trace is integrated scaled by 1e-2 and rescaled by 1e2 below.
            # NOTE(review): presumably to keep this state on a scale comparable
            # to x for the adaptive solver's error control — confirm.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        xt, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            torch.tensor([t, 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return log_normal(xt[-1]) + ladj[-1] * 1e2

    def sample(self, n):
        """Draw samples by decoding standard-normal draws.

        Args:
            n: sample shape, either an int or a shape tuple such as ``(n_gen,)``.

        Returns:
            Tensor of shape ``(*n, features)``.
        """
        if isinstance(n, int):
            n = (n,)
        device = self.freqs.device
        base = DiagNormal(torch.zeros(self.features, device=device), torch.ones(self.features, device=device))
        return self.decode(base.sample(n))

class CNF(nn.Module):
    """Continuous normalizing flow whose vector field can take a conditioning context.

    The vector field ``f(t, z[, c])`` is an MLP applied to time features
    ``[cos(k*pi*t), sin(k*pi*t)]`` (``k = 1..freqs``), the state ``z`` and,
    when ``context > 0``, a context vector ``c``. Data sits at time 0 and the
    base distribution at time 1. ``log_prob`` integrates ``t -> 1`` and
    ``decode``/``sample`` integrate ``1 -> 0``. ODE solves use
    ``odeint_adjoint`` with ``atol = rtol = 1e-8``.

    Args:
        features: dimension of the state ``z``.
        freqs: number of frequencies in the time embedding.
        device: stored as ``self.device``. Tensors created internally follow the
            device of the inputs or of the module's buffers.
        context: size of the conditioning vector. 0 gives an unconditional flow.
        **kwargs: forwarded to ``MLP`` (e.g. ``hidden_features``, ``fct``).
    """

    def __init__(self, features: int, freqs: int = 3, device='cpu', context: int = 0, **kwargs):
        super().__init__()
        self.features = features
        self.context = context
        self.net = MLP(2 * freqs + features + context, features, **kwargs)
        self.device = device

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, c: Optional[Tensor] = None) -> Tensor:
        """Vector field value at time ``t``.

        Args:
            t: time, a 0-d tensor when called by the ODE solver.
            x: current state, shape ``(..., features)``.
            c: optional context, shape ``(..., context)``, broadcast over x's
                leading dimensions. Must be given exactly when ``context > 0``.
        """
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1).to(x.device)

        inputs = [t, x]
        if c is not None:
            inputs.append(c.expand(*x.shape[:-1], -1))
        return self.net(torch.cat(inputs, dim=-1))

    def decode(self, z: Tensor, x: Optional[Tensor] = None, t: Optional[float] = None) -> Tensor:
        """Integrate ``z`` from time ``t`` (default 1) back to time 0.

        This inverts the map used by ``log_prob``, turning base samples into
        data-space samples.

        Args:
            z: state at time ``t``.
            x: optional context. It is detached, because only ``adjoint_params``
                receive gradients through the adjoint solve.
            t: start time; ``None`` means 1.
        """
        if t is None:
            t = 1.
        func = self if x is None else partial(self, c=x.detach())

        return odeint_adjoint(
            func,
            z,
            torch.tensor([t, 0.], device=z.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]

    def log_prob(self, z: Tensor, x: Optional[Tensor] = None, t=0., prior=None) -> Tensor:
        """Log-density of ``z`` (taken at time ``t``), optionally conditioned on ``x``.

        Computes ``log p_base(z_1) + integral_t^1 tr(df/dz) ds`` with the exact
        Jacobian trace.

        Args:
            z: points to evaluate, shape ``(..., features)``.
            x: optional context (detached; see ``decode``).
            t: start time of the integration.
            prior: object with ``log_prob`` used as the base density at time 1.
                ``None`` means the standard normal.
        """
        # Identity rows used as batched grad_outputs, shape (D, *z.shape).
        eye = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
        eye = eye.expand(*z.shape, z.shape[-1]).movedim(-1, 0)
        c = None if x is None else x.detach()

        def augmented(s: Tensor, state):
            zs, _ = state
            with torch.enable_grad():
                zs = zs.requires_grad_()
                dz = self(s, zs, c)

            # jacobian[i, ..., j] = d dz[..., i] / d zs[..., j]; 'i...i' takes the trace.
            jacobian = torch.autograd.grad(dz, zs, eye, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Integrated scaled by 1e-2 and rescaled by 1e2 below.
            return dz, trace * 1e-2

        ladj = torch.zeros_like(z[..., 0])
        zt, ladj = odeint_adjoint(
            augmented,
            (z, ladj),
            torch.tensor([t, 1.0], device=z.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)

        z1 = zt[-1]
        base = log_normal(z1) if prior is None else prior.log_prob(z1)
        return base + ladj[-1] * 1e2

    def sample(self, n, x: Optional[Tensor] = None):
        """Decode standard-normal draws of shape ``(*n, features)``.

        Args:
            n: sample shape, either an int or a shape tuple.
            x: optional context forwarded to ``decode``.
        """
        if isinstance(n, int):
            n = (n,)
        device = self.freqs.device
        base = DiagNormal(torch.zeros(self.features, device=device), torch.ones(self.features, device=device))
        return self.decode(base.sample(n), x)



if __name__ == "__main__":
    torch.manual_seed(1234)
    batch_size = 125
    num_per_class = 10_000
    dat_dir = "data"
    num_classes = 5
    # NOTE(review): DataLoader, PINWHEEL, tqdm and first_eigen_proj are not imported
    # explicitly in this file; presumably they come from the
    # `dataloader.dataloader_pinwheel` star import — confirm.
    dataloader = DataLoader(PINWHEEL(num_per_class, dat_dir, num_classes=num_classes, gen=True, plot=False), batch_size=batch_size, shuffle=True)
    z_feature_dim = 1
    x_feature_dim = 2
    n_gen = 2000
    n_epoch = 20  # 400

    # Gaussian encoder q(z|x) and Gaussian decoder p(x|z).
    encoder = GaussianNet(z_feature_dim, x_feature_dim, hidden_features=[1024]*2, fct=nn.Tanh())
    decoder = GaussianNet(x_feature_dim, z_feature_dim, hidden_features=[1024]*2, fct=nn.Tanh())
    # Learned prior p(z). The unconditional CNF `mCNF` provides the
    # `log_prob(z)` used by ELBO and the `sample(shape)` used below.
    # A fixed alternative: DiagNormal(torch.zeros(z_feature_dim), torch.ones(z_feature_dim))
    prior = mCNF(z_feature_dim, hidden_features=[32, 32]*2)

    # Training. ELBO's positional order is (decoder, encoder, prior), so pass
    # keywords to avoid swapping the two networks.
    loss = ELBO(decoder=decoder, encoder=encoder, prior=prior)
    optimizer = torch.optim.Adam(
        itertools.chain(
            encoder.parameters(),
            decoder.parameters(),
            prior.parameters(),
        ),
        lr=1e-3)

    for epoch in (bar := tqdm(range(n_epoch))):
        ls = []
        for x, _ in dataloader:
            l = loss(x)
            l.backward()

            optimizer.step()
            optimizer.zero_grad()
            ls.append(l.detach())
            print("epoch={}, loss={}".format(epoch, ls[-1]))
        ls = torch.stack(ls)
        bar.set_postfix(loss=ls.mean().item())

    # Sampling: z ~ prior, x ~ p(x|z), then re-encode x to colour the scatter plot.
    with torch.no_grad():
        z0 = prior.sample((n_gen,))
        x1 = decoder(z0).sample((1,)).squeeze()
        z1 = encoder(x1).sample((1,))

    if z_feature_dim > 1:
        # Project z onto its first eigenvector direction for colouring.
        z1 = first_eigen_proj(z1.squeeze())
    else:
        z1 = z1.squeeze()

    plt.figure(dpi=150)
    plt.scatter(x1[:, 0], x1[:, 1], c=z1, marker=".", cmap=plt.colormaps['gist_rainbow'])
    plt.colorbar()
    plt.show()
    plt.close()









