import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import math
# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import TensorDataset
from typing import *

def jvp(f: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Jacobian-vector product of ``f`` at ``x`` along direction ``v``.

    Thin wrapper over ``torch.autograd.functional.jvp``.

    Returns:
        ``(f(x), J_f(x) @ v)``. If ``f`` returns a tuple, both elements of
        the result are tuples with matching structure.

    A differentiable graph is requested (``create_graph=True``) only when
    grad mode is currently enabled. Callers running under
    ``torch.no_grad()`` therefore do not pay for graph construction.
    """
    build_graph = torch.is_grad_enabled()
    value, tangent = torch.autograd.functional.jvp(f, x, v, create_graph=build_graph)
    return value, tangent

def t_dir(f: Callable[[torch.Tensor], torch.Tensor ], t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``f(t)`` together with its JVP along an all-ones direction.

    The JVP along ``ones_like(t)`` equals ``J_f(t) @ 1``. When ``f`` acts
    independently on each element/row of ``t`` (presumably the case for
    the time-schedule networks this is used with), this is the elementwise
    derivative ``df/dt``.

    Returns:
        ``(f(t), df/dt)``, with the same structure as ``jvp`` returns.
    """
    direction = torch.ones_like(t)
    return jvp(f, t, direction)

def get_t_dir(model: nn.Module, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``model(t)`` and its time derivative via ``t_dir``.

    ``model`` is wrapped in a plain one-argument callable before it is
    handed to ``t_dir``.
    """
    def call_model(t_in):
        return model(t_in)

    return t_dir(call_model, t)

class TabFlowMatching:
    """Flow-matching training objective for mixed continuous/categorical tabular data.

    Column layout of the data tensor ``x_1``:
    ``[d_cont continuous columns | one-hot group 1 | one-hot group 2 | ...]``.
    The width of each one-hot group is the corresponding entry of ``cat_list``.
    Either part may be absent: ``d_cont`` of 0/None means there are no
    continuous columns, and ``cat_list=None`` means there are no categorical
    groups.
    """

    def __init__(self, d_cont, cat_list) -> None:
        super().__init__()
        self.eps = 1e-5  # keeps sampled t strictly below 1 - eps (see `loss`)
        self.cat_list = cat_list
        self.d_cont = d_cont
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def psi_t(self, x: torch.Tensor, x_1: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Conditional flow: interpolate noise ``x`` and data ``x_1``.

        Returns ``alpha * x_1 + beta * x``, with broadcasting.
        """
        return alpha * x_1 + beta * x

    def loss(self, v_t: nn.Module,
             x_1: torch.Tensor) -> torch.Tensor:
        """Compute the training loss for one batch ``x_1`` (batch-first, 2-D).

        Steps:
            1. Sample per-row times ``t`` and Gaussian noise ``x_0``.
            2. Build ``x_t = psi_t(x_0, x_1, alpha, beta)``, where
               ``(alpha, beta) = v_t.net_t(t)``.
            3. Have ``v_t(t, x_t)`` predict ``x_1``.

        Loss terms:
            - Continuous columns: mean squared error.
            - Each categorical group: cross-entropy against the argmax of the
              one-hot target.

        All terms are summed.
        """
        # Treat None (or a non-positive value) as "no continuous columns".
        n_cont = self.d_cont if (self.d_cont or 0) > 0 else 0

        # The modulo wraps the rare draws in [1 - eps, 1) back to [0, eps).
        t = torch.rand((x_1.shape[0], 1), device=x_1.device) % (1 - self.eps)
        x_0 = torch.randn_like(x_1)
        alpha, beta = v_t.net_t(t)

        x_t = self.psi_t(x_0, x_1, alpha.expand(x_1.shape), beta.expand(x_1.shape))
        theta_all = v_t(t[:, 0], x_t)

        loss_q = torch.tensor(0.0, device=x_1.device)
        if n_cont > 0:
            loss_q = torch.mean((x_1[:, :n_cont] - theta_all[:, :n_cont]) ** 2)
        if self.cat_list is not None:
            start = n_cont
            for width in self.cat_list:
                stop = start + width
                targets = torch.argmax(x_1[:, start:stop], dim=-1)
                # Out-of-place accumulation keeps the autograd graph simple.
                loss_q = loss_q + self.loss_fn(theta_all[:, start:stop], targets)
                start = stop

        return loss_q


### the conditional vector fields baseline
class CondVF(nn.Module):
    """Conditional vector field built from a data-prediction network.

    Components:
        - ``net(t, x)`` predicts the clean sample. Continuous columns are
          predicted directly; each categorical group is predicted as logits,
          and a softmax is applied in the ODE path.
        - ``net_t(t)`` returns the schedule pair ``(alpha, beta)``.

    ``forward`` returns the raw network output, except while an ODE solve is
    in progress (``useodeint`` is True). During a solve it returns the
    velocity from ``forward_for_ode``.
    """

    def __init__(self, net: nn.Module, net_t: nn.Module, n_steps: int = 100, d_cont: int = None, cat_list: list = None) -> None:
        super().__init__()
        self.net = net
        self.net_t = net_t
        self.atol = 1e-06
        self.rtol = 1e-05
        self.useodeint = False  # routes forward() to the ODE velocity while solving
        self.n_steps = n_steps
        self.d_cont = d_cont
        self.cat_list = cat_list

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if self.useodeint:
            return self.forward_for_ode(t, x)
        return self.net(t, x)

    def _cat_softmax(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with softmax applied to each categorical group.

        Column handling:
            - Continuous columns (the first ``d_cont``; None/0 means none)
              are passed through unchanged.
            - Columns beyond the last categorical group are passed through
              unchanged.

        Built with ``torch.cat`` instead of writing in place into the network
        output, so autograd/adjoint backward never sees a modified tensor.
        """
        if self.cat_list is None:
            return logits
        start = self.d_cont if (self.d_cont or 0) > 0 else 0
        pieces = [logits[:, :start]]
        for width in self.cat_list:
            stop = start + width
            pieces.append(F.softmax(logits[:, start:stop], dim=-1))
            start = stop
        pieces.append(logits[:, start:])
        return torch.cat(pieces, dim=-1)

    def forward_for_ode(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Velocity at state ``x`` and time ``t``, used by the ODE solver.

        With ``x_t = alpha * x1 + beta * x0``, the velocity is
        ``dalpha * x1 + dbeta * x0``. Eliminating ``x0`` gives::

            (dalpha - dbeta / beta * alpha) * x1 + (dbeta / beta) * x_t

        so ``beta`` must be nonzero at the queried ``t``. ``t`` is expected to
        be a scalar (0-d or single-element) tensor, as the solver supplies.
        """
        t = t.expand(x.size(0))
        x1 = self._cat_softmax(self.net(t, x))
        t = t[:, None].expand((x.size(0), 1))

        (alpha, beta), (dalpha, dbeta) = t_dir(self.net_t, t)
        alpha = alpha.expand(x1.shape)
        beta = beta.expand(x1.shape)
        dalpha = dalpha.expand(x1.shape)
        dbeta = dbeta.expand(x1.shape)

        p1 = dalpha - (dbeta / beta) * alpha
        p2 = dbeta / beta
        return p1 * x1 + p2 * x

    def wrapper(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Broadcast a scalar time to one value per row of ``x``, then call ``self``."""
        t = t * torch.ones(len(x), device=x.device)
        return self(t, x)

    def _solve(self, x: torch.Tensor, t0, t1) -> torch.Tensor:
        """Integrate from ``t0`` to ``t1`` with fixed-step Euler (step ``1/n_steps``).

        Uses torchdiffeq's ``odeint_adjoint`` and returns the state at ``t1``.
        ``useodeint`` is reset in ``finally`` so a failed solve cannot leave
        ``forward`` stuck in ODE mode.
        """
        self.useodeint = True
        try:
            return odeint(self, y0=x, t=torch.tensor([t0, t1], device=x.device),
                          method='euler', options={'step_size': (1 / self.n_steps)})[-1]
        finally:
            self.useodeint = False

    def decode_t0_t1(self, x_0, t0, t1):
        """Map ``x_0`` at time ``t0`` to the solution at time ``t1``."""
        return self._solve(x_0, t0, t1)

    def encode(self, x_1: torch.Tensor) -> torch.Tensor:
        """Integrate backward from t=1 (data) to t=0."""
        return self._solve(x_1, 1., 0.)

    def decode(self, x_0: torch.Tensor) -> torch.Tensor:
        """Integrate forward from t=0 to t=1 (data)."""
        return self._solve(x_0, 0., 1.)

    def decode_manual(self, x_0: torch.Tensor, n_steps: int = 100) -> torch.Tensor:
        """Solver-free Euler integration of ``forward_for_ode`` from t=0 to t=1.

        Starts from ``x_0`` (the input is not modified) and takes exactly
        ``n_steps`` steps of size ``1 / n_steps``.
        """
        h = 1 / n_steps
        x = x_0
        for k in range(n_steps):
            # Compute t from the step index to avoid accumulated float drift.
            tt = torch.tensor(k * h, device=x_0.device)
            x = x + h * self.forward_for_ode(tt, x)
        return x
