import os
from typing import List, Optional, Union
import json
import argparse

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from generation import *
from pyramids import GaussianPyramid, IdentityPyramid
from torch.nn import functional as F
from tqdm.auto import tqdm
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from data import get_dataset_loader


def get_dataset_resolution(dataset_name: str) -> int:
    """Return the image side length (in pixels) used for ``dataset_name``.

    Raises:
        KeyError: if ``dataset_name`` is not one of the known datasets.
    """
    # Datasets grouped by the resolution we work at. FFHQ and CelebA-HQ are
    # the 64x64 versions.
    names_by_resolution = {
        28: ("mnist", "fashion_mnist"),
        32: ("cifar10",),
        64: ("ffhq", "celeba_hq", "afhq"),
    }
    lookup = {
        name: side
        for side, names in names_by_resolution.items()
        for name in names
    }
    return lookup[dataset_name]


def _batch_to_patches(
    data: torch.Tensor, kernel_size: int, device: torch.device
) -> torch.Tensor:
    """Move one batch to ``device``, map it to [-1, 1] if needed, and return flattened patches.

    The result is 2-D: one row per extracted patch (stride 1, no padding).
    """
    data = data.to(device)
    # A batch lying entirely inside [0, 1] is treated as unnormalized and
    # rescaled to [-1, 1]. NOTE(review): this is decided per batch, so an
    # already-normalized batch that happens to be all non-negative would also be
    # rescaled — confirm the loader's output range.
    if data.min() >= 0 and data.max() <= 1:
        data = data * 2 - 1

    patches = get_patches(
        [data],
        kernel_size=kernel_size,
        stride=1,
        # Dim 1 of the batch is passed as the channel count. This matches the
        # original sample-batch probe; assumes (B, C, H, W) layout — TODO confirm.
        n_channels=data.shape[1],
        padding=0,
    )
    # reshape (not view) so a non-contiguous result from get_patches still works.
    return patches.reshape(patches.shape[0], -1)


