import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev, vmap
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid, softmax
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils as nn_utils
from torch.distributions import Dirichlet, Categorical, Normal, Uniform
from torchdiffeq import odeint_adjoint

from zuko.distributions import DiagNormal
from unet import *
from sfa_lds import sGRU


# Process-wide default floating dtype: tensors created without an explicit
# dtype (here and in any code running after this module is imported) are float64.
torch.set_default_dtype(torch.float64)
# Print tensors with 3 decimal places.
torch.set_printoptions(precision=3)

class MLP(nn.Sequential):
    """Fully connected network ``in_features -> *hidden_features -> out_features``.

    Each hidden block is ``Linear -> [BatchNorm1d | LayerNorm | Dropout] -> fct``.
    At most one optional module is inserted per block, with priority
    batch_norm > layer_norm > dropout. The output layer is a bare ``Linear``,
    with no normalization, dropout or activation after it.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        hidden_features: Widths of the hidden layers.
        fct: Activation module placed after every hidden block. The same
            instance is reused at every position. The default ``nn.Tanh()`` is
            evaluated once, so it is shared by all MLPs built with the default.
        batch_norm: Insert ``nn.BatchNorm1d`` after each hidden ``Linear``.
        dropout: Insert ``nn.Dropout(p)`` after each hidden ``Linear``.
        weight_norm: Wrap every ``Linear`` (including the output one) with
            ``torch.nn.utils.weight_norm``.
        layer_norm: Insert ``nn.LayerNorm`` after each hidden ``Linear``.
        p: Dropout probability, used only when ``dropout`` is set.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct=nn.Tanh(),
        batch_norm=False,
        dropout=False,
        weight_norm=False,
        layer_norm=False,
        p=0.2
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            linear_layer = nn.Linear(a, b)
            if weight_norm:
                # Reparametrizes the Linear in place; adds no extra module.
                linear_layer = nn_utils.weight_norm(linear_layer)
            if batch_norm:
                layers.extend([linear_layer, nn.BatchNorm1d(b), fct])
            elif layer_norm:
                layers.extend([linear_layer, nn.LayerNorm(b), fct])
            elif dropout:
                layers.extend([linear_layer, nn.Dropout(p=p), fct])
            else:
                layers.extend([linear_layer, fct])

        # Strip what was appended after the output Linear: the activation, plus
        # the norm/dropout module when one is used. weight_norm adds no module,
        # so it must not influence how many trailing modules are removed.
        n_trailing = 2 if (batch_norm or layer_norm or dropout) else 1
        super().__init__(*layers[:-n_trailing])


class Encoder(nn.Module):
    """MLP encoder from ``x_dim`` input features to a ``z_dim`` latent.

    Extra keyword arguments configure the underlying :class:`MLP`.
    """

    def __init__(self, x_dim, z_dim, **kwargs):
        super().__init__()
        self.net = MLP(in_features=x_dim, out_features=z_dim, **kwargs)

    def forward(self, x):
        latent = self.net(x)
        return latent

class Decoder(nn.Module):
    """MLP decoder from a ``z_dim`` latent back to ``x_dim`` features.

    Extra keyword arguments configure the underlying :class:`MLP`.
    """

    def __init__(self, x_dim, z_dim, **kwargs):
        super().__init__()
        self.net = MLP(in_features=z_dim, out_features=x_dim, **kwargs)

    def forward(self, z):
        reconstruction = self.net(z)
        return reconstruction


class MLPEncoder(nn.Module):
    """Flatten each sample and encode it with an :class:`MLP`.

    The MLP expects ``in_ch * x_dim * x_dim`` input features, i.e. samples of
    shape ``(in_ch, x_dim, x_dim)``. Extra keyword arguments configure the MLP.
    """

    def __init__(self, x_dim, z_dim, in_ch=1, **kwargs):
        super().__init__()
        n_inputs = in_ch * x_dim * x_dim
        self.net = MLP(n_inputs, z_dim, **kwargs)

    def forward(self, x):
        # Keep the batch axis, merge everything else into one feature axis.
        flat = torch.flatten(x, start_dim=1)
        return self.net(flat)

class MLPDecoder(nn.Module):
    """Decode latents into square multi-channel images with an :class:`MLP`.

    Counterpart of :class:`MLPEncoder`. The MLP outputs ``in_ch * x_dim**2``
    features, which are reshaped to ``(B, in_ch, x_dim, x_dim)``. Extra keyword
    arguments configure the MLP.
    """

    def __init__(self, x_dim, z_dim, in_ch=1, **kwargs):
        super().__init__()
        self.net = MLP(z_dim, x_dim**2*in_ch, **kwargs)
        self.in_ch = in_ch
        self.x_dim = x_dim

    def forward(self, z):
        # Restore both spatial axes. Reshaping to (-1, in_ch, x_dim) would fold
        # one spatial axis into the batch axis (batch size B * x_dim).
        return self.net(z).reshape(-1, self.in_ch, self.x_dim, self.x_dim)


class CNNEncoder(nn.Module):
    """Convolutional encoder: ``(B, in_ch, H, W)`` images -> ``(B, z_dim)`` latents.

    Three stride-2 convolutions (32, 64 and 128 channels), each followed by
    ``fct``, shrink 32x32 inputs to 4x4 (28x28 inputs also end at 4x4). The
    flattened ``128 * 4 * 4`` features are projected to ``z_dim`` by a Linear
    layer of fixed input size. Inputs that do not reduce to 4x4 will not match it.
    """

    def __init__(self, in_ch: int, z_dim: int, fct):
        super().__init__()
        # (in_channels, out_channels, kernel_size); every conv uses stride 2, padding 1.
        stages = ((in_ch, 32, 4), (32, 64, 4), (64, 128, 3))
        modules = []
        for c_in, c_out, k in stages:
            modules += [nn.Conv2d(c_in, c_out, k, stride=2, padding=1), fct]
        self.conv_layers = nn.Sequential(*modules)
        self.fc = nn.Linear(128 * 4 * 4, z_dim)

    def forward(self, x):
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


class CNNDecoder(nn.Module):
    """Convolutional decoder: ``(B, z_dim)`` latents -> ``(B, out_ch, S, S)`` images.

    The decoder runs four stages, with no activation on the output:

    1. A Linear layer produces a ``128 x 4 x 4`` feature map.
    2. Two stride-2 transposed convolutions (each followed by ``fct``) upsample
       it to ``32 x 16 x 16``.
    3. Bilinear interpolation resizes it to ``target_size x target_size``.
    4. A 3x3 convolution maps it to ``out_ch`` channels.
    """

    def __init__(self, out_ch: int, z_dim: int, fct, target_size=32):
        super().__init__()
        self.target_size = target_size
        self.fc = nn.Linear(z_dim, 128 * 4 * 4)
        upsample = []
        for c_in, c_out in ((128, 64), (64, 32)):
            # Each transposed conv doubles the spatial size: 4 -> 8 -> 16.
            upsample.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            upsample.append(fct)
        self.deconv_layers = nn.Sequential(*upsample)
        self.final_conv = nn.Conv2d(32, out_ch, kernel_size=3, padding=1)

    def forward(self, z):
        seed = self.fc(z)
        grid = seed.view(seed.size(0), 128, 4, 4)
        grid = self.deconv_layers(grid)
        side = self.target_size
        grid = F.interpolate(grid, size=(side, side), mode='bilinear', align_corners=False)
        return self.final_conv(grid)


class LatentCNF(nn.Module):
    """Time-conditioned velocity field ``v(t, z)`` on the latent space.

    ``forward(t, z)`` has the ``func(t, y)`` signature expected by torchdiffeq,
    so :meth:`decode` integrates the module itself.

    Args:
        z_dim: Latent dimensionality (input and output size of the field).
        freqs: Number of Fourier frequencies in the time embedding. The MLP input
            is ``z_dim + 2 * freqs`` wide.
        **kwargs: Forwarded to :class:`MLP`.
    """

    def __init__(self, z_dim, freqs=2, **kwargs):
        super().__init__()
        self.fc = MLP(z_dim + 2 * freqs, z_dim, **kwargs)
        # Angular frequencies pi, 2*pi, ..., freqs*pi for the time embedding.
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t, z):
        """Evaluate the velocity at time ``t`` for latents ``z``.

        Args:
            t: Tensor time. It may be 0-d (as passed by the ODE solver),
                1-element, or shaped ``z.shape[:-1]`` (one time per row). It must
                be a tensor because it is indexed with ``[..., None]``. A
                ``(batch, 1)`` time does not broadcast against ``z``.
            z: Latents with the feature axis last, e.g. ``(batch, z_dim)``.

        Returns:
            Velocity with the same leading shape as ``z`` and ``z_dim`` features.
        """
        angles = self.freqs * t[..., None]
        temb = torch.cat((angles.cos(), angles.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)
        # Concatenate on the last axis so extra leading axes are allowed too.
        return self.fc(torch.cat([z, temb], dim=-1))

    def decode(self, z, t=None):
        """Integrate the field from time ``t`` (default 1) back to time 0.

        Uses ``odeint_adjoint`` with dopri5 and ``atol = rtol = 1e-8``.

        Args:
            z: Latents at time ``t``.
            t: Start time (a Python number); defaults to 1.

        Returns:
            The state at time 0, same shape as ``z``.
        """
        if t is None:
            t = 1.
        # Build the time grid on z's device so the solver does not need to move it.
        times = torch.tensor([t, 0.], device=z.device)
        return odeint_adjoint(
            self, z, times,
            adjoint_params=tuple(self.parameters()),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]


class FlowMatchingLoss(nn.Module):
    """Conditional flow-matching loss in the encoder's latent space.

    Encoded data ``z`` sits at ``t = 0`` and a prior sample ``z1`` at ``t = 1``.
    The path is ``z_t = (1 - t) z + (sig_min + (1 - sig_min) t) z1``. Its time
    derivative ``(1 - sig_min) z1 - z`` is the regression target for ``vt``.
    A single ``t ~ U[0, 1)`` is drawn per call and shared by the whole batch.

    Args:
        vt: Velocity field, called as ``vt(t, z_t)`` with ``t`` of shape ``(1,)``.
        encoder: Maps ``x`` to latents whose batch axis is dim 0.
        prior: Distribution with ``.sample(sample_shape)``; one draw per latent.
        sig_min: Minimum noise scale of the path at ``t = 0``.
    """

    def __init__(self, vt, encoder, prior, sig_min=1e-4):
        super().__init__()
        self.vt, self.encoder, self.prior = vt, encoder, prior
        self.sig_min = sig_min

    def forward(self, x):
        z0 = self.encoder(x)

        tau = torch.rand(1).to(x.device)
        # Column of the shared time, one row per latent.
        tcol = torch.ones_like(z0[..., 0, None]) * tau

        noise = self.prior.sample((len(z0),)).to(z0.device)
        sig = self.sig_min
        z_t = (1 - tcol) * z0 + (sig + (1 - sig) * tcol) * noise
        target = (1 - sig) * noise - z0

        residual = self.vt(tau, z_t) - target
        return residual.square().mean(-1).mean()


class ReconstructionLoss(nn.Module):
    """Mean squared error between ``x`` and ``decoder(encoder(x))``."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x):
        x_hat = self.decoder(self.encoder(x))
        err = x - x_hat
        return err.square().mean()


