"""
The M2C encoder is jointly trained with a ViT-VAE to learn a latent space
that is both good for reconstruction and for contrastive learning.
This latent space is then used as a conditioning for the diffusion model in the next step.

Contrastive learning (CLIP-like) is performed on the CLS tokens of the ViTVAE encoder
(for the band diagram images) and the CLS tokens of the M2C encoder (for the material sequences).
The other tokens from the ViTVAE encoder are not used for CL, but they are used for reconstruction.
The other tokens from the M2C encoder are ignored here, but they will be used as conditioning
for the diffusion model in the next step.

Training pipeline is inspired by: https://github.com/revantteotia/clip-training/tree/main
"""

import sys
import os
import copy

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import random
import math
from tqdm import tqdm
import time

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import transforms

from components.BD_encoder import ViT, ViTVAE
from components.M2C_encoder import M2CEncoder
from datasets.datasets import HoleyLayerDataset, RandomRoll

from TheNoiseMustFlow.trainer.custom_lr_schedulers import CosineLRScheduler

import matplotlib.pyplot as plt


# ---- Loss helpers ----
def vae_kl_loss(mu, logvar):
    """Return the batch-averaged KL divergence of N(mu, exp(logvar)) from N(0, 1).

    The element-wise KL terms are summed over dimensions 1, 2 and 3 (so both
    inputs must be 4-D, with the batch on dimension 0) and the resulting
    per-sample values are averaged over the batch.

    Args:
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior, same shape as ``mu``.

    Returns:
        A scalar tensor holding the mean KL divergence per sample.
    """
    variance = logvar.exp()
    kl_terms = mu.pow(2) + variance - 1.0 - logvar
    kl_per_sample = kl_terms.sum(dim=[1, 2, 3])
    return 0.5 * kl_per_sample.mean()


def reconstruction_loss(x_hat, x, mode="mse", pos_weight=None):
    """Compute the mean reconstruction loss between ``x_hat`` and ``x``.

    Args:
        x_hat: Reconstruction. For ``mode="BCE"`` it is interpreted as raw
            logits (the sigmoid is applied inside the loss).
        x: Reconstruction target.
        mode: ``"mse"``, ``"l1"`` or ``"BCE"``.
        pos_weight: Optional weight of the positive class; only used when
            ``mode="BCE"``.

    Returns:
        A scalar tensor with the loss averaged over all elements.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode == "BCE":
        return F.binary_cross_entropy_with_logits(x_hat, x, pos_weight=pos_weight)

    for name, loss_fn in (("mse", F.mse_loss), ("l1", F.l1_loss)):
        if mode == name:
            return loss_fn(x_hat, x)

    raise ValueError(mode)


# --- Dataloader helper ---
def get_dataloader(dataset, args, training=True):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``args``.

    Batches have ``args.batch_size`` samples, incomplete trailing batches are
    dropped, and pinned memory is requested. Data is shuffled only when
    ``training`` is true.

    Args:
        dataset: Dataset to iterate over.
        args: Object exposing ``batch_size`` and ``num_workers``.
        training: Whether the loader feeds the training loop (enables shuffling).

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # A dedicated sampler for multi-GPU runs is currently disabled, so no
    # explicit sampler is ever passed and shuffling is decided by `training`.
    sampler = None
    should_shuffle = bool(sampler is None and training)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        shuffle=should_shuffle,
    )


# ---- LogitScale helper ----
class LogitScale(torch.nn.Module):
    """Learnable temperature for contrastive logits.

    The module stores the logarithm of the scale as the scalar parameter
    ``logit_scale`` (initialised to ``log(1 / init_value)``) and multiplies
    its input by ``exp(logit_scale)``.
    """

    def __init__(self, init_value=0.07):
        """Create the scale so that the initial multiplier is ``1 / init_value``."""
        super().__init__()
        initial_log_scale = torch.log(torch.tensor(1 / init_value))
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * initial_log_scale)

    def forward(self, x):
        """Return ``x`` multiplied by the current (exponentiated) scale."""
        scale = self.logit_scale.exp()
        return x * scale


# ---- EMA helper ----
def ema_update(model, model_ema, m=0.999):
    """Update ``model_ema`` in place as an exponential moving average of ``model``.

    Every EMA parameter becomes ``m * p_ema + (1 - m) * p``. Parameters are
    paired by iteration order; buffers are not touched.

    Args:
        model: Online model. If it exposes a ``module`` attribute (as the
            ``DataParallel`` wrapper does), the wrapped module is used.
        model_ema: EMA copy whose parameters are updated in place.
        m: Momentum, i.e. the weight kept from the previous EMA value.
    """
    online = getattr(model, "module", model)
    keep, take = m, 1.0 - m
    with torch.no_grad():
        pairs = zip(online.parameters(), model_ema.parameters())
        for online_param, ema_param in pairs:
            ema_param.data.mul_(keep)
            ema_param.data.add_(online_param.data, alpha=take)


# ---- Memory Queue helper ----
class MemoryQueue(torch.nn.Module):
    """
    MoCo-style FIFO ring buffer of L2-normalized keys with shape [K, dim].

    Two buffers are registered: ``queue`` (the keys) and ``ptr`` (the next
    write position). One queue is meant to be kept per modality.
    """

    def __init__(self, dim: int, K: int = 65536):
        """Allocate ``K`` random unit-norm keys of size ``dim``; pointer starts at 0."""
        super().__init__()
        self.K = K

        # Two Gaussian draws are made (allocation, then an in-place redraw) so
        # that the global RNG stream is consumed exactly as before.
        initial_keys = torch.randn(K, dim)
        initial_keys.normal_()
        self.register_buffer("queue", F.normalize(initial_keys, dim=1))
        self.register_buffer("ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def enqueue(self, keys: torch.Tensor):
        """Write ``keys`` into the queue, overwriting the oldest entries.

        The keys are detached, cast to the queue dtype, L2-normalized along
        dimension 1 and moved to the queue device. If more than ``K`` keys
        are given, only the last ``K`` are kept. Empty input is a no-op.
        """
        if keys.numel() == 0:
            return

        new_keys = F.normalize(keys.detach().to(self.queue.dtype), dim=1)
        if new_keys.device != self.queue.device:
            new_keys = new_keys.to(self.queue.device, non_blocking=True)

        # Never write more than one full queue: keep the most recent K keys.
        if new_keys.shape[0] >= self.K:
            new_keys = new_keys[-self.K :]
        count = new_keys.shape[0]

        start = int(self.ptr.item())

        # Fill up to the end of the buffer, then wrap the rest to the front.
        tail = min(count, self.K - start)
        self.queue[start : start + tail].copy_(new_keys[:tail])
        if tail < count:
            self.queue[: count - tail].copy_(new_keys[tail:])

        self.ptr[0] = (start + count) % self.K


# ---- L2 Normalization helper ----
def l2_normalize(x, dim=-1, eps=1e-6):
    """Return ``x`` cast to float32 and scaled to unit L2 norm along ``dim``.

    The norm is floored at ``eps`` before dividing, so all-zero vectors map
    to zeros instead of producing NaNs.
    """
    as_float = x.float()
    safe_norm = as_float.norm(dim=dim, keepdim=True).clamp_min(eps)
    return as_float / safe_norm


# ---- Distributed helper ----
def is_dist_avail_and_initialized():
    """Tell whether ``torch.distributed`` is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()