def save_s_matrix(
    dataloader: DataLoader,
    dataset_name: str,
    kernel_size: int,
    num_images: int,
    save_dir: str = "data",
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Calculate and save the S (covariance) matrix and mean for a dataset.

    Two passes are made over ``dataloader``: the first computes the mean
    flattened patch, the second accumulates the centred outer products. The
    covariance is normalized by the total patch count (population covariance).

    Args:
        dataloader: DataLoader yielding ``(data, label)`` pairs
        dataset_name: Name of the dataset, used in the saved file names
        kernel_size: Size of patches to extract
        num_images: Number of images used (-1 for full dataset); only used to
            build the file names
        save_dir: Directory to save the matrix (created if missing)

    Returns:
        ``(S, mean)``. ``S`` has shape ``(1, D, D)`` and ``mean`` has shape
        ``(1, D)``, where ``D`` is the flattened patch length (presumably
        ``channels * kernel_size**2`` — depends on ``get_patches``).

    Raises:
        ValueError: if the dataloader yields no patches (e.g. it is empty).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # First pass: calculate mean
    print("Calculating patch mean...")
    mean = 0.0
    patch_count = 0
    for data, _ in tqdm(dataloader, desc="First pass - mean calculation"):
        patches = _batch_to_patches(data, kernel_size, device)
        mean += patches.sum(dim=0)
        patch_count += patches.shape[0]

    if patch_count == 0:
        raise ValueError(
            f"No patches extracted for dataset {dataset_name!r}; "
            "is the dataloader empty?"
        )
    mean = mean / patch_count

    # Second pass: calculate covariance of the centred patches
    print("Calculating covariance matrix...")
    S = 0.0
    for data, _ in tqdm(dataloader, desc="Second pass - covariance calculation"):
        patches = _batch_to_patches(data, kernel_size, device) - mean.unsqueeze(0)
        S += torch.einsum("bi,bj->ij", patches, patches)

    # Normalize covariance and add a leading singleton dim to both outputs.
    S = S / patch_count
    S = S.expand(1, -1, -1)
    mean = mean.expand(1, -1)

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Save matrices with appropriate names
    dataset_size_str = "full" if num_images == -1 else str(num_images)
    base_path = os.path.join(save_dir, f"{dataset_name}_{dataset_size_str}")

    cov_path = f"{base_path}_s_matrix_ks{kernel_size}.pt"
    mean_path = f"{base_path}_mean_ks{kernel_size}.pt"

    torch.save(S, cov_path)
    torch.save(mean, mean_path)

    # Print matrix properties
    print(f"\nS matrix and mean properties for {dataset_name}:")
    print(f"Covariance shape: {S.shape}")
    print(f"Mean shape: {mean.shape}")
    print(
        f"Mean statistics: min={mean.min():.4f}, max={mean.max():.4f}, avg={mean.mean():.4f}"
    )
    print(f"Has NaN values: {torch.isnan(S).any() or torch.isnan(mean).any()}")
    print(f"Has Inf values: {torch.isinf(S).any() or torch.isinf(mean).any()}")
    print(f"Matrix symmetry check: {torch.allclose(S, S.transpose(1, 2))}")
    print(f"Files saved to:\n  {cov_path}\n  {mean_path}\n")

    return S, mean


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``dataset``, ``num_images``, ``kernel_sizes``,
        ``save_dir``, ``ffhq_dir``, ``celeba_dir`` and ``afhq_dir``.
    """
    cli = argparse.ArgumentParser(
        description="Calculate S matrix for various datasets"
    )

    supported_datasets = [
        "mnist", "fashion_mnist", "cifar10", "ffhq", "celeba_hq", "afhq",
    ]
    cli.add_argument(
        "--dataset",
        type=str,
        choices=supported_datasets,
        required=True,
        help="Dataset to calculate S matrix for",
    )
    cli.add_argument(
        "--num-images",
        type=int,
        default=-1,
        help="Number of images to use (-1 for full dataset, default: -1)",
    )
    cli.add_argument(
        "--kernel-sizes",
        type=int,
        nargs="+",
        help="Kernel sizes to use for patch extraction. If not specified, uses the dataset's native resolution.",
    )
    cli.add_argument(
        "--save-dir",
        type=str,
        default="data",
        help="Directory to save S matrices (default: 'data')",
    )

    # Dataset-specific paths: (flag, default location, help text).
    dataset_dir_options = (
        ("--ffhq-dir", "data/ffhq_70k", "Directory containing FFHQ dataset"),
        (
            "--celeba-dir",
            "data/celebahq-resized-256x256/versions/1/celeba_hq_256",
            "Directory containing CelebA-HQ dataset",
        ),
        ("--afhq-dir", "./data/afhq", "Directory containing AFHQ dataset"),
    )
    for flag, default_dir, help_text in dataset_dir_options:
        cli.add_argument(flag, type=str, default=default_dir, help=help_text)

    return cli.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # If kernel_sizes not specified, use the dataset's native resolution
    if args.kernel_sizes is None:
        args.kernel_sizes = [get_dataset_resolution(args.dataset)]

    # Get the appropriate dataloader
    dataloader = get_dataset_loader(
        dataset_name=args.dataset,
        num_images=args.num_images,
        batch_size=256,  # Smaller batch size to handle large kernel sizes
        ffhq_dir=args.ffhq_dir,
        celeba_dir=args.celeba_dir,
        afhq_dir=args.afhq_dir,
    )

    # Calculate S matrix for each kernel size. Results are written to disk by
    # save_s_matrix; the returned tensors are deliberately not kept so the
    # previous kernel size's covariance does not stay alive (on the GPU) while
    # the next one is computed.
    for ks in tqdm(args.kernel_sizes, desc="Processing kernel sizes"):
        try:
            save_s_matrix(
                dataloader=dataloader,
                dataset_name=args.dataset,
                kernel_size=ks,
                num_images=args.num_images,
                save_dir=args.save_dir,
            )
        except RuntimeError as e:
            # Only out-of-memory errors are skippable; anything else propagates
            # with its original traceback.
            if "out of memory" not in str(e):
                raise
            print(f"\nSkipping kernel size {ks} due to GPU memory constraints")
