import torch
import torch.nn as nn
from torch import Tensor
from typing import *
import math
from typing import Callable, Union

from .transformer import Reconstructor, Tokenizer, Transformer

class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: returns ``x * sigmoid(x)`` elementwise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

# ---------- Shared bits ----------

class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    With ``half = num_channels // 2`` and ``D = half`` (``half - 1`` when
    ``endpoint`` is True), frequencies are ``(1 / max_positions) ** (k / D)``
    for ``k = 0 .. half - 1``. Each input value ``p`` is mapped to
    ``[cos(p * f_0..f_{half-1}), sin(p * f_0..f_{half-1})]``.

    Args:
        num_channels: output width; an odd value yields ``num_channels - 1``
            channels because only ``num_channels // 2`` frequencies are used.
        max_positions: base controlling the lowest frequency.
        endpoint: if True, the last frequency is exactly ``1 / max_positions``.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed ``x`` into a ``[x.numel(), 2 * (num_channels // 2)]`` tensor.

        ``x`` is flattened to 1-D first, so ``[bs]``, ``[bs, 1]`` and 0-d inputs
        are all accepted (a 0-d input gives a single ``[1, C]`` row).
        The first half of the output is cosines, the second half sines.
        """
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        # max(..., 1) avoids 0/0 (NaN) when endpoint=True and half == 1.
        freqs = freqs / max(half - (1 if self.endpoint else 0), 1)
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer replaces the deprecated Tensor.ger alias; reshape(-1) is a
        # no-op for 1-D x and lets column / scalar timesteps through.
        x = torch.outer(x.reshape(-1), freqs.to(x.dtype))
        return torch.cat([x.cos(), x.sin()], dim=1)


class MLPDiffusion(nn.Module):
    """Time-conditioned MLP applied to a flattened token vector.

    ``x`` (width ``d_in``) is projected to ``dim_t``. The sinusoidal time
    embedding from ``PositionalEmbedding`` has its two halves swapped and is
    passed through a two-layer SiLU MLP. It is then added to the projection.
    The sum is mapped back to ``d_in`` by a 4-layer SiLU MLP (``use_mlp=True``)
    or by a single linear layer (``use_mlp=False``).
    """

    def __init__(self, d_in, dim_t=512, use_mlp=True):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Submodules are created in the same order as before so parameter
        # initialisation (RNG draws) and state_dict keys are unchanged.
        if use_mlp:
            wide = dim_t * 2
            self.mlp = nn.Sequential(
                nn.Linear(dim_t, wide),
                nn.SiLU(),
                nn.Linear(wide, wide),
                nn.SiLU(),
                nn.Linear(wide, dim_t),
                nn.SiLU(),
                nn.Linear(dim_t, d_in),
            )
        else:
            self.mlp = nn.Linear(dim_t, d_in)

        self.map_noise = PositionalEmbedding(num_channels=dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t),
        )

        self.use_mlp = use_mlp

    def forward(self, x, timesteps):
        raw = self.map_noise(timesteps)
        # The embedding is laid out as [cos | sin]; feed it to time_embed as [sin | cos].
        cos_half, sin_half = raw.chunk(2, dim=1)
        t_emb = self.time_embed(torch.cat([sin_half, cos_half], dim=1))
        return self.mlp(self.proj(x) + t_emb)
    
class Net(nn.Module):
    """
    Denoising network for rows with numerical and categorical columns.

    Pipeline: Tokenizer -> Transformer encoder -> MLPDiffusion (time-conditioned,
    applied to the flattened tokens) -> Transformer decoder -> Reconstructor.

    Args:
        d_numerical: number of numerical columns at the front of ``x``.
        categories: per-column category list forwarded to ``Tokenizer`` and
            ``Reconstructor``; ``len(categories)`` is the number of
            categorical columns.
        num_layers, n_head, factor: forwarded to both ``Transformer`` stacks.
        d_token: token width shared by tokenizer, transformers and reconstructor.
        bias: forwarded to ``Tokenizer``.
        dim_t, use_mlp: forwarded to ``MLPDiffusion``.
        **kwargs: accepted and ignored.

    forward(timesteps, x):
        timesteps: time values handed to ``MLPDiffusion``.
        x: ``[bs, d_numerical + len(categories)]``; the first ``d_numerical``
            columns are numerical, the remaining ones categorical.
        Returns the numerical prediction concatenated (dim 1) with the
        categorical outputs of ``Reconstructor``. Per the original notes these
        are ``[bs, d_numerical]`` predicted means and ``[bs, sum(categories)]``
        unnormalized logits. Confirm the shapes against ``Reconstructor``.
        If ``Reconstructor`` yields no categorical outputs, zeros shaped like
        the categorical slice of ``x`` are used instead.

    Default parameter source:
        https://github.com/MinkaiXu/TabDiff/blob/main/tabdiff/configs/tabdiff_configs.toml
    """
    def __init__(
            self, d_numerical, categories, num_layers=2, d_token=4,
            n_head = 1, factor = 32, bias = True, dim_t=1024, use_mlp=True, **kwargs
        ):
        super().__init__()
        self.d_numerical = d_numerical
        self.categories = categories

        self.tokenizer = Tokenizer(d_numerical, categories, d_token, bias = bias)
        self.encoder = Transformer(num_layers, d_token, n_head, d_token, factor)
        d_in = d_token * (d_numerical + len(categories))
        self.mlp = MLPDiffusion(d_in, dim_t=dim_t, use_mlp=use_mlp)
        self.decoder = Transformer(num_layers, d_token, n_head, d_token, factor)
        self.detokenizer = Reconstructor(d_numerical, categories, d_token)

        # Re-registers the same submodules under "model.*". Removing it would
        # change state_dict keys, so it is kept for checkpoint compatibility.
        self.model = nn.ModuleList([self.tokenizer, self.encoder, self.mlp, self.decoder, self.detokenizer])

    def forward(self, timesteps: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        x_num = x[:, :self.d_numerical]
        x_cat = x[:, self.d_numerical:]
        tokens = self.tokenizer(x_num, x_cat)
        tokens = tokens[:, 1:, :]  # ignore the first (CLS) token
        y = self.encoder(tokens)
        # MLPDiffusion works on one flat vector per row; restore the token layout afterwards.
        pred_y = self.mlp(y.reshape(y.shape[0], -1), timesteps)
        pred_e = self.decoder(pred_y.reshape(*y.shape))
        x_num_pred, x_cat_pred = self.detokenizer(pred_e)
        if len(x_cat_pred) > 0:
            x_cat_pred = torch.cat(x_cat_pred, dim=-1)
        else:
            x_cat_pred = torch.zeros_like(x_cat).to(x_num_pred.dtype)
        return torch.cat([x_num_pred, x_cat_pred], dim=1)

#### network for time function ####

class OT_t(nn.Module):
    '''
    Linear (optimal-transport style) time function.

    ``forward(t)`` returns the column pair ``(t, 1 - 0.999 t)``. ``atx(t)`` returns
    ``1 / (1 - 0.999 t)``, which equals ``dα/dt - (dβ/dt / β) α`` for
    ``α = t``, ``β = 1 - 0.999 t``. Both outputs are shaped ``[N, 1]``.
    '''
    def __init__(self) -> None:
        super().__init__()

    def atx(self, t: torch.Tensor) -> torch.Tensor:
        col = t.view(-1, 1)
        return 1 / (1. - 0.999 * col)

    def forward(self, t: Tensor) -> Tensor:
        col = t.view(-1, 1)
        return col, 1. - 0.999 * col

class VPDiffusion_t(nn.Module):
    '''
    Variance-preserving diffusion time function.

    With ``beta(s) = beta_min + s * (beta_max - beta_min)`` and its integral
    ``T(s)``, the signal scale is ``alpha(s) = exp(-T(s) / 2)``. ``forward(t)``
    evaluates it at ``s = 1 - t`` and returns
    ``(alpha(1 - t), sqrt(1 - alpha(1 - t) ** 2))``, i.e. ``(1, 0)`` at ``t = 1``.
    '''
    def __init__(
        self
    ) -> None:
        super().__init__()
        self.beta_min = 0.1
        self.beta_max = 20.0
        # atx clamps t to <= 1 - eps; at t = 1 its denominator is exactly 0.
        self.eps = 1e-5

    def T(self, s: torch.Tensor) -> torch.Tensor:
        """Integral of ``beta`` over ``[0, s]``."""
        return self.beta_min * s + 0.5 * (s ** 2) * (self.beta_max - self.beta_min)

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        return self.beta_min + t*(self.beta_max - self.beta_min)

    def alpha(self, t: torch.Tensor) -> torch.Tensor:
        return torch.exp(-0.5 * self.T(t))

    def _one_minus_alpha_sq(self, s: torch.Tensor) -> torch.Tensor:
        # 1 - alpha(s)**2 == 1 - exp(-T(s)); expm1 avoids catastrophic
        # cancellation as T(s) -> 0 (s -> 0, i.e. t -> 1).
        return -torch.expm1(-self.T(s))

    def sigma_t(self, t: torch.Tensor, x_1: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Noise scale ``sqrt(1 - alpha(1 - t) ** 2)``; ``x_1`` is unused."""
        return torch.sqrt(self._one_minus_alpha_sq(1. - t))

    def atx(self, t: torch.Tensor) -> torch.Tensor:
        """Return ``0.5 * beta(1-t) * alpha(1-t) / (1 - alpha(1-t) ** 2)``.

        This is ``dα/dt - (dσ/dt / σ) α`` for the coefficients of ``forward``.
        ``t`` is clamped to ``<= 1 - eps`` so the result stays finite at ``t = 1``.
        """
        t = t.clamp(max=1. - self.eps)
        s = 1. - t
        return 0.5 * self.beta(s) * self.alpha(s) / self._one_minus_alpha_sq(s)

    def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        s = 1. - t
        return self.alpha(s), torch.sqrt(self._one_minus_alpha_sq(s))

class VEDiffusion_t(nn.Module):
    '''
    Variance-exploding diffusion time function.

    Noise scale ``sigma(s) = sigma_min * (sigma_max / sigma_min) ** s``.
    ``forward(t)`` returns ``(1, sigma(1 - t))``, so the signal coefficient is
    always one.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.sigma_min = 0.01
        self.sigma_max = 2.
        self.eps = 1e-5  # currently not read by any method of this class

    def sigma_t(self, t: torch.Tensor) -> torch.Tensor:
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def dsigma_dt(self, t: torch.Tensor) -> torch.Tensor:
        # d/dt [c * r**t] = c * r**t * log(r). math.log keeps the constant in
        # double precision and avoids allocating a float32 CPU tensor per call.
        return self.sigma_t(t) * math.log(self.sigma_max / self.sigma_min)

    def atx(self, t: torch.Tensor) -> torch.Tensor:
        # Mathematically constant (log(sigma_max / sigma_min)). It is computed
        # through the ratio so the result keeps t's shape, dtype and device.
        s = 1. - t
        return self.dsigma_dt(s) / self.sigma_t(s)

    def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return torch.ones_like(t), self.sigma_t(1. - t)
    
class LogitNormal_t(nn.Module):
    '''
    Shifted time function.

    With ``t`` clamped to ``>= 1e-7`` and reshaped to ``[N, 1]``,
    ``s = 0.999 * t`` and ``u = 1 / s - 1``, the shifted time is
    ``t' = alpha / (alpha + u ** sigma)``. ``forward(t)`` returns ``(t', 1 - t')``.
    The 0.999 factor keeps ``u > 0`` at ``t = 1``, so ``1 - t'`` never reaches 0
    (for stability).
    See FLUX.1 paper for details: https://arxiv.org/pdf/2506.15742
    '''
    def __init__(
        self
    ) -> None:
        super().__init__()
        self.alpha = 3.0
        self.sigma = 1.0

    def _prep(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(t, u)``: t clamped and shaped ``[N, 1]``, and ``u = 1 / (0.999 t) - 1``."""
        t = t.clamp(min=1e-7).reshape(-1, 1)
        return t, 1 / (0.999 * t) - 1

    def atx(self, t: Tensor) -> Tensor:
        '''
        Return ``(dt'/dt) / (1 - t')`` as an ``[N, 1]`` column, where
        ``dt'/dt = sigma * alpha * u**(sigma - 1) / (0.999 * t**2 * (alpha + u**sigma)**2)``.
        '''
        t, u = self._prep(t)
        u_pow = u ** self.sigma
        t_prime = self.alpha / (self.alpha + u_pow)
        num = self.sigma * self.alpha * u ** (self.sigma - 1)
        # t is already [N, 1] here. Squaring the un-reshaped 1-D t used to
        # broadcast the denominator (and the result) to [N, N].
        den = 0.999 * t ** 2 * (self.alpha + u_pow) ** 2
        t_prime_dot = num / den
        return t_prime_dot / (1 - t_prime)

    def forward(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        _, u = self._prep(t)
        t_prime = self.alpha / (self.alpha + u ** self.sigma)
        return t_prime, 1. - t_prime

class Cosine_t(nn.Module):
    '''
    Trigonometric (cosine) time function.

    ``forward(t)`` returns ``(sin(πt/2), cos(πt/2))``. ``atx(t)`` returns
    ``dα/dt - (dβ/dt / β) α`` for that pair, with ``t`` clamped to ``<= 1 - 1e-5``
    so ``β`` stays nonzero.
    See stochastic interpolant paper for details: https://arxiv.org/pdf/2209.15571
    '''
    def __init__(self) -> None:
        super().__init__()

    def atx(self, t) -> Tensor:
        t = t.clamp(max=1 - 1e-5)
        angle = 0.5 * math.pi * t
        sin_a = torch.sin(angle)
        cos_a = torch.cos(angle)
        d_alpha = cos_a * 0.5 * math.pi
        d_beta = -sin_a * 0.5 * math.pi
        return d_alpha - (d_beta / cos_a) * sin_a

    def forward(self, t : Tensor) -> Tensor:
        angle = 0.5 * math.pi * t
        return torch.sin(angle), torch.cos(angle)