import math
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.optim
from einops import rearrange

from experiments.models.tabbyflow.ef_vfm.modules.transformer import Reconstructor, Tokenizer, Transformer

# Type alias for a module specification: either a string (presumably a name resolved
# to a module class elsewhere) or a callable that returns an nn.Module.
ModuleType = Union[str, Callable[..., nn.Module]]


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: returns ``x * sigmoid(x)`` elementwise."""

    def forward(self, x):
        return x.mul(x.sigmoid())


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions / noise levels.

    The output is ``[cos(x * f_0..f_{k-1}), sin(x * f_0..f_{k-1})]`` with
    ``k = num_channels // 2`` geometrically spaced frequencies starting at 1 and
    decreasing towards ``1 / max_positions``. The last frequency equals
    ``1 / max_positions`` only when ``endpoint=True``.

    Args:
        num_channels: Requested embedding width. The output has
            ``2 * (num_channels // 2)`` columns, so an odd value loses one channel.
        max_positions: Sets the lowest frequency, ``1 / max_positions``.
        endpoint: If True, include ``1 / max_positions`` as the last frequency.

    Raises:
        ValueError: if ``endpoint=True`` and ``num_channels < 4``. In that case the
            frequency exponent would be 0/0, i.e. every output would be NaN.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        if endpoint and num_channels // 2 < 2:
            raise ValueError(f"endpoint=True needs num_channels >= 4, got {num_channels}")
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed ``x`` of shape ``(batch,)`` into ``(batch, 2 * (num_channels // 2))``."""
        # With integer input, `freqs.to(x.dtype)` below would cast the fractional
        # frequencies to an integer dtype and truncate nearly all of them to 0.
        if not x.is_floating_point():
            x = x.to(torch.float32)
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        freqs = freqs / (half - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.ger(freqs.to(x.dtype))
        return torch.cat([x.cos(), x.sin()], dim=1)


class TimeStepEmbedding(nn.Module):
    """
    Layer that embeds diffusion timesteps.

    Timesteps are first mapped to ``dim`` features, either sinusoidal or random
    Fourier, and then passed through ``n_layers`` dense layers with SiLU in between.

    Args:
        - dim (int): the dimension of the output; must be even.
        - max_period (int): controls the minimum frequency of the sinusoidal
          embeddings (unused when ``fourier=True``).
        - n_layers (int): number of dense layers.
        - fourier (bool): whether to use random Fourier features as embeddings.
        - scale (float): standard deviation of the random Fourier frequencies
          (only used when ``fourier=True``).

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(
        self,
        dim: int,
        max_period: int = 10000,
        n_layers: int = 2,
        fourier: bool = False,
        scale=16,
    ):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.n_layers = n_layers
        self.fourier = fourier

        if dim % 2 != 0:
            raise ValueError(f"embedding dim must be even, got {dim}")

        if fourier:
            # Fixed (non-trainable) random frequencies, saved with the state_dict.
            self.register_buffer("freqs", torch.randn(dim // 2) * scale)

        layers = []
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(dim, dim))
            layers.append(nn.SiLU())
        self.fc = nn.Sequential(*layers, nn.Linear(dim, dim))

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps of shape ``(batch,)`` into ``(batch, dim)``."""
        if self.fourier:
            # Integer timesteps would otherwise force the random frequencies to be
            # cast to an integer dtype (truncating them), so promote to the buffer dtype.
            t = timesteps if timesteps.is_floating_point() else timesteps.to(self.freqs.dtype)
            args = t.ger((2 * torch.pi * self.freqs).to(t.dtype))
            emb = torch.cat([args.cos(), args.sin()], dim=1)
        else:
            mid = self.dim // 2
            # Geometric frequencies from 1 down towards 1 / max_period.
            fs = torch.exp(
                -math.log(self.max_period)
                / mid
                * torch.arange(mid, dtype=torch.float32, device=timesteps.device)
            )
            args = timesteps[:, None].float() * fs[None]
            emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

        return self.fc(emb)


class MLP(nn.Module):
    """Timestep-conditioned MLP computing ``head(proj(x) + time_embed(sinusoid(t)))``.

    Args:
        d_in: Width of the input ``x`` and of the output.
        dim_t: Hidden width, also the width of the timestep embedding.
        use_mlp: If True the head is a SiLU MLP
            (dim_t -> 2*dim_t -> 2*dim_t -> dim_t -> d_in); otherwise it is a
            single ``Linear(dim_t, d_in)``.
    """

    def __init__(self, d_in, dim_t=512, use_mlp=True):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        if use_mlp:
            wide = dim_t * 2
            head = nn.Sequential(
                nn.Linear(dim_t, wide),
                nn.SiLU(),
                nn.Linear(wide, wide),
                nn.SiLU(),
                nn.Linear(wide, dim_t),
                nn.SiLU(),
                nn.Linear(dim_t, d_in),
            )
        else:
            head = nn.Linear(dim_t, d_in)
        self.mlp = head

        self.map_noise = PositionalEmbedding(num_channels=dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t),
        )

        self.use_mlp = use_mlp

    def forward(self, x, timesteps):
        """Return the head output (width ``d_in``) for inputs ``x`` and ``timesteps``."""
        # PositionalEmbedding emits [cos, sin]; this layer consumes [sin, cos].
        cos_half, sin_half = self.map_noise(timesteps).chunk(2, dim=1)
        t_emb = self.time_embed(torch.cat([sin_half, cos_half], dim=1))
        return self.mlp(self.proj(x) + t_emb)


