import argparse
import math
import os

import onnx
import torch
from torch import nn
from torch.utils.data import DataLoader
from onnxsim import simplify

from dataset import *
from environment import *


# Process-wide setup executed at import time. ``set_env`` and ``set_seed`` come
# from the star imports above (presumably ``environment``); their exact effects
# are defined there -- likely environment configuration and RNG seeding for
# reproducibility (TODO confirm against that module).
set_env()
# Cap PyTorch's intra-op CPU thread pool at 30 threads.
torch.set_num_threads(30)
set_seed(42)


def optimize_onnx(input_path, output_path):
    """Simplify an ONNX model with onnxsim and save the result.

    Args:
        input_path: Path of the ONNX model to load.
        output_path: Destination path for the simplified model.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model could not
            be validated; nothing is written in that case.
    """
    model = onnx.load(input_path)

    model_optimized, check = simplify(model)

    # An ``assert`` is stripped under ``python -O`` and would let an
    # unvalidated model be written, so raise explicitly instead.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    onnx.save(model_optimized, output_path)

    print(f"Optimized model saved to {output_path}")


def _pos_encoding(time_idx, output_dim, device="cpu"):
    """Return the sinusoidal embedding of a timestep.

    Channel ``k`` uses the angle ``t / 10000 ** (k / output_dim)``. Even
    channels take its sine and odd channels its cosine. Unlike the classic
    Transformer encoding, every channel has its own frequency; sin/cos pairs
    do not share one. Trained models depend on this exact formula.

    Args:
        time_idx: Timestep(s). A Python number or 0-d tensor yields shape
            ``(output_dim,)``. A column tensor of shape ``(B, 1)`` broadcasts
            to ``(B, output_dim)``.
        output_dim: Embedding width.
        device: Device on which the frequency table is created.

    Returns:
        The embedding tensor.
    """
    i = torch.arange(0, output_dim, device=device).float()
    # exp(k / D * ln 10000) == 10000 ** (k / D)
    div_term = torch.exp(i / output_dim * math.log(10000))

    angle = time_idx / div_term
    return torch.where(i % 2 == 0, torch.sin(angle), torch.cos(angle))


def pos_encoding(timesteps, output_dim, device="cpu"):
    """Return sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: Tensor of shape ``(B,)`` with one timestep per sample.
        output_dim: Embedding width.
        device: Unused and kept for backward compatibility. The embedding is
            always created on ``timesteps.device``, as before.

    Returns:
        Float tensor of shape ``(B, output_dim)``. Row ``b`` equals
        ``_pos_encoding(timesteps[b], output_dim)``.
    """
    # Reshaping to a (B, 1) column lets _pos_encoding broadcast over the whole
    # batch in one set of tensor ops. This replaces a per-sample Python loop
    # with in-place row writes, which also unrolled into B ops when exported.
    # Casting to float matches the float32 division performed per sample.
    column = timesteps.reshape(-1, 1).float()
    return _pos_encoding(column, output_dim, timesteps.device)


class ConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU block conditioned on a time embedding.

    A two-layer MLP maps the time embedding to one value per input channel.
    That value is added to the feature map as a per-channel bias before the
    convolution.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by the convolution.
        time_embed_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        # 3x3 convolution with padding=1 keeps the spatial size unchanged.
        conv_layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*conv_layers)
        # Projects the time embedding to one scalar per input channel.
        mlp_layers = [
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x, v):
        """Add the projected embedding ``v`` to ``x`` and apply the conv stack.

        Args:
            x: Feature map of shape ``(N, in_ch, H, W)``.
            v: Time embedding of shape ``(N, time_embed_dim)``.

        Returns:
            Tensor of shape ``(N, out_ch, H, W)``.
        """
        batch, channels, _, _ = x.shape
        channel_bias = self.mlp(v).view(batch, channels, 1, 1)
        return self.convs(x + channel_bias)


