import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from tqdm import tqdm

# Minimal in-memory dataset wrapper used to hand image tensors to torch_fidelity.


class CustomDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` adapter around an already-loaded, indexable container.

    The wrapped object only needs to support ``len()`` and ``[]`` indexing;
    items are returned exactly as stored, without any transformation.
    """

    def __init__(self, data):
        # Keep a reference to the container; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of items held by the wrapped container."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item


def augment_image(image, augmentation_type=None):
    """
    Return an augmented version of a single image.

    Args:
        image: Tensor image of shape [C, H, W]
        augmentation_type: One of 'rotate', 'flip' or 'unchanged'. When None,
                           one of the three is drawn at random.

    Returns:
        The rotated or flipped image, or the input image itself when the
        augmentation is 'unchanged' (or any other unrecognised value).
    """
    kind = augmentation_type
    if kind is None:
        kind = random.choice(["rotate", "flip", "unchanged"])

    if kind == "flip":
        # Coin toss between a horizontal and a vertical mirror.
        mirror = TF.hflip if random.random() > 0.5 else TF.vflip
        return mirror(image)

    if kind == "rotate":
        # Quarter-turn rotations only: 90, 180 or 270 degrees.
        return TF.rotate(image, random.choice([90, 180, 270]))

    # 'unchanged' (and anything unrecognised) passes the image through as-is.
    return image


def split_real_images_randomly(real_images, split_ratio=0.5):
    """
    Randomly partition the images into two disjoint groups.

    Args:
        real_images: List of image tensors
        split_ratio: Fraction of the images assigned to the first group
                     (default 0.5 for an even split; the cut is rounded down)

    Returns:
        Tuple of (set1_images, set2_images)
    """
    total = len(real_images)

    # Shuffle positions rather than the images so the input is left untouched.
    order = list(range(total))
    random.shuffle(order)

    cut = int(total * split_ratio)
    first_half = [real_images[k] for k in order[:cut]]
    second_half = [real_images[k] for k in order[cut:]]
    return first_half, second_half


def augment_to_target_size(images, target_size):
    """
    Bring a set of images to exactly ``target_size`` elements.

    If there are already at least ``target_size`` images, a random subset of
    that size is returned (sampled without replacement). Otherwise all
    originals are kept and the remainder is filled with randomly augmented
    copies of randomly chosen originals.

    Args:
        images: Sequence of image tensors
        target_size: Target number of images

    Returns:
        List of image tensors with target_size length

    Raises:
        ValueError: If ``images`` is empty while ``target_size`` is positive;
            there is nothing to augment from in that case.
    """
    available = len(images)

    if available >= target_size:
        # If we already have enough images, randomly sample
        indices = random.sample(range(available), target_size)
        return [images[i] for i in indices]

    if available == 0:
        # Fail early with a clear message instead of an opaque IndexError
        # from random.choice on an empty sequence further down.
        raise ValueError(
            f"Cannot augment an empty image set to target size {target_size}"
        )

    # If we need more images, apply augmentations. list() makes a shallow copy
    # and also accepts non-list sequences such as tuples.
    augmented_images = list(images)
    additional_samples_needed = target_size - available

    print(
        f"Applying {additional_samples_needed} augmentations to reach target size {target_size}"
    )

    for _ in range(additional_samples_needed):
        # Always draw the base image from the originals, never from an
        # already-augmented copy.
        base_image = random.choice(images)
        augmented_images.append(augment_image(base_image))

    return augmented_images


def _to_uint8_rgb_dataset(image_batch):
    """Convert a stacked batch of images into a uint8 RGB ``CustomDataset``.

    Each image is passed through ``to_rgb`` on its own (with a temporary
    leading batch dimension), then scaled by 255 and cast to ``torch.uint8``.
    """
    rgb_batch = torch.stack(
        [to_rgb(img.cpu()[None]).squeeze(0) for img in image_batch]
    )
    # NOTE(review): assumes to_rgb yields values in [0, 1] -- confirm. The clamp
    # keeps slightly out-of-range values from wrapping around in the uint8 cast.
    return CustomDataset((rgb_batch * 255).clamp(0, 255).to(torch.uint8))


if __name__ == "__main__":
    # Script: estimate a "real vs real" FID baseline. For a random sample of
    # perturbations, the real images are split into two disjoint halves and
    # the FID between the halves is computed; results are written to a CSV.
    seed = 7
    cell_type_id = 1
    seed_everything(seed)

    # Load yaml file
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    # Use the same perturbation IDs as in the original script
    all_perturbation_ids = list(range(1, 1139))  # 1 to 1138
    sampled_perturbation_ids = random.sample(all_perturbation_ids, 50)
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    # Fall back to CPU when no GPU is present instead of failing outright.
    use_cuda = torch.cuda.is_available()

    # Iterate through each perturbation ID and calculate metrics
    results = []

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id, cell_type_id=cell_type_id
        )

        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        # Get all available real images (first element of each dataset item)
        all_real_images = [
            real_filtered_dataset[idx][0] for idx in range(len(real_filtered_dataset))
        ]
        print(
            f"Found {len(all_real_images)} total real images for perturbation ID {pert_id}"
        )

        # A single image cannot be split into two non-empty sets; the empty
        # half would make augmentation and torch.stack fail below.
        if len(all_real_images) < 2:
            print(
                f"Skipping perturbation ID {pert_id}: need at least 2 real images to form two sets"
            )
            continue

        # Split the real images into two mutually exclusive sets
        real_set1, real_set2 = split_real_images_randomly(
            all_real_images, split_ratio=0.5
        )
        print(
            f"Split into set1: {len(real_set1)} images, set2: {len(real_set2)} images"
        )

        # Pad the smaller half with augmented copies so both sets have the
        # size of the larger half.
        target_size = max(len(real_set1), len(real_set2))
        real_set1_augmented = augment_to_target_size(real_set1, target_size)
        real_set2_augmented = augment_to_target_size(real_set2, target_size)

        print(
            f"After augmentation - Set1: {len(real_set1_augmented)}, Set2: {len(real_set2_augmented)}"
        )

        # Convert to tensors
        real_set1_tensor = torch.stack(real_set1_augmented)
        real_set2_tensor = torch.stack(real_set2_augmented)

        print(f"Calculating metrics for perturbation ID {pert_id}")
        print(f"Real set1 shape: {real_set1_tensor.shape}")
        print(f"Real set2 shape: {real_set2_tensor.shape}")

        # Convert images to RGB uint8 and wrap them as datasets
        real_set1_dataset = _to_uint8_rgb_dataset(real_set1_tensor)
        real_set2_dataset = _to_uint8_rgb_dataset(real_set2_tensor)

        # Calculate metrics between the two real sets (FID only)
        metrics = torch_fidelity.calculate_metrics(
            input1=real_set1_dataset,
            input2=real_set2_dataset,
            cuda=use_cuda,
            fid=True,
            kid=False,
        )
        fid = metrics["frechet_inception_distance"]
        # KID is not computed (kid=False); 0.0 is a placeholder that keeps the
        # output columns stable, not a measured value.
        kid_mean = 0.0
        kid_std = 0.0

        print(f"Perturbation ID: {pert_id}")
        print(f"FID (Real vs Real): {fid:.4f}")
        print(f"KID (Real vs Real): {kid_mean:.4f} ± {kid_std:.4f}")

        # Save results
        results.append(
            {
                "perturbation_id": pert_id,
                "num_original_real": len(all_real_images),
                "num_set1": len(real_set1_augmented),
                "num_set2": len(real_set2_augmented),
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
            }
        )

    # With no rows the DataFrame has no columns, so the averaging below would
    # raise a KeyError; stop with a clear message instead.
    if not results:
        raise SystemExit(
            "No perturbation produced metrics; nothing to average or save."
        )

    # Create a DataFrame and save results to CSV
    results_df = pd.DataFrame(results)

    # Calculate average metrics
    avg_metrics = {
        "perturbation_id": "Average",
        "num_original_real": results_df["num_original_real"].mean(),
        "num_set1": results_df["num_set1"].mean(),
        "num_set2": results_df["num_set2"].mean(),
        "fid": results_df["fid"].mean(),
        "kid_mean": results_df["kid_mean"].mean(),
        "kid_std": results_df["kid_std"].mean(),
    }

    # Add average row to the DataFrame
    results_df = pd.concat([results_df, pd.DataFrame([avg_metrics])], ignore_index=True)

    output_file = f"perturbation_type_metrics_real_vs_real_results_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Real vs Real Perturbation Metrics:")
    print(results_df.to_string(index=False))

    # Print the average metrics separately for clarity
    print("\nAverage Metrics (Real vs Real):")
    print(f"Average FID: {avg_metrics['fid']:.4f}")
    print(f"Average KID: {avg_metrics['kid_mean']:.4f} ± {avg_metrics['kid_std']:.4f}")

    print(
        "\nNote: These metrics represent the baseline FID/KID between two sets of real images"
    )
    print(
        "from the same perturbations. Lower values indicate better consistency in the real data."
    )