class MLPDiffusion(nn.Module):
    """ReLU MLP denoiser conditioned on a timestep embedding.

    Computes ``mlp(proj(x) + time_emb(timesteps))``.

    Args:
        d_in: Width of the input ``x`` and of the output.
        n_layers: Number of hidden ``Linear + ReLU`` layers. With ``0``, the network
            is a single linear map from the embedding space back to ``d_in``.
        n_units: Width of every hidden layer.
        emb_dim: Width of the input projection and of the timestep embedding.

    Raises:
        ValueError: if ``n_layers`` is negative.
    """

    def __init__(self, d_in, n_layers, n_units, emb_dim):
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")

        self.proj = nn.Linear(d_in, emb_dim)

        # Hidden stack: emb_dim -> n_units -> ... -> n_units, each followed by ReLU.
        in_dims = [emb_dim] + [n_units] * (n_layers - 1)
        out_dims = [n_units] * n_layers
        layers = []
        for d_from, d_to in zip(in_dims, out_dims):
            layers += [nn.Linear(d_from, d_to), nn.ReLU()]
        # Final projection back to the input width. Without hidden layers it
        # reads directly from the embedding space.
        layers.append(nn.Linear(n_units if n_layers > 0 else emb_dim, d_in))
        self.mlp = nn.Sequential(*layers)
        self.time_emb = TimeStepEmbedding(emb_dim)

    def forward(self, x, timesteps):
        """Return the prediction (width ``d_in``) for inputs ``x`` and 1-D ``timesteps``."""
        t_emb = self.time_emb(timesteps)
        x = self.proj(x) + t_emb
        return self.mlp(x)


class UniModMLP(nn.Module):
    """
    Tokenize mixed numerical/categorical rows, denoise them with an MLP, then detokenize.

    Input:
        x_num: [bs, d_numerical]
        x_cat: [bs, len(categories)]
        timesteps: 1-D tensor of diffusion timesteps (forwarded to MLPDiffusion)
    Output:
        x_num_pred: [bs, d_numerical], the predicted mean for numerical data
        x_cat_pred: [bs, sum(categories)], the predicted UNNORMALIZED logits for
            categorical data. When the detokenizer returns no categorical outputs,
            this is a zero tensor shaped like ``x_cat`` in ``x_num_pred``'s dtype.

    ``n_head``, ``factor``, ``use_mlp`` and ``**kwargs`` are not used by this class.
    Presumably they are accepted so existing configs still construct it.
    ``dim_t`` is used as the MLP's embedding width.
    """

    def __init__(
        self,
        d_numerical,
        categories,
        num_layers,
        n_units,
        d_token,
        n_head=1,
        factor=4,
        bias=True,
        dim_t=512,
        use_mlp=True,
        **kwargs,
    ):
        super().__init__()
        self.d_numerical = d_numerical
        self.categories = categories
        self.d_token = d_token

        self.tokenizer = Tokenizer(d_numerical, categories, d_token, bias=bias)
        # One d_token-wide token per numerical column and per categorical column.
        n_tokens = d_numerical + len(categories)
        self.mlp = MLPDiffusion(d_token * n_tokens, num_layers, n_units, dim_t)
        self.detokenizer = Reconstructor(d_numerical, categories, d_token)

        # NOTE: re-registers the same submodule objects, so their parameters also
        # appear under ``model.*`` keys in state_dict; left as-is to keep that layout.
        self.model = nn.ModuleList([self.tokenizer, self.mlp, self.detokenizer])

    def forward(self, x_num, x_cat, timesteps):
        tokens = self.tokenizer(x_num, x_cat)
        # Drop the leading CLS token and flatten the remaining tokens per row.
        flat = tokens[:, 1:, :].flatten(1)
        denoised = self.mlp(flat, timesteps)
        x_num_pred, cat_logits = self.detokenizer(
            rearrange(denoised, "b (n d) -> b n d", d=self.d_token)
        )

        if len(cat_logits) == 0:
            return x_num_pred, torch.zeros_like(x_cat, dtype=x_num_pred.dtype)
        return x_num_pred, torch.cat(cat_logits, dim=-1)