# ---- Metrics helper ----
@torch.no_grad()
def mean_pos_neg_inbatch(S):
    """Return the mean diagonal and mean off-diagonal value of a square matrix.

    Args:
        S: Square (B, B) similarity matrix whose diagonal holds the positive
            pairs and whose off-diagonal entries hold the in-batch negatives.

    Returns:
        Tuple ``(mean_pos, mean_neg)`` of scalar tensors.
    """
    n = S.size(0)
    is_negative = ~torch.eye(n, dtype=torch.bool, device=S.device)
    positives = S.diag()  # (B,)
    negatives = S[is_negative]  # (B * (B - 1),)
    return positives.mean(), negatives.mean()


@torch.no_grad()
def mean_pos_neg_inbatch_ddp(q, k, concat_all_gather):
    """Return mean positive / negative similarity over the gathered batch.

    Args:
        q: Query embeddings of the local batch.
        k: Key embeddings of the local batch.
        concat_all_gather: Callable applied to ``q`` and then to ``k``; its
            outputs are used as the full query and key matrices.

    Returns:
        Tuple ``(mean_pos, mean_neg)``: the mean of the diagonal and the mean
        of the off-diagonal entries of the gathered similarity matrix.
    """
    gathered_q = concat_all_gather(q)
    gathered_k = concat_all_gather(k)

    similarity = gathered_q @ gathered_k.t()  # (B_total, B_total)
    total = similarity.size(0)
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=similarity.device)

    return similarity.diag().mean(), similarity[off_diagonal].mean()


