import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from experiments.models.highres.layers import (
    CatEmbedding,
    TimeStepEmbedding,
    Timewarp,
    Timewarp_Logistic,
    WeightNetwork,
)
from experiments.models.highres.utils import low_discrepancy_sampler, set_seeds


class HighResMLP(nn.Module):
    """
    MLP backbone that maps a noisy input, a conditioning vector and a time
    value to one output per numerical feature.

    The noisy input, the time embedding and the projected conditioning vector
    are summed in an ``emb_dim``-wide space, passed through a stack of
    ``Linear`` + activation layers, and projected back to ``num_features``.

    Args:
        num_features: width of the noisy input and of the output.
        x_low_dim: width of the conditioning vector ``x_low``.
        emb_dim: width of the shared embedding space.
        n_layers: number of ``Linear`` + activation pairs in the trunk.
        n_units: hidden width used between the first and last trunk layer.
        act: ``"relu"`` selects ``nn.ReLU``; any other value selects ``nn.SiLU``.
    """

    def __init__(self, num_features, x_low_dim, emb_dim, n_layers, n_units, act="relu"):
        super().__init__()

        self.num_features = num_features
        self.time_emb = TimeStepEmbedding(emb_dim)
        self.proj = nn.Linear(self.num_features, emb_dim)
        # conditioning goes through a small two-layer projection with a SiLU in between
        self.proj_low = nn.Sequential(
            nn.Linear(x_low_dim, 2 * emb_dim),
            nn.SiLU(),
            nn.Linear(2 * emb_dim, emb_dim),
        )

        # widths of the trunk: emb_dim -> n_units -> ... -> n_units -> emb_dim
        widths = [emb_dim, *([n_units] * (n_layers - 1)), emb_dim]
        trunk = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(d_in, d_out))
            trunk.append(nn.ReLU() if act == "relu" else nn.SiLU())

        self.mlp = nn.Sequential(*trunk)
        self.final_layer = nn.Linear(widths[-1], self.num_features)

    def forward(self, x_t, x_low, t):
        """
        Args:
            x_t: noisy (already input-scaled) features; last dim ``num_features``.
            x_low: conditioning vector; last dim ``x_low_dim``.
            t: time values fed to the time embedding (after a log transform).

        Returns:
            Tensor with last dim ``num_features``.
        """
        # log-scaled time conditioning, from CDTD / EDM; 1e-8 keeps log finite at t = 0
        noise_cond = torch.log(t + 1e-8) * 0.25 * 1000
        time_part = self.time_emb(noise_cond)
        low_part = self.proj_low(x_low)
        hidden = self.proj(x_t) + time_part + low_part
        hidden = self.mlp(hidden)
        return self.final_layer(hidden)