class UNet(nn.Module):
    """Three-level U-Net that predicts noise from an image and its timesteps.

    Every ConvBlock receives the same sinusoidal time embedding. Each encoder
    level halves the spatial size with 2x2 average pooling. Each decoder level
    doubles it with bilinear upsampling and concatenates the matching encoder
    output. For the skip concatenations to line up, H and W must be divisible
    by 8.

    Args:
        in_ch: Number of image channels (input and output).
        time_embed_dim: Width of the sinusoidal timestep embedding.
    """

    def __init__(self, in_ch=1, time_embed_dim=100):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        emb = time_embed_dim

        # Encoder (32 -> 64 -> 128 channels) followed by the 256-channel bottleneck.
        self.down1 = ConvBlock(in_ch, 32, emb)
        self.down2 = ConvBlock(32, 64, emb)
        self.down3 = ConvBlock(64, 128, emb)
        self.bot1 = ConvBlock(128, 256, emb)
        # Decoder: each block consumes upsampled features concatenated with
        # the corresponding encoder output.
        self.up3 = ConvBlock(256 + 128, 128, emb)
        self.up2 = ConvBlock(128 + 64, 64, emb)
        self.up1 = ConvBlock(64 + 32, 32, emb)
        # 1x1 projection back to the image channel count.
        self.out = nn.Conv2d(32, in_ch, 1)

        self.averagepool = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")

    def forward(self, x, timesteps):
        """Return the predicted noise with shape ``(N, in_ch, H, W)``.

        Args:
            x: Noisy images of shape ``(N, in_ch, H, W)``.
            timesteps: Per-sample timesteps of shape ``(N,)``.
        """
        emb = pos_encoding(timesteps, self.time_embed_dim, x.device)

        skips = []
        h = x
        for block in (self.down1, self.down2, self.down3):
            h = block(h, emb)
            skips.append(h)
            h = self.averagepool(h)

        h = self.bot1(h, emb)

        # Pop skips in reverse order so each decoder level meets its encoder twin.
        for block in (self.up3, self.up2, self.up1):
            h = torch.cat([self.upsample(h), skips.pop()], dim=1)
            h = block(h, emb)

        return self.out(h)


class DiffusionModel(nn.Module):
    """DDIM sampler wrapped around a noise-predicting UNet.

    ``forward`` noises its input to timestep ``test_num_timesteps``. It then
    denoises back towards timestep 0, visiting every ``sampling_step``-th
    timestep. All schedule tensors, timesteps and sampled noise are created on
    the CPU.

    Args:
        test_num_timesteps: Timestep the input is noised to in ``forward``.
        sampling_step: Stride between consecutive sampling timesteps.
        num_timesteps: Length of the linear beta schedule.
        beta_start: First beta of the schedule.
        beta_end: Last beta of the schedule.
        eta: Scale of the stochastic noise term in each DDIM update
            (0 makes the update deterministic).
    """

    def __init__(
        self,
        test_num_timesteps,
        sampling_step,
        num_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        eta=1,
    ):
        super().__init__()
        self.eta = eta
        self.num_timesteps = num_timesteps
        self.test_num_timesteps = test_num_timesteps
        self.sampling_step = sampling_step
        # A zero beta is prepended so alpha_bars[0] == 1. The cumulative
        # product for 0-based timestep t therefore lives at alpha_bars[t + 1],
        # and index 0 stands for the clean signal.
        schedule = torch.linspace(beta_start, beta_end, num_timesteps)
        self.betas = torch.cat([torch.zeros(1), schedule], dim=0)
        self.alpha_bars = torch.cumprod(1 - self.betas, dim=0).view(-1, 1, 1, 1)
        self.unet = UNet()

    def forward(self, input_x):
        """Noise ``input_x`` to ``test_num_timesteps`` and denoise it with DDIM.

        Returns a tensor with the same shape as ``input_x``.
        """
        x, _ = self.add_noise(input_x, self.test_num_timesteps)
        batch = x.size(0)

        # Visited timesteps, from the noisiest downwards, and each one's
        # successor. The final successor is -1, which selects alpha_bars[0] == 1.
        seq = torch.arange(
            self.test_num_timesteps, -1, -self.sampling_step, device="cpu"
        )
        tail = torch.full((1,), -1, dtype=seq.dtype, device="cpu")
        seq_next = torch.cat([seq[1:], tail])

        for i in range(len(seq)):
            t = torch.full((batch,), seq[i], device="cpu", dtype=torch.long)
            next_t = torch.full((batch,), seq_next[i], device="cpu", dtype=torch.long)
            eps_hat = self.unet(x, t)
            x = self._ddim_update(
                x, eps_hat, self.alpha_bars[t + 1], self.alpha_bars[next_t + 1]
            )
        return x

    def _ddim_update(self, x, eps_hat, at, at_next):
        """Apply one DDIM update from alpha_bar ``at`` to alpha_bar ``at_next``."""
        # Clean sample implied by the current noise estimate.
        x0_pred = (x - eps_hat * torch.sqrt(1 - at)) / torch.sqrt(at)
        # sigma scales fresh noise (eta controls stochasticity); direction
        # re-applies the predicted noise so the marginal variance is preserved.
        sigma = self.eta * torch.sqrt((1 - at / at_next) * (1 - at_next) / (1 - at))
        direction = torch.sqrt(1 - at_next - sigma**2)
        noise = torch.randn_like(x, device="cpu")
        return torch.sqrt(at_next) * x0_pred + sigma * noise + direction * eps_hat

    def add_noise(self, x, t):
        """Diffuse ``x`` to timestep ``t`` and return ``(x_t, eps)``.

        ``t`` may be a Python int or a tensor of per-sample timesteps.
        """
        noise = torch.randn_like(x, device="cpu")
        alpha_bar = self.alpha_bars[t + 1].view(-1, 1, 1, 1)
        noisy = alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise
        return noisy, noise

    def predict_noise(self, x, t):
        """Return the UNet's noise estimate for ``x`` at timesteps ``t``."""
        return self.unet(x, t)