# ---- Training step ----
def train_step(BD_encoder, M2C_encoder, logit, args):
    """Jointly train the band-diagram ViT-VAE and the M2C encoder.

    Each epoch runs a training pass followed by a validation pass. The
    training loss is the VAE reconstruction loss plus a weighted KL term
    plus (unless switched off by the queue warm-up) a weighted, symmetric
    contrastive loss. The contrastive logits are built MoCo-style: queries
    come from the online encoders, positive keys from EMA copies of the
    encoders, and negatives from one FIFO ``MemoryQueue`` per modality.
    Validation runs the EMA encoders with in-batch negatives only.

    Side effects: creates ``<args.save_path>/encoders``; writes
    ``encoders.pth`` whenever the validation total loss improves and
    ``encoders_best_acc.pth`` whenever the validation top-1 accuracy
    improves; prints the per-epoch losses; and, when
    ``args.use_tensorboard`` is set, logs scalars and a reconstruction
    figure to TensorBoard.

    Args:
        BD_encoder: Module called as ``BD_encoder(BD)``; its output is
            unpacked as the 5-tuple ``(x_hat, mu, logvar, _, cls)``.
        M2C_encoder: Module called as
            ``M2C_encoder(structures, thicknesses, masks, pool=True)``.
        logit: Module exposing a ``logit_scale`` parameter (e.g.
            ``LogitScale``) that is applied to the similarity logits.
        args: Namespace holding every hyper-parameter and path read below
            (``device``, ``n_gpus``, ``batch_size``, ``epochs``, ``lr``,
            ``save_path``, ``data_path``, ...).

    Returns:
        None.
    """
    # torch.autograd.set_detect_anomaly(True)
    torch.set_float32_matmul_precision(
        "high"
    )  # "high" lets float32 matmuls use faster reduced-precision kernels (e.g. TF32) where available

    if not os.path.exists(os.path.join(args.save_path, "encoders")):
        os.makedirs(os.path.join(args.save_path, "encoders"))

    # Create EMA models (frozen copies taken before any multi-GPU wrapping)
    BD_encoder_ema = copy.deepcopy(BD_encoder).eval().requires_grad_(False)
    M2C_encoder_ema = copy.deepcopy(M2C_encoder).eval().requires_grad_(False)

    # Moving to device and wrapping for multi-gpu
    if args.n_gpus > 1:
        BD_encoder = torch.nn.DataParallel(BD_encoder)
        M2C_encoder = torch.nn.DataParallel(M2C_encoder)
        logit = torch.nn.DataParallel(logit)
    BD_encoder.to(torch.device(args.device))
    BD_encoder_ema.to(torch.device(args.device))
    M2C_encoder.to(torch.device(args.device))
    M2C_encoder_ema.to(torch.device(args.device))
    logit.to(torch.device(args.device))

    # Create the queues (one per modality)
    Q_BD = MemoryQueue(dim=args.out_dim, K=args.queue_size).to(
        torch.device(args.device)
    )
    Q_M2C = MemoryQueue(dim=args.out_dim, K=args.queue_size).to(
        torch.device(args.device)
    )

    # Dataset
    h5_files = [
        os.path.join(args.data_path, f)
        for f in os.listdir(args.data_path)
        if f.endswith(".h5")
    ]

    # Seed the RNGs, then shuffle the file list before the train/test split
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    random.shuffle(h5_files)

    # The first `train_size` fraction of the shuffled files is used for training
    train_files = h5_files[: int(args.train_size * len(h5_files))]
    test_files = h5_files[int(args.train_size * len(h5_files)) :]

    train_dataset = HoleyLayerDataset(
        config_file=os.path.join(args.data_path, "config.json"),
        h5_files=train_files,
        Nx=args.Nx,
        max_seq_len=args.max_seq_len,
        scale_eps=args.scale_eps,
        binarize_BD=args.binarize_BD,
        clip_limit=args.clip_limit,
        threshold_size=args.threshold_size,
        kernel_size=args.kernel_size,
        additional_transforms={
            "pixmap_layers": [
                RandomRoll() if args.random_roll else transforms.Lambda(lambda x: x)
            ],
            "R": [transforms.ToTensor()],
            "T": [transforms.ToTensor()],
            "BD": [transforms.ToTensor()],
        },
        random_air_layers=args.random_air_layers,
    )
    # The test dataset uses the same settings but no random augmentation
    test_dataset = HoleyLayerDataset(
        config_file=os.path.join(args.data_path, "config.json"),
        h5_files=test_files,
        Nx=args.Nx,
        max_seq_len=args.max_seq_len,
        scale_eps=args.scale_eps,
        binarize_BD=args.binarize_BD,
        clip_limit=args.clip_limit,
        threshold_size=args.threshold_size,
        kernel_size=args.kernel_size,
        additional_transforms={
            "pixmap_layers": [transforms.Lambda(lambda x: x)],
            "R": [transforms.ToTensor()],
            "T": [transforms.ToTensor()],
            "BD": [transforms.ToTensor()],
        },
        random_air_layers=False,
    )
    train_loader = get_dataloader(train_dataset, args, training=True)
    test_loader = get_dataloader(test_dataset, args, training=False)

    # Create parameter groups for the optimizer
    decay, no_decay = [], []
    for n, p in (
        list(BD_encoder.named_parameters())
        + list(M2C_encoder.named_parameters())
        + list(logit.named_parameters())
    ):
        if p.requires_grad:
            # Positional encodings are excluded from weight decay.
            # NOTE(review): this is a substring match, so any parameter whose
            # name contains "pe" is excluded as well — confirm this is intended.
            if "pe" in n.lower() or "pos_" in n.lower():
                no_decay.append(p)
            elif (
                p.ndim >= 2 and ("norm" not in n.lower()) and ("bias" not in n.lower())
            ):
                decay.append(p)
            else:
                no_decay.append(p)

    optim_groups = [
        {"params": decay, "weight_decay": args.weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    scheduler = CosineLRScheduler(
        optimizer,
        base_lr=args.lr,
        total_epochs=args.epochs,
        warmup_epochs=args.warmup_epochs,
    )

    if args.use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(log_dir=os.path.join(args.save_path, "encoders"))

    # Training loop
    scaler = torch.amp.GradScaler(enabled=args.use_amp)
    best_loss = float("inf")
    best_acc = 0.0

    optimizer.zero_grad(set_to_none=True)
    for epoch in range(args.epochs):
        # With queue warm-up, the contrastive term is left out of the loss
        # during the first epoch (keys are still enqueued during that epoch).
        enable_CLIP = not (args.queue_warmup and epoch == 0)

        if args.use_kl_warmup:
            # Linear warmup for beta_kl
            if epoch < args.warmup_epochs:
                beta_kl = (epoch + 1) / args.warmup_epochs * args.beta_kl
            else:
                beta_kl = args.beta_kl
        else:
            beta_kl = args.beta_kl

        if args.ramp_ema:
            # Cosine ramp of the EMA momentum from m_start (first epoch) to m_end (last epoch)
            m_start, m_end = 0.99, 0.9998
            m_ema = m_end - (m_end - m_start) * 0.5 * (
                1 + math.cos(math.pi * epoch / max(1, args.epochs - 1))
            )
        else:
            m_ema = 0.999

        if args.clip_warmup:
            # Linear warmup for the contrastive loss weight
            if epoch < args.warmup_epochs:
                clip_weight = (epoch + 1) / args.warmup_epochs * args.beta_clip
            else:
                clip_weight = args.beta_clip
        else:
            clip_weight = args.beta_clip

        #
        # Training
        #

        # Number of optimizer steps taken in this epoch (reset every epoch)
        global_step = 0

        BD_encoder.train()
        BD_encoder_ema.eval()
        M2C_encoder.train()
        M2C_encoder_ema.eval()
        logit.train()

        train_loss = {
            # For the contrastive loss between the BD encoder and the M2C encoder
            "CLIP": 0.0,
            # For the reconstruction loss of the BD encoder
            "recon": 0.0,
            "kl": 0.0,
            # For the total loss
            "total": 0.0,
        }
        train_top1_acc = 0.0
        train_top5_acc = 0.0
        train_pos_cosine_sim = 0.0
        train_neg_cosine_sim = 0.0

        pbar = tqdm(
            enumerate(train_loader),
            desc=f"Epoch {epoch + 1}/{args.epochs} - Train",
            unit=" batch",
            disable=not args.use_tqdm,
        )
        for step, batch in pbar:
            # NOTE(review): `device_type` receives args.device as-is, which
            # presumably must be a bare type such as "cuda" or "cpu" — confirm.
            with torch.amp.autocast(enabled=args.use_amp, device_type=args.device):
                BD = batch["BD"].to(args.device, non_blocking=True)  # (B, 1, H, W)
                structures = batch["pixmap_layers"].to(
                    args.device, non_blocking=True
                )  # (B, L, Nx, Nx)
                thicknesses = batch["thicknesses"].to(
                    args.device, non_blocking=True
                )  # (B, L)
                masks = batch["key_padding_mask"].to(
                    args.device, non_blocking=True
                )  # (B, L)

                #
                # === 1) Forward queries with ONLINE encoders (grad enabled) ===
                #

                # Encode the band diagram
                x_hat, mu, logvar, _, q_BD = BD_encoder(BD)
                q_BD = l2_normalize(q_BD, dim=-1)

                # Encode the material structure
                q_M2C = M2C_encoder(
                    structures, thicknesses, masks, pool=True
                )  # (B, out_dim)
                q_M2C = l2_normalize(q_M2C, dim=-1)

                #
                # === 2) Forward keys with EMA encoders (no grad) ===
                #

                with torch.no_grad():
                    # Encode the band diagram
                    _, _, _, _, k_BD = BD_encoder_ema(BD)
                    k_BD = l2_normalize(k_BD, dim=-1)

                    # Encode the material structure
                    k_M2C = M2C_encoder_ema(structures, thicknesses, masks, pool=True)
                    k_M2C = l2_normalize(k_M2C, dim=-1)

                #
                # === 3) For numerical safety with AMP, do matmuls in float32 ===
                #
                # NOTE(review): these casts happen inside the autocast context,
                # where matmul is an autocast-eligible op, so the products below
                # may still run in reduced precision on CUDA — confirm.

                q_BD_f = q_BD.float()
                q_M2C_f = q_M2C.float()
                k_BD_f = k_BD.float()
                k_M2C_f = k_M2C.float()
                Q_BD_f = Q_BD.queue.float()
                Q_M2C_f = Q_M2C.queue.float()

                #
                # === 4) Build logits with MoCo-style formulation ===
                #

                # Disabled branch: would reset the logit scale to log(1 / 0.07)
                if False:  # epoch < args.warmup_epochs:
                    with torch.no_grad():
                        if args.n_gpus <= 1:
                            logit.logit_scale.copy_(torch.log(torch.tensor(1 / 0.07)))
                        else:
                            logit.module.logit_scale.copy_(
                                torch.log(torch.tensor(1 / 0.07))
                            )

                # BD -> M2C  (query is BD, positive key is same-sample M2C key)
                l_pos_1 = torch.sum(q_BD_f * k_M2C_f, dim=-1, keepdim=True)  # (B, 1)
                l_neg_1 = q_BD_f @ Q_M2C_f.t()  # (B, K)

                # Optionally keep, per query, only the 90% of queue negatives
                # with the highest similarity (the lowest 10% are dropped).
                if args.filter_negatives and epoch > args.warmup_epochs:
                    k = int(0.9 * l_neg_1.size(1))
                    vals, _ = torch.topk(l_neg_1, k, dim=1, largest=True)
                    l_neg_1 = vals

                logits_1 = logit(torch.cat([l_pos_1, l_neg_1], dim=1))  # (B, 1+K)
                # logits_1 = torch.cat([l_pos_1, l_neg_1], dim=1) * torch.tensor(
                #    1 / 0.07
                # ).to(args.device)

                if args.normalize_logits:
                    logits_1 = logits_1 - logits_1.max(dim=1, keepdim=True).values

                # The positive logit sits in column 0, so every target label is 0
                labels = torch.zeros(
                    logits_1.size(0), dtype=torch.long, device=logits_1.device
                )

                # M2C -> BD  (query is M2C, positive key is same-sample BD key)
                l_pos_2 = torch.sum(q_M2C_f * k_BD_f, dim=-1, keepdim=True)  # (B, 1)
                l_neg_2 = q_M2C_f @ Q_BD_f.t()  # (B, K)

                if args.filter_negatives and epoch > args.warmup_epochs:
                    k = int(0.9 * l_neg_2.size(1))
                    vals, _ = torch.topk(l_neg_2, k, dim=1, largest=True)
                    l_neg_2 = vals

                logits_2 = logit(torch.cat([l_pos_2, l_neg_2], dim=1))  # (B, 1+K)
                # logits_2 = torch.cat([l_pos_2, l_neg_2], dim=1) * torch.tensor(
                #    1 / 0.07
                # ).to(args.device)

                if args.normalize_logits:
                    logits_2 = logits_2 - logits_2.max(dim=1, keepdim=True).values

                # Symmetric contrastive loss: average of the two directional cross-entropies
                loss_CLIP = 0.5 * (
                    F.cross_entropy(
                        logits_1, labels, label_smoothing=args.label_smoothing
                    )
                    + F.cross_entropy(
                        logits_2, labels, label_smoothing=args.label_smoothing
                    )
                )

                # Compute the metrics (in-batch query/key similarities, queue not involved)
                S1 = q_BD_f @ k_M2C_f.t()
                S2 = q_M2C_f @ k_BD_f.t()

                mean_pos_1, mean_neg_1 = mean_pos_neg_inbatch(S1)
                mean_pos_2, mean_neg_2 = mean_pos_neg_inbatch(S2)
                mean_pos = 0.5 * (mean_pos_1 + mean_pos_2)
                mean_neg = 0.5 * (mean_neg_1 + mean_neg_2)
                train_pos_cosine_sim += mean_pos.item() / len(train_loader)
                train_neg_cosine_sim += mean_neg.item() / len(train_loader)

                # For the in-batch retrieval metrics the match is on the diagonal
                labels = torch.arange(S1.size(0), device=S1.device)

                # Top-1 / top-5 retrieval accuracy, averaged over both directions
                train_top1_acc += (
                    ((S1.argmax(dim=1) == labels).float().mean().item())
                    / len(train_loader)
                    / 2
                )
                train_top1_acc += (
                    ((S2.argmax(dim=1) == labels).float().mean().item())
                    / len(train_loader)
                    / 2
                )
                train_top5_acc += (
                    (
                        (S1.topk(5, dim=1).indices == labels.unsqueeze(1))
                        .any(dim=1)
                        .float()
                        .mean()
                        .item()
                    )
                    / len(train_loader)
                    / 2
                )
                train_top5_acc += (
                    (
                        (S2.topk(5, dim=1).indices == labels.unsqueeze(1))
                        .any(dim=1)
                        .float()
                        .mean()
                        .item()
                    )
                    / len(train_loader)
                    / 2
                )

                # Compute the reconstruction loss
                loss_recon = reconstruction_loss(
                    x_hat, BD, mode=args.VAE_reconstruction_loss
                )
                loss_kl = vae_kl_loss(mu, logvar)

                if enable_CLIP:
                    loss = loss_recon + beta_kl * loss_kl + clip_weight * loss_CLIP
                else:
                    loss = loss_recon + beta_kl * loss_kl

                if args.n_gpus > 1:
                    loss_CLIP = loss_CLIP.mean()
                    loss_recon = loss_recon.mean()
                    loss_kl = loss_kl.mean()
                    loss = loss.mean()
                # Scale the losses so that accumulated gradients average over the micro-batches
                if args.gradient_accumulation_steps > 1:
                    loss_CLIP = loss_CLIP / args.gradient_accumulation_steps
                    loss_recon = loss_recon / args.gradient_accumulation_steps
                    loss_kl = loss_kl / args.gradient_accumulation_steps
                    loss = loss / args.gradient_accumulation_steps

            scaler.scale(loss).backward()

            # Step the optimizer once every `gradient_accumulation_steps` batches
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1

                # Clip the gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    list(BD_encoder.parameters())
                    + list(M2C_encoder.parameters())
                    + list(logit.parameters()),
                    max_norm=1.0,
                )

                # Optimizer step
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                # Update the EMA models
                ema_update(BD_encoder, BD_encoder_ema, m=m_ema)
                ema_update(M2C_encoder, M2C_encoder_ema, m=m_ema)

                # Enqueue the EMA keys of the current batch. With gradient
                # accumulation only the batch on which the optimizer steps is enqueued.
                with torch.no_grad():
                    Q_BD.enqueue(k_BD)
                    Q_M2C.enqueue(k_M2C)

                # Clamp the log-scale parameter to [0, 3.4]
                if args.n_gpus <= 1:
                    logit.logit_scale.data = torch.clamp(logit.logit_scale.data, 0, 3.4)
                    current_logit_scale = logit.logit_scale.data.exp().item()
                else:
                    logit.module.logit_scale.data = torch.clamp(
                        logit.module.logit_scale.data, 0, 3.4
                    )
                    current_logit_scale = logit.module.logit_scale.data.exp().item()

            # The logged CLIP term is weighted by args.beta_clip (not the
            # warmed-up clip_weight) and is accumulated even when enable_CLIP is False.
            train_loss["CLIP"] += loss_CLIP.item() * args.beta_clip
            train_loss["recon"] += loss_recon.item()
            train_loss["kl"] += loss_kl.item() * beta_kl
            train_loss["total"] += loss.item()

        # NOTE(review): global_step stays 0 if the loader yields fewer than
        # `gradient_accumulation_steps` batches, which would raise
        # ZeroDivisionError here (and leave current_logit_scale unassigned
        # for the TensorBoard logging below) — confirm this cannot happen.
        train_loss["CLIP"] /= global_step
        train_loss["recon"] /= global_step
        train_loss["kl"] /= global_step
        train_loss["total"] /= global_step

        print(
            f"Epoch {epoch + 1}/{args.epochs} - Train Loss: {train_loss['total']:.4f}"
        )

        #
        # Validation
        #

        BD_encoder.eval()
        BD_encoder_ema.eval()
        M2C_encoder.eval()
        M2C_encoder_ema.eval()
        logit.eval()

        test_loss = {
            "CLIP": 0.0,
            "recon": 0.0,
            "kl": 0.0,
            "total": 0.0,
        }
        test_top1_acc = 0.0
        test_top5_acc = 0.0
        test_pos_cosine_sim = 0.0
        test_neg_cosine_sim = 0.0

        with torch.no_grad():
            pbar = tqdm(
                enumerate(test_loader),
                desc=f"Epoch {epoch + 1}/{args.epochs} - Test",
                unit=" batch",
                disable=not args.use_tqdm,
            )
            for step, batch in pbar:
                with torch.amp.autocast(enabled=args.use_amp, device_type=args.device):
                    BD = batch["BD"].to(args.device, non_blocking=True)
                    structures = batch["pixmap_layers"].to(
                        args.device, non_blocking=True
                    )
                    thicknesses = batch["thicknesses"].to(
                        args.device, non_blocking=True
                    )
                    masks = batch["key_padding_mask"].to(args.device, non_blocking=True)

                    # Encode the band diagram (validation uses the EMA encoders)
                    x_hat, mu, logvar, _, cls_BD = BD_encoder_ema(BD)
                    cls_BD = l2_normalize(cls_BD, dim=-1)

                    # Encode the material structure
                    cls_M2C = M2C_encoder_ema(
                        structures, thicknesses, masks, pool=True
                    )  # (B, out_dim)
                    cls_M2C = l2_normalize(cls_M2C, dim=-1)

                    # Compute the contrastive loss (in-batch negatives only, no queue)
                    cls_BD_f = cls_BD.float()
                    cls_M2C_f = cls_M2C.float()

                    logits_per_BD = logit(cls_BD_f @ cls_M2C_f.t())  # (B, B)
                    # logits_per_BD = (cls_BD_f @ cls_M2C_f.t()) * torch.tensor(
                    #    1 / 0.07
                    # ).to(args.device)
                    logits_per_M2C = logit(cls_M2C_f @ cls_BD_f.t())  # (B, B)
                    # logits_per_M2C = (cls_M2C_f @ cls_BD_f.t()) * torch.tensor(
                    #    1 / 0.07
                    # ).to(args.device)

                    if args.normalize_logits:
                        logits_per_BD = (
                            logits_per_BD
                            - logits_per_BD.max(dim=1, keepdim=True).values
                        )
                        logits_per_M2C = (
                            logits_per_M2C
                            - logits_per_M2C.max(dim=1, keepdim=True).values
                        )

                    # Matching pairs are on the diagonal of the (B, B) logits
                    labels = torch.arange(
                        logits_per_BD.size(0), device=logits_per_BD.device
                    )

                    loss_CLIP = 0.5 * (
                        F.cross_entropy(
                            logits_per_BD, labels, label_smoothing=args.label_smoothing
                        )
                        + F.cross_entropy(
                            logits_per_M2C, labels, label_smoothing=args.label_smoothing
                        )
                    )

                    # Compute the metrics
                    S1 = cls_BD_f @ cls_M2C_f.t()
                    S2 = cls_M2C_f @ cls_BD_f.t()

                    mean_pos_1, mean_neg_1 = mean_pos_neg_inbatch(S1)
                    mean_pos_2, mean_neg_2 = mean_pos_neg_inbatch(S2)
                    mean_pos = 0.5 * (mean_pos_1 + mean_pos_2)
                    mean_neg = 0.5 * (mean_neg_1 + mean_neg_2)
                    test_pos_cosine_sim += mean_pos.item() / len(test_loader)
                    test_neg_cosine_sim += mean_neg.item() / len(test_loader)

                    test_top1_acc += (
                        ((S1.argmax(dim=1) == labels).float().mean().item())
                        / len(test_loader)
                        / 2
                    )
                    test_top1_acc += (
                        ((S2.argmax(dim=1) == labels).float().mean().item())
                        / len(test_loader)
                        / 2
                    )
                    test_top5_acc += (
                        (
                            (S1.topk(5, dim=1).indices == labels.unsqueeze(1))
                            .any(dim=1)
                            .float()
                            .mean()
                            .item()
                        )
                        / len(test_loader)
                        / 2
                    )
                    test_top5_acc += (
                        (
                            (S2.topk(5, dim=1).indices == labels.unsqueeze(1))
                            .any(dim=1)
                            .float()
                            .mean()
                            .item()
                        )
                        / len(test_loader)
                        / 2
                    )

                    # Compute the reconstruction loss
                    loss_recon = reconstruction_loss(
                        x_hat, BD, mode=args.VAE_reconstruction_loss
                    )
                    loss_kl = vae_kl_loss(mu, logvar)

                    if enable_CLIP:
                        loss = loss_recon + beta_kl * loss_kl + clip_weight * loss_CLIP
                    else:
                        loss = loss_recon + beta_kl * loss_kl

                    if args.n_gpus > 1:
                        loss_CLIP = loss_CLIP.mean()
                        loss_recon = loss_recon.mean()
                        loss_kl = loss_kl.mean()
                        loss = loss.mean()

                test_loss["CLIP"] += loss_CLIP.item() * args.beta_clip
                test_loss["recon"] += loss_recon.item()
                test_loss["kl"] += loss_kl.item() * beta_kl
                test_loss["total"] += loss.item()

        test_loss["CLIP"] /= len(test_loader)
        test_loss["recon"] /= len(test_loader)
        test_loss["kl"] /= len(test_loader)
        test_loss["total"] /= len(test_loader)

        print(f"Epoch {epoch + 1}/{args.epochs} - Test Loss: {test_loss['total']:.4f}")
        # NOTE(review): the meaning of the 0.0 argument depends on
        # CosineLRScheduler.step, which is not visible here — confirm.
        scheduler.step(0.0)

        # Save the model if the test loss is the best we've seen so far.
        # (DataParallel-wrapped models are saved through `.module`.)
        if test_loss["total"] < best_loss:
            best_loss = test_loss["total"]
            if args.n_gpus > 1:
                torch.save(
                    {
                        "BD_encoder_state_dict": BD_encoder.module.state_dict(),
                        "BD_encoder_ema_state_dict": BD_encoder_ema.state_dict(),
                        "M2C_encoder_state_dict": M2C_encoder.module.state_dict(),
                        "M2C_encoder_ema_state_dict": M2C_encoder_ema.state_dict(),
                        "logit_state_dict": logit.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch,
                        "best_loss": best_loss,
                    },
                    os.path.join(args.save_path, "encoders", "encoders.pth"),
                )
            else:
                torch.save(
                    {
                        "BD_encoder_state_dict": BD_encoder.state_dict(),
                        "BD_encoder_ema_state_dict": BD_encoder_ema.state_dict(),
                        "M2C_encoder_state_dict": M2C_encoder.state_dict(),
                        "M2C_encoder_ema_state_dict": M2C_encoder_ema.state_dict(),
                        "logit_state_dict": logit.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch,
                        "best_loss": best_loss,
                    },
                    os.path.join(args.save_path, "encoders", "encoders.pth"),
                )
            print(f"Epoch {epoch + 1}: Best model saved with loss {best_loss:.4f}")

        # Save a separate checkpoint when the validation top-1 accuracy improves.
        if test_top1_acc > best_acc:
            best_acc = test_top1_acc
            if args.n_gpus > 1:
                torch.save(
                    {
                        "BD_encoder_state_dict": BD_encoder.module.state_dict(),
                        "BD_encoder_ema_state_dict": BD_encoder_ema.state_dict(),
                        "M2C_encoder_state_dict": M2C_encoder.module.state_dict(),
                        "M2C_encoder_ema_state_dict": M2C_encoder_ema.state_dict(),
                        "logit_state_dict": logit.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch,
                        "best_acc": best_acc,
                    },
                    os.path.join(args.save_path, "encoders", "encoders_best_acc.pth"),
                )
            else:
                torch.save(
                    {
                        "BD_encoder_state_dict": BD_encoder.state_dict(),
                        "BD_encoder_ema_state_dict": BD_encoder_ema.state_dict(),
                        "M2C_encoder_state_dict": M2C_encoder.state_dict(),
                        "M2C_encoder_ema_state_dict": M2C_encoder_ema.state_dict(),
                        "logit_state_dict": logit.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch,
                        "best_acc": best_acc,
                    },
                    os.path.join(args.save_path, "encoders", "encoders_best_acc.pth"),
                )
            print(f"Epoch {epoch + 1}: Best model saved with acc {best_acc:.4f}")

        # Visualize the last validation batch (inputs on top, reconstructions below)
        # NOTE(review): with num_samples == 1, plt.subplots returns a 1-D axes
        # array by default, so the 2-D indexing below would fail — confirm
        # that batch_size is always > 1.
        num_samples = min(16, args.batch_size)
        fig, axes = plt.subplots(2, num_samples, figsize=(3 * num_samples, 6))
        for i in range(num_samples):
            axes[0, i].imshow(
                BD[i, 0].cpu().numpy().squeeze(), cmap="grey", vmin=0, vmax=1
            )
            axes[0, i].set_title("Original")
            axes[0, i].axis("off")

            axes[1, i].imshow(
                x_hat[i, 0].cpu().numpy().squeeze(), cmap="grey", vmin=0, vmax=1
            )
            axes[1, i].set_title("Reconstructed")
            axes[1, i].axis("off")
        plt.tight_layout()

        if args.use_tensorboard:
            writer.add_figure("ViTVAE Reconstructions", fig, global_step=epoch)

            writer.add_scalar("Train/Total", train_loss["total"], global_step=epoch)
            writer.add_scalar(
                "Train/Reconstruction", train_loss["recon"], global_step=epoch
            )
            writer.add_scalar(
                "Train/KL Divergence", train_loss["kl"], global_step=epoch
            )
            writer.add_scalar("Train/CLIP", train_loss["CLIP"], global_step=epoch)
            writer.add_scalar("Train/Top1 Acc", train_top1_acc, global_step=epoch)
            writer.add_scalar("Train/Top5 Acc", train_top5_acc, global_step=epoch)
            writer.add_scalar(
                "Train/Pos Cosine Sim", train_pos_cosine_sim, global_step=epoch
            )
            writer.add_scalar(
                "Train/Neg Cosine Sim", train_neg_cosine_sim, global_step=epoch
            )

            writer.add_scalar("Test/Total", test_loss["total"], global_step=epoch)
            writer.add_scalar(
                "Test/Reconstruction", test_loss["recon"], global_step=epoch
            )
            writer.add_scalar("Test/KL Divergence", test_loss["kl"], global_step=epoch)
            writer.add_scalar("Test/CLIP", test_loss["CLIP"], global_step=epoch)
            writer.add_scalar("Test/Top1 Acc", test_top1_acc, global_step=epoch)
            writer.add_scalar("Test/Top5 Acc", test_top5_acc, global_step=epoch)
            writer.add_scalar(
                "Test/Pos Cosine Sim", test_pos_cosine_sim, global_step=epoch
            )
            writer.add_scalar(
                "Test/Neg Cosine Sim", test_neg_cosine_sim, global_step=epoch
            )

            writer.add_scalar(
                "Others/Learning Rate", scheduler.get_last_lr()[0], global_step=epoch
            )
            writer.add_scalar(
                "Others/Logit Scale", current_logit_scale, global_step=epoch
            )

        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: parse the command-line options, build the ViT-VAE
    # (band diagram) encoder, the M2C (material sequence) encoder and the
    # learnable logit scale, then hand everything over to `train_step`.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a ViT-VAE to encode and decode Band Diagram images."
    )

    # ViTVAE parameters
    parser.add_argument(
        "--patch_size", type=int, default=8, help="Patch size for the ViT model."
    )
    parser.add_argument(
        "--channels",
        type=int,
        nargs="+",
        default=[256, 128, 32],
        help="Channels for the ViT-VAE model.",
    )
    parser.add_argument(
        "--dim", type=int, default=512, help="Dimension of the ViT model."
    )
    parser.add_argument(
        "--out_dim", type=int, default=768, help="Output dimension of the ViT model."
    )
    parser.add_argument("--depth", type=int, default=6, help="Depth of the ViT model.")
    parser.add_argument(
        "--num_heads", type=int, default=16, help="Number of heads in the ViT model."
    )
    parser.add_argument(
        "--mlp_dim", type=int, default=1024, help="MLP dimension of the ViT model."
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Dropout rate for the ViT model."
    )
    parser.add_argument(
        "--patch_dropout",
        type=float,
        default=0.1,
        help="Dropout rate for the patch embeddings in the ViT model.",
    )
    parser.add_argument(
        "--latent_channels",
        type=int,
        default=4,
        help="Number of channels in the latent space of the VAE.",
    )

    # M2C parameters
    parser.add_argument(
        "--m2c_latent_dim",
        type=int,
        default=128,
        help="Latent dimension for each layer in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_channels",
        nargs="+",
        type=int,
        default=[32, 64, 128],
        help="List of channels for the convolutional layers in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_num_vision_heads",
        type=int,
        default=8,
        help="Number of attention heads for the visual encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_groups",
        type=int,
        default=8,
        help="Number of groups for the group normalization in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_vision_dropout",
        type=float,
        default=0.1,
        help="Dropout rate for the visual encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_modality",
        type=str,
        choices=["pixmap", "spectral", "both"],
        default="pixmap",
        help="Input modality for the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_use_phase",
        action="store_true",
        help="Whether to use phase encoding in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_out_dim",
        type=int,
        default=768,
        help="Output dimension of the sequence encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_num_heads",
        type=int,
        default=8,
        help="Number of attention heads for the sequence encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_num_layers",
        type=int,
        default=6,
        help="Number of layers for the sequence encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_d_ff",
        type=int,
        default=1024,
        help="Dimension of the feed forward layer in the sequence encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_dropout",
        type=float,
        default=0.1,
        help="Dropout rate for the sequence encoder in the M2C encoder.",
    )
    parser.add_argument(
        "--m2c_trainable_pe",
        action="store_true",
        help="Whether to use trainable positional encoding in the M2C encoder.",
    )

    # Training parameters
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device to use for training."
    )
    parser.add_argument(
        "--n_gpus", type=int, default=1, help="Number of GPUs to use for training."
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of training epochs."
    )
    parser.add_argument(
        "--warmup_epochs", type=int, default=20, help="Number of warmup epochs."
    )
    parser.add_argument(
        "--lr", type=float, default=1e-4, help="Learning rate for the optimizer."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=2e-4,
        help="Weight decay for the optimizer.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=16, help="Batch size for training."
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of workers for data loading."
    )
    parser.add_argument(
        "--beta_kl",
        type=float,
        default=1e-3,
        help="Weight for the KL divergence loss.",
    )
    # NOTE: store_false => the option is True by default and passing the flag
    # on the command line turns it OFF.
    parser.add_argument(
        "--use_kl_warmup",
        action="store_false",
        help="Whether to use KL warmup.",
    )
    parser.add_argument(
        "--beta_clip",
        type=float,
        default=1.0,
        help="Weight for the contrastive loss.",
    )
    parser.add_argument(
        "--clip_warmup",
        action="store_true",
        help="Whether to use CLIP warmup.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument(
        "--VAE_reconstruction_loss",
        type=str,
        choices=["mse", "BCE"],
        default="mse",
        help="Loss function for VAE reconstruction loss.",
    )
    parser.add_argument(
        "--queue_size",
        type=int,
        default=65536,
        help="Size of the memory queue for MoCo-style contrastive learning.",
    )
    parser.add_argument(
        "--label_smoothing",
        type=float,
        default=0.0,
        help="Label smoothing for the contrastive loss.",
    )
    parser.add_argument(
        "--filter_negatives",
        action="store_true",
        help="Whether to filter out some negative samples.",
    )
    parser.add_argument(
        "--normalize_logits",
        action="store_true",
        help="Whether to normalize the logits.",
    )
    parser.add_argument(
        "--queue_warmup",
        action="store_true",
        help="Whether to disable CLIP loss during the first epoch to fill up the queue.",
    )
    parser.add_argument(
        "--ramp_ema",
        action="store_true",
        help="Whether to use ramped exponential moving average.",
    )

    # Dataset parameters
    parser.add_argument(
        "--data_path",
        type=str,
        default="../datasets/simple_structures/",
        help="Path to the dataset.",
    )
    # The value is a proportion (default 0.8), so it must be parsed as a float;
    # with type=int any fractional value given on the command line was rejected.
    parser.add_argument(
        "--train_size", type=float, default=0.8, help="Proportion of data for training."
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=8, help="Maximum sequence length."
    )
    parser.add_argument(
        "--Nx", type=int, default=64, help="Spatial dimension of the input images."
    )
    # NOTE: store_false => enabled by default, passing the flag disables it.
    parser.add_argument(
        "--scale_eps",
        action="store_false",
        help="Whether to scale the permittivity values.",
    )
    # NOTE: store_false => enabled by default, passing the flag disables it.
    parser.add_argument(
        "--binarize_BD",
        action="store_false",
        help="Whether to binarize the band diagrams.",
    )
    parser.add_argument(
        "--clip_limit",
        type=float,
        default=1.0,
        help="Clip limit for CLAHE when binarizing the band diagram.",
    )
    parser.add_argument(
        "--threshold_size",
        type=int,
        default=51,
        help="Size of the neighborhood area for adaptive thresholding when binarizing the band diagram. Must be odd.",
    )
    parser.add_argument(
        "--kernel_size",
        type=int,
        default=5,
        help="Size of the kernel for Gaussian blur when processing the band diagram.",
    )
    parser.add_argument(
        "--random_air_layers",
        action="store_true",
        help="Whether to add random air layers as data augmentation.",
    )
    parser.add_argument(
        "--random_roll",
        action="store_true",
        help="Whether to apply random roll as data augmentation.",
    )

    # Saving parameters
    # The default path embeds a timestamp taken when the parser is built.
    parser.add_argument(
        "--save_path",
        type=str,
        default=f"./checkpoints/{time.strftime('%Y%m%d_%H%M%S')}/",
        help="Path to save the trained model.",
    )

    # Misc
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--use_tqdm", action="store_true", help="Whether to use tqdm for progress bars."
    )
    parser.add_argument(
        "--use_tensorboard",
        action="store_true",
        help="Whether to use TensorBoard for logging.",
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Whether to use automatic mixed precision.",
    )

    args = parser.parse_args()
    # Only a freshly created run directory gets an args dump; an already
    # existing directory is left untouched (no dump is written or overwritten).
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
        # Dump the args to a text file
        with open(os.path.join(args.save_path, "args_CLIP.txt"), "w") as f:
            f.write(str(args))

    # assert (256 // args.patch_size) * (128 // args.patch_size) == args.dim, (
    #    "The dimension must be equal to (256//patch_size) * (128//patch_size)."
    # )
    # The number of visible CUDA devices must match --n_gpus exactly; this is
    # checked regardless of the value of --device.
    assert torch.cuda.device_count() == args.n_gpus, (
        f"Expected {args.n_gpus} GPUs, but found {torch.cuda.device_count()}."
    )

    # ViT backbone operating on single-channel 256x128 images with CLS pooling.
    ViT_model = ViT(
        image_size=(256, 128),
        patch_size=args.patch_size,
        dim=args.dim,  # Should be equal to 256//8 * 128//8 = 32 * 16 = 512
        out_dim=args.out_dim,
        depth=args.depth,
        heads=args.num_heads,
        mlp_dim=args.mlp_dim,
        pool="cls",
        channels=1,
        dropout=args.dropout,
        emb_dropout=args.dropout,  # the same --dropout value is reused here
        patch_dropout=args.patch_dropout,
    )
    print(f"ViT model parameters: {sum(p.numel() for p in ViT_model.parameters())}")

    # VAE wrapper around the ViT backbone.
    BD_encoder = ViTVAE(
        vit=ViT_model,
        image_size=(256, 128),
        patch_size=args.patch_size,
        out_channels=1,
        latent_channels=args.latent_channels,
        channels=args.channels,
    )
    print(
        f"ViT-VAE encoder parameters: {sum(p.numel() for p in BD_encoder.encoder.parameters())} ;",
        f"ViT-VAE decoder parameters: {sum(p.numel() for p in BD_encoder.decoder.parameters())}",
    )

    # Material-sequence encoder, built with a CLS token enabled.
    M2C_encoder = M2CEncoder(
        spatial_dim=args.Nx,
        in_channels=1,
        latent_dim=args.m2c_latent_dim,
        channels=args.m2c_channels,
        num_vision_heads=args.m2c_num_vision_heads,
        groups=args.m2c_groups,
        vision_dropout=args.m2c_vision_dropout,
        modality=args.m2c_modality,
        use_phase=args.m2c_use_phase,
        out_dim=args.m2c_out_dim,
        max_seq_len=args.max_seq_len,
        num_heads=args.m2c_num_heads,
        num_layers=args.m2c_num_layers,
        d_ff=args.m2c_d_ff,
        dropout=args.m2c_dropout,
        kwargs_pe={
            "ini_freq_scale": 1.0,
            "tunable_freq_scale": True,
            "dropout": 0.0,
        },
        use_cls=True,
        trainable_pe=args.m2c_trainable_pe,
    )
    print(f"M2C encoder parameters: {sum(p.numel() for p in M2C_encoder.parameters())}")

    # Logit scale module for the contrastive loss (defined elsewhere in this file).
    logit = LogitScale(init_value=0.07)
    print(f"Logit scale parameters: {sum(p.numel() for p in logit.parameters())}")

    train_step(BD_encoder, M2C_encoder, logit, args)
