"""
Diffusion model training script.
"""

import sys
import os
import copy
import ast
import re

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import random
import math
from tqdm import tqdm
import time

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import transforms

from components.BD_encoder import ViT, ViTVAE
from components.M2C_encoder import M2CEncoder
from train_encoders import ema_update, get_dataloader
from datasets.datasets import HoleyLayerDataset, RandomRoll

from TheNoiseMustFlow.core.models import Diffusion
from TheNoiseMustFlow.core.schedulers import NoiseScheduler
from TheNoiseMustFlow.core.samplers import DDIMSampler
from TheNoiseMustFlow.trainer.custom_lr_schedulers import CosineLRScheduler
from TheNoiseMustFlow.trainer.losses import snr_weighted_mse_loss, mse_loss

from metrics import mean_average_precision_binary, dice_coeff

import matplotlib.pyplot as plt


# ---- Loss functions ----


def dice_loss(x_hat, x, eps=1e-6):
    """Soft Dice loss between logits ``x_hat`` and targets ``x``.

    The logits are passed through a sigmoid. The overlap is then pooled over
    every element of the whole tensor (not per sample), giving one scalar:
    ``1 - 2 * sum(p * x) / (sum(p) + sum(x) + eps)``.
    """
    probs = torch.sigmoid(x_hat)
    overlap = torch.sum(probs * x)
    denominator = probs.sum() + x.sum() + eps
    return 1 - 2 * overlap / denominator


# ---- Training step ----
def _autocast_device_type(device):
    """Return the bare device type (``"cuda"``, ``"cpu"``, ...) of ``device``.

    ``torch.amp.autocast`` only accepts a device *type*. An indexed device
    string such as ``"cuda:1"`` must therefore be reduced before being passed.
    """
    return torch.device(device).type


def _unwrap(model):
    """Return the module wrapped by (Distributed)DataParallel, else ``model`` itself."""
    if isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    ):
        return model.module
    return model


def _dataset_transforms(pixmap_transform):
    """Build the ``additional_transforms`` mapping passed to ``HoleyLayerDataset``."""
    return {
        "pixmap_layers": [pixmap_transform],
        "R": [transforms.ToTensor()],
        "T": [transforms.ToTensor()],
        "BD": [transforms.ToTensor()],
    }


def _build_dataloaders(args):
    """Split the ``.h5`` files into train/test sets and wrap both in dataloaders.

    Dataset settings (path, seed, split ratio, preprocessing options) come from
    the module-level ``args_encoders`` namespace. Loader settings come from
    ``args`` via ``get_dataloader``.
    NOTE(review): ``args_encoders`` is assumed to be defined by the script entry
    point before ``train_step`` runs. Confirm.

    Returns:
        ``(train_loader, test_loader)``.
    """
    data_path = args_encoders.data_path
    h5_files = [
        os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith(".h5")
    ]

    # Seeded shuffle so the train/test split is reproducible between runs.
    torch.manual_seed(args_encoders.seed)
    torch.cuda.manual_seed_all(args_encoders.seed)
    random.seed(args_encoders.seed)
    random.shuffle(h5_files)

    n_train = int(args_encoders.train_size * len(h5_files))
    train_files, test_files = h5_files[:n_train], h5_files[n_train:]

    common = {
        "config_file": os.path.join(data_path, "config.json"),
        "Nx": args_encoders.Nx,
        "max_seq_len": args_encoders.max_seq_len,
        "scale_eps": args_encoders.scale_eps,
        "binarize_BD": args_encoders.binarize_BD,
        "clip_limit": args_encoders.clip_limit,
        "threshold_size": args_encoders.threshold_size,
        "kernel_size": args_encoders.kernel_size,
    }
    train_dataset = HoleyLayerDataset(
        h5_files=train_files,
        additional_transforms=_dataset_transforms(
            RandomRoll()
            if args_encoders.random_roll
            else transforms.Lambda(lambda x: x)
        ),
        random_air_layers=args_encoders.random_air_layers,
        **common,
    )
    # The test split is never augmented: no random roll, no random air layers.
    test_dataset = HoleyLayerDataset(
        h5_files=test_files,
        additional_transforms=_dataset_transforms(transforms.Lambda(lambda x: x)),
        random_air_layers=False,
        **common,
    )
    return get_dataloader(train_dataset, args), get_dataloader(test_dataset, args)


