import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import argparse
import numpy as np
from tqdm.auto import tqdm
import wandb
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from tasks.denoise import NoiseTransform
from tasks.metrics import batch_ssim, batch_psnr

from models.ae import AutoEncoder
from models.vae import VAE
from models.t3vae import T3VAE
from models.paretovae import ParetoVAE
from models.laplacevae import LaplaceVAE

import random


class CelebADataset(Dataset):
    """Image-only view over ``torchvision.datasets.CelebA``.

    The wrapped dataset is created without a transform; ``transform`` is
    applied here, in ``__getitem__``, and the target returned by the wrapped
    dataset is dropped so that indexing yields just the image.

    Args:
        root: Directory handed to ``datasets.CelebA``.
        split: Split name handed to ``datasets.CelebA``.
        attr_list: Optional attribute names to select. Each name is looked
            up in the wrapped dataset's ``attr_names`` (``list.index``
            raises ``ValueError`` for an unknown name). ``None`` selects
            every attribute.
        transform: Optional callable applied to each image on access.
        download: Forwarded to ``datasets.CelebA``.
        target_type: Forwarded to ``datasets.CelebA``.
    """

    def __init__(
        self,
        root="./data",
        split="train",
        attr_list=None,
        transform=None,
        download=True,
        target_type="attr",
    ):
        celeba_options = {
            "root": root,
            "split": split,
            "target_type": target_type,
            # The image transform is applied by this wrapper instead.
            "transform": None,
            "download": download,
        }
        self.dataset = datasets.CelebA(**celeba_options)

        self.transform = transform
        self.attr_list = attr_list
        self.attr_names = self.dataset.attr_names

        # Positions of the selected attributes inside ``attr_names``.
        self.attr_indices = (
            list(range(len(self.attr_names)))
            if attr_list is None
            else [self.attr_names.index(name) for name in attr_list]
        )

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at ``idx``."""
        image, _ = self.dataset[idx]
        return self.transform(image) if self.transform else image

    def get_attr_names(self):
        """Return all attribute names of the wrapped dataset."""
        return self.attr_names

    def get_selected_attr_names(self):
        """Return the attribute names picked out by ``attr_indices``."""
        names = self.attr_names
        return [names[position] for position in self.attr_indices]


class CleanNoiseDataset(Dataset):
    """Pair every item of a dataset with a (possibly corrupted) copy.

    Args:
        dataset: Indexable collection whose items provide ``.clone()``.
        transform_func: Optional callable applied to a clone of each item
            to produce the second element of the pair. When it is falsy the
            second element is an unmodified clone.
    """

    def __init__(self, dataset, transform_func=None):
        self.transform_func = transform_func
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(item, corrupted_copy)`` for the item at ``idx``."""
        clean = self.dataset[idx]
        # Always work on a clone so the transform cannot alter ``clean``.
        corrupted = clean.clone()
        if self.transform_func:
            corrupted = self.transform_func(corrupted)
        return clean, corrupted


def _build_noise_kwargs(args):
    """Return the extra ``NoiseTransform`` keyword arguments for the noise type."""
    if args.noise_type == "salt_pepper":
        return {"salt_prob": args.salt_prob}
    if args.noise_type == "gaussian":
        return {"mean": args.gaussian_mean, "std": args.gaussian_std}
    if args.noise_type == "mixed":
        return {
            "noise_types": ["salt_pepper", "gaussian"],
            "noise_probs": [0.5, 0.5],
        }
    raise ValueError(f"Unknown noise type: {args.noise_type}")


def _build_optimizer(name, parameters, learning_rate):
    """Create the optimizer called ``name`` over ``parameters``."""
    if name == "adam":
        return optim.Adam(parameters, lr=learning_rate)
    if name == "adamw":
        return optim.AdamW(parameters, lr=learning_rate)
    if name == "adam-wd":
        return optim.Adam(parameters, lr=learning_rate, weight_decay=1e-5)
    if name == "adamw-wd":
        return optim.AdamW(parameters, lr=learning_rate, weight_decay=1e-5)
    if name == "sgd":
        # The command-line parser in this file offers "sgd" as a choice.
        return optim.SGD(parameters, lr=learning_rate)
    raise ValueError(f"Unsupported optimizer type: {name}")


