import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from ablation import (
    get_dataset_loader,
    calculate_l2_distances,
    calculate_r2_score,
)
from config import get_dataset_config, get_unet_config
from denoising_pyramid import *
from diffusers import DDIMScheduler
from tqdm import tqdm

from nn_baselines.src.training_utils import load_model


def preload_denoisers(
    candidate_kernel_sizes: List[int],
    num_images: int,
    device: str = "cuda",
    dataset_name: str = "cifar10",
) -> Dict[int, "DenoisingKamb"]:
    """
    Build and train one DenoisingKamb denoiser per distinct kernel size.

    Repeated kernel sizes in ``candidate_kernel_sizes`` are processed only
    once; the returned mapping has exactly one entry per distinct size.

    Args:
        candidate_kernel_sizes: List of kernel sizes to try (duplicates allowed)
        num_images: Number of images in the dataset (-1 is labelled "full"
            in the save prefix)
        device: Device to run on
        dataset_name: Name of the dataset to use

    Returns:
        Dictionary mapping kernel sizes to trained DenoisingKamb models
    """
    denoisers = {}
    dataloader = get_dataset_loader(
        dataset_name=dataset_name, num_images=num_images, batch_size=32
    )
    dataset_config = get_dataset_config(dataset_name)

    # Format save_prefix to use "full" when num_images is -1
    num_images_str = "full" if num_images == -1 else str(num_images)
    save_prefix = f"{dataset_name}_{num_images_str}"

    # Drop repeated sizes (order preserved): a repeat would only rebuild and
    # retrain an identical model and overwrite the same dictionary entry.
    unique_kernel_sizes = list(dict.fromkeys(candidate_kernel_sizes))

    for ks in tqdm(unique_kernel_sizes, desc="Pre-loading denoisers"):
        # The config is rebuilt per model so that no mutable value (e.g. the
        # cond_gamma list) is shared between denoiser instances.
        kamb_config = {
            "resolution": dataset_config["img_size"],
            "device": device,
            "denoiser": "knn",
            "temperature": 1.0,
            "latent_diffusion": False,
            "stride": 1,
            "sigma_correction": True,
            "level_mixture_alpha": 1.0,
            "random_padding": False,
            "kernel_size": [ks],
            "kernel_overlap": 0.75,
            "stride_gen": [1],
            "cond_gamma": [
                0.56,
                0.249,
                0.453,
                0.749,
                0.881,
                47.14,
                59.55,
                24.86,
                12.84,
                6.0,
            ],
            "num_steps": 10,
            "denoiser_args": {"num_neighbors": 200},
            "fill_in_zeros_in_x0": True,
            "embed_w": -1.0,
            "aggregation_mode": "mean",
            "save_dir": "trained_models",
            "dataset_name": dataset_name,
            "save_prefix": save_prefix,
            "in_channels": dataset_config["in_channels"],
            "out_channels": dataset_config["out_channels"],
        }

        model = DenoisingKamb(**kamb_config)
        # NOTE(review): `train` here is DenoisingKamb's own method taking the
        # loader; presumably it fits or loads cached state under save_dir.
        model.train(dataloader)
        denoisers[ks] = model

    return denoisers