# --------------------------------- LDS -------------------------------------
class LatentCNF_LDS(nn.Module):
    """Sequence variant of ``LatentCNF``: a latent velocity field whose step ``s``
    is conditioned on a recurrent (``sGRU``) embedding of the earlier steps.

    For every step the MLP ``fc`` receives four inputs:

    - the step's latent;
    - the history embedding (``num_hidden_z * num_layer`` features, zeros for
      the first step);
    - a Fourier embedding of the flow time ``t`` (``2 * freqs`` features);
    - a Fourier embedding of the step position ``indices[s]`` (``dsemb`` features).

    Args:
        z_dim: Latent dimensionality.
        S: Number of sequence steps; steps are indexed along dim 0 of ``zS``.
        F: Forwarded to ``sGRU`` (meaning defined there).
        freqs: Number of time-embedding frequencies.
        num_hidden_z: Hidden size passed to ``sGRU``.
        num_layer: Number of ``sGRU`` layers.
        dsemb: Width of the step-position embedding; also forwarded to ``sGRU``.
            NOTE(review): assumed even, because the embedding uses
            ``dsemb // 2`` cos/sin pairs. Confirm.
        **kwargs: Forwarded to ``MLP``.
    """

    def __init__(self, z_dim, S, F, freqs=2, num_hidden_z=16, num_layer=1, dsemb=0, **kwargs):
        super().__init__()
        self.S = S
        self.F = F
        self.num_hidden_z = num_hidden_z
        self.num_layers = num_layer

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)
        # Frequencies for the step-position embedding used by ``_forward``.
        # ``fc`` reserves ``dsemb`` input features for it, hence dsemb // 2
        # frequencies. The buffer is non-persistent, so existing state_dicts
        # still load.
        self.register_buffer('freqs_s', torch.arange(1, dsemb // 2 + 1) * torch.pi, persistent=False)
        self.rnnz = sGRU(z_dim, self.num_hidden_z, self.num_layers, S=self.S, F=self.F, dsemb=dsemb, bidirectional=False)

        self.fc = MLP(z_dim+num_hidden_z*num_layer+2*freqs+dsemb, z_dim, **kwargs)

    def _forward(self, t, z, zcemb, s):
        """Velocity for a single step.

        Args:
            t: Flow time as a tensor (0-d or 1-element); it is indexed with ``[..., None]``.
            z: Latents of this step, presumably ``(batch, z_dim)``.
            zcemb: History embedding ``(batch, num_hidden_z * num_layers)``, or
                ``None`` for the first step (replaced by zeros).
            s: Step position as a 0-d tensor.
        """
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)

        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*z.shape[:-1], -1)

        if zcemb is None:
            # No earlier steps to summarize: feed a zero history of the width
            # ``fc`` was built for (torch.cat cannot take None).
            zcemb = z.new_zeros(*z.shape[:-1], self.num_hidden_z * self.num_layers)

        return self.fc(torch.cat([z, zcemb, temb, semb], dim=-1))

    def forward(self, zS, t=None, indices=None):
        """Sum of the per-step velocities for a sequence of latents.

        Args:
            zS: Latents indexed by step along dim 0, presumably ``(S, batch, z_dim)``.
            t: Flow time (Python number or tensor); defaults to 0.
            indices: One position per step; defaults to ``arange(S) / S``.

        Returns:
            ``sum_s v_s`` with the shape of one step's velocity.
        """
        if t is None:
            t = 0.
        # ``_forward`` indexes ``t``, so a Python number must become a tensor.
        t = torch.as_tensor(t, device=zS.device)
        if indices is None:
            indices = torch.arange(self.S, device=zS.device) / self.S
        out = self._forward(t, zS[0], None, indices[0])
        h = None
        for s in range(1, self.S):
            # Advance the recurrent summary with the previous step's latent.
            # NOTE(review): assumes rnnz.recurse accepts h=None for the first call.
            h = self.rnnz.recurse(zS[s - 1], h, indices[s])
            out = out + self._forward(t, zS[s], h.view(-1, self.num_hidden_z * self.num_layers), indices[s])
        return out

    def _decode(self, z, zcemb, s, t=None):
        """Integrate one step's field from time ``t`` (default 1) back to 0."""
        if t is None:
            t = 1.
        if zcemb is not None:
            # Cut the history embedding from the graph; only self.parameters()
            # are passed as adjoint parameters below.
            zcemb = zcemb.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self._forward, zcemb=zcemb, s=s),
            z, torch.tensor([t, 0.], device=z.device),
            adjoint_params=tuple(self.parameters()),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]

        return zt

    def decode(self, zS, t=None, indices=None):
        """Decode every step sequentially, conditioning each on the decoded past.

        Args:
            zS: Latents at time ``t``, indexed by step along dim 0. Assumes
                ``zS.shape[0] == S``.
            t: Start time; defaults to 1, as in ``_decode``. With a start time
                of 0 the integration interval [0, 0] would be empty.
            indices: One position per step; defaults to ``arange(S) / S``.

        Returns:
            Tensor stacking the decoded steps along dim 0.
        """
        if t is None:
            t = 1.
        if indices is None:
            indices = torch.arange(self.S, device=zS.device) / self.S
        # Collect outputs in a list rather than writing into a tensor in place,
        # so tensors saved by rnnz for autograd are not modified afterwards.
        decoded = [self._decode(zS[0], None, indices[0], t)]
        h = None
        for s in range(1, self.S):
            # Update the hidden state from the previously *decoded* step.
            h = self.rnnz.recurse(decoded[s - 1], h, indices[s])
            decoded.append(self._decode(zS[s], h.view(-1, self.num_hidden_z * self.num_layers), indices[s], t))
        return torch.stack(decoded)