def _build_optimizer(diffusion, BD_encoder, M2C_encoder, args):
    """Create AdamW over every trainable tensor, with weight decay on matrices only.

    Encoder parameters are included only when the matching ``finetune_*`` flag
    is set. Tensors whose (lower-cased) name contains ``"pe"`` or ``"pos_"``,
    norm parameters, biases and tensors with fewer than two dimensions get no
    weight decay.

    Returns:
        ``(optimizer, params)``, where ``params`` lists every tensor handed to
        the optimizer (used for gradient clipping).
    """
    named = list(diffusion.named_parameters())
    if args.finetune_BD:
        named += list(BD_encoder.named_parameters())
    if args.finetune_M2C:
        named += list(M2C_encoder.named_parameters())

    decay, no_decay = [], []
    for name, param in named:
        if not param.requires_grad:
            continue
        lname = name.lower()
        if "pe" in lname or "pos_" in lname:
            # Positional encodings. NOTE(review): "pe" is a loose substring
            # match that can also catch unrelated parameter names.
            no_decay.append(param)
        elif param.ndim >= 2 and "norm" not in lname and "bias" not in lname:
            decay.append(param)
        else:
            no_decay.append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": args.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    return optimizer, decay + no_decay


def _noise_loss(noise, pred_noise, snr, args):
    """Denoising loss selected by ``args.loss_function``."""
    if args.loss_function == "snr_weighted_mse_loss":
        return snr_weighted_mse_loss(noise, pred_noise, snr, reduction="mean")
    if args.loss_function == "mse_loss":
        return mse_loss(noise, pred_noise, reduction="mean")
    raise ValueError(f"Unknown loss function: {args.loss_function}")


def _sampler_loss(rec_BD, BD, snr, args):
    """Unweighted reconstruction loss selected by ``args.sampler_loss_function``.

    ``BCE_loss`` and ``dice_loss`` treat ``rec_BD`` as logits.
    """
    name = args.sampler_loss_function
    if name == "snr_weighted_mse_loss":
        return snr_weighted_mse_loss(BD, rec_BD, snr, reduction="mean")
    if name == "mse_loss":
        return mse_loss(BD, rec_BD, reduction="mean")
    if name == "BCE_loss":
        return F.binary_cross_entropy_with_logits(rec_BD, BD)
    if name == "dice_loss":
        return dice_loss(rec_BD, BD)
    raise ValueError(f"Unknown sampler loss function: {args.sampler_loss_function}")


def _empty_losses(args):
    """Fresh running-loss dict; disabled optional terms are ``None``."""
    return {
        "noise": 0.0,
        "sampler": 0.0 if args.sampler_loss else None,
        "sim": 0.0 if args.cosine_sim_loss else None,
        "total": 0.0,
    }


def _accumulate(running, losses, n_batches):
    """Add each enabled batch loss, divided by ``n_batches``, into ``running``."""
    for key, value in losses.items():
        if value is not None:
            running[key] += value.item() / n_batches