class HighResCDTDModel(nn.Module):
    """
    This implements the numerical part of the CDTD.

    A denoising model for ``num_features`` numerical columns, conditioned on
    embedded discrete inputs (``x_cat`` and, optionally, ``z_num``). Training
    (`loss_fn`) jointly fits the denoiser, the timewarping module and a
    per-time weight network; `sampler` / `sample_data` generate samples by
    integrating from the maximum to the minimum noise level.

    Args:
        group_means: sequence of 1-D tensors, one per numerical feature;
            concatenated into a frozen lookup table.
        group_stds: same layout as ``group_means``.
        n_classes_cat: list with the number of classes per categorical input.
        n_classes_num: list with the number of classes per discretised
            numerical input (the ``z_num`` columns) -- presumably; confirm
            against the caller.
        emb_dim, n_layers, n_units: forwarded to `HighResMLP`.
        cat_emb_dim: embedding width per discrete input column.
        sigma_min, sigma_max: noise range handed to the timewarping module.
        sigma_data: data standard deviation used in the preconditioning.
        weight_low_noise: forwarded to the timewarping module.
        timewarp_variant: only ``"logistic"`` creates ``self.timewarp``.
    """

    def __init__(
        self,
        group_means,
        group_stds,
        n_classes_cat,
        n_classes_num,
        emb_dim,
        n_layers,
        n_units,
        cat_emb_dim,
        sigma_min=1e-5,
        sigma_max=100,
        sigma_data=1,
        weight_low_noise=1.0,
        timewarp_variant="logistic",
    ):
        super().__init__()
        self.num_features = len(group_means)
        self.sigma_data = sigma_data
        categories = n_classes_cat + n_classes_num
        self.emb = CatEmbedding(cat_emb_dim, categories, cat_emb_init_sigma=1)

        self.weightnet = WeightNetwork(1024)

        # init embeddings that allow for efficient retrieval of group moments:
        # all groups are stored in one flat table, group_offset[i] is the index
        # of the first entry belonging to feature i
        n_groups = torch.tensor([len(m) for m in group_means])
        group_offset = n_groups.cumsum(dim=-1)[:-1]
        group_offset = torch.cat((torch.zeros((1,), dtype=torch.long), group_offset))
        self.register_buffer("group_offset", group_offset)
        self.get_group_means = nn.Embedding.from_pretrained(torch.cat(group_means).unsqueeze(-1), freeze=True)
        self.get_group_stds = nn.Embedding.from_pretrained(torch.cat(group_stds).unsqueeze(-1), freeze=True)

        self.timewarp_variant = timewarp_variant
        # NOTE(review): for any other variant no `self.timewarp` is created, so
        # `loss_fn` and `sampler` would fail with an AttributeError -- confirm
        # whether other variants are meant to be supported.
        if timewarp_variant == "logistic":
            self.timewarp = Timewarp_Logistic(
                "single",
                0,
                self.num_features,
                torch.tensor(sigma_min),
                torch.tensor(sigma_max),
                weight_low_noise=weight_low_noise,
                decay=0,
            )

        # the conditioning vector is the concatenation of one embedding per discrete column
        x_low_dim = len(categories) * cat_emb_dim
        self.mlp = HighResMLP(self.num_features, x_low_dim, emb_dim=emb_dim, n_layers=n_layers, n_units=n_units)

    @property
    def device(self):
        """Device of the first parameter of the score network."""
        return next(self.mlp.parameters()).device

    def _embed_conditioning(self, x_cat, z_num):
        """Embed the discrete inputs and flatten them into one vector per row."""
        if z_num is not None:
            return self.emb(torch.column_stack((x_cat, z_num))).flatten(1)
        return self.emb(x_cat).flatten(1)

    @staticmethod
    def _observed_row_mean(loss, mask):
        """Per-row mean of ``loss`` over the entries that are NOT flagged in ``mask``."""
        obs_mask = ~mask
        # 1e-8 avoids 0/0 for rows in which every entry is flagged
        return (loss * obs_mask).sum(1) / (obs_mask.sum(1) + 1e-8)

    def loss_fn(self, x_num, x_cat, z_num=None, mask=None):
        """
        Compute the scalar training loss for one batch.

        Args:
            x_num: clean numerical targets, shape ``(B, num_features)``.
            x_cat: discrete conditioning inputs (indices for the embedding).
            z_num: optional additional discrete conditioning columns.
            mask: optional boolean tensor shaped like ``x_num``; entries that
                are ``True`` are excluded from all loss terms (its inverse is
                used as the "observed" mask).

        Returns:
            Sum of the weighted denoising loss, the timewarping loss and the
            weight-network loss, each averaged to a scalar.
        """
        B = x_num.shape[0]
        x_low = self._embed_conditioning(x_cat, z_num)

        # draw time and convert to standard deviations for noise
        with torch.no_grad():
            t = low_discrepancy_sampler(B, device=self.device)  # (B,)
            sigma = self.timewarp.get_sigmas(t)
            t = t.to(torch.float32)
            assert sigma.shape == (B, self.num_features)

        # add noise
        x_t = x_num + torch.randn_like(x_num) * sigma

        # pass to score model
        preds = self.precondition(x_t, t, sigma, x_low)
        mse_losses = (preds - x_num) ** 2

        # noise-level dependent loss weight; 1e-7 guards against division by zero
        weight = (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2 + 1e-7)
        weighted_mse_losses = weight * mse_losses

        # the timewarping module is fitted to the (detached) unweighted errors
        timewarp_target = mse_losses.detach()
        timewarp_sigma = sigma.squeeze(-1).detach()
        if mask is not None:
            timewarp_losses = self.timewarp.loss_fn(timewarp_sigma, timewarp_target, ~mask)
            # Average over the observed entries only. The same per-row average
            # drives both the training loss and the weight-network target, so
            # the two can no longer disagree about which entries count.
            row_losses = self._observed_row_mean(weighted_mse_losses, mask)
            denoising_loss = row_losses.mean()
        else:
            timewarp_losses = self.timewarp.loss_fn(timewarp_sigma, timewarp_target)
            row_losses = weighted_mse_losses.mean(1)
            denoising_loss = weighted_mse_losses.mean()

        # weight network learns (in log space) the per-row average loss at time t
        time_reweight = self.weightnet(t).unsqueeze(1)
        time_reweight_loss = (time_reweight.exp() - row_losses.detach().unsqueeze(1)) ** 2

        return denoising_loss + timewarp_losses.mean() + time_reweight_loss.mean()

    def precondition(self, x_t, t, sigma, x_low):
        """
        Wrap the score network with input/output scaling and a skip connection.

        Args:
            x_t: noisy inputs.
            t: time values passed through to the network.
            sigma: noise standard deviations, broadcastable against ``x_t``.
            x_low: conditioning vector for the network.

        Returns:
            ``c_skip * x_t + c_out * mlp(c_in * x_t, x_low, t)``, the denoised estimate.
        """
        total_var = sigma**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / total_var
        c_out = sigma * self.sigma_data / total_var.sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        preds = self.mlp(c_in * x_t, x_low, t)  # c_noise is included in score model
        D_x = c_skip * x_t + c_out * preds

        return D_x

    @torch.inference_mode()
    def sampler(self, x_cat, z_num, num_steps=200):
        """
        Generate one batch of samples with ``num_steps`` Euler steps.

        Args:
            x_cat: discrete conditioning inputs, one row per sample.
            z_num: optional additional discrete conditioning columns, or ``None``.
            num_steps: number of integration steps from t = 1 to t = 0.

        Returns:
            float64 CPU tensor of shape ``(B, num_features)``.
        """
        B = x_cat.shape[0]
        latents = torch.randn((B, self.num_features), device=self.device)
        x_low = self._embed_conditioning(x_cat, z_num)

        # construct time steps (float64 to limit accumulation error over the steps)
        t_steps = torch.linspace(1, 0, num_steps + 1, device=self.device, dtype=torch.float64)
        # one row of noise levels per time step -- presumably (num_steps + 1, num_features)
        s_steps = self.timewarp.get_sigmas(t_steps).to(torch.float64)

        # sanity check: the schedule has to span exactly [sigma_min, sigma_max]
        assert torch.allclose(s_steps[0].to(torch.float32), self.timewarp.sigma_max.float())
        assert torch.allclose(s_steps[-1].to(torch.float32), self.timewarp.sigma_min.float())

        # initialize latents at maximum noise level
        x_next = latents.to(torch.float64) * s_steps[0]

        for s_cur, s_next, t_cur in zip(s_steps[:-1], s_steps[1:], t_steps[:-1]):
            s_cur = s_cur.repeat((B, 1))
            s_next = s_next.repeat((B, 1))

            # estimate scores (the network itself runs in float32)
            pred = self.precondition(
                x_next.to(torch.float32),
                t_cur.to(torch.float32).repeat((B,)),
                s_cur.to(torch.float32),
                x_low,
            )
            d_cur = (x_next - pred.to(torch.float64)) / s_cur

            # adjust data samples: Euler step from s_cur to s_next
            h = s_next - s_cur
            x_next = x_next + h * d_cur

        return x_next.cpu()

    @torch.inference_mode()
    def sample_data(self, x_cat, z_num, num_steps=200, batch_size=4096, seed=42, verbose=True):
        """
        Generate samples for all rows of ``x_cat`` in batches of ``batch_size``.

        Args:
            x_cat: discrete conditioning inputs, one row per sample.
            z_num: optional additional discrete conditioning columns, or ``None``.
            num_steps: number of sampler steps per batch.
            batch_size: maximum number of rows sampled at once.
            seed: seed passed to ``set_seeds`` before sampling.
            verbose: show a progress bar over the batches.

        Returns:
            CPU tensor with one row per row of ``x_cat``; an empty
            ``(0, num_features)`` tensor if ``x_cat`` has no rows.
        """
        set_seeds(seed, cuda_deterministic=True)

        n_rows = x_cat.shape[0]
        if n_rows == 0:
            # nothing to sample; torch.cat would fail on an empty list of batches
            return torch.empty((0, self.num_features), dtype=torch.float64)

        n_batches, remainder = divmod(n_rows, batch_size)
        sample_sizes = n_batches * [batch_size]
        if remainder != 0:
            sample_sizes.append(remainder)

        x_cat_parts = torch.split(x_cat, sample_sizes, dim=0)
        if z_num is not None:
            z_num_parts = torch.split(z_num, sample_sizes, dim=0)
        else:
            z_num_parts = [None] * len(sample_sizes)

        x = []
        batches = zip(x_cat_parts, z_num_parts)
        for x_cat_part, z_num_part in tqdm(batches, total=len(sample_sizes), disable=(not verbose)):
            x_cat_part = x_cat_part.to(self.device)
            z_num_part = z_num_part.to(self.device) if z_num_part is not None else None
            x_gen = self.sampler(x_cat_part, z_num_part, num_steps=num_steps)
            x.append(x_gen)
        x = torch.cat(x).cpu()

        return x