class FlowMatchingLoss_LDS(nn.Module):
    """Flow-matching loss for sequences of latents.

    The loss follows these steps:

    1. Encode ``x`` (steps along dim 0) to latents ``z``.
    2. Interpolate every latent towards its own prior sample with a shared
       ``t ~ U[0, 1)``.
    3. Regress the summed per-step target velocity with ``vt``. ``vt`` returns
       the sum of per-step velocities (see ``LatentCNF_LDS.forward``).
    4. Divide the mean squared error by the number of steps ``S``.

    Args:
        vt: Sequence velocity field, called as ``vt(zt, t=..., indices=...)``.
        encoder: Applied to ``x`` directly; latents keep the step axis first.
        prior: Distribution with ``.sample(sample_shape)``; one draw per latent.
        sig_min: Minimum noise scale of the path at ``t = 0``.
    """

    def __init__(self, vt, encoder, prior, sig_min=1e-4):
        super().__init__()
        self.vt = vt
        self.encoder = encoder
        self.prior = prior
        self.sig_min = sig_min

    def forward(self, x, indices=None):
        S = x.shape[0]
        z = self.encoder(x)

        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(z[..., 0, None]) * _t

        # One prior draw per (step, batch) latent. sampling only len(z) = S
        # vectors would not broadcast against (S, B, d) latents.
        z1 = self.prior.sample(z.shape[:-1]).to(z.device)
        zt = (1-t) * z + (self.sig_min + (1 - self.sig_min) * t) * z1

        # Target velocity summed over steps, matching vt's summed output.
        ut = ((1 - self.sig_min) * z1 - z).sum(0)

        # LatentCNF_LDS.forward takes (zS, t, indices): pass the latents first.
        fm_loss = (self.vt(zt, t=_t, indices=indices) - ut).square().mean(-1).mean()/S
        return fm_loss


class ReconstructionLoss_LDS(nn.Module):
    """Reconstruction MSE for sequence data ``x`` of shape ``(S, B, ...)``.

    The step and batch axes are merged into one batch axis of size ``S * B``
    before encoding, so ``encoder``/``decoder`` see ordinary batched samples.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        # Take the first two sizes; unpacking x.shape[0] (an int) raises TypeError.
        S, B = x.shape[:2]
        remaining_dims = x.shape[2:]

        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(S * B, *remaining_dims)

        loss = (x - self.decoder(self.encoder(x))).square().mean()
        return loss

