import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
import ot

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import leaky_relu, sigmoid
import torch.nn.utils.parametrize as parametrize
from torch.distributions import Dirichlet, Categorical, Normal, Uniform

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint
from zuko.distributions import DiagNormal

from dataloader.dataloader_pinwheel import *

def log_normal(x: Tensor) -> Tensor:
    """Log-density of a standard normal, summed over the last axis of ``x``.

    Each coordinate contributes ``-(x_i**2 + log(2*pi)) / 2``; the result has
    shape ``x.shape[:-1]``.
    """
    per_coordinate = torch.square(x) + math.log(2 * math.pi)
    return -per_coordinate.sum(dim=-1) / 2

def first_eigen_proj(x):
    """Project samples onto the leading eigenvector of their covariance.

    Args:
        x: array-like of shape ``(n_samples, n_features)``.

    Returns:
        1-D real array of length ``n_samples`` equal to ``x @ v``, where ``v``
        is the unit eigenvector of the sample covariance with the largest
        eigenvalue. The projection is taken of the uncentred ``x``. The sign
        of ``v`` is arbitrary, as for any eigenvector.
    """
    x = np.asarray(x)
    # np.cov centres the data itself. atleast_2d keeps the single-feature case
    # working, because np.cov returns a 0-d array there.
    cov_matrix = np.atleast_2d(np.cov(x, rowvar=False))

    # The covariance is symmetric, so use eigh: it always returns real
    # eigenpairs. eig targets general matrices and may return a complex dtype.
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    first_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

    return x @ first_eigenvector

class ShiftedTanh(torch.nn.Module):
    """Element-wise activation computing ``tanh(x) + b``.

    NOTE(review): ``a`` is stored but not used by ``forward``. Confirm whether
    it was meant to scale the input or the output.
    """

    def __init__(self, a=1.0, b=1.0):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        offset = self.b
        squashed = torch.tanh(x)
        return squashed + offset

class MLP(nn.Sequential):
    """Fully connected network ``Linear -> [BatchNorm1d] -> fct -> ... -> Linear``.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        hidden_features: widths of the hidden layers (default ``[64, 64]``).
            An empty sequence gives a single ``Linear`` layer.
        fct: activation module placed after every layer except the last.
            The same instance is reused at every position.
        batch_norm: if True, insert ``BatchNorm1d`` after each ``Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Optional[Sequence[int]] = None,
        fct=nn.Tanh(),
        batch_norm=False
    ):
        # None sentinel instead of a shared mutable list default.
        if hidden_features is None:
            hidden_features = [64, 64]

        widths = (in_features, *hidden_features, out_features)
        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(a, b))
            if batch_norm:
                layers.append(nn.BatchNorm1d(b))
            layers.append(fct)

        # Drop the trailing activation so the output is not squashed by fct.
        # NOTE(review): with batch_norm=True the last BatchNorm1d is kept on
        # the output layer. Confirm whether that is intended.
        super().__init__(*layers[:-1])


class GaussianPrior(nn.Module):
    """Diagonal Gaussian over ``z`` whose mean and scale are predicted from ``x`` and ``t``.

    ``t`` is embedded as ``[cos(k*pi*t), sin(k*pi*t)]`` for ``k = 1..freqs``.
    The embedding is concatenated with ``x`` and passed through an MLP with
    ``2 * z_features`` outputs, which are split into a mean and a
    pre-softplus scale.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super().__init__()

        # The Sequential wrapper keeps parameter names as ``hyper.0.*``.
        self.hyper = nn.Sequential(
            MLP(x_features + 2 * freqs, z_features * 2, **kwargs),
        )
        self.z_features = z_features
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, x: Tensor, t: Optional[Tensor] = None, min_variance=1e-6):
        """Return the ``DiagNormal`` over ``z`` for inputs ``x`` at time ``t``.

        Args:
            x: conditioning input. Leading dimensions ``x.shape[:-1]`` are
                treated as batch dimensions.
            t: time tensor. Its embedding must broadcast to
                ``x.shape[:-1]``. Defaults to a single zero on ``x``'s device.
            min_variance: constant added to the softplus output before it is
                passed as the second (scale) argument of ``DiagNormal``.
        """
        if t is None:
            # Build the default per call on x's device. A default tensor in
            # the signature would be created once, on the CPU, and be shared
            # between calls.
            t = torch.zeros(1, dtype=self.freqs.dtype, device=x.device)
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        phi = self.hyper(torch.cat((temb, x), dim=-1))

        mu, sigma = phi.chunk(2, dim=-1)
        return DiagNormal(mu, F.softplus(sigma) + min_variance)

    def rsample(self, x: Tensor, t: Optional[Tensor] = None):
        """Reparameterised sample of ``z`` given ``x`` and ``t``."""
        return self(x, t).rsample()

    def sample(self, x: Tensor, t: Optional[Tensor] = None):
        """Sample ``z`` given ``x`` and ``t``, without gradient through the sample."""
        return self(x, t).sample()

    def log_prob(self, x: Tensor, z: Tensor, t: Optional[Tensor] = None):
        """Log-density of ``z`` under the distribution predicted from ``x`` and ``t``."""
        return self(x, t).log_prob(z)


