import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import argparse
import numpy as np
from tqdm.auto import tqdm
import wandb
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import os
from PIL import Image
import urllib.request
import zipfile
import shutil

from tasks.denoise import NoiseTransform
from tasks.metrics import batch_ssim, batch_psnr

from models.ae import AutoEncoder
from models.vae import VAE
from models.t3vae import T3VAE
from models.paretovae import ParetoVAE
from models.laplacevae import LaplaceVAE

import random


class FCLayer(nn.Module):
    """Fully connected layer computing ``activation(Linear(x))``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        activation: Activation *class* (not an instance); it is instantiated
            as ``activation(inplace=True)``. Pass ``None`` for a purely
            linear layer (``nn.Identity`` is used instead).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        activation: "type[nn.Module] | None" = nn.LeakyReLU,
    ) -> None:
        super().__init__()

        self.lin = nn.Linear(input_dim, output_dim)
        # The activation is constructed in-place; it only ever sees the
        # freshly produced Linear output, so no caller tensor is modified.
        self.activation = (
            nn.Identity() if activation is None else activation(inplace=True)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``activation(lin(x))``."""
        return self.activation(self.lin(x))


class CleanNoiseDataset(Dataset):
    """Pair each item of a labelled dataset with a corrupted copy of itself.

    Items of the wrapped ``dataset`` must be ``(tensor, label)`` pairs. Each
    access returns ``(clean, noisy, label)``, where ``noisy`` is
    ``transform_func`` applied to a clone of ``clean`` (or just a clone when
    no transform is set). Because the transform runs on every access, a
    randomised transform yields a different corruption each time.
    """

    def __init__(self, dataset, transform_func=None):
        self.dataset = dataset
        self.transform_func = transform_func

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        clean, target = self.dataset[idx]
        corrupt = self.transform_func
        # Clone so the transform can never mutate the clean tensor.
        noisy = corrupt(clean.clone()) if corrupt else clean.clone()
        return clean, noisy, target