def _forward_losses(
    diffusion,
    BD_encoder,
    M2C_encoder,
    batch,
    noise_scheduler,
    sampler,
    generator,
    args,
    training,
):
    """Noise one batch at random timesteps, predict the noise, and compute all losses.

    Args:
        diffusion: noise-prediction model, called as ``diffusion(x_t, t, context)``.
        BD_encoder: autoencoder. The first output of ``encode`` is the target
            latent and the fourth is the embedding used by the cosine loss.
            ``decode`` maps latents back to ``BD`` space.
        M2C_encoder: produces the conditioning context from the layer stack.
        batch: dataloader batch with ``pixmap_layers``, ``thicknesses``,
            ``key_padding_mask`` and ``BD`` entries.
        noise_scheduler: provides ``steps``, ``compute_snr`` and
            ``add_noise_cumulative``.
        sampler: DDIM sampler; only ``sampler.noise_scheduler.alphas_cumprod``
            is read here.
        generator: RNG for the noise and timestep draws.
        args: run configuration.
        training: True for the optimisation pass. Evaluation callers should
            also wrap the call in ``torch.no_grad()``.

    Returns:
        dict with ``"losses"`` (``noise``/``sampler``/``sim``/``total``, where
        disabled terms are ``None``) plus the intermediate tensors used for logging.
    """
    device = args.device
    structures = batch["pixmap_layers"].to(
        device, non_blocking=True
    )  # (B, n_layers, Nx, Nx)
    thicknesses = batch["thicknesses"].to(device, non_blocking=True)  # (B, n_layers)
    masks = batch["key_padding_mask"].to(device, non_blocking=True)  # (B, n_layers)
    BD = batch["BD"].to(device, non_blocking=True)  # (B, 1, 256, 128)

    with torch.amp.autocast(
        device_type=_autocast_device_type(device), enabled=args.use_amp
    ):
        # 1) Conditioning context; differentiable only when M2C is fine-tuned.
        with torch.enable_grad() if training and args.finetune_M2C else torch.no_grad():
            q_M2C = M2C_encoder(structures, thicknesses, masks, pool=False)

        # 2) Target latent. The BD encoder's encode path is never trained here.
        with torch.no_grad():
            latent_BD, _, _, _ = BD_encoder.encode(BD)

        # 3) Forward diffusion at random timesteps, then predict the added noise.
        noise = torch.randn(latent_BD.size(), device=device, generator=generator)
        t = torch.randint(
            0,
            noise_scheduler.steps,
            (latent_BD.size(0),),
            device=device,
            generator=generator,
        )
        snr = noise_scheduler.compute_snr(t)
        noisy_latent = noise_scheduler.add_noise_cumulative(latent_BD, t, noise)
        pred_noise = diffusion(noisy_latent, t, q_M2C)
        loss_noise = _noise_loss(noise, pred_noise, snr, args)

        # 4) One-step estimate of the clean latent:
        #    x0 = (x_t - sqrt(1 - abar_t) * eps_hat) / sqrt(abar_t)
        alpha_cumprod_t = sampler.noise_scheduler.alphas_cumprod[t]
        alpha_cumprod_t = alpha_cumprod_t.view(-1, *((1,) * (noisy_latent.ndim - 1)))
        pred_latent = (
            noisy_latent - torch.sqrt(1.0 - alpha_cumprod_t) * pred_noise
        ) / torch.sqrt(alpha_cumprod_t)

        loss_sampler = loss_sim = None
        if args.sampler_loss or args.cosine_sim_loss:
            # The sampler loss must backpropagate through the decoder into
            # pred_latent, and hence into the diffusion model. A frozen BD
            # encoder (requires_grad=False) simply gets no weight gradients.
            with torch.enable_grad() if training and args.sampler_loss else torch.no_grad():
                rec_BD = BD_encoder.decode(pred_latent)

            if args.sampler_loss:
                loss_sampler = args.sampler_loss_weight * _sampler_loss(
                    rec_BD, BD, snr, args
                )

            if args.cosine_sim_loss:
                # q_BD is a fixed target; gradients only reach q_M2C.
                with torch.no_grad():
                    _, _, _, q_BD = BD_encoder.encode(rec_BD)
                cos_sim = F.cosine_similarity(q_M2C[:, 0], q_BD, dim=-1, eps=1e-6)
                loss_sim = args.cosine_sim_loss_weight * (1.0 - cos_sim.mean())

        loss = loss_noise
        if loss_sampler is not None:
            loss = loss + loss_sampler
        if loss_sim is not None:
            loss = loss + loss_sim

    return {
        "losses": {
            "noise": loss_noise,
            "sampler": loss_sampler,
            "sim": loss_sim,
            "total": loss,
        },
        "BD": BD,
        "q_M2C": q_M2C,
        "latent_BD": latent_BD,
        "noise": noise,
        "t": t,
        "snr": snr,
        "pred_noise": pred_noise,
        "pred_latent": pred_latent,
    }


def _noise_figure(noise, pred_noise, t, snr, max_samples=32):
    """Plot true vs. predicted noise (mean over dim 0) for up to ``max_samples`` items."""
    num_samples = min(max_samples, noise.size(0))
    vmin = min(noise.min().item(), pred_noise.min().item())
    vmax = max(noise.max().item(), pred_noise.max().item())
    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(
        2, num_samples, figsize=(5 * num_samples, 10), squeeze=False
    )
    rows = ((noise, "Original Noise"), (pred_noise, "Predicted Noise"))
    for i in range(num_samples):
        for row, (tensor, title) in enumerate(rows):
            ax = axes[row, i]
            # .float() so half/bfloat16 tensors (AMP) convert cleanly to numpy.
            ax.imshow(
                tensor[i].mean(dim=0).float().cpu().numpy().squeeze(),
                vmin=vmin,
                vmax=vmax,
                cmap="grey",
            )
            ax.set_title(f"{title} {i + 1}")
            ax.axis("off")
        axes[0, i].text(
            0.0,
            0.0,
            f"{t[i].item()}: {snr[i]}",
            color="red",
            fontsize=14,
            ha="left",
            va="top",
        )
    plt.tight_layout()
    return fig


