import os

import torch
from tqdm.auto import tqdm
from torchgmm.bayes import GaussianMixture
import os

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from generation import *


def get_cifar10_dataloader(
    num_images: int = None, batch_size: int = 1024
) -> DataLoader:
    """Return a shuffled DataLoader over the CIFAR10 training split.

    Images are converted with ``transforms.ToTensor()`` only; no further
    normalization is applied here. The dataset is downloaded to ``./data``
    if it is not already present.

    Args:
        num_images: If given, only the first ``num_images`` samples of the
            training split are used. ``None`` uses the whole split.
        batch_size: Number of images per batch.

    Returns:
        A DataLoader with ``shuffle=True`` whose shuffling order is driven by
        a generator seeded with a fixed value.
    """
    # Fixed seed so the shuffling order is reproducible. The original note says
    # this value is meant to match the seed used in training_utils.py.
    seed = 42

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )

    if num_images is not None:
        from torch.utils.data import Subset

        # Deterministic subset: always the first `num_images` samples.
        dataset = Subset(dataset, range(num_images))

    # Seed both the global RNG and the loader's private generator. The
    # generator must exist on every path: it used to be created only inside
    # the `num_images` branch, so calling with num_images=None crashed with
    # an UnboundLocalError when building the DataLoader.
    torch.manual_seed(seed)
    g = torch.Generator().manual_seed(seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=g,  # Fixed generator for reproducible shuffling
        drop_last=False,  # Keep the final partial batch
    )


def save_gmm_patches(
    dataloader: DataLoader, kernel_size: int = 32, n_components: int = 3
) -> None:
    """Fit a GMM on image patches of a given size and save its parameters.

    Every batch of the dataloader is rescaled with ``x * 2 - 1``, cut into
    patches via ``get_patches`` (stride 1, no padding, 3 channels), flattened
    to one row per patch, and the rows from all batches are concatenated
    before fitting. The fitted means, covariances and component probabilities
    are written with ``torch.save`` to
    ``trained_models/gmms/cifar10_50000_gmm_ks{kernel_size}_nc{n_components}.pt``.

    Args:
        dataloader: DataLoader yielding ``(images, labels)`` batches; the
            labels are ignored. Images are assumed to lie in [0, 1] so that
            the rescaling maps them to [-1, 1] (TODO confirm for loaders
            other than the CIFAR10 one in this file).
        kernel_size: Size of the patches to extract.
        n_components: Number of GMM components.
    """
    # NOTE: the file name hard-codes "cifar10_50000" regardless of what the
    # dataloader actually contains.
    save_path = (
        f"trained_models/gmms/cifar10_50000_gmm_ks{kernel_size}_nc{n_components}.pt"
    )
    # Create the output directory up front so a missing folder fails (or is
    # fixed) before the expensive extraction and fit, not at torch.save time.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Initialize GMM
    gmm = GaussianMixture(
        num_components=n_components,
        covariance_type="full",
        batch_size=2**21,
        convergence_tolerance=1e-4,
        covariance_regularization=1e-6,
        trainer_params=dict(
            accelerator="gpu",
            devices=1,
            logger=False,
            enable_model_summary=False,
            enable_progress_bar=True,
        ),
    )

    # Extract patches batch by batch; all flattened patches are kept in memory
    # and concatenated once at the end.
    all_patches = []
    for data, _ in tqdm(dataloader, desc="Extracting patches", total=len(dataloader)):
        data = data * 2 - 1  # Rescale (to [-1, 1] for inputs in [0, 1])

        # Extract patches
        patches = get_patches(
            [data],
            kernel_size=kernel_size,
            stride=1,
            n_channels=3,
            padding=0,
        )

        # Flatten each patch to a single row for the GMM. reshape (rather than
        # view) also works if get_patches returns a non-contiguous tensor.
        patches_flat = patches.reshape(patches.shape[0], -1)
        all_patches.append(patches_flat)

    # Concatenate all patches
    all_patches = torch.cat(all_patches, dim=0)

    # Fit GMM
    print(f"Fitting GMM with {n_components} components for kernel size {kernel_size}")
    gmm.fit(all_patches)

    # Save GMM parameters
    print(f"Saving GMM to {save_path}")
    torch.save(
        {
            "means": gmm.model_.means,
            "covariances": gmm.model_.covariances,
            "component_probs": gmm.model_.component_probs,
        },
        save_path,
    )

    print(f"GMM saved to {save_path}")
    print("GMM parameters:")
    print(f"Number of components: {n_components}")
    print(f"Means shape: {gmm.model_.means.shape}")
    print(f"Covariances shape: {gmm.model_.covariances.shape}")
    print(f"Component probabilities shape: {gmm.model_.component_probs.shape}")


if __name__ == "__main__":
    # Script entry point: fit and save one GMM for every combination of
    # kernel size and component count listed below.
    kernel_sizes = [5]  # [32, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3]
    component_counts = [1024]  # [1, 4, 16, 64, 256, 1024]  # Number of GMM components

    loader = get_cifar10_dataloader(num_images=50000)
    # loader = get_ffhq_dataloader(dataset_dir="data/ffhq_70k", num_images=1000)

    for kernel_size in tqdm(
        kernel_sizes, desc="Kernel size", total=len(kernel_sizes)
    ):
        component_bar = tqdm(
            component_counts,
            desc="Number of components",
            total=len(component_counts),
        )
        for num_components in component_bar:
            save_gmm_patches(
                loader, kernel_size=kernel_size, n_components=num_components
            )