def fit_kernel_sizes(
    num_images: int,
    device: str = "cuda",
    num_steps: int = 1000,
    dataset_name: str = "cifar10",
    batch_size: int = 6,
    candidate_kernel_sizes: Optional[List[int]] = None,
) -> Tuple[List[int], List[float], List[torch.Tensor]]:
    """
    Fit kernel sizes for DenoisingKamb by comparing against UNet predictions.
    Builds trajectory step by step, optimizing kernel size at each step.

    At every scheduler timestep each candidate denoiser predicts the noise for
    the current sample, the prediction is scored against the UNet output with
    ``calculate_r2_score``, and the candidate with the smallest score is used
    to advance the sample.

    Args:
        num_images: Number of images in the dataset
        device: Device to run on
        num_steps: Number of diffusion steps
        dataset_name: Name of the dataset to use
        batch_size: Batch size for processing
        candidate_kernel_sizes: List of kernel sizes to try (if None, uses
            dataset defaults). Repeated sizes are evaluated once.

    Returns:
        Tuple of (best kernel sizes per step, corresponding losses, trajectory).
        The trajectory holds the initial noise followed by one sample per step.

    Raises:
        ValueError: If ``candidate_kernel_sizes`` is an empty list.
    """
    # Resolve the candidates first so bad input fails before any model is loaded.
    if candidate_kernel_sizes is None:
        # Use dataset-specific kernel sizes from ablation.py
        if dataset_name in ["mnist", "fashion_mnist"]:
            candidate_kernel_sizes = [28, 23, 17, 13, 9, 5, 3]
        elif dataset_name == "cifar10":
            candidate_kernel_sizes = [32, 32, 32, 29, 25, 17, 13, 9, 7, 3]
        else:  # ffhq, celeba_hq, afhq
            candidate_kernel_sizes = [64, 45, 33, 25, 17, 9, 5, 3]

    # Drop repeats (order preserved): scoring the same denoiser twice adds
    # cost, and the first occurrence already wins any tie.
    candidate_kernel_sizes = list(dict.fromkeys(candidate_kernel_sizes))
    if not candidate_kernel_sizes:
        raise ValueError("candidate_kernel_sizes must contain at least one kernel size")

    # Initialize UNet model and scheduler
    config = get_unet_config(dataset_name, num_images)
    model = load_model(config, device)
    model.eval()

    scheduler = DDIMScheduler(
        beta_start=config["beta_1"],
        beta_end=config["beta_T"],
        beta_schedule="linear",
        prediction_type="epsilon",
    )
    scheduler.set_timesteps(num_steps)

    # Get dataset-specific configuration
    dataset_config = get_dataset_config(dataset_name)

    # Pre-load all denoisers
    denoisers = preload_denoisers(
        candidate_kernel_sizes, num_images, device, dataset_name
    )

    # Initialize lists to store results
    best_kernel_sizes = []
    best_losses = []
    trajectory = []

    # Start with random noise
    cur_img = torch.randn(
        batch_size,
        dataset_config["in_channels"],
        dataset_config["img_size"],
        dataset_config["img_size"],
        device=device,
    )
    cur_img = cur_img * scheduler.init_noise_sigma
    trajectory.append(cur_img.clone())

    # Nothing below needs gradients.
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, desc="Building trajectory"):
            # Reference noise prediction from the UNet
            unet_output = model(cur_img, t.to(device)[None])

            # Score every candidate and keep its prediction for reuse below
            step_losses = []
            noise_predictions = []
            for ks in candidate_kernel_sizes:
                predicted_noise, _, _, _ = denoisers[ks].denoise(cur_img, t.item())
                noise_predictions.append(predicted_noise)

                # float() gives a plain Python number even when the helper
                # returns a 0-dim tensor, so np.argmin / np.mean / np.save
                # downstream do not have to handle (possibly CUDA) tensors.
                loss = float(calculate_r2_score(predicted_noise, unet_output))
                step_losses.append(loss)

                print(f"Step {t.item():4d} | Kernel size {ks:2d} | Loss: {loss:.6f}")

            # Find best kernel size for this step
            # NOTE(review): the score comes from calculate_r2_score but is
            # minimised like a loss. If the helper returns a true R^2 (higher
            # is better) this selects the worst candidate -- confirm against
            # ablation.calculate_r2_score before relying on the result.
            best_idx = int(np.argmin(step_losses))
            best_ks = candidate_kernel_sizes[best_idx]
            best_kernel_sizes.append(best_ks)
            best_losses.append(step_losses[best_idx])

            print(
                f"Step {t.item():4d} | Best kernel size: {best_ks} | Best loss: {best_losses[-1]:.6f}"
            )
            print("-" * 50)

            # Step forward with the prediction that was just scored instead of
            # running the winning denoiser a second time.
            cur_img = scheduler.step(
                model_output=noise_predictions[best_idx],
                timestep=t,
                sample=cur_img,
                generator=None,
            ).prev_sample
            trajectory.append(cur_img.clone())

    return best_kernel_sizes, best_losses, trajectory


def main():
    """
    Command-line entry point.

    Parses the CLI options, runs fit_kernel_sizes with them, prints a summary
    of the selected kernel sizes and losses, and stores both lists as .npy
    files inside the "results" directory (created if missing).
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        (
            "--num_images",
            {
                "type": int,
                "default": -1,
                "help": "Number of images in the dataset (-1 for full dataset)",
            },
        ),
        (
            "--num_steps",
            {"type": int, "default": 10, "help": "Number of diffusion steps"},
        ),
        (
            "--device",
            {"type": str, "default": "cuda", "help": "Device to run on (cuda/cpu)"},
        ),
        (
            "--dataset",
            {
                "type": str,
                "choices": [
                    "mnist",
                    "fashion_mnist",
                    "cifar10",
                    "ffhq",
                    "celeba_hq",
                    "afhq",
                ],
                "default": "cifar10",
                "help": "Dataset to use",
            },
        ),
        (
            "--batch_size",
            {"type": int, "default": 6, "help": "Batch size for processing"},
        ),
    )

    cli = argparse.ArgumentParser(
        description="Fit kernel sizes for DenoisingKamb against UNet predictions"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    # Run the fit; the trajectory is returned but not used here.
    print("Fitting kernel sizes and building trajectory...")
    kernel_sizes, losses, _trajectory = fit_kernel_sizes(
        num_images=opts.num_images,
        device=opts.device,
        num_steps=opts.num_steps,
        dataset_name=opts.dataset,
        batch_size=opts.batch_size,
    )

    # Summary on stdout
    print("\nFinal Results:")
    print("=" * 50)
    print(f"Best kernel sizes across steps: {kernel_sizes}")
    print(f"Corresponding losses: {losses}")
    print(f"Average best kernel size: {np.mean(kernel_sizes):.2f}")
    print(f"Average loss: {np.mean(losses):.6f}")

    # Persist both result lists
    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True)
    kernel_sizes_path = out_dir / f"kernel_sizes_{opts.dataset}_{opts.num_images}.npy"
    losses_path = out_dir / f"losses_{opts.dataset}_{opts.num_images}.npy"
    np.save(kernel_sizes_path, kernel_sizes)
    np.save(losses_path, losses)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
