import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, MultivariateNormal
import math
# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import TensorDataset
from typing import *

def jvp(f: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Jacobian-vector product of ``f`` at ``x`` along the tangent ``v``.

    Delegates to ``torch.autograd.functional.jvp``. A differentiable graph is
    requested only when autograd is currently enabled: during training the
    result can itself be backpropagated, while under ``torch.no_grad()`` it
    is returned without a graph.

    Returns:
        The pair ``(f(x), J_f(x) @ v)`` as produced by torch.
    """
    keep_graph = torch.is_grad_enabled()
    value, tangent = torch.autograd.functional.jvp(f, x, v, create_graph=keep_graph)
    return value, tangent

def t_dir(f: Callable[[torch.Tensor], torch.Tensor], t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(f(t), J_f(t) @ 1)``: ``f`` and its JVP along an all-ones tangent.

    When every output entry depends only on the matching entry of ``t``
    (e.g. a per-sample time schedule), the second element is the time
    derivative of ``f``.
    """
    ones = torch.ones_like(t)
    return jvp(f, t, ones)

def get_t_dir(model: nn.Module, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` at ``t`` together with its JVP along an all-ones tangent.

    Equivalent to ``t_dir`` applied to a closure that simply forwards its
    input to ``model``.
    """
    return t_dir(lambda t_in: model(t_in), t)

class TabFlowMatching:  ## note: data should contain both continuous and categorical variables
    """Flow-matching training objective for mixed continuous/categorical tabular rows.

    Each row of ``x_1`` is laid out as ``d_cont`` continuous columns followed by
    one block of columns per entry of ``cat_list``. The loss takes ``argmax``
    over each categorical block, so the blocks are presumably one-hot encoded
    (TODO confirm against the data pipeline).

    Args:
        d_cont: number of leading continuous columns (``None`` is treated as 0).
        cat_list: sizes of the consecutive categorical blocks, or ``None``.
        var_min: variance of the continuous Gaussian likelihood at ``t = 1``;
            the variance schedule is ``1 - (1 - var_min) * t**2``. The default
            reproduces the previously hard-coded value.
    """

    def __init__(self, d_cont, cat_list, var_min: float = 0.01) -> None:
        super().__init__()
        self.eps = 1e-5
        self.cat_list = cat_list
        self.d_cont = d_cont
        self.var_min = var_min
        # Not used by ``loss`` (which calls F.cross_entropy directly, the same
        # quantity); kept so external code reading this attribute still works.
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def psi_t(self, x: torch.Tensor, x_1: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Conditional flow ``psi_t(x) = alpha * x_1 + beta * x``.

        In ``loss``, ``x`` is the Gaussian noise sample and ``x_1`` the data
        row; ``alpha``/``beta`` are the schedule values at the sampled times,
        applied elementwise to every column (continuous or categorical).
        """
        return alpha * x_1 + beta * x

    def loss(self, v_t: nn.Module,
           x_1: torch.Tensor) -> torch.Tensor:
        """Compute the training loss for a batch ``x_1``.

        ``v_t`` must expose ``v_t.net_t(t)`` returning the schedule
        ``(alpha, beta)`` and be callable as ``v_t(t, x_t)``. Its output
        ``theta_all`` is read as the Gaussian mean of the continuous columns
        and as logits for each categorical block.

        Returns:
            Scalar tensor: mean (over rows) Gaussian negative log-likelihood of
            the continuous columns (summed over those columns) plus, for each
            categorical block, the mean cross-entropy against the argmax target.
        """
        n_rows = x_1.shape[0]
        d_cont = self.d_cont or 0

        # One time per row. The modulo folds draws in [1 - eps, 1) back to
        # [0, eps), so t lies in [0, 1 - eps).
        t = torch.rand((n_rows, 1), device=x_1.device) % (1 - self.eps)
        x_0 = torch.randn_like(x_1)

        alpha, beta = v_t.net_t(t)
        x_t = self.psi_t(x_0, x_1, alpha.expand(x_1.shape), beta.expand(x_1.shape))

        theta_all = v_t(t[:, 0], x_t)

        loss_q = torch.tensor(0.0, device=x_1.device)
        if d_cont > 0:
            # Isotropic Gaussian with per-row variance 1 - (1 - var_min) * t^2.
            # An independent Normal summed over columns equals a
            # MultivariateNormal with covariance var * I, without the per-row
            # Cholesky factorisation / positive-definiteness check.
            var = 1 - (1 - self.var_min) * t ** 2  # (n_rows, 1), broadcast over columns
            normal = Normal(theta_all[:, :d_cont], var.sqrt())
            nll = -normal.log_prob(x_1[:, :d_cont]).sum(dim=-1)
            loss_q = loss_q + nll.mean()
        if self.cat_list is not None:
            start = d_cont
            for size in self.cat_list:
                stop = start + size
                target = torch.argmax(x_1[:, start:stop], dim=-1)
                # Same value as -Categorical(logits=...).log_prob(target).mean().
                loss_q = loss_q + F.cross_entropy(theta_all[:, start:stop], target)
                start = stop

        return loss_q

### the conditional vector fields baseline
class CondVF(nn.Module):
    """Conditional vector field used to integrate the flow with an ODE solver.

    ``net(t, x)`` is treated as a prediction of the data point ``x_1`` (its
    categorical slices as logits), and ``net_t(t)`` returns the schedule
    ``(alpha, beta)``. While ``useodeint`` is set, ``forward`` returns the
    velocity of the path ``x_t = alpha(t) * x_1 + beta(t) * x_0`` instead of
    the raw network output.

    Args:
        net: model called as ``net(t, x)``.
        net_t: schedule model called as ``net_t(t)`` returning ``(alpha, beta)``.
        n_steps: fixed-grid steps per unit time (step size ``1 / n_steps``).
        d_cont: number of leading continuous columns (``None`` treated as 0).
        cat_list: sizes of consecutive categorical blocks, or ``None``.
    """

    # torchdiffeq solvers that integrate on a fixed grid and therefore need a
    # ``step_size`` option; without one they take a single step from t0 to t1.
    _FIXED_GRID_METHODS = ('euler', 'midpoint', 'rk4', 'heun3',
                           'explicit_adams', 'implicit_adams', 'fixed_adams')

    def __init__(self, net: nn.Module, net_t: nn.Module, n_steps: int = 100, d_cont: int = None, cat_list: list = None) -> None:
        super().__init__()
        self.net = net
        self.net_t = net_t
        self.atol = 1e-06
        self.rtol = 1e-05
        self.useodeint = False
        self.n_steps = n_steps
        self.d_cont = d_cont
        self.cat_list = cat_list

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Raw ``net(t, x)``, or the ODE velocity while a solve is running."""
        if self.useodeint:
            return self.forward_for_ode(t, x)
        return self.net(t, x)

    def _logits_to_probs(self, out: torch.Tensor) -> torch.Tensor:
        """Softmax each categorical slice of ``out``; other columns unchanged.

        Built out-of-place with ``torch.cat`` so the network output is never
        modified in place (which can break autograd / the adjoint pass).
        """
        if not self.cat_list:
            return out
        start = self.d_cont or 0
        pieces = [out[:, :start]]
        for size in self.cat_list:
            pieces.append(F.softmax(out[:, start:start + size], dim=-1))
            start += size
        pieces.append(out[:, start:])
        return torch.cat(pieces, dim=-1)

    def forward_for_ode(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Velocity ``dx/dt`` at time ``t`` (a scalar tensor from the solver)."""
        t = t.expand(x.size(0))
        x1 = self._logits_to_probs(self.net(t, x))
        t = t[:, None].expand((x.size(0), 1))

        (alpha, beta), (dalpha, dbeta) = t_dir(self.net_t, t)
        alpha = alpha.expand(x1.shape)
        beta = beta.expand(x1.shape)
        dalpha = dalpha.expand(x1.shape)
        dbeta = dbeta.expand(x1.shape)

        # From x_t = alpha*x_1 + beta*x_0: x_0 = (x - alpha*x_1) / beta, hence
        # dx/dt = (dalpha - dbeta/beta * alpha) * x_1 + dbeta/beta * x.
        # NOTE(review): beta is not clipped; beta == 0 yields inf.
        p2 = dbeta / beta
        p1 = dalpha - p2 * alpha
        return p1 * x1 + p2 * x

    def wrapper(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Broadcast a scalar ``t`` to one value per row and call ``self``."""
        t = t * torch.ones(len(x), device=x.device)
        return self(t, x)

    # uses from torchdiffeq import odeint_adjoint as odeint
    def decode_t0_t1(self, x_0, t0, t1, method = 'euler'):
        """Integrate the ODE from ``t0`` to ``t1`` starting at ``x_0``.

        Fixed-grid methods use ``step_size = 1 / n_steps``; any other method
        is passed through with no options and ``self.rtol``/``self.atol``.
        ``useodeint`` is restored even if the solver raises.
        """
        if method in self._FIXED_GRID_METHODS:
            options = {'step_size': 1 / self.n_steps}
        else:
            options = {}
        # float() keeps the time grid floating-point even for int endpoints.
        t_span = torch.tensor([float(t0), float(t1)], device=x_0.device)

        self.useodeint = True
        try:
            x_1 = odeint(
                self,
                y0=x_0,
                t=t_span,
                method=method,
                options=options,
                rtol=self.rtol, atol=self.atol,
            )[-1]
        finally:
            self.useodeint = False
        return x_1

    def encode(self, x_1: torch.Tensor) -> torch.Tensor:
        """Map data ``x_1`` back to ``t = 0`` with fixed-step Euler."""
        return self.decode_t0_t1(x_1, 1., 0., method='euler')

    def decode(self, x_0: torch.Tensor) -> torch.Tensor:
        """Map ``x_0`` at ``t = 0`` to ``t = 1`` with fixed-step Euler."""
        return self.decode_t0_t1(x_0, 0., 1., method='euler')

    def decode_manual(self, x_0: torch.Tensor, n_steps: int = 100) -> torch.Tensor:
        """Explicit Euler from ``t = 0`` to ``t = 1`` starting at ``x_0``.

        Takes exactly ``n_steps`` steps of size ``1 / n_steps``. Updates are
        out-of-place so the result stays differentiable.
        """
        h = 1 / n_steps
        x = x_0
        for i in range(n_steps):
            # Recompute t from the index to avoid float drift from accumulation.
            tt = torch.tensor(i * h, device=x_0.device)
            x = x + h * self.forward_for_ode(tt, x)
        return x

### the conditional vector fields using SDE
### documentation can be seen in https://github.com/google-research/torchsde/blob/master/DOCUMENTATION.md
class StochasticCondVF(nn.Module):
    """Conditional vector field exposed as an SDE (``f`` drift, ``g`` diffusion).

    The ``noise_type``/``sde_type`` attributes and the ``f``/``g`` methods
    follow the interface torchsde expects. ``net(t, x)`` is treated as a
    prediction of ``x_1`` (categorical slices as logits), ``net_t(t)``
    returns ``(alpha, beta)``, and ``net_sigma_t(t)`` returns a pair whose
    second element is the diffusion scale ``sigma_t``.

    Args:
        net, net_t, net_sigma_t: the models described above.
        sigma_max: multiplier applied to ``sigma_t``; 0 reduces the drift to
            the ODE velocity and the diffusion to zero.
        n_steps: stored for callers; not used by ``f``/``g``.
        d_cont: number of leading continuous columns (``None`` treated as 0).
        cat_list: sizes of consecutive categorical blocks, or ``None``.
    """

    def __init__(self, net: nn.Module, net_t: nn.Module, net_sigma_t: nn.Module, sigma_max: float = 1.,
                 n_steps: int = 100, d_cont: int = None, cat_list: list = None) -> None:
        super().__init__()
        self.net = net
        self.net_t = net_t
        self.net_sigma_t = net_sigma_t
        self.atol = 1e-06
        self.rtol = 1e-05
        self.useodeint = False
        self.n_steps = n_steps
        self.d_cont = d_cont
        self.cat_list = cat_list
        self.sigma_max = sigma_max  # multiplier of sigma_t

        self.noise_type = 'diagonal'
        self.sde_type = 'ito'

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Raw ``net(t, x)`` output."""
        return self.net(t, x)

    def _logits_to_probs(self, out: torch.Tensor) -> torch.Tensor:
        """Softmax each categorical slice of ``out``; other columns unchanged.

        Built out-of-place with ``torch.cat`` so the network output is never
        modified in place (which can break autograd).
        """
        if not self.cat_list:
            return out
        start = self.d_cont or 0
        pieces = [out[:, :start]]
        for size in self.cat_list:
            pieces.append(F.softmax(out[:, start:start + size], dim=-1))
            start += size
        pieces.append(out[:, start:])
        return torch.cat(pieces, dim=-1)

    def f(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Drift: ODE velocity plus ``0.5 * (sigma_max * sigma_t)**2 * score``.

        Following Example 14 / Summary 17 and Section 3.1 of
        https://diffusion.csail.mit.edu/docs/lecture-notes.pdf (where z is the
        data point, here ``x_1``), the score of the Gaussian conditional path
        is approximated with the predicted ``x_1``:
        ``score = -(x - alpha * x_1) / beta**2``.
        """
        t = t.expand(x.size(0))
        x1 = self._logits_to_probs(self.net(t, x))  # x1 = theta_t(x_t) in the equation
        t = t[:, None].expand((x.size(0), 1))

        (alpha, beta), (dalpha, dbeta) = t_dir(self.net_t, t)
        alpha = alpha.expand(x1.shape)
        beta = beta.expand(x1.shape)
        dalpha = dalpha.expand(x1.shape)
        dbeta = dbeta.expand(x1.shape)

        # NOTE(review): beta is not clipped here; the divisions below produce
        # inf if beta reaches 0 — confirm net_t keeps beta > 0 on the interval.
        p2 = dbeta / beta
        p1 = dalpha - p2 * alpha
        vtx = p1 * x1 + p2 * x

        _, sigma_t = self.net_sigma_t(t)
        score = -(x - alpha * x1) / beta ** 2
        gt = self.sigma_max * sigma_t

        # for sigma_max = 0 (ODE) the correction term vanishes
        return vtx + 0.5 * gt ** 2 * score

    def g(self, t: torch.Tensor, x: torch.Tensor):
        """Diagonal diffusion ``sigma_max * sigma_t`` broadcast to ``x``'s shape."""
        _, sigma_t = self.net_sigma_t(t)
        # for sigma_max = 0 (ODE) the diffusion is zero
        gt = self.sigma_max * sigma_t
        return gt.expand(x.size())