import random

import pandas as pd
import torch
import torch_fidelity
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb


class CustomDataset(torch.utils.data.Dataset):
    """Minimal ``Dataset`` adapter around any sized, indexable collection.

    Items are returned exactly as stored; no transform or copy is applied.
    """

    def __init__(self, data):
        # Held by reference so large tensors are not duplicated.
        self.data = data

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item

    def __len__(self):
        """Return the number of stored items."""
        size = len(self.data)
        return size


def split_real_images_randomly(real_images, num_samples_per_set):
    """
    Split real images into two random, mutually exclusive subsets.

    When at least ``2 * num_samples_per_set`` images are available, each
    subset holds exactly ``num_samples_per_set`` images. Otherwise a warning
    is printed and all available images are split as evenly as possible;
    with an odd count the second subset receives the extra image.

    Randomness comes from the global ``random`` module state, so seed it
    beforehand for reproducible splits.

    Args:
        real_images: Sized, indexable collection of real image tensors that
            can be combined with ``torch.stack``
        num_samples_per_set: Number of samples to include in each subset
            (must be at least 1)

    Returns:
        Tuple of (set1_images, set2_images) as torch tensors

    Raises:
        ValueError: If ``num_samples_per_set`` is less than 1, or fewer than
            two images are available (one of the subsets would be empty).
    """
    if num_samples_per_set < 1:
        raise ValueError(
            f"num_samples_per_set must be at least 1, got {num_samples_per_set}"
        )

    total_available = len(real_images)
    if total_available < 2:
        # torch.stack rejects an empty list with an unhelpful message, so
        # fail early with one that names the real problem.
        raise ValueError(
            "Need at least 2 images to build two non-empty subsets, "
            f"got {total_available}"
        )

    total_needed = num_samples_per_set * 2

    if total_available < total_needed:
        print(
            f"Warning: Only {total_available} images available, but need {total_needed}"
        )
        # Use all available images, split them as evenly as possible
        indices = list(range(total_available))
        random.shuffle(indices)
        split_point = total_available // 2
    else:
        # Sample the required number of images
        indices = random.sample(range(total_available), total_needed)
        # random.sample already returns a random order; the extra shuffle is
        # kept so the RNG stream (and seeded results) stays unchanged.
        random.shuffle(indices)
        split_point = num_samples_per_set

    # Slicing one permutation at a single point makes the subsets disjoint.
    set1_images = [real_images[i] for i in indices[:split_point]]
    set2_images = [real_images[i] for i in indices[split_point:]]

    return torch.stack(set1_images), torch.stack(set2_images)


# Script entry point: for each cell type, split the real images into two
# disjoint random sets and compute FID between them. The per-cell-type
# results are written to a CSV and printed as a summary table.
if __name__ == "__main__":
    seed = 1337
    # Seed before any sampling so the random splits are reproducible.
    seed_everything(seed)
    # load yaml file
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"

    # load yaml
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    cell_type_to_label = {
        0: "HEPG2",
        1: "HUVEC",
        2: "RPE",
        3: "U2OS",
    }
    NUM_SAMPLES_PER_SET = (
        500  # Each set will have 500 images (1000 total like original)
    )
    # Loop-invariant: images needed to fill both sets completely.
    min_required = NUM_SAMPLES_PER_SET * 2

    # Fall back to CPU when no GPU is present instead of crashing.
    use_cuda = torch.cuda.is_available()

    # Iterate through each cell type and calculate real vs real metrics
    results = []

    for cell_type, cell_label in cell_type_to_label.items():
        print(f"\n\n{'='*80}")
        print(f"Processing cell type: {cell_label} (ID: {cell_type})")
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(cell_type_id=cell_type)

        if real_filtered_dataset is None:
            print(f"No real data found for cell type {cell_type}")
            continue

        # Get all available real images (first element of each sample).
        # Indexed access is used because only __len__/__getitem__ are known
        # to be supported by the filtered dataset.
        all_real_images = [
            real_filtered_dataset[i][0] for i in range(len(real_filtered_dataset))
        ]
        print(f"Total available real images: {len(all_real_images)}")

        # Two non-empty sets need at least two images; skip this cell type
        # rather than abort the remaining ones.
        if len(all_real_images) < 2:
            print(
                f"Not enough real images for cell type {cell_type} "
                "to form two sets, skipping"
            )
            continue

        # Check if we have enough images for two sets
        if len(all_real_images) < min_required:
            print(
                f"Warning: Only {len(all_real_images)} images available, need at least {min_required}"
            )
            print("Will use all available images split into two sets")

        # Split into two random, mutually exclusive subsets
        set1_images, set2_images = split_real_images_randomly(
            all_real_images, NUM_SAMPLES_PER_SET
        )

        print(f"Set 1 images: {len(set1_images)}")
        print(f"Set 2 images: {len(set2_images)}")

        print(f"Calculating real vs real metrics for cell type {cell_type}")
        print(f"Set 1 shape: {set1_images.shape}")
        print(f"Set 2 shape: {set2_images.shape}")

        # Convert to RGB format (same as original script); to_rgb is applied
        # one image at a time with a temporary leading batch dimension.
        set1_rgb = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in set1_images]
        )
        set2_rgb = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in set2_images]
        )

        # Convert to uint8 format for torch_fidelity
        # NOTE(review): assumes to_rgb returns floats in [0, 1]; values
        # outside that range would not survive the uint8 cast - confirm.
        set1_uint8 = (set1_rgb * 255).to(torch.uint8)
        set2_uint8 = (set2_rgb * 255).to(torch.uint8)

        # Create datasets
        set1_dataset = CustomDataset(set1_uint8)
        set2_dataset = CustomDataset(set2_uint8)

        # Calculate metrics between the two real image sets.
        # KID is disabled; to re-enable it pass kid=True together with
        # kid_subset_size=min(500, min(len(set1_images), len(set2_images)))
        # and kid_subsets=100, and drop the placeholders below.
        metrics = torch_fidelity.calculate_metrics(
            input1=set1_dataset,
            input2=set2_dataset,
            cuda=use_cuda,
            fid=True,
            kid=False,
        )
        # Placeholders so the report/CSV schema stays stable while KID is
        # off. These zeros are NOT measured values.
        metrics["kernel_inception_distance_mean"] = 0.0
        metrics["kernel_inception_distance_std"] = 0.0
        fid = metrics["frechet_inception_distance"]
        kid_mean = metrics["kernel_inception_distance_mean"]
        kid_std = metrics["kernel_inception_distance_std"]

        print(f"Cell Type: {cell_label}")
        print(f"Real vs Real FID: {fid:.4f}")
        print(f"Real vs Real KID: {kid_mean:.4f} ± {kid_std:.4f}")

        # Save results
        results.append(
            {
                "cell_type": cell_label,
                "cell_type_id": cell_type,
                "num_original_real": len(all_real_images),
                "num_set1": len(set1_images),
                "num_set2": len(set2_images),
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
            }
        )

    # Create a DataFrame and save results to CSV
    results_df = pd.DataFrame(results)
    output_file = f"cell_type_metrics_real_vs_real_results_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Cell Type Real vs Real Metrics:")
    print(results_df.to_string(index=False))