def _ddim_figure(
    sampler, diffusion_ema, BD_decoder, q_M2C, BD, latent_shape, generator, device
):
    """DDIM-sample one latent per item (up to 10) and plot decoded intermediates.

    Each row shows 10 intermediate decodes followed by the original ``BD``.
    NOTE(review): assumes ``sampler.sample(..., return_intermediates=True)``
    yields at least 10 latents with ``return_step=steps // 10``. Confirm.
    """
    n_frames = 10
    step_stride = sampler.steps // 10
    num_samples = min(10, BD.size(0))
    fig, axes = plt.subplots(
        num_samples,
        n_frames + 1,
        figsize=((n_frames + 1) * 5, num_samples * 5),
        squeeze=False,
    )
    for j in range(num_samples):
        x = torch.randn((1, *latent_shape[1:]), device=device, generator=generator)
        intermediates = sampler.sample(
            x,
            pred_noise_func=diffusion_ema,
            func_inputs={"context": q_M2C[j, None]},
            return_intermediates=True,
            return_step=step_stride,
        )
        generated_BD = BD_decoder.decode(torch.cat(intermediates, dim=0))

        for i in range(n_frames):
            ax = axes[j, i]
            ax.imshow(
                generated_BD[i].float().cpu().numpy().transpose(1, 2, 0),
                cmap="grey",
                vmin=0.0,
                vmax=1.0,
            )
            ax.set_title(f"Step {i * step_stride}")
            ax.axis("off")

        ax = axes[j, n_frames]
        ax.imshow(
            BD[j].float().cpu().numpy().squeeze(), cmap="grey", vmin=0.0, vmax=1.0
        )
        ax.set_title("Original")
        ax.axis("off")
    plt.tight_layout()
    return fig


def _ddim_metrics(
    test_loader, sampler, diffusion_ema, BD_decoder, M2C_eval, latent_shape,
    generator, args, desc,
):
    """DDIM-sample every test item and average mAP / Dice over batches.

    Each batch contributes ``metric / len(test_loader)``, so the result is a
    mean over batches. Decoded outputs are passed to the mAP metric with
    ``from_logits=True``.

    Returns:
        ``(mAP, dice, last_sampled_latent)``.
    """
    device_type = _autocast_device_type(args.device)
    mAP, dice, sampled_latent = 0.0, 0.0, None
    for batch in tqdm(test_loader, desc=desc, unit=" batch", disable=not args.use_tqdm):
        structures = batch["pixmap_layers"].to(args.device, non_blocking=True)
        thicknesses = batch["thicknesses"].to(args.device, non_blocking=True)
        masks = batch["key_padding_mask"].to(args.device, non_blocking=True)
        BD = batch["BD"].to(args.device, non_blocking=True)

        with torch.amp.autocast(enabled=args.use_amp, device_type=device_type):
            q_M2C = M2C_eval(structures, thicknesses, masks, pool=False)
            # Size the noise from *this* batch: the last test batch may be smaller.
            x = torch.randn(
                (BD.size(0), *latent_shape[1:]),
                device=args.device,
                generator=generator,
            )
            sampled_latent = sampler.sample(
                x,
                pred_noise_func=diffusion_ema,
                func_inputs={"context": q_M2C},
                return_intermediates=False,
            )
            generated_BD = BD_decoder.decode(sampled_latent)

            mAP += mean_average_precision_binary(
                generated_BD, BD, from_logits=True, reduce=True
            ) / len(test_loader)
            dice += dice_coeff(generated_BD, BD) / len(test_loader)
    return mAP, dice, sampled_latent