class LLK(nn.Module):
    """Conditional CNF velocity field ``v(t, x | z)`` over ``x``.

    ``t`` is embedded as ``[cos(k*pi*t), sin(k*pi*t)]`` for ``k = 1..freqs``,
    and ``z`` is passed through a small MLP. Both are concatenated with ``x``
    and mapped to a velocity of size ``x_features``. ``log_prob`` integrates
    from ``t`` to 1 and scores the endpoint under a standard normal.
    ``decode`` integrates from ``t`` back to 0.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super().__init__()
        # TODO: need more elaborate architecture with the embedding
        self.embz = MLP(z_features, z_features, [32, 32])

        self.fc1 = MLP(2 * freqs + x_features + z_features, x_features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tensor:
        """Return the velocity at time ``t`` for state ``x`` given latent ``z``."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        zemb = self.embz(z)

        return self.fc1(torch.cat((temb, zemb, x), dim=-1))

    def _forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the velocity twice, as a pair."""
        out = self.forward(t, x, z)
        return out, out

    def encode(self, x: Tensor) -> Tensor:
        # NOTE(review): ``odeint`` is not imported in this module (the zuko
        # import is commented out), and ``forward`` also requires ``z``. This
        # cannot run as written unless ``odeint`` is provided by a star import.
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate ``x`` from time ``t`` (default 1) back to 0 conditioned on ``z``."""
        if t is None:
            t = 1.
        # z is a detached constant during integration: no gradient flows back
        # to whatever produced it.
        z = z.clone().detach().requires_grad_(True)
        # Build the time grid on x's device, as CNF/mCNF do.
        xt = odeint_adjoint(
            partial(self, z=z), x,
            torch.tensor([t, 0.], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return xt[-1]

    def log_prob(self, x: Tensor, z: Tensor, t=0) -> Tensor:
        """Log-density of ``x`` at time ``t`` given ``z`` (instantaneous change of variables).

        The Jacobian trace is computed exactly, with one batched
        vector-Jacobian product per dimension of ``x``.
        """
        # One identity row per output dimension, as batched grad_outputs.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        z = z.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tuple[Tensor, Tensor]:
            x, adj = state
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x, z)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Integrated at 1e-2 scale; rescaled by 1e2 after integration.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        x0, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            torch.tensor([t, 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-7, rtol=1e-7)
        return log_normal(x0[-1]) + ladj[-1] * 1e2


class LLK_high(nn.Module):
    """Conditional CNF velocity field ``v(t, x | z)`` with learned embeddings of ``x`` and ``z``.

    ``x`` goes through a Softplus MLP and ``z`` through a linear layer, each
    to ``hidden_dim`` features. They are concatenated with a sinusoidal time
    embedding and mapped to a velocity of size ``x_features``. ``log_prob``
    integrates from ``t`` to 1 and scores the endpoint under a standard
    normal. ``decode`` integrates from ``t`` back to 0.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, hidden_dim=800, **kwargs):
        super().__init__()
        # TODO: need more elaborate architecture with the embedding
        self.hidden_dim = hidden_dim
        self.embx = MLP(x_features, hidden_dim, hidden_features=[800], fct=nn.Softplus())
        self.embz = nn.Linear(z_features, hidden_dim)

        self.fc1 = MLP(2 * freqs + 2 * self.hidden_dim, x_features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tensor:
        """Return the velocity at time ``t`` for state ``x`` given latent ``z``."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        zemb = self.embz(z)
        xemb = self.embx(x)

        return self.fc1(torch.cat((temb, zemb, xemb), dim=-1))

    def _forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the velocity twice, as a pair."""
        out = self.forward(t, x, z)
        return out, out

    def encode(self, x: Tensor) -> Tensor:
        # NOTE(review): ``odeint`` is not imported in this module (the zuko
        # import is commented out), and ``forward`` also requires ``z``. This
        # cannot run as written unless ``odeint`` is provided by a star import.
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate ``x`` from time ``t`` (default 1) back to 0 conditioned on ``z``."""
        if t is None:
            t = 1.
        # z is a detached constant during integration: no gradient flows back
        # to whatever produced it.
        z = z.clone().detach().requires_grad_(True)
        # Build the time grid on x's device, as CNF/mCNF do.
        xt = odeint_adjoint(
            partial(self, z=z), x,
            torch.tensor([t, 0.], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return xt[-1]

    def log_prob(self, x: Tensor, z: Tensor, t=0, priory=None) -> Tensor:
        """Log-density of ``x`` at time ``t`` given ``z`` (instantaneous change of variables).

        ``priory`` is accepted but not used; the base density is a standard
        normal.
        """
        # One identity row per output dimension, as batched grad_outputs.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        z = z.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tuple[Tensor, Tensor]:
            x, adj = state
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x, z)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Integrated at 1e-2 scale; rescaled by 1e2 after integration.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        x0, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            torch.tensor([t, 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-7, rtol=1e-7)
        return log_normal(x0[-1]) + ladj[-1] * 1e2




class CNF(nn.Module):
    """Continuous flow over latent ``z`` whose velocity is conditioned on ``x``.

    ``x`` is embedded by a linear layer to ``hidden_dim`` features. It is
    concatenated with ``z`` and a sinusoidal time embedding, and the MLP
    output is negated to give the velocity. ``decode`` integrates from 1 to
    ``t`` (default 0). ``log_prob`` integrates from ``t`` to 1 and scores the
    endpoint under ``prior``.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, hidden_dim=784, **kwargs):
        super().__init__()

        self.embx = nn.Linear(x_features, hidden_dim)

        in_dim = 2 * freqs + hidden_dim + z_features
        self.fc1 = nn.Sequential(MLP(in_dim, z_features, **kwargs))

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, z: Tensor, x: Tensor) -> Tensor:
        """Return the velocity of ``z`` at time ``t`` given context ``x``."""
        phase = self.freqs * t[..., None]
        time_features = torch.cat((phase.cos(), phase.sin()), dim=-1).expand(*z.shape[:-1], -1)
        context_features = self.embx(x)
        hidden = self.fc1(torch.cat((z, time_features, context_features), dim=-1))
        return -hidden

    def encode(self, x: Tensor) -> Tensor:
        # NOTE(review): ``odeint`` is not imported in this module (the zuko
        # import is commented out), and ``forward`` also requires ``x``.
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``z`` from time 1 to ``t`` (default 0) conditioned on ``x``."""
        end = 0. if t is None else t
        # x acts as a detached constant during integration.
        context = x.clone().detach().requires_grad_(True)
        grid = torch.tensor([1., end], device=context.device)
        path = odeint_adjoint(
            partial(self, x=context),
            z,
            grid,
            adjoint_params=self.parameters(),
            atol=1e-8,
            rtol=1e-8,
        )
        return path[-1]

    def log_prob(self, z: Tensor, x: Tensor, t, prior) -> Tensor:
        """Log-density of ``z`` at time ``t`` given ``x``, with ``prior`` as the t=1 density.

        The Jacobian trace is computed exactly, with one batched
        vector-Jacobian product per latent dimension.
        """
        dim = z.shape[-1]
        basis = torch.eye(dim, dtype=z.dtype, device=z.device)
        basis = basis.expand(*z.shape, dim).movedim(-1, 0)

        context = x.clone().detach().requires_grad_(True)

        def dynamics(s: Tensor, state) -> Tensor:
            z_s, _ = state
            with torch.enable_grad():
                z_s = z_s.requires_grad_()
                dz = self(s, z_s, context)

            jac = torch.autograd.grad(dz, z_s, basis, create_graph=True, is_grads_batched=True)[0]
            # Integrated at 1e-2 scale; rescaled by 1e2 after integration.
            return dz, torch.einsum('i...i', jac) * 1e-2

        ladj_start = torch.zeros_like(z[..., 0]).to(device=context.device)
        grid = torch.tensor([t, 1.0], device=context.device)
        z_path, ladj_path = odeint_adjoint(
            dynamics,
            (z, ladj_start),
            grid,
            adjoint_params=self.parameters(),
            atol=1e-8,
            rtol=1e-8,
        )
        return prior.log_prob(z_path[-1]) + ladj_path[-1] * 1e2


class FlowMatchingLoss(nn.Module):
    """Flow-matching loss for a conditional velocity field ``vt(t, x, z)``.

    Uses the path ``x_t = (1 - t) x + (sig_min + (1 - sig_min) t) eps`` with
    ``eps ~ N(0, I)``, and regresses ``vt`` onto ``(1 - sig_min) eps - x``.
    A single ``t ~ U(0, 1)`` is shared by the whole batch.

    The latent ``z`` comes from ``rt``:
    - with ``fixz=True``: ``rt.forward(x_t, t).rsample()``;
    - otherwise: ``rt.decode(z0, x_t, t)`` with ``z0`` drawn from ``prior``.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, alpha=0.001, sig_min=1e-4, fixz=False):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.prior = prior
        self.sig_min = sig_min
        # NOTE(review): stored for API compatibility; not used by ``forward``.
        self.alpha = alpha
        self.fixz = fixz

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean squared flow-matching error for the batch ``x``."""
        # Draw t on x's device. A CPU tensor of shape [1] cannot be combined
        # with a CUDA tensor.
        _t = torch.rand(1, device=x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        x0 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x0
        ut = (1 - self.sig_min) * x0 - x

        if self.fixz:
            zt = self.rt.forward(xt, _t).rsample()
        else:
            z0 = self.prior.sample((len(x),))
            zt = self.rt.decode(z0, xt, _t)

        return (self.vt(t.squeeze(-1), xt, zt) - ut).square().mean()

class FlowMatchingLossCNN(nn.Module):
    """Flow-matching loss for a conditional velocity field ``vt(t, x, z)`` with multi-dim ``x``.

    Uses the path ``x_t = (1 - t) x + (sig_min + (1 - sig_min) t) eps`` with
    ``eps ~ N(0, I)``, and regresses ``vt`` onto ``(1 - sig_min) eps - x``.
    One ``t ~ U(0, 1)`` of shape ``(1,)`` is shared by the batch and passed
    directly to ``vt``. ``rt`` receives ``x_t`` flattened from dim 1.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, sig_min=1e-4, alpha=0.001, fixz=False):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.prior = prior
        self.sig_min = sig_min
        self.fixz = fixz
        # NOTE(review): stored for API compatibility; not used by ``forward``.
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean squared flow-matching error for the batch ``x``."""
        # Draw t on x's device. A CPU tensor of shape [1] cannot be combined
        # with a CUDA tensor.
        _t = torch.rand(1, device=x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        x0 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x0
        ut = (1 - self.sig_min) * x0 - x
        xt_flat = xt.flatten(start_dim=1)

        if self.fixz:
            zt = self.rt.forward(xt_flat, _t).rsample()
        else:
            z0 = self.prior.sample((len(x),))
            zt = self.rt.decode(z0, xt_flat, _t)

        # NOTE(review): a latent regulariser and a KL term used to be computed
        # here but were never added to the returned loss. The KL call
        # rt.log_prob(x, z, t) also did not match CNF.log_prob(z, x, t, prior).
        # Both were dropped to avoid the extra ODE solves and that crash.
        return (self.vt(_t, xt, zt) - ut).square().mean()




class mCNF(nn.Module):
    """Unconditional continuous normalizing flow with velocity field ``v(t, x)``.

    Time is embedded as ``[cos(k*pi*t), sin(k*pi*t)]`` for ``k = 1..freqs``
    and concatenated with ``x`` before the MLP. ``log_prob`` integrates from
    ``t`` to 1 and scores the endpoint under a standard normal. ``decode``
    integrates from ``t`` back to 0.
    """

    def __init__(self, features: int, freqs: int = 3, device='cpu', **kwargs):
        super().__init__()
        self.features = features
        self.net = MLP(2 * freqs + features, features, **kwargs)
        # Device for the base distribution in ``sample``. It is a plain
        # attribute, so ``Module.to()`` does not update it.
        self.device = device

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Evaluate the velocity at time ``t`` and state ``x``."""
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        # Put the embedding on x's device. torch.cat below requires it, and
        # self.device may be stale after the module has been moved.
        t = t.expand(*x.shape[:-1], -1).to(x.device)

        return self.net(torch.cat((t, x), dim=-1))

    def decode(self, x: Tensor, t=1.) -> Tensor:
        """Integrate ``x`` from time ``t`` back to 0 and return the final state."""
        xt = odeint_adjoint(
            self,
            x,
            torch.tensor([t, 0.], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decode_with_trajectory(self, x: Tensor, t=None, num_points=100) -> tuple:
        """Decode ``x`` from time ``t`` to 0 and also return the path.

        Args:
            x: state at time ``t``, e.g. base samples when ``t=1``.
            t: start time; defaults to 1.
            num_points: number of evenly spaced times, including both
                endpoints, at which the solution is reported.

        Returns:
            ``(x_end, trajectory, times)``. ``trajectory[k]`` is the solution
            at ``times[k]``. ``trajectory`` has shape ``(num_points, *x.shape)``
            and is detached; ``times`` is a CPU tensor running from ``t`` to 0.
        """
        if t is None:
            t = 1.

        # Integrate in the same direction as ``decode`` (t -> 0). Read the
        # trajectory from the solver output at the requested times. Recording
        # every right-hand-side call would mix in intermediate stage evaluations.
        time_points = torch.linspace(t, 0., num_points, device=x.device)

        xt = odeint_adjoint(
            self,
            x,
            time_points,
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return xt[-1], xt.detach().clone(), time_points.cpu()

    def log_prob(self, x: Tensor, t=0.) -> Tensor:
        """Log-density of ``x`` at time ``t`` (instantaneous change of variables).

        The Jacobian trace is computed exactly, with one batched
        vector-Jacobian product per dimension of ``x``.
        """
        # One identity row per output dimension, as batched grad_outputs.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, state) -> Tensor:
            x, ladj = state
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Integrated at 1e-2 scale; rescaled by 1e2 after integration.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0]).to(device=x.device)
        xt, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            torch.tensor([t, 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return log_normal(xt[-1]) + ladj[-1] * 1e2

    def sample(self, n):
        """Draw samples by decoding standard-normal base draws.

        Args:
            n: number of samples (int), or a sample shape (tuple, list or
                ``torch.Size``).
        """
        # Distribution.sample expects a shape, not a bare int.
        sample_shape = tuple(n) if isinstance(n, (tuple, list)) else (int(n),)
        p = DiagNormal(torch.zeros(self.features).to(self.device), torch.ones(self.features).to(self.device))
        z0 = p.sample(sample_shape).to(self.device)
        return self.decode(z0)


class mFlowMatchingLoss(nn.Module):
    """Flow-matching loss for an unconditional velocity field ``vt(t, x)``.

    Uses the path ``x_t = (1 - t) x + (sig_min + (1 - sig_min) t) eps`` with
    ``eps ~ N(0, I)``, and regresses ``vt`` onto its time derivative
    ``(1 - sig_min) eps - x``. A single ``t ~ U(0, 1)`` is shared by the batch.
    """

    def __init__(self, vt: nn.Module, sig_min=1e-4):
        super().__init__()

        self.vt = vt
        self.sig_min = sig_min

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean squared flow-matching error for the batch ``x``."""
        # Draw t on x's device. A CPU tensor of shape [1] cannot be multiplied
        # with a CUDA tensor.
        t = torch.ones_like(x[..., 0, None]) * torch.rand(1, device=x.device)
        x0 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x0
        ut = (1 - self.sig_min) * x0 - x

        return (self.vt(t.squeeze(-1), xt) - ut).square().mean()



if __name__ == '__main__':
    # ``os`` is not imported explicitly at module level; import it here.
    import os

    # data, _ = make_moons(16384, noise=0.05)
    # data = torch.from_numpy(data).float()

    batch_size = 1000
    num_per_class = 100_000
    dat_dir = "data"
    num_classes = 5
    dataloader = DataLoader(PINWHEEL(num_per_class, dat_dir, num_classes=num_classes, gen=True, plot=False), batch_size=batch_size, shuffle=True)
    # z_feature_dim = 1
    log_sample_path = "gen_figs/CFM"
    # savefig does not create missing directories.
    os.makedirs(log_sample_path, exist_ok=True)
    x_feature_dim = 2
    n_gen = 5000
    n_epoch = 30

    vt = mCNF(x_feature_dim, hidden_features=[120] * 5)
    loss = mFlowMatchingLoss(vt)

    optimizer = torch.optim.Adam(vt.parameters(), lr=1e-3)

    # ---- Training ----
    for epoch in tqdm(range(n_epoch)):
        for x, _ in dataloader:
            loss(x).backward()

            optimizer.step()
            optimizer.zero_grad()

        # Periodically save a scatter plot of generated samples.
        if epoch % 10 == 0:
            with torch.no_grad():
                x0 = torch.randn(n_gen, x_feature_dim)
                x1 = vt.decode(x0).cpu()

            plt.figure(figsize=(6.5, 6))
            plt.scatter(x1[:, 0], x1[:, 1], marker=".")
            plt.savefig(os.path.join(log_sample_path, 'image_grid_{}.png'.format(epoch)))
            plt.close()

    # ---- Evaluation: OT distance between generated and held-out samples ----
    dataset, test_dataset = get_dataset(5, "data", 2000, 2000)

    test_dataloader = DataLoader(test_dataset, batch_size=500, shuffle=True,
                                 num_workers=0, drop_last=True)
    train_dataloader = DataLoader(dataset, batch_size=500, shuffle=True,
                                  num_workers=0, drop_last=True)

    def _w1(a, b):
        """Exact OT cost between two equal-size point sets with uniform weights."""
        n = len(a)
        w = np.full(n, 1 / n)
        M = ot.dist(a, b, "euclidean")
        # Rescale so the largest pairwise distance in the batch maps to 10.
        M /= M.max() * 0.1
        return ot.emd2(w, w, M)

    w1_loss = []
    # One pass over the test set; no gradients are needed for evaluation.
    with torch.no_grad():
        for x, _ in test_dataloader:
            x0 = torch.randn(len(x), x_feature_dim)
            xnew = vt.decode(x0).cpu().detach().numpy()
            xorg = x.cpu().detach().numpy()
            w1_loss.append(_w1(xorg, xnew))

    print("average w1 (FM)", np.array(w1_loss).mean())

    # Baseline: distance between the k-th test batch and the k-th train batch.
    w1_loss = []
    for (x_0, _), (x_1, _) in zip(test_dataloader, train_dataloader):
        assert len(x_0) == len(x_1)
        w1_loss.append(_w1(x_0.cpu().detach().numpy(), x_1.cpu().detach().numpy()))
    print("average w1 (truth)", np.array(w1_loss).mean())


    # llk_net = LLK(x_feature_dim, z_feature_dim, hidden_features=[64] * 3)
    # pos_net = CNF(x_feature_dim, z_feature_dim, hidden_features=[32] * 2)
    # prior = DiagNormal(torch.zeros(z_feature_dim), torch.ones(z_feature_dim))
    # # Training
    # loss = FlowMatchingLoss(llk_net, pos_net, prior)
    # optimizer = torch.optim.Adam(
    #     itertools.chain(
    #             llk_net.parameters(),
    #             pos_net.parameters()
    #             ), 
    #     lr=1e-3)


    # for epoch in tqdm(range(n_epoch)):
    #     for i, (x, y) in enumerate(dataloader):
    #         # emb = y.view(-1,1).to(torch.float32)
    #         loss(x).backward()

    #         optimizer.step()
    #         optimizer.zero_grad()

    # # Sampling
    # with torch.no_grad():
    #     # z1 = torch.randint(0,5,(16384,1)).to(torch.float32)
    #     z0 = prior.sample((n_gen,))
    #     x0 = torch.randn(n_gen, x_feature_dim)
    #     x1 = llk_net.decode(x0, z0)
    #     z1 = pos_net.decode(z0, x1)

    # if z_feature_dim > 1:
    # # get the projection of z onto its first eigenvector direction
    #     z1 = first_eigen_proj(z1)
    # else:
    #     z1 = z1.squeeze()

    # plt.figure(dpi=150)
    # # plt.hist2d(*x.T, bins=64)
    # plt.scatter(x1[:,0], x1[:,1], c=z1, marker=".", cmap=plt.colormaps['gist_rainbow'])
    # plt.colorbar()
    # plt.show()
    # plt.close()