class OmniglotDataset(Dataset):
    """Omniglot images labelled by *alphabet* (not by character).

    Expects ``root_dir/omniglot/images_background`` and
    ``root_dir/omniglot/images_evaluation``, each laid out as
    ``<alphabet>/<character>/<image file>``. Alphabets from both splits are
    pooled. Labels are assigned in order: the sorted background alphabets
    first, then any not-yet-seen evaluation alphabets (sorted).

    Args:
        root_dir: Directory that contains the ``omniglot`` folder.
        transform: Optional callable applied to each grayscale PIL image.

    Raises:
        FileNotFoundError: If either split directory is missing.
    """

    _SPLITS = ("images_background", "images_evaluation")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        split_dirs = [
            os.path.join(root_dir, "omniglot", split) for split in self._SPLITS
        ]
        for split_dir in split_dirs:
            if not os.path.exists(split_dir):
                raise FileNotFoundError(f"Not found: {split_dir}")

        alphabets_per_split = [self._sorted_subdirs(d) for d in split_dirs]
        # dict.fromkeys keeps first-seen order while deduplicating in O(1),
        # so label indices are unchanged from the list-based version.
        self.alphabet_classes = list(
            dict.fromkeys(
                name for names in alphabets_per_split for name in names
            )
        )
        alphabet_to_label = {
            name: i for i, name in enumerate(self.alphabet_classes)
        }

        self.data = []
        for split_dir, alphabet_names in zip(split_dirs, alphabets_per_split):
            for alphabet_name in alphabet_names:
                label = alphabet_to_label[alphabet_name]
                alphabet_path = os.path.join(split_dir, alphabet_name)
                # Only descend into directories; stray files (e.g. OS
                # metadata) next to character folders are ignored.
                for char_name in self._sorted_subdirs(alphabet_path):
                    char_path = os.path.join(alphabet_path, char_name)
                    for img_name in sorted(os.listdir(char_path)):
                        img_path = os.path.join(char_path, img_name)
                        if os.path.isfile(img_path):
                            self.data.append((img_path, label))

        # Shuffle with the global ``random`` state. Listing in sorted order
        # above makes the result reproducible for a given seed regardless of
        # the filesystem's directory ordering.
        self.data = random.sample(self.data, len(self.data))

    @staticmethod
    def _sorted_subdirs(path):
        """Return the names of the immediate subdirectories of ``path``, sorted."""
        return sorted(
            name
            for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, alphabet_label)``; the image is loaded as grayscale."""
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert("L")
        if self.transform:
            image = self.transform(image)
        return image, label


def _select_model_class(model_name):
    """Return the model class selected by ``--model-name``.

    Raises:
        ValueError: For an unrecognised model name.
    """
    model_classes = {
        "vae": VAE,
        "pareto": ParetoVAE,
        "laplace": LaplaceVAE,
        "t3": T3VAE,
        "ae": AutoEncoder,
    }
    if model_name not in model_classes:
        raise ValueError(f"Unknown model type: {model_name}")
    return model_classes[model_name]


def _dataset_config(dataset):
    """Return ``(latent_dim, image_size)`` for an upper-cased dataset name.

    Raises:
        ValueError: For an unsupported dataset.
    """
    if dataset in ("MNIST", "FASHIONMNIST"):
        return 16, 32
    if dataset == "OMNIGLOT":
        return 32, 64
    if dataset in ("CIFAR10", "SVHN"):
        return 64, 32
    raise ValueError(f"Unknown dataset: {dataset}")


def _noise_kwargs(args):
    """Return the extra keyword arguments passed to ``NoiseTransform``.

    Raises:
        ValueError: For an unrecognised ``--noise-type``.
    """
    if args.noise_type == "salt_pepper":
        return {"salt_prob": args.salt_prob}
    if args.noise_type == "gaussian":
        return {"mean": args.gaussian_mean, "std": args.gaussian_std}
    if args.noise_type == "mixed":
        return {
            "noise_types": ["salt_pepper", "gaussian"],
            "noise_probs": [0.5, 0.5],
        }
    raise ValueError(f"Unknown noise type: {args.noise_type}")


def _build_optimizer(params, optimizer_name, learning_rate):
    """Create the optimizer named by ``--optimizer`` over ``params``.

    The ``-wd`` variants add ``weight_decay=1e-5``. The plain variants use
    the torch defaults.

    Raises:
        ValueError: For an unsupported optimizer name.
    """
    if optimizer_name == "adam":
        return optim.Adam(params, lr=learning_rate)
    if optimizer_name == "adamw":
        return optim.AdamW(params, lr=learning_rate)
    if optimizer_name == "adam-wd":
        return optim.Adam(params, lr=learning_rate, weight_decay=1e-5)
    if optimizer_name == "adamw-wd":
        return optim.AdamW(params, lr=learning_rate, weight_decay=1e-5)
    raise ValueError("Unsupported optimizer type")


def _evaluate(
    model,
    latent_classifier,
    pretrained_classifier,
    criterion,
    loader,
    device,
    desc,
):
    """Run one evaluation pass over ``loader`` and return averaged metrics.

    Returned keys:
        ``loss`` / ``accuracy``: latent classifier on the latents of the
            noisy inputs.
        ``reconstruction``: model reconstruction loss against clean images.
        ``reconstructed_accuracy``: pretrained classifier on reconstructions.
        ``consistency``: fraction where the pretrained classifier predicts
            the same class for the clean and the reconstructed image.
        ``ssim`` / ``psnr``: clean vs. reconstructed images.
    """
    model.eval()
    latent_classifier.eval()
    with torch.no_grad():
        clf_loss_sum = 0
        clf_correct = 0
        n_seen = 0
        recon_sum = 0
        reconstructed_correct = 0
        consistent = 0
        ssim_sum = 0
        psnr_sum = 0

        for clean_x, noisy_x, labels in tqdm(loader, desc=desc, leave=False):
            clean_x = clean_x.to(device)
            noisy_x = noisy_x.to(device)
            labels = labels.to(device)
            batch_size = clean_x.size(0)

            z_noisy, x_hat, _ = model(noisy_x)
            recon_loss = model.reconstruction(x_hat, clean_x)

            logits_noisy = latent_classifier(z_noisy)
            clf_loss_sum += criterion(logits_noisy, labels).item() * batch_size
            _, preds_noisy = torch.max(logits_noisy, 1)
            clf_correct += (preds_noisy == labels).sum().item()

            _, preds_clean = torch.max(pretrained_classifier(clean_x), 1)
            _, preds_reconstructed = torch.max(pretrained_classifier(x_hat), 1)
            reconstructed_correct += (
                (preds_reconstructed == labels).sum().item()
            )
            consistent += (preds_clean == preds_reconstructed).sum().item()

            n_seen += batch_size
            recon_sum += recon_loss.item() * batch_size

            clean_cpu = clean_x.cpu()
            denoised_cpu = x_hat.cpu()
            # NOTE(review): summed per batch and divided by the number of
            # batches below; this assumes batch_ssim/batch_psnr return a
            # per-batch mean, and weights a short final batch like a full one.
            ssim_sum += batch_ssim(clean_cpu, denoised_cpu)
            psnr_sum += batch_psnr(clean_cpu, denoised_cpu)

    n_batches = len(loader)
    return {
        "loss": clf_loss_sum / n_seen,
        "accuracy": clf_correct / n_seen,
        "reconstruction": recon_sum / n_seen,
        "reconstructed_accuracy": reconstructed_correct / n_seen,
        "consistency": consistent / n_seen,
        "ssim": ssim_sum / n_batches,
        "psnr": psnr_sum / n_batches,
    }


def _log_denoising_figures(model, loader, device, args, epoch, metrics,
                           max_batches=3):
    """Plot clean/noisy/denoised grids for the first batches and log to W&B.

    Upload failures are printed and ignored so they cannot abort training.
    """
    model.eval()
    with torch.no_grad():
        for i, (clean_batch, noisy_batch, _) in enumerate(loader):
            if i >= max_batches:
                break
            noisy_batch = noisy_batch.to(device)
            _, denoised_batch, _ = model(noisy_batch)

            log_name = (
                f"{args.model_name}-{args.loss_type}-{args.noise_prob}\n"
                f"Epoch{epoch}\n"
                f"ReconAcc:{metrics['reconstructed_accuracy']:.4f}\n"
                f"Consistency: {metrics['consistency']:.3f}"
            )

            denoise_fig = plot_denoising_results(
                clean_batch.cpu(),
                noisy_batch.cpu(),
                denoised_batch.cpu(),
                epoch,
                n_samples=32,
                ssim_scores=metrics["ssim"],
                psnr_scores=metrics["psnr"],
                log_name=log_name,
            )

            try:
                wandb.log(
                    {f"denoising_results_{i}": wandb.Image(denoise_fig)},
                    step=epoch,
                )
            except Exception as e:
                # Best effort: an image upload problem must not stop training.
                print(f"wandb error: {e}")

            plt.close(denoise_fig)

    torch.cuda.empty_cache()


def main(args):
    """Train a denoising (V)AE on noisy images and log metrics to W&B.

    Each epoch trains the model to reconstruct clean images from noisy ones.
    A linear probe (``latent_classifier``) is trained on the detached latents.
    The epoch then evaluates on the test split and logs images when SSIM/PSNR
    improve (and every 5 epochs and at the last epoch). Training stops early
    after ``patience`` epochs without a new best reconstruction/SSIM/PSNR,
    and aborts (after closing the W&B run) if the loss becomes NaN.
    """
    set_random_seed(args.seed)
    wandb.init(
        project="iclr2026",
        group="image_denoising",
        config={
            "model": args.model_name,
            "loss_type": args.loss_type,
            "normalization": args.normalization,
            "epochs": args.epochs,
            "activation": args.activation,
            "optimizer": args.optimizer,
            "dropout_rate": args.dropout_rate,
            "nu": args.nu,
            "noise_type": args.noise_type,
            "noise_prob": args.noise_prob,
            "salt_prob": args.salt_prob,
            "gaussian_mean": args.gaussian_mean,
            "gaussian_std": args.gaussian_std,
            "anneal_function": args.anneal_function,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "dataset": args.dataset,
            "seed": args.seed,
        },
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = args.dataset.upper()

    model_cls = _select_model_class(args.model_name)
    latent_dim, image_size = _dataset_config(dataset)
    clean_transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )
    noise_kwargs = _noise_kwargs(args)

    base_train_dataset = create_dataset(dataset, "train", clean_transform)
    base_test_dataset = create_dataset(dataset, "test", clean_transform)

    transform_func = NoiseTransform(
        noise_type=args.noise_type,
        noise_prob=args.noise_prob,
        **noise_kwargs,
    )

    train_paired_dataset = CleanNoiseDataset(
        base_train_dataset, transform_func
    )
    test_paired_dataset = CleanNoiseDataset(base_test_dataset, transform_func)

    # Omniglot is labelled by alphabet (50); MNIST, SVHN, FashionMNIST and
    # CIFAR10 all have 10 classes.
    num_classes = 50 if dataset == "OMNIGLOT" else 10

    train_loader = DataLoader(
        train_paired_dataset, batch_size=args.batch_size, shuffle=True
    )
    test_loader = DataLoader(
        test_paired_dataset, batch_size=args.batch_size, shuffle=False
    )

    # One sample is drawn only to read the input shape for the model.
    data_sample, _, _ = train_paired_dataset[0]

    model = model_cls(
        nu=args.nu,
        input_shape=data_sample.shape,
        latent_dim=latent_dim,
        reconstruction=args.loss_type,
        device=device,
        normalization=args.normalization,
        activation=args.activation,
        dropout_rate=args.dropout_rate,
    ).to(device)

    latent_classifier = FCLayer(latent_dim, num_classes, None).to(device)
    latent_clf_criterion = nn.CrossEntropyLoss()

    # Frozen reference classifier used only to score reconstructions.
    pretrained_classifier = load_pretrained_classifier(dataset, device)
    pretrained_classifier.eval()
    for param in pretrained_classifier.parameters():
        param.requires_grad = False

    latent_clf_optimizer = _build_optimizer(
        latent_classifier.parameters(), args.optimizer, args.learning_rate
    )
    optimizer = _build_optimizer(
        model.parameters(), args.optimizer, args.learning_rate
    )

    wandb.watch(model, log_freq=100)

    # Schedule parameters for anneal_function, measured in optimizer steps:
    # x0 is the logistic midpoint / cosine warm-up length.
    step = 0
    k = 0.0025
    total_steps = len(train_loader.dataset) / args.batch_size * args.epochs
    if args.anneal_function == "linear":
        x0 = total_steps
    else:
        # NOTE(review): the schedule starts at step 10; the reason for this
        # offset is not visible here.
        step = 10
        x0 = total_steps * 0.25

    patience = 20
    best_test_acc = 0.0
    best_test_loss = float("inf")
    best_test_reconstruction = float("inf")
    best_ssim = 0.0
    best_psnr = 0.0
    best_epoch = 0
    metrics = None  # last epoch's test metrics; stays None if no epoch runs

    for epoch in range(1, args.epochs + 1):
        if best_epoch + patience < epoch:
            print(f"Early stopping as no improvement for {patience} epochs.")
            break

        model.train()
        latent_classifier.train()
        epoch_rec_loss = 0
        epoch_reg_loss = 0
        latent_clf_acc = 0
        n_samples = 0

        train_loop = tqdm(
            train_loader,
            desc=f"Train Epoch {epoch}/{args.epochs}",
            leave=False,
        )

        for clean_x, noisy_x, labels in train_loop:
            clean_x = clean_x.to(device)
            noisy_x = noisy_x.to(device)
            labels = labels.to(device)
            batch_size = noisy_x.size(0)

            n_samples += batch_size

            if args.anneal_function == "linear":
                regularization_weight = 2 * epoch / args.epochs
            else:
                regularization_weight = anneal_function(
                    args.anneal_function, step, k, x0, args.epochs, start=0.01
                )
            if args.model_name == "laplace":
                regularization_weight = regularization_weight / 10

            # Denoising (V)AE update: reconstruct the clean image from noise.
            optimizer.zero_grad()
            z, x_hat, derivatives = model(noisy_x)
            recon_loss = model.reconstruction(x_hat, clean_x)
            reg_loss = model.regularization(derivatives)
            loss = recon_loss + regularization_weight * reg_loss

            if torch.isnan(loss).any():
                print(
                    f"NaN loss detected at epoch {epoch}. "
                    "Stopping training."
                )
                wandb.log({"nan_detected": True, "epoch_nan": epoch})
                wandb.finish()
                return

            loss.backward()
            optimizer.step()

            epoch_rec_loss += recon_loss.item() * batch_size
            epoch_reg_loss += reg_loss.item() * batch_size

            # Linear probe on detached latents: does not affect the model.
            latent_clf_logits = latent_classifier(z.detach())
            latent_clf_loss = latent_clf_criterion(latent_clf_logits, labels)
            _, preds = torch.max(latent_clf_logits, 1)
            latent_clf_acc += (preds == labels).sum().item()
            latent_clf_optimizer.zero_grad()
            latent_clf_loss.backward()
            latent_clf_optimizer.step()

            train_loop.set_postfix(
                {
                    "acc": f"{(latent_clf_acc / n_samples):.4f}",
                    "recon_loss": f"{recon_loss.item():.4f}",
                    "regularization_loss": f"{reg_loss.item():.4f}",
                    "regularization_weight": f"{regularization_weight:.4f}",
                }
            )
            step += 1

        wandb.log(
            {
                "train/accuracy": latent_clf_acc / n_samples,
                "train/reconstruction_loss": epoch_rec_loss / n_samples,
                "train/regularization_loss": epoch_reg_loss / n_samples,
                "train/regularization_weight": regularization_weight,
            },
            step=epoch,
        )

        # Test evaluation
        metrics = _evaluate(
            model,
            latent_classifier,
            pretrained_classifier,
            latent_clf_criterion,
            test_loader,
            device,
            desc=f"Test Epoch {epoch}/{args.epochs}",
        )

        wandb.log(
            {
                "test/loss": metrics["loss"],
                "test/accuracy": metrics["accuracy"],
                "test/reconstruction_loss": metrics["reconstruction"],
                "test/reconstructed_accuracy": metrics[
                    "reconstructed_accuracy"
                ],
                "test/consistency": metrics["consistency"],
                "test/ssim": metrics["ssim"],
                "test/psnr": metrics["psnr"],
            },
            step=epoch,
        )

        # Only reconstruction/SSIM/PSNR improvements reset early stopping;
        # only SSIM/PSNR improvements trigger extra image logging.
        improved = False
        log_images = False
        if metrics["accuracy"] > best_test_acc:
            best_test_acc = metrics["accuracy"]
            improved = True
        if metrics["loss"] < best_test_loss:
            best_test_loss = metrics["loss"]
            improved = True
        if metrics["reconstruction"] < best_test_reconstruction:
            best_test_reconstruction = metrics["reconstruction"]
            improved = True
            best_epoch = epoch
        if metrics["ssim"] > best_ssim:
            best_ssim = metrics["ssim"]
            improved = True
            log_images = True
            best_epoch = epoch
        if metrics["psnr"] > best_psnr:
            best_psnr = metrics["psnr"]
            improved = True
            log_images = True
            best_epoch = epoch
        if improved:
            wandb.log(
                {
                    "best/test_accuracy": metrics["accuracy"],
                    "best/ssim": metrics["ssim"],
                    "best/psnr": metrics["psnr"],
                },
                step=epoch,
            )

        if log_images or epoch == args.epochs or epoch % 5 == 0:
            _log_denoising_figures(
                model, test_loader, device, args, epoch, metrics
            )

        print(
            f"Epoch {epoch}/{args.epochs} - "
            f"Test Acc: {metrics['accuracy']:.4f}, "
            f"Recon Acc: {metrics['reconstructed_accuracy']:.4f}, "
            f"Consistency: {metrics['consistency']:.4f}, "
            f"SSIM: {metrics['ssim']:.4f}, "
            f"PSNR: {metrics['psnr']:.2f} dB"
        )

    print(f"Training completed! Best Test Accuracy: {best_test_acc:.4f}")
    if metrics is not None:
        print(
            "Final Reconstructed Accuracy: "
            f"{metrics['reconstructed_accuracy']:.4f}"
        )
        print(f"Final Consistency: {metrics['consistency']:.4f}")
    print(f"Best SSIM: {best_ssim:.4f}, Best PSNR: {best_psnr:.2f} dB")

    wandb.finish()


def set_random_seed(seed=42):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    CUDA generators are seeded only when a GPU is available. cuDNN is put
    in deterministic mode with autotuning (``benchmark``) disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def download_and_extract_omniglot(root_dir):
    """Download and unpack the Omniglot archives under ``root_dir/omniglot``.

    Afterwards ``root_dir/omniglot/images_background`` and
    ``root_dir/omniglot/images_evaluation`` exist. Splits that are already
    extracted are skipped. Archives already on disk are reused rather than
    downloaded again.

    Raises:
        urllib.error.URLError: If a required download fails.
        zipfile.BadZipFile: If an archive on disk is corrupt.
    """
    omniglot_url = "https://github.com/brendenlake/omniglot/raw/master/python"
    folders = ("images_background", "images_evaluation")

    dataset_dir = os.path.join(root_dir, "omniglot")
    os.makedirs(dataset_dir, exist_ok=True)

    for folder in folders:
        if os.path.isdir(os.path.join(dataset_dir, folder)):
            continue  # already extracted

        zip_file = f"{folder}.zip"
        zip_path = os.path.join(dataset_dir, zip_file)

        if not os.path.exists(zip_path):
            # Download to a temporary name and rename on success, so an
            # interrupted download never leaves a truncated archive that a
            # later run would trust.
            tmp_path = zip_path + ".part"
            try:
                urllib.request.urlretrieve(f"{omniglot_url}/{zip_file}", tmp_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            os.replace(tmp_path, zip_path)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dataset_dir)

    # If the archives unpacked into a nested ``omniglot/`` directory, move
    # the split folders up one level (without clobbering existing ones).
    nested_root = os.path.join(dataset_dir, "omniglot")
    for folder in folders:
        src = os.path.join(nested_root, folder)
        dst = os.path.join(dataset_dir, folder)
        if os.path.exists(src) and not os.path.exists(dst):
            shutil.move(src, dst)
    if os.path.exists(nested_root):
        shutil.rmtree(nested_root)


def anneal_function(anneal_fn, step, k, x0, epochs, start=0.001):
    """Return the regularization weight at ``step`` for the chosen schedule.

    ``anneal_fn`` may be a numeric string (a constant weight, returned
    without ``start``) or one of:

    * ``"logistic"``: ``start + sigmoid(k * (step - x0))``.
    * ``"cosine"``: half-cosine ramp from ``start`` to ``start + 1`` over
      ``x0`` steps, then flat at ``start + 1``.
    * ``"cyclic"``: half-cosine ramp repeating every ``x0 / 4`` steps.

    ``epochs`` is accepted for interface compatibility but not used.

    Raises:
        ValueError: For an unknown schedule name.
    """
    # Numeric strings select a constant weight.
    try:
        return float(anneal_fn)
    except ValueError:
        pass

    if anneal_fn == "logistic":
        sigmoid = 1 / (1 + np.exp(-k * (step - x0)))
        return start + float(sigmoid)

    if anneal_fn == "cosine":
        if step < x0:
            return start + 0.5 * (1 - np.cos(np.pi * step / x0))
        return start + 1.0

    if anneal_fn == "cyclic":
        period = x0 / 4
        phase = (step % period) / period
        return start + 0.5 * (1 - np.cos(np.pi * phase))

    raise ValueError(f"Unknown annealing function: {anneal_fn}")


def load_pretrained_classifier(dataset, device, checkpoint_dir="checkpoints"):
    """Load the TorchScript reference classifier for ``dataset`` in eval mode.

    Args:
        dataset: Upper-cased dataset name (MNIST, SVHN, OMNIGLOT, CIFAR10).
        device: Device the classifier's tensors are loaded onto.
        checkpoint_dir: Directory holding the ``.pt`` checkpoints.

    Returns:
        The loaded ``torch.jit`` module, already on ``device`` and in eval mode.

    Raises:
        ValueError: If no checkpoint is known for ``dataset`` (e.g.
            FASHIONMNIST, which the training script otherwise accepts).
    """
    checkpoints = {
        "MNIST": "mnist_9976.pt",
        "SVHN": "svhn_9611.pt",
        "OMNIGLOT": "omniglot_8497.pt",
        "CIFAR10": "cifar10_8452.pt",
    }
    if dataset not in checkpoints:
        raise ValueError(
            f"Unknown dataset: {dataset} "
            "(no pretrained classifier checkpoint available)"
        )

    path = os.path.join(checkpoint_dir, checkpoints[dataset])
    # map_location remaps stored tensors at load time, so a checkpoint saved
    # on a GPU can still be loaded on a CPU-only machine.
    model = torch.jit.load(path, map_location=device).to(device)
    model.eval()
    return model


def create_dataset(dataset, split, transform, split_seed=0):
    """Build the clean (un-noised) image dataset for ``split``.

    Args:
        dataset: Dataset name, matched case-insensitively against MNIST,
            FASHIONMNIST, SVHN, CIFAR10 and OMNIGLOT.
        split: ``"train"`` selects the training portion. Any other value
            selects the held-out portion (and is passed verbatim as the SVHN
            ``split`` argument).
        transform: Transform applied to each image.
        split_seed: Seed for the 90/10 Omniglot train/held-out split. The
            "train" and "test" calls must use the same seed so that the two
            subsets are complementary.

    Raises:
        ValueError: For an unsupported dataset name.
    """
    name = dataset.upper()
    data_root = "./data"

    if name == "OMNIGLOT":
        omniglot_dir = os.path.join(data_root, "omniglot")
        if not all(
            os.path.isdir(os.path.join(omniglot_dir, folder))
            for folder in ("images_background", "images_evaluation")
        ):
            download_and_extract_omniglot(data_root)

        full_dataset = OmniglotDataset(root_dir=data_root, transform=transform)
        # The "train" and "test" calls each build a fresh dataset, which
        # shuffles with whatever global RNG state it finds. Restore a
        # canonical order and split with a dedicated seeded generator, so both
        # calls partition the same list the same way and do not overlap.
        full_dataset.data.sort()
        train_size = int(0.9 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        generator = torch.Generator().manual_seed(split_seed)
        train_dataset, val_dataset = random_split(
            full_dataset, [train_size, val_size], generator=generator
        )
        return train_dataset if split == "train" else val_dataset

    is_train = split == "train"
    if name == "MNIST":
        return datasets.MNIST(
            root=data_root,
            train=is_train,
            download=True,
            transform=transform,
        )
    if name == "SVHN":
        return datasets.SVHN(
            root=data_root,
            split=split,
            download=True,
            transform=transform,
        )
    if name == "FASHIONMNIST":
        return datasets.FashionMNIST(
            root=data_root,
            train=is_train,
            download=True,
            transform=transform,
        )
    if name == "CIFAR10":
        return datasets.CIFAR10(
            root=data_root,
            train=is_train,
            download=True,
            transform=transform,
        )
    raise ValueError(f"Unknown dataset: {dataset}")


def _imshow_grid(ax, images, nrow):
    """Tile ``images`` with ``make_grid`` and draw the grid on ``ax``."""
    grid_np = (
        make_grid(images, nrow=nrow, padding=2, normalize=False)
        .permute(1, 2, 0)
        .numpy()
    )
    if grid_np.shape[2] == 1:
        ax.imshow(grid_np.squeeze(2), cmap="gray")
    else:
        ax.imshow(grid_np)


def _average_score(scores):
    """Return ``np.mean(scores)`` for a list/array, otherwise ``scores`` as is."""
    if isinstance(scores, (list, np.ndarray)):
        return np.mean(scores)
    return scores


def plot_denoising_results(
    original,
    noisy,
    reconstructed,
    epoch,
    n_samples=8,
    ssim_scores=None,
    psnr_scores=None,
    log_name="",
):
    """Build a figure comparing original, noisy and denoised images.

    Only the first ``n_samples`` images of each batch are shown. With
    ``n_samples == 32`` the three kinds are drawn as separate panels of two
    rows of 16. Otherwise a single grid with one row per kind is drawn.

    When both ``ssim_scores`` and ``psnr_scores`` are given (scalars, or
    lists/arrays that get averaged), the title is ``log_name`` followed by
    the average scores. Without scores, a default title is used.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    original = original[:n_samples]
    noisy = noisy[:n_samples]
    reconstructed = reconstructed[:n_samples]

    title = None
    if ssim_scores is not None and psnr_scores is not None:
        avg_ssim = _average_score(ssim_scores)
        avg_psnr = _average_score(psnr_scores)
        title = (
            log_name
            + f"\nAvg SSIM: {avg_ssim:.4f}, Avg PSNR: {avg_psnr:.2f} dB"
        )

    if n_samples == 32:
        fig, axes = plt.subplots(3, 1, figsize=(24, 10))

        panels = [
            (original, "Original"),
            (noisy, "Noisy"),
            (reconstructed, "Denoised"),
        ]
        for ax, (imgs, type_name) in zip(axes, panels):
            _imshow_grid(ax, imgs, nrow=16)
            ax.set_title(
                f"{type_name} Images", fontsize=14, fontweight="bold"
            )
            ax.axis("off")

        # Previously ``title`` was unbound here when scores were omitted.
        if title is None:
            title = log_name or f"Denoising Results - Epoch {epoch}"
        fig.suptitle(title, fontsize=16, fontweight="bold")
        plt.tight_layout()

    else:
        comparison = torch.cat([original, noisy, reconstructed], dim=0)
        fig, ax = plt.subplots(figsize=(15, 12))
        _imshow_grid(ax, comparison, nrow=n_samples)

        if title is None:
            title = (
                f"Denoising Results - Epoch {epoch}\n"
                "Row 1: Original Images\n"
                "Row 2: Noisy Images\n"
                "Row 3: Denoised Images"
            )
        ax.set_title(title, fontsize=14)
        ax.axis("off")

    return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Image Denoising Training with SSIM/PSNR Evaluation"
    )

    parser.add_argument(
        "--dataset", type=str, default="mnist", help="Dataset name"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="ae",
        choices=["ae", "vae", "t3", "pareto", "laplace"],
        help="Model type",
    )
    parser.add_argument(
        "--loss-type",
        type=str,
        default="l1",
        choices=["l1", "l2", "mae", "mse"],
        help="Reconstruction loss type",
    )

    parser.add_argument(
        "--nu",
        type=float,
        default=3.1,
        help="Nu parameter for Pareto VAE models",
    )

    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--anneal-function",
        type=str,
        default="logistic",
        help="Annealing function for VAE regularization weight",
    )
    parser.add_argument(
        "--normalization",
        type=str,
        default="none",
        choices=[
            "batchnorm",
            "layernorm",
            "instancenorm",
            "groupnorm",
            "none",
        ],
        help="Normalization technique to use",
    )
    parser.add_argument(
        "--activation",
        type=str,
        default="relu",
        choices=["relu", "leaky_relu", "elu", "selu", "gelu", "swish", "tanh"],
        help="Activation function to use",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.0,
        help="Dropout rate (0.0 = no dropout)",
    )

    parser.add_argument(
        "--noise-type",
        type=str,
        default="salt_pepper",
        choices=["salt_pepper", "gaussian", "mixed"],
        help="Type of noise to add",
    )
    parser.add_argument(
        "--noise-prob",
        type=float,
        default=0.1,
        help="Probability of noise (0.0-1.0)",
    )

    parser.add_argument(
        "--salt-prob",
        type=float,
        default=0.5,
        help="Probability of salt vs pepper in salt_pepper noise",
    )

    parser.add_argument(
        "--gaussian-mean",
        type=float,
        default=0.0,
        help="Mean of gaussian noise",
    )
    parser.add_argument(
        "--gaussian-std",
        type=float,
        default=0.1,
        help="Standard deviation of gaussian noise",
    )
    # Only optimizers that main() can actually build are offered; "sgd" was
    # previously listed here but always ended in "Unsupported optimizer type".
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "adamw", "adam-wd", "adamw-wd"],
        help="Optimizer to use",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-3, help="Learning rate"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    args = parser.parse_args()
    main(args)