def train_step(diffusion, BD_encoder, M2C_encoder, args):
    """Train the latent diffusion model, optionally fine-tuning both encoders.

    The BD encoder's ``encode`` gives the latents the diffusion model learns to
    denoise. The M2C encoder turns the layer stack into the conditioning
    context. EMA copies of every trained model are kept and used for
    validation, sampling figures, metrics and checkpoints.

    Each epoch runs:
      * a training pass;
      * a validation pass with the EMA models;
      * optional TensorBoard logging;
      * every 10 epochs (and on the last one), DDIM sampling figures and
        mAP/Dice on the test set.

    Checkpoints go to ``<args.save_path>/diffusion/``. ``diffusion.pth`` is
    written on validation improvement. ``diffusion_epoch_<n>.pth`` is written
    every 50 epochs and on the last one.

    Side effects:
      * When ``args.finetune_BD`` is false, ``BD_encoder`` is frozen in place
        (``requires_grad_(False)``). The sampler loss can then backpropagate
        through its decoder without changing its weights.
      * All models are moved to ``args.device``.

    Args:
        diffusion: noise-prediction network, called as ``diffusion(x_t, t, context)``.
        BD_encoder: autoencoder exposing ``encode`` / ``decode``.
        M2C_encoder: structure encoder, called with ``pool=False``.
        args: parsed command-line namespace.
    """
    # "high" lets float32 matmuls use TF32: faster, at slightly lower precision.
    torch.set_float32_matmul_precision("high")

    save_dir = os.path.join(args.save_path, "diffusion")
    os.makedirs(save_dir, exist_ok=True)

    if not args.finetune_BD:
        # Freeze the weights without blocking autograd *through* the decoder.
        BD_encoder.requires_grad_(False)

    # EMA copies, taken before any DataParallel wrapping.
    diffusion_ema = copy.deepcopy(diffusion).eval().requires_grad_(False)
    BD_encoder_ema = (
        copy.deepcopy(BD_encoder).eval().requires_grad_(False)
        if args.finetune_BD
        else None
    )
    M2C_encoder_ema = (
        copy.deepcopy(M2C_encoder).eval().requires_grad_(False)
        if args.finetune_M2C
        else None
    )

    # Move to device and wrap for multi-GPU.
    if args.n_gpus > 1:
        diffusion = torch.nn.DataParallel(diffusion)
    diffusion = diffusion.to(args.device)
    diffusion_ema = diffusion_ema.to(args.device)
    BD_encoder = BD_encoder.to(args.device)
    M2C_encoder = M2C_encoder.to(args.device)
    if BD_encoder_ema is not None:
        BD_encoder_ema = BD_encoder_ema.to(args.device)
    if M2C_encoder_ema is not None:
        M2C_encoder_ema = M2C_encoder_ema.to(args.device)

    # Models for validation, sampling and checkpoints: the EMA copy of whatever
    # is trained, the live (frozen) module otherwise.
    BD_eval = BD_encoder_ema if args.finetune_BD else BD_encoder
    M2C_eval = M2C_encoder_ema if args.finetune_M2C else M2C_encoder

    train_loader, test_loader = _build_dataloaders(args)
    optimizer, trainable_params = _build_optimizer(
        diffusion, BD_encoder, M2C_encoder, args
    )
    scheduler = CosineLRScheduler(
        optimizer,
        base_lr=args.lr,
        total_epochs=args.epochs,
        warmup_epochs=args.warmup_epochs,
    )
    noise_scheduler = NoiseScheduler(
        steps=args.steps, betas=args.betas, schedule=args.schedule
    ).to(args.device)

    writer = None
    if args.use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(log_dir=save_dir)

    scaler = torch.amp.GradScaler(enabled=args.use_amp)
    generator = torch.Generator(device=args.device)
    sampler = DDIMSampler(
        noise_scheduler, steps=args.DDIM_steps, eta=0.0, use_tqdm=False
    )
    best_loss = float("inf")
    # EMA momentum follows a cosine ramp from m_start (first epoch) to m_end (last).
    m_start, m_end = 0.99, 0.9998

    try:
        for epoch in range(args.epochs):
            step = epoch + 1
            tag = f"Epoch {step}/{args.epochs}"
            m_ema = m_end - (m_end - m_start) * 0.5 * (
                1 + math.cos(math.pi * epoch / max(1, args.epochs - 1))
            )

            # ---- Training ----
            diffusion.train()
            BD_encoder.train(bool(args.finetune_BD))
            M2C_encoder.train(bool(args.finetune_M2C))

            train_loss = _empty_losses(args)
            for batch in tqdm(
                train_loader,
                desc=f"{tag} - Train",
                unit=" batch",
                disable=not args.use_tqdm,
            ):
                optimizer.zero_grad(set_to_none=True)
                losses = _forward_losses(
                    diffusion, BD_encoder, M2C_encoder, batch,
                    noise_scheduler, sampler, generator, args, training=True,
                )["losses"]

                scaler.scale(losses["total"]).backward()
                # Unscale before clipping so the threshold applies to true
                # gradients; clip exactly the tensors the optimizer updates.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=5.0)
                scaler.step(optimizer)
                scaler.update()

                ema_update(_unwrap(diffusion), diffusion_ema, m=m_ema)
                if args.finetune_BD:
                    ema_update(BD_encoder, BD_encoder_ema, m=m_ema)
                if args.finetune_M2C:
                    ema_update(M2C_encoder, M2C_encoder_ema, m=m_ema)

                _accumulate(train_loss, losses, len(train_loader))

            print(f"{tag} - Train Loss: {train_loss['total']:.4f}")

            # ---- Validation (EMA models) ----
            diffusion.eval()
            BD_encoder.eval()
            M2C_encoder.eval()

            test_loss = _empty_losses(args)
            last = None  # outputs of the final validation batch, reused for logging
            with torch.no_grad():
                for batch in tqdm(
                    test_loader,
                    desc=f"{tag} - Test",
                    unit=" batch",
                    disable=not args.use_tqdm,
                ):
                    last = _forward_losses(
                        diffusion_ema, BD_eval, M2C_eval, batch,
                        noise_scheduler, sampler, generator, args, training=False,
                    )
                    _accumulate(test_loss, last["losses"], len(test_loader))

            print(f"{tag} - Test Loss: {test_loss['total']:.4f}")
            scheduler.step(0.0)

            if writer is not None:
                for key, name in (
                    ("total", "Total"),
                    ("noise", "Noise"),
                    ("sampler", "Sampler"),
                    ("sim", "Sim"),
                ):
                    if train_loss[key] is not None:
                        writer.add_scalar(f"Train/{name}", train_loss[key], step)
                        writer.add_scalar(f"Test/{name}", test_loss[key], step)
                writer.add_scalar(
                    "Others/Learning Rate", scheduler.get_last_lr()[0], step
                )
                if last is not None:
                    writer.add_scalar(
                        "Statistics/Latent std", last["latent_BD"].std().item(), step
                    )
                    writer.add_scalar(
                        "Statistics/Pred noise std",
                        last["pred_noise"].std().item(),
                        step,
                    )
                    writer.add_scalar(
                        "Statistics/Pred latent std",
                        last["pred_latent"].std().item(),
                        step,
                    )

            # ---- Sampling figures and mAP / Dice on the test set ----
            if last is not None and (epoch % 10 == 0 or epoch == args.epochs - 1):
                latent_shape = last["latent_BD"].shape
                with torch.no_grad():
                    if writer is not None:
                        fig = _noise_figure(
                            last["noise"], last["pred_noise"], last["t"], last["snr"]
                        )
                        writer.add_figure("Predicted Noise", fig, step)
                        plt.close("all")

                    fig = _ddim_figure(
                        sampler, diffusion_ema, BD_eval, last["q_M2C"], last["BD"],
                        latent_shape, generator, args.device,
                    )
                    fig.savefig(os.path.join(save_dir, f"sampling_{step}.png"))
                    if writer is not None:
                        writer.add_figure("DDIM Sampling", fig, step)
                    plt.close("all")

                    mAP, dice, sampled_latent = _ddim_metrics(
                        test_loader, sampler, diffusion_ema, BD_eval, M2C_eval,
                        latent_shape, generator, args, desc=f"{tag} - Test",
                    )

                if writer is not None:
                    writer.add_scalar(
                        "Statistics/Latent DDIM std",
                        sampled_latent[-1].std().item(),
                        step,
                    )
                    writer.add_scalar("Metrics/mAP", mAP, step)
                    writer.add_scalar("Metrics/Dice", dice, step)

            # ---- Checkpoints ----
            improved = test_loss["total"] < best_loss
            if improved:
                best_loss = test_loss["total"]
            periodic = epoch % 50 == 0 or epoch == args.epochs - 1
            if improved or periodic:
                checkpoint = {
                    "diffusion": diffusion_ema.state_dict(),
                    "BD_encoder": BD_eval.state_dict(),
                    "M2C_encoder": M2C_eval.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                if improved:
                    torch.save(checkpoint, os.path.join(save_dir, "diffusion.pth"))
                if periodic:
                    torch.save(
                        {**checkpoint, "epoch": step, "best_loss": best_loss},
                        os.path.join(save_dir, f"diffusion_epoch_{step}.pth"),
                    )
    finally:
        if writer is not None:
            writer.close()


def parse_namespace_repr(text):
    """Rebuild an ``argparse.Namespace`` from the text of its ``repr()``.

    The text is expected to look like ``Namespace(key=value, ...)``, i.e. what
    ``str(args)`` produces (this script dumps its own arguments that way).
    Each value is parsed with ``ast.literal_eval``, so only Python literals are
    accepted and nothing is executed. Unlike a regex rewrite of ``key=`` into
    ``'key':``, this is not confused by ``=`` characters inside string values
    (e.g. paths), and it also accepts the ``**{'non-identifier-key': value}``
    form that ``Namespace.__repr__`` emits for non-identifier attribute names.

    Args:
        text: The repr string, optionally surrounded by whitespace.

    Returns:
        An ``argparse.Namespace`` holding the recovered attributes.

    Raises:
        ValueError: If ``text`` is not a ``Namespace(...)`` call, or if a value
            is not a Python literal.
        SyntaxError: If ``text`` is not a valid Python expression.
    """
    import argparse

    tree = ast.parse(text.strip(), mode="eval")
    call = tree.body
    if not (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Name)
        and call.func.id == "Namespace"
        and not call.args
    ):
        raise ValueError(f"Not an argparse Namespace repr: {text[:80]!r}")

    values = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # `**{...}` entry: attribute names that are not valid identifiers.
            values.update(ast.literal_eval(keyword.value))
        else:
            values[keyword.arg] = ast.literal_eval(keyword.value)
    return argparse.Namespace(**values)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a diffusion model for photonic structures."
    )

    # Pre-trained M2C and BD encoders
    parser.add_argument(
        "--encoders_path",
        type=str,
        required=True,
        help="Path to the pre-trained encoders.",
    )

    # Diffusion model parameters
    parser.add_argument(
        "--config_diffusion",
        type=str,
        default="../../configs/models/UNet_normal.yaml",
        help="Path to the diffusion model architecture config file.",
    )
    parser.add_argument(
        "--steps", type=int, default=1000, help="Number of diffusion steps."
    )
    parser.add_argument(
        "--betas",
        type=float,
        nargs=2,
        default=(1e-4, 0.02),
        help="Beta values for the noise scheduler.",
    )
    parser.add_argument(
        "--schedule",
        type=str,
        default="cosine",
        choices=["linear", "cosine", "sqrt", "sigmoid"],
        help="Schedule type for the noise scheduler.",
    )
    parser.add_argument(
        "--DDIM_steps",
        type=int,
        default=50,
        help="Number of DDIM steps during training.",
    )

    # Training parameters
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device to use for training."
    )
    parser.add_argument(
        "--n_gpus", type=int, default=1, help="Number of GPUs to use for training."
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of training epochs."
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=20, help="Number of warmup epochs."
    )
    parser.add_argument(
        "--lr", type=float, default=1e-4, help="Learning rate for the optimizer."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=2e-4,
        help="Weight decay for the optimizer.",
    )
    parser.add_argument(
        "--sampler_loss",
        action="store_true",
        default=False,
        help="Whether to use the sampler loss during diffusion training.",
    )
    parser.add_argument(
        "--sampler_loss_weight",
        type=float,
        default=1.0,
        help="Weight for the sampler loss during diffusion training.",
    )
    parser.add_argument(
        "--loss_function",
        type=str,
        default="snr_weighted_mse_loss",
        choices=["snr_weighted_mse_loss", "mse_loss"],
        help="Loss function to use for diffusion training.",
    )
    parser.add_argument(
        "--sampler_loss_function",
        type=str,
        default="snr_weighted_mse_loss",
        choices=["snr_weighted_mse_loss", "mse_loss", "BCE_loss", "dice_loss"],
        help="Loss function to use for the reconstruction during diffusion training.",
    )
    parser.add_argument(
        "--cosine_sim_loss",
        action="store_true",
        default=False,
        help="Whether to use cosine similarity loss during diffusion training.",
    )
    parser.add_argument(
        "--cosine_sim_loss_weight",
        type=float,
        default=1.0,
        help="Weight for the cosine similarity loss during diffusion training.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=16, help="Batch size for training."
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of workers for data loading."
    )
    parser.add_argument(
        "--finetune_BD", action="store_true", help="Whether to finetune the BD encoder."
    )
    parser.add_argument(
        "--finetune_M2C",
        action="store_true",
        help="Whether to finetune the M2C encoder.",
    )

    # Saving parameters
    parser.add_argument(
        "--save_path",
        type=str,
        default=f"./checkpoints/{time.strftime('%Y%m%d_%H%M%S')}/",
        help="Path to save the trained model.",
    )
    parser.add_argument(
        "--reload_from",
        type=str,
        default=None,
        help="Path to a checkpoint to reload the model from.",
    )

    # Misc
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--use_tqdm", action="store_true", help="Whether to use tqdm for progress bars."
    )
    parser.add_argument(
        "--use_tensorboard",
        action="store_true",
        help="Whether to use TensorBoard for logging.",
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Whether to use automatic mixed precision.",
    )
    parser.add_argument(
        "--load_best_acc",
        action="store_true",
        help="Whether to load the best model in terms of accuracy instead of in terms of loss.",
    )

    args = parser.parse_args()

    # Validate paths up front with parser.error instead of `assert`: asserts
    # are stripped under `python -O`. Checking --reload_from here also fails
    # fast, before any model is built.
    if not os.path.exists(args.encoders_path):
        parser.error("Encoders path does not exist.")
    if args.reload_from is not None and not os.path.exists(args.reload_from):
        parser.error("Checkpoint path to reload the model from does not exist.")

    # Hyper-parameters of the encoder run, read back from its args dump. They
    # are needed to rebuild the encoder architectures below.
    with open(
        os.path.join(args.encoders_path, "args_CLIP.txt"), "r", encoding="utf-8"
    ) as f:
        args_encoders = parse_namespace_repr(f.read())

    # The args are dumped only when the save directory is newly created. An
    # already-existing directory keeps whatever args file it already has.
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
        with open(
            os.path.join(args.save_path, "args_diffusion.txt"), "w", encoding="utf-8"
        ) as f:
            f.write(str(args))

    # Load the pre-trained encoders
    ViT_model = ViT(
        image_size=(256, 128),
        patch_size=args_encoders.patch_size,
        dim=args_encoders.dim,  # Should be equal to 256//8 * 128//8 = 32 * 16 = 512
        out_dim=args_encoders.out_dim,
        depth=args_encoders.depth,
        heads=args_encoders.num_heads,
        mlp_dim=args_encoders.mlp_dim,
        pool="cls",
        channels=1,
        dropout=args_encoders.dropout,
        emb_dropout=args_encoders.dropout,
        patch_dropout=args_encoders.patch_dropout,
    )
    print(f"ViT model parameters: {sum(p.numel() for p in ViT_model.parameters())}")

    BD_encoder = ViTVAE(
        vit=ViT_model,
        image_size=(256, 128),
        patch_size=args_encoders.patch_size,
        out_channels=1,
        latent_channels=args_encoders.latent_channels,
        channels=args_encoders.channels,
    )
    print(
        f"ViT-VAE encoder parameters: {sum(p.numel() for p in BD_encoder.encoder.parameters())} ;",
        f"ViT-VAE decoder parameters: {sum(p.numel() for p in BD_encoder.decoder.parameters())}",
    )

    M2C_encoder = M2CEncoder(
        spatial_dim=args_encoders.Nx,
        in_channels=1,
        latent_dim=args_encoders.m2c_latent_dim,
        channels=args_encoders.m2c_channels,
        num_vision_heads=args_encoders.m2c_num_vision_heads,
        groups=args_encoders.m2c_groups,
        vision_dropout=args_encoders.m2c_vision_dropout,
        modality=args_encoders.m2c_modality,
        use_phase=args_encoders.m2c_use_phase,
        out_dim=args_encoders.m2c_out_dim,
        max_seq_len=args_encoders.max_seq_len,
        num_heads=args_encoders.m2c_num_heads,
        num_layers=args_encoders.m2c_num_layers,
        d_ff=args_encoders.m2c_d_ff,
        dropout=args_encoders.m2c_dropout,
        kwargs_pe={
            "ini_freq_scale": 1.0,
            "tunable_freq_scale": True,
            "dropout": 0.0,
        },
        use_cls=True,
        trainable_pe=args_encoders.m2c_trainable_pe,
    )
    print(f"M2C encoder parameters: {sum(p.numel() for p in M2C_encoder.parameters())}")

    # Pick the encoder checkpoint. With --load_best_acc, use the best-accuracy
    # checkpoint if it exists. Otherwise (or if it is missing) use encoders.pth.
    encoders_dir = os.path.join(args.encoders_path, "encoders")
    best_acc_path = os.path.join(encoders_dir, "encoders_best_acc.pth")
    if args.load_best_acc and os.path.exists(best_acc_path):
        encoders_path = best_acc_path
    else:
        if args.load_best_acc:
            # The fallback used to be silent; make it visible.
            print(
                f"Warning: {best_acc_path} not found; "
                "falling back to encoders.pth."
            )
        encoders_path = os.path.join(encoders_dir, "encoders.pth")
    checkpoint = torch.load(
        encoders_path,
        map_location="cpu",
        weights_only=True,
    )

    # Only the EMA copies of the encoder weights are loaded.
    BD_encoder.load_state_dict(checkpoint["BD_encoder_ema_state_dict"])
    M2C_encoder.load_state_dict(checkpoint["M2C_encoder_ema_state_dict"])

    print("\n>>> Successfully loaded the pre-trained encoders.\n")

    diffusion = Diffusion(
        latent_dim=args_encoders.latent_channels, config_file=args.config_diffusion
    )
    print(
        f"Diffusion model parameters: {sum(p.numel() for p in diffusion.parameters())}"
    )
    if args.reload_from is not None:
        # Existence of args.reload_from was already validated after parsing.
        checkpoint = torch.load(
            args.reload_from,
            map_location="cpu",
            weights_only=True,
        )
        diffusion.load_state_dict(checkpoint["diffusion"])

        print("\n>>> Successfully reloaded the diffusion model.\n")

    train_step(diffusion, BD_encoder, M2C_encoder, args)