import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from tqdm import tqdm

# Minimal map-style dataset so in-memory image tensors can be passed to torch-fidelity.


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by an in-memory indexable collection.

    In this script it wraps pre-built image tensors so they can be passed
    to ``torch_fidelity.calculate_metrics`` as ``input1`` / ``input2``.
    """

    def __init__(self, data):
        # Any object supporting ``len()`` and integer indexing works here.
        self.data = data

    def __len__(self):
        """Return how many items the wrapped collection holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx`` unchanged."""
        item = self.data[idx]
        return item


def augment_image(image, augmentation_type=None):
    """
    Return a randomly rotated, flipped, or untouched version of ``image``.

    Args:
        image: Tensor image of shape [C, H, W]
        augmentation_type: One of 'rotate', 'flip' or 'unchanged'. When None,
                           one of the three is picked uniformly at random.
                           Any other value leaves the image untouched.

    Returns:
        Augmented image tensor (the input object itself for 'unchanged')
    """
    mode = (
        augmentation_type
        if augmentation_type is not None
        else random.choice(["rotate", "flip", "unchanged"])
    )

    if mode == "rotate":
        # Quarter-turn rotations only: 90, 180 or 270 degrees.
        return TF.rotate(image, random.choice([90, 180, 270]))

    if mode == "flip":
        # Coin toss between a horizontal and a vertical mirror.
        flipper = TF.hflip if random.random() > 0.5 else TF.vflip
        return flipper(image)

    # 'unchanged' (or any unrecognised mode): pass the image through.
    return image


def augment_to_target_size(images, target_size):
    """
    Resample or augment a set of images so exactly ``target_size`` are returned.

    If at least ``target_size`` images are given, a random subset of that size
    is drawn without replacement (in random order). Otherwise every original
    image is kept, in order, followed by ``target_size - len(images)``
    augmented copies of randomly chosen originals (see ``augment_image``).

    Args:
        images: Sequence (or other iterable) of image tensors. It is copied,
                never mutated.
        target_size: Target number of images (must be >= 0)

    Returns:
        List of image tensors with target_size length

    Raises:
        ValueError: if ``target_size`` is negative, or if ``images`` is empty
            while ``target_size`` is positive (there is nothing to augment).
    """
    if target_size < 0:
        raise ValueError(f"target_size must be non-negative, got {target_size}")

    # Materialise into a fresh list: this accepts tuples, stacked tensors or
    # generators, and keeps the caller's container safe from the appends below.
    originals = list(images)

    if len(originals) >= target_size:
        # If we already have enough images, randomly sample
        indices = random.sample(range(len(originals)), target_size)
        return [originals[i] for i in indices]

    if not originals:
        # random.choice would otherwise fail with an opaque IndexError.
        raise ValueError(
            f"Cannot augment an empty image set to target size {target_size}"
        )

    # If we need more images, apply augmentations
    augmented_images = originals.copy()
    additional_samples_needed = target_size - len(originals)

    print(
        f"Applying {additional_samples_needed} augmentations to reach target size {target_size}"
    )

    for _ in range(additional_samples_needed):
        # Select a random image from the original images (never from earlier
        # augmentations) and apply a random augmentation to it.
        base_image = random.choice(originals)
        augmented_images.append(augment_image(base_image))

    return augmented_images


def load_and_prepare_perturbation_data(datamodule, pert_id, cell_type_id, num_samples):
    """
    Load and prepare data for a specific perturbation.

    Real images for ``pert_id`` / ``cell_type_id`` are obtained from
    ``datamodule.filter_samples``, resampled or augmented to exactly
    ``num_samples`` images with ``augment_to_target_size``, converted with
    ``to_rgb`` and scaled to uint8.

    Args:
        datamodule: CellDataModule instance
        pert_id: Perturbation ID
        cell_type_id: Cell type ID
        num_samples: Target number of samples

    Returns:
        uint8 tensor stacking ``num_samples`` converted images, or None if no
        data found
    """
    # Filter real images using CellDataModule
    real_filtered_dataset = datamodule.filter_samples(
        perturbation_id=pert_id, cell_type_id=cell_type_id
    )

    if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
        print(f"No real data found for perturbation ID {pert_id}")
        return None

    # Get all available real images; element 0 of each dataset item is used
    # as the image (NOTE(review): confirm against the dataset's item layout).
    all_real_images = [
        real_filtered_dataset[i][0] for i in range(len(real_filtered_dataset))
    ]
    print(
        f"Found {len(all_real_images)} total real images for perturbation ID {pert_id}"
    )

    # Augment to target size
    augmented_images = augment_to_target_size(all_real_images, num_samples)
    print(
        f"After augmentation: {len(augmented_images)} images for perturbation ID {pert_id}"
    )

    # Convert each image to RGB: to_rgb is called with a leading batch
    # dimension added ([None]) which is removed again afterwards.
    rgb_images = torch.stack(
        [to_rgb(img.cpu()[None]).squeeze(0) for img in augmented_images]
    )

    # Scale to [0, 255] and clamp before casting: converting out-of-range
    # floats to uint8 is undefined in PyTorch (it typically wraps around),
    # which would silently corrupt pixel values instead of saturating them.
    images_uint8 = (rgb_images * 255).clamp(0, 255).to(torch.uint8)

    return images_uint8


if __name__ == "__main__":
    seed = 7
    cell_type_id = 1
    # Seed the RNGs so the sampling/augmentation below is reproducible.
    seed_everything(seed)

    # Load yaml file
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # Use the same perturbation IDs as in the original script
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    NUM_SAMPLES = 500
    # Each side of a comparison uses one half of a perturbation's images.
    HALF_SAMPLES = NUM_SAMPLES // 2
    # Requested KID subset size; clamped per comparison below because
    # torch-fidelity rejects a subset size larger than either input.
    KID_SUBSET_SIZE = 100
    KID_SUBSETS = 100
    # Fall back to CPU instead of failing on machines without a GPU.
    use_cuda = torch.cuda.is_available()

    # Pre-load all perturbation data to avoid repeated loading
    print("Pre-loading all perturbation data...")
    perturbation_data = {}

    for pert_id in sampled_perturbation_ids:
        print(f"\nLoading data for perturbation ID: {pert_id}")
        data = load_and_prepare_perturbation_data(
            datamodule, pert_id, cell_type_id, NUM_SAMPLES
        )
        if data is not None:
            perturbation_data[pert_id] = data
        else:
            print(f"Skipping perturbation ID {pert_id} due to insufficient data")

    print(f"\nSuccessfully loaded data for {len(perturbation_data)} perturbations")

    # Compare every loaded perturbation p1 against every loaded perturbation
    # p2. Self-comparisons (p1 == p2) are kept on purpose as an
    # intra-perturbation baseline; they are excluded from the
    # cross-perturbation statistics further down.
    loaded_ids = [p for p in sampled_perturbation_ids if p in perturbation_data]
    results = []

    for p1 in loaded_ids:
        for p2 in loaded_ids:
            print(f"\n{'='*80}")
            print(f"Calculating metrics: P{p1} vs P{p2}")
            print(f"{'='*80}")

            # p1 contributes the first half of its images and p2 the second
            # half, so a self-comparison uses disjoint index ranges.
            # NOTE(review): when a perturbation has fewer than NUM_SAMPLES real
            # images, augment_to_target_size appends the augmented copies after
            # the originals, so the second half holds most of them; for
            # p1 == p2 those may be rotated/flipped copies of first-half
            # images. Confirm this is acceptable for the baseline.
            p1_data = perturbation_data[p1][:HALF_SAMPLES]
            p2_data = perturbation_data[p2][HALF_SAMPLES:]

            print(f"P{p1} data shape: {p1_data.shape}")
            print(f"P{p2} data shape: {p2_data.shape}")

            # Create datasets
            p1_dataset = CustomDataset(p1_data)
            p2_dataset = CustomDataset(p2_data)

            kid_subset_size = min(KID_SUBSET_SIZE, len(p1_data), len(p2_data))

            # Calculate metrics between the two perturbations
            metrics = torch_fidelity.calculate_metrics(
                input1=p1_dataset,
                input2=p2_dataset,
                cuda=use_cuda,
                fid=True,
                kid=True,
                kid_subset_size=kid_subset_size,
                kid_subsets=KID_SUBSETS,
            )
            fid = metrics["frechet_inception_distance"]
            kid_mean = metrics["kernel_inception_distance_mean"]
            kid_std = metrics["kernel_inception_distance_std"]

            print(f"P{p1} vs P{p2}")
            print(f"FID (Cross-perturbation): {fid:.4f}")
            print(f"KID (Cross-perturbation): {kid_mean:.4f} ± {kid_std:.4f}")

            # Save results
            results.append(
                {
                    "perturbation_p1": p1,
                    "perturbation_p2": p2,
                    "comparison": f"P{p1}_vs_P{p2}",
                    "num_samples_p1": len(p1_data),
                    "num_samples_p2": len(p2_data),
                    "fid": fid,
                    "kid_mean": kid_mean,
                    "kid_std": kid_std,
                }
            )

    if not results:
        # Without any rows the column lookups below would raise KeyError.
        print("\nNo comparisons were computed; no results to save.")
        sys.exit(1)

    # Create a DataFrame and save results to CSV
    results_df = pd.DataFrame(results)

    # Separate diagonal (self) from off-diagonal (cross) comparisons before
    # the summary row, whose p1/p2 columns hold the string "All", is appended.
    is_self = results_df["perturbation_p1"] == results_df["perturbation_p2"]
    cross_df = results_df[~is_self]
    self_df = results_df[is_self]

    # Summary statistics over every per-comparison row (self-comparisons
    # included), written to the CSV as a single "Summary_Stats" row.
    summary_stats = {
        "comparison": "Summary_Stats",
        "perturbation_p1": "All",
        "perturbation_p2": "All",
        "num_samples_p1": results_df["num_samples_p1"].mean(),
        "num_samples_p2": results_df["num_samples_p2"].mean(),
        "fid": results_df["fid"].mean(),
        "kid_mean": results_df["kid_mean"].mean(),
        "kid_std": results_df["kid_std"].mean(),
    }

    # Add summary row to the DataFrame
    results_df = pd.concat(
        [results_df, pd.DataFrame([summary_stats])], ignore_index=True
    )

    output_file = (
        f"perturbation_type_metrics_real_vs_real_cross_analysis_seed_{seed}.csv"
    )
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Cross-Perturbation Metrics:")
    print(results_df.to_string(index=False))

    # Print detailed analysis (off-diagonal comparisons only)
    print(f"\n{'='*80}")
    print("CROSS-PERTURBATION FID ANALYSIS")
    print(f"{'='*80}")
    print(f"Number of cross-perturbation comparisons: {len(cross_df)}")

    if cross_df.empty:
        print(
            "No cross-perturbation comparisons available "
            "(at least two perturbations must load successfully)."
        )
    else:
        # Calculate min/max for FID bounds analysis
        min_fid = cross_df["fid"].min()
        max_fid = cross_df["fid"].max()
        mean_fid = cross_df["fid"].mean()
        std_fid = cross_df["fid"].std()

        print(f"Average FID (cross-perturbation): {mean_fid:.4f} ± {std_fid:.4f}")
        print(f"Min FID (cross-perturbation): {min_fid:.4f}")
        print(f"Max FID (cross-perturbation): {max_fid:.4f}")

        print(f"\n{'='*80}")
        print("EXPECTED FID BOUNDS FOR GENERATED IMAGES:")
        print(f"{'='*80}")
        print("For good generated images, you should expect:")
        print("FID_intra_perturbation < FID_generated < FID_cross_perturbation")
        print("Lower bound (intra-perturbation): ~0-50 (from previous analysis)")
        if not self_df.empty:
            print(
                "Lower bound (intra-perturbation, measured in this run): "
                f"{self_df['fid'].min():.4f} - {self_df['fid'].max():.4f}"
            )
        print(f"Upper bound (cross-perturbation): {min_fid:.4f} - {max_fid:.4f}")
        print(f"Average upper bound: {mean_fid:.4f}")

        print("\nNote: These cross-perturbation FID values represent the expected")
        print("upper bounds when comparing images from different perturbation types.")
        print(
            "Your generated images should have FID values between the intra-perturbation"
        )
        print("baseline and these cross-perturbation values.")