def _train_diffusion(diffusion_model, train_loader, epochs, lr, device):
    """Fit ``diffusion_model`` with the DDPM noise-prediction (MSE) objective.

    Each batch is noised at uniformly sampled timesteps via
    ``diffusion_model.add_noise``, and the UNet is trained to recover that
    noise. The mean batch loss is printed after every epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        loss_sum = 0.0
        cnt = 0
        for x, _, _ in train_loader:
            optimizer.zero_grad()
            t = torch.randint(
                0, diffusion_model.num_timesteps, (len(x),), device=device
            )
            # add_noise reads alpha_bars[t + 1], i.e.
            # cumprod(1 - linspace(beta_start, beta_end, num_timesteps))[t].
            # This is the same schedule the sampler uses, so training and
            # sampling cannot drift apart.
            x_noisy, eps = diffusion_model.add_noise(x, t)
            eps_hat = diffusion_model.predict_noise(x_noisy, t)
            loss = criterion(eps_hat, eps)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            cnt += 1

        if cnt == 0:
            raise ValueError("Training loader produced no batches")
        loss = loss_sum / cnt
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss}")


def _export_onnx(diffusion_model, img_size, onnx_path, device):
    """Switch to eval mode and export the sampler with a (1, 1, S, S) dummy input."""
    diffusion_model.eval()
    dummy_input = torch.randn(1, 1, img_size, img_size).to(device)
    torch.onnx.export(diffusion_model, dummy_input, onnx_path)


def train(category: str):
    """Train one diffusion model per image size, then export and simplify it.

    For every size in the configuration, a fresh ``DiffusionModel`` is trained
    on a ``SyntheticDataset``. It is exported to
    ``{save_dir}/diffusion_size{S}_timesteps{T}_step{K}.onnx``, and a simplified
    copy is written next to it with a ``sim_`` prefix.

    Args:
        category: Configuration name. Only ``"syn"`` is supported.

    Raises:
        ValueError: If ``category`` is unknown. Previously this crashed later
            with an ``UnboundLocalError`` on ``img_sizes``.
    """
    if category != "syn":
        raise ValueError(f"Unknown category {category!r}; expected 'syn'")

    batch_size = 32
    epochs = 1000
    lr = 3e-4
    num_samples = 1000
    img_sizes = [8, 16, 32, 64]
    test_num_timesteps = 460
    sampling_step = 115
    save_dir = "../model/syn"

    beta_start = 0.0001
    beta_end = 0.02
    num_timesteps = 1000
    # DiffusionModel creates its schedule, timesteps and noise on the CPU, so
    # training stays on the CPU too. The CUDA device chosen previously was
    # overwritten with "cpu" before use.
    device = "cpu"

    # Create the output directory up front so the export cannot fail after
    # all the training work has been done.
    os.makedirs(save_dir, exist_ok=True)

    for img_size in img_sizes:
        dataset = SyntheticDataset(
            num_samples=num_samples,
            img_size=img_size,
            signal=0,
            seed=42,
        )
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        diffusion_model = DiffusionModel(
            test_num_timesteps=test_num_timesteps,
            sampling_step=sampling_step,
            beta_start=beta_start,
            beta_end=beta_end,
            num_timesteps=num_timesteps,
        )
        _train_diffusion(diffusion_model, train_loader, epochs, lr, device)

        stem = (
            f"diffusion_size{img_size}_timesteps{test_num_timesteps}"
            f"_step{sampling_step}.onnx"
        )
        onnx_path = f"{save_dir}/{stem}"
        _export_onnx(diffusion_model, img_size, onnx_path, device)
        optimize_onnx(onnx_path, f"{save_dir}/sim_{stem}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train diffusion models and export them to ONNX."
    )
    parser.add_argument("-category", "--category", type=str, default="syn")

    # parse_known_args tolerates extra arguments (e.g. from a launcher). Report
    # them so a mistyped flag does not silently fall back to the defaults.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")
    train(category=args.category)