def main(args):
    """Train a denoising autoencoder variant on CelebA and log to wandb.

    Each epoch trains on (clean, noisy) pairs, evaluates reconstruction
    loss, SSIM and PSNR on the test split, logs the metrics to wandb and,
    on selected epochs, uploads comparison figures for the first three
    test batches. Training stops early once more than 10 epochs have passed
    since SSIM or PSNR last improved, and aborts (after logging
    ``nan_detected``) as soon as the training loss becomes NaN.

    Args:
        args: Namespace providing ``seed``, ``model_name``, ``loss_type``,
            ``normalization``, ``epochs``, ``activation``, ``optimizer``,
            ``dropout_rate``, ``nu``, ``noise_type``, ``noise_prob``,
            ``salt_prob``, ``gaussian_mean``, ``gaussian_std``,
            ``anneal_function``, ``batch_size`` and ``learning_rate``.

    Raises:
        ValueError: If the model name, noise type or optimizer name is not
            recognised.
    """
    set_random_seed(args.seed)
    wandb.init(
        project="iclr2026",
        group="image_denoising",
        config={
            "model": args.model_name,
            "loss_type": args.loss_type,
            "normalization": args.normalization,
            "epochs": args.epochs,
            "activation": args.activation,
            "optimizer": args.optimizer,
            "dropout_rate": args.dropout_rate,
            "nu": args.nu,
            "noise_type": args.noise_type,
            "noise_prob": args.noise_prob,
            "salt_prob": args.salt_prob,
            "gaussian_mean": args.gaussian_mean,
            "gaussian_std": args.gaussian_std,
            "anneal_function": args.anneal_function,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "dataset": "celeba",
            "seed": args.seed,
        },
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_classes = {
        "vae": VAE,
        "pareto": ParetoVAE,
        "laplace": LaplaceVAE,
        "t3": T3VAE,
        "ae": AutoEncoder,
    }
    if args.model_name not in model_classes:
        raise ValueError(f"Unknown model type: {args.model_name}")
    model_cls = model_classes[args.model_name]

    latent_dim = 128
    clean_transform = transforms.Compose(
        [
            transforms.CenterCrop(148),
            transforms.Resize(64),
            transforms.ToTensor(),
        ]
    )

    noise_kwargs = _build_noise_kwargs(args)

    # The validation split is built by the helper but not used here.
    base_train_dataset, _, base_test_dataset = create_celeba_dataset(
        transform=clean_transform
    )

    transform_func = NoiseTransform(
        noise_type=args.noise_type,
        noise_prob=args.noise_prob,
        **noise_kwargs,
    )

    train_paired_dataset = CleanNoiseDataset(
        base_train_dataset, transform_func
    )
    test_paired_dataset = CleanNoiseDataset(base_test_dataset, transform_func)

    train_loader = DataLoader(
        train_paired_dataset, batch_size=args.batch_size, shuffle=True
    )
    test_loader = DataLoader(
        test_paired_dataset, batch_size=args.batch_size, shuffle=False
    )

    clean_sample, noisy_sample = train_paired_dataset[0]
    print("Data sample shape:", clean_sample.shape)
    print("Data sample shape (noisy):", noisy_sample.shape)

    model = model_cls(
        nu=args.nu,
        input_shape=clean_sample.shape,
        latent_dim=latent_dim,
        reconstruction=args.loss_type,
        device=device,
        normalization=args.normalization,
        activation=args.activation,
        dropout_rate=args.dropout_rate,
    ).to(device)

    optimizer = _build_optimizer(
        args.optimizer, model.parameters(), args.learning_rate
    )

    wandb.watch(model, log_freq=100)

    # Annealing state: ``step`` counts optimizer updates across all epochs,
    # ``x0`` is the reference step handed to ``anneal_function``.
    step = 0
    k = 0.0025
    total_steps = len(train_loader.dataset) / args.batch_size * args.epochs
    if args.anneal_function == "linear":
        x0 = total_steps
    else:
        x0 = total_steps * 0.3

    best_test_reconstruction = float("inf")
    best_ssim = 0.0
    best_psnr = 0.0
    best_epoch = 0

    # Defaults keep the final summary printable when no epoch is run.
    avg_ssim = 0.0
    avg_psnr = 0.0
    test_avg_reconstructed_acc = 0.0
    test_avg_consistency = 0.0

    # Set when SSIM or PSNR improves so that figures get uploaded.
    log_images = False

    for epoch in range(1, args.epochs + 1):
        if best_epoch + 10 < epoch:
            print("Early stopping as no improvement for 10 epochs.")
            break

        model.train()
        epoch_loss = 0
        epoch_rec_loss = 0
        epoch_reg_loss = 0

        train_loop = tqdm(
            train_loader,
            desc=f"Train Epoch {epoch}/{args.epochs}",
            leave=False,
        )

        n_samples = 0
        for clean_x, noisy_x in train_loop:
            clean_x = clean_x.to(device)
            noisy_x = noisy_x.to(device)
            batch_count = noisy_x.size(0)

            n_samples += batch_count

            if args.anneal_function == "linear":
                # Linear schedule is per epoch, not per step.
                regularization_weight = 2 * epoch / args.epochs
            else:
                regularization_weight = anneal_function(
                    args.anneal_function, step, k, x0, args.epochs, start=0.01
                )
            if args.model_name == "laplace":
                regularization_weight = regularization_weight / 10
            optimizer.zero_grad()

            # The model reads the noisy image; the loss targets the clean one.
            _, x_hat, derivatives = model(noisy_x)

            recon_loss = model.reconstruction(x_hat, clean_x)
            reg_loss = model.regularization(derivatives)
            loss = recon_loss + regularization_weight * reg_loss

            if torch.isnan(loss).any():
                print(
                    f"NaN loss detected at epoch {epoch}. "
                    "Stopping training."
                )
                wandb.log({"nan_detected": True, "epoch_nan": epoch})
                wandb.finish()
                return

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * batch_count
            epoch_rec_loss += recon_loss.item() * batch_count
            epoch_reg_loss += reg_loss.item() * batch_count

            train_loop.set_postfix(
                {
                    "recon_loss": f"{recon_loss.item():.4f}",
                    "regularization_loss": f"{reg_loss.item():.4f}",
                    "regularization_weight": f"{regularization_weight:.4f}",
                }
            )
            step += 1

        wandb.log(
            {
                "train/reconstruction_loss": epoch_rec_loss / n_samples,
                "train/regularization_loss": epoch_reg_loss / n_samples,
                "train/regularization_weight": regularization_weight,
            },
            step=epoch,
        )

        model.eval()
        with torch.no_grad():
            test_n_samples = 0
            recon_total = 0

            # NOTE: nothing in this function updates these two counters, so
            # the metrics derived from them are always logged as zero.
            clf_reconstructed_acc = 0
            consistency = 0

            test_loop = tqdm(
                test_loader,
                desc=f"Test Epoch {epoch}/{args.epochs}",
                leave=False,
            )
            ssim_score = 0
            psnr_score = 0
            for clean_x, noisy_x in test_loop:
                clean_x = clean_x.to(device)
                noisy_x = noisy_x.to(device)

                _, x_hat, _ = model(noisy_x)

                recon_loss = model.reconstruction(x_hat, clean_x)
                test_n_samples += clean_x.size(0)

                recon_total += recon_loss.item() * clean_x.size(0)
                clean_batch_cpu = clean_x.cpu()
                denoised_batch_cpu = x_hat.cpu()

                ssim_score += batch_ssim(clean_batch_cpu, denoised_batch_cpu)
                psnr_score += batch_psnr(clean_batch_cpu, denoised_batch_cpu)

            # NOTE(review): scores are averaged per batch, so a smaller last
            # batch weighs as much as a full one; presumably batch_ssim and
            # batch_psnr return per-batch means - TODO confirm.
            avg_ssim = ssim_score / len(test_loader)
            avg_psnr = psnr_score / len(test_loader)

            avg_test_reconstruction = recon_total / test_n_samples

            test_avg_reconstructed_acc = clf_reconstructed_acc / test_n_samples
            test_avg_consistency = consistency / test_n_samples

            wandb.log(
                {
                    "test/reconstruction_loss": avg_test_reconstruction,
                    "test/reconstructed_accuracy": test_avg_reconstructed_acc,
                    "test/consistency": test_avg_consistency,
                    "test/ssim": avg_ssim,
                    "test/psnr": avg_psnr,
                },
                step=epoch,
            )

            improved_recon = avg_test_reconstruction < best_test_reconstruction
            improved_ssim = avg_ssim > best_ssim
            improved_psnr = avg_psnr > best_psnr

            if improved_recon:
                best_test_reconstruction = avg_test_reconstruction
            if improved_ssim:
                best_ssim = avg_ssim
            if improved_psnr:
                best_psnr = avg_psnr
            if improved_ssim or improved_psnr:
                # Only SSIM/PSNR gains reset early stopping and trigger
                # figure uploads; a better reconstruction loss alone does not.
                log_images = True
                best_epoch = epoch
            if improved_recon or improved_ssim or improved_psnr:
                # One call suffices: the payload is the same whichever
                # metric improved.
                wandb.log(
                    {
                        "best/test_reconstruction": avg_test_reconstruction,
                        "best/ssim": avg_ssim,
                        "best/psnr": avg_psnr,
                    },
                    step=epoch,
                )

        if log_images or epoch == args.epochs or epoch % 5 == 0:
            log_images = False
            model.eval()
            with torch.no_grad():
                for i, (clean_batch, noisy_batch) in enumerate(test_loader):
                    if i >= 3:
                        break
                    noisy_batch = noisy_batch.to(device)
                    _, denoised_batch, _ = model(noisy_batch)

                    clean_batch_cpu = clean_batch.cpu()
                    noisy_batch_cpu = noisy_batch.cpu()
                    denoised_batch_cpu = denoised_batch.cpu()

                    log_name = (
                        f"{args.model_name}-{args.loss_type}-{args.noise_prob}\n"
                        f"Epoch{epoch}\n"
                        f"ReconAcc:{test_avg_reconstructed_acc:.4f}\n"
                        f"Consistency: {test_avg_consistency:.3f}"
                    )

                    denoise_fig = plot_denoising_results(
                        clean_batch_cpu,
                        noisy_batch_cpu,
                        denoised_batch_cpu,
                        epoch,
                        n_samples=32,
                        ssim_scores=avg_ssim,
                        psnr_scores=avg_psnr,
                        log_name=log_name,
                    )

                    # Best effort: a failed upload must not stop training.
                    try:
                        wandb.log(
                            {
                                f"denoising_results_{i}": wandb.Image(
                                    denoise_fig
                                )
                            },
                            step=epoch,
                        )
                    except Exception as e:
                        print(f"wandb error: {e}")

                    plt.close(denoise_fig)

                torch.cuda.empty_cache()

        print(
            f"Epoch {epoch}/{args.epochs} - "
            f"Recon Acc: {test_avg_reconstructed_acc:.4f}, "
            f"Consistency: {test_avg_consistency:.4f}, "
            f"SSIM: {avg_ssim:.4f}, "
            f"PSNR: {avg_psnr:.2f} dB"
        )

    print(f"Final Reconstructed Accuracy: {test_avg_reconstructed_acc:.4f}")
    print(f"Final Consistency: {test_avg_consistency:.4f}")
    print(f"Best SSIM: {best_ssim:.4f}, Best PSNR: {best_psnr:.2f} dB")

    wandb.finish()


def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN settings.

    Args:
        seed: Value passed to every seeding call.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN behaviour and turn off its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def anneal_function(anneal_fn, step, k, x0, epochs, start=0.001):
    """Return the regularization weight for the given training step.

    Args:
        anneal_fn: Either a value ``float()`` can parse, which is returned
            as a constant weight, or one of ``"logistic"``, ``"linear"``,
            ``"cosine"`` or ``"cyclic"``.
        step: Current step counter.
        k: Steepness of the logistic schedule.
        x0: Reference step: the logistic midpoint, the end of the cosine
            ramp, four cycle lengths for ``"cyclic"``, and the total step
            count for ``"linear"``.
        epochs: Number of epochs; only used by ``"linear"``.
        start: Offset added to every named schedule.

    Returns:
        The weight for ``step``.

    Raises:
        ValueError: If ``anneal_fn`` is neither numeric nor a known name.
    """
    try:
        return float(anneal_fn)
    except ValueError:
        # Not a number, so treat it as a schedule name below.
        pass

    if anneal_fn == "logistic":
        sigmoid = 1 / (1 + np.exp(-k * (step - x0)))
        return start + float(sigmoid)

    if anneal_fn == "linear":
        steps_per_epoch = x0 / epochs
        current_epoch = step / steps_per_epoch
        return start + 2 * current_epoch / epochs

    if anneal_fn == "cosine":
        if step < x0:
            return start + 0.5 * (1 - np.cos(np.pi * step / x0))
        # Ramp finished: hold the final value.
        return start + 1.0

    if anneal_fn == "cyclic":
        cycle_length = x0 / 4
        progress = (step % cycle_length) / cycle_length
        return start + 0.5 * (1 - np.cos(np.pi * progress))

    raise ValueError(f"Unknown annealing function: {anneal_fn}")


def create_celeba_dataset(batch_size=64, attr_list=None, transform=None):
    """Build the train, validation and test ``CelebADataset`` objects.

    Every split reads from ``./data`` with downloading disabled.

    Args:
        batch_size: Accepted but not used by this function.
        attr_list: Forwarded to each ``CelebADataset``.
        transform: Forwarded to each ``CelebADataset``.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    train_dataset, val_dataset, test_dataset = (
        CelebADataset(
            root="./data",
            split=split_name,
            attr_list=attr_list,
            transform=transform,
            download=False,
        )
        for split_name in ("train", "valid", "test")
    )
    return train_dataset, val_dataset, test_dataset


def _mean_score(scores):
    """Return the mean of a list/array of scores; pass other values through."""
    if isinstance(scores, (list, np.ndarray)):
        return np.mean(scores)
    return scores


def _show_grid(ax, grid):
    """Draw a channel-first image grid tensor on ``ax``."""
    # Detach and move to CPU so grids built from GPU or grad-tracking
    # tensors can still be converted to NumPy.
    grid_np = grid.detach().cpu().permute(1, 2, 0).numpy()
    if grid_np.shape[2] == 1:
        ax.imshow(grid_np.squeeze(2), cmap="gray")
    else:
        ax.imshow(grid_np)


def plot_denoising_results(
    original,
    noisy,
    reconstructed,
    epoch,
    n_samples=8,
    ssim_scores=None,
    psnr_scores=None,
    log_name="",
):
    """Build a figure comparing original, noisy and denoised images.

    With ``n_samples == 32`` three stacked panels are drawn, one image grid
    (8 per row) per image kind. Otherwise a single grid is drawn with
    ``n_samples`` images per row, the three kinds concatenated in order.

    Args:
        original: Batch of clean images (assumed to be an image batch
            accepted by ``make_grid`` - TODO confirm layout with callers).
        noisy: Batch of corrupted images, same layout as ``original``.
        reconstructed: Batch of model outputs, same layout as ``original``.
        epoch: Epoch number used in the fallback title.
        n_samples: Number of leading images taken from each batch.
        ssim_scores: Scalar, list or array of SSIM values, or ``None``.
        psnr_scores: Scalar, list or array of PSNR values, or ``None``.
        log_name: Text placed before the metric line in the title.

    Returns:
        The created matplotlib figure; the caller is responsible for
        closing it.
    """
    original = original[:n_samples]
    noisy = noisy[:n_samples]
    reconstructed = reconstructed[:n_samples]

    # The metric title is only available when both scores are supplied.
    metrics_title = None
    if ssim_scores is not None and psnr_scores is not None:
        avg_ssim = _mean_score(ssim_scores)
        avg_psnr = _mean_score(psnr_scores)
        metrics_title = (
            log_name
            + f"\nAvg SSIM: {avg_ssim:.4f}, Avg PSNR: {avg_psnr:.2f} dB"
        )

    if n_samples == 32:
        cols = 8
        fig, axes = plt.subplots(3, 1, figsize=(20, 12))

        panels = [
            (original, "Original"),
            (noisy, "Noisy"),
            (reconstructed, "Denoised"),
        ]
        for ax, (imgs, type_name) in zip(axes, panels):
            _show_grid(
                ax, make_grid(imgs, nrow=cols, padding=2, normalize=False)
            )
            ax.set_title(
                f"{type_name} Images", fontsize=14, fontweight="bold"
            )
            ax.axis("off")

        if metrics_title is not None:
            title = metrics_title
        else:
            # Without scores there used to be no title at all, which made
            # the suptitle call fail; fall back to a plain caption.
            title = log_name or f"Denoising Results - Epoch {epoch}"

        fig.suptitle(title, fontsize=16, fontweight="bold")
        plt.tight_layout()
        return fig

    comparison = torch.cat([original, noisy, reconstructed], dim=0)
    grid = make_grid(comparison, nrow=n_samples, padding=2, normalize=False)

    fig, ax = plt.subplots(figsize=(15, 12))
    _show_grid(ax, grid)

    if metrics_title is not None:
        # When scores are given the metric title replaces the row legend.
        title = metrics_title
    else:
        title = f"Denoising Results - Epoch {epoch}\n"
        title += "Row 1: Original Images\n"
        title += "Row 2: Noisy Images\n"
        title += "Row 3: Denoised Images"

    ax.set_title(title, fontsize=14)
    ax.axis("off")

    return fig


if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, register
    # them in order, parse the arguments and start training.
    parser = argparse.ArgumentParser(
        description="Image Denoising Training with SSIM/PSNR Evaluation"
    )

    option_table = [
        ("--dataset", dict(type=str, default="celeba", help="Dataset name")),
        (
            "--model-name",
            dict(
                type=str,
                default="ae",
                choices=["ae", "vae", "t3", "pareto", "laplace"],
                help="Model type",
            ),
        ),
        (
            "--loss-type",
            dict(
                type=str,
                default="l1",
                choices=["l1", "l2", "mae", "mse"],
                help="Reconstruction loss type",
            ),
        ),
        (
            "--nu",
            dict(
                type=float,
                default=3.1,
                help="Nu parameter for Pareto VAE models",
            ),
        ),
        ("--epochs", dict(type=int, default=100)),
        ("--batch-size", dict(type=int, default=512)),
        (
            "--anneal-function",
            dict(
                type=str,
                default="logistic",
                help="Annealing function for VAE regularization weight",
            ),
        ),
        (
            "--normalization",
            dict(
                type=str,
                default="none",
                choices=[
                    "batchnorm",
                    "layernorm",
                    "instancenorm",
                    "groupnorm",
                    "none",
                ],
                help="Normalization technique to use",
            ),
        ),
        (
            "--activation",
            dict(
                type=str,
                default="relu",
                choices=[
                    "relu",
                    "leaky_relu",
                    "elu",
                    "selu",
                    "gelu",
                    "swish",
                    "tanh",
                ],
                help="Activation function to use",
            ),
        ),
        (
            "--dropout-rate",
            dict(
                type=float,
                default=0.0,
                help="Dropout rate (0.0 = no dropout)",
            ),
        ),
        (
            "--noise-type",
            dict(
                type=str,
                default="salt_pepper",
                choices=["salt_pepper", "gaussian", "mixed"],
                help="Type of noise to add",
            ),
        ),
        (
            "--noise-prob",
            dict(
                type=float,
                default=0.1,
                help="Probability of noise (0.0-1.0)",
            ),
        ),
        (
            "--salt-prob",
            dict(
                type=float,
                default=0.5,
                help="Probability of salt vs pepper in salt_pepper noise",
            ),
        ),
        (
            "--gaussian-mean",
            dict(type=float, default=0.0, help="Mean of gaussian noise"),
        ),
        (
            "--gaussian-std",
            dict(
                type=float,
                default=0.1,
                help="Standard deviation of gaussian noise",
            ),
        ),
        (
            "--learning-rate",
            dict(type=float, default=5e-4, help="Learning rate"),
        ),
        (
            "--optimizer",
            dict(
                type=str,
                default="adam",
                choices=["adam", "adamw", "adam-wd", "adamw-wd", "sgd"],
                help="Optimizer to use",
            ),
        ),
        (
            "--seed",
            dict(
                type=int,
                default=42,
                help="Random seed for reproducibility",
            ),
        ),
    ]

    # Registration order determines the order shown by --help.
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
