import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from sc_perturb.models.sit import SiT_models
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from tqdm import tqdm

# Minimal in-memory Dataset wrapper so stacked image tensors can be passed to torch_fidelity.


class CustomDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` adapter around an already-materialised indexable.

    The wrapped object only needs to support ``len()`` and integer indexing
    (for example a stacked tensor or a plain list); items are returned
    unchanged.
    """

    def __init__(self, data) -> None:
        # Keep a reference only; nothing is copied or converted here.
        self.data = data

    def __len__(self) -> int:
        """Return the number of items in the wrapped container."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx`` as-is."""
        item = self.data[idx]
        return item


def split_real_images_randomly(real_images, split_ratio=0.5, rng=None):
    """
    Split real images into two mutually exclusive sets randomly.

    The indices ``0..len(real_images)-1`` are shuffled and cut at
    ``int(len(real_images) * split_ratio)``; the first part becomes set 1 and
    the remainder set 2, so together they always cover every input exactly
    once.

    Args:
        real_images: Sequence of images (anything supporting ``len()`` and
            integer indexing).
        split_ratio: Fraction of the images assigned to the first set
            (default 0.5 for an equal split). Must lie in ``[0, 1]``.
        rng: Optional ``random.Random`` instance used for shuffling. When
            omitted, the module-level ``random`` state is used, so results
            follow whatever global seed the caller has set.

    Returns:
        Tuple of (set1_images, set2_images), both plain lists.

    Raises:
        ValueError: If ``split_ratio`` is outside ``[0, 1]`` (including NaN).
    """
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio!r}")

    indices = list(range(len(real_images)))
    # Same shuffle call as before on the default path, so seeded runs
    # reproduce the exact same partition.
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(indices)

    split_point = int(len(indices) * split_ratio)
    set1_images = [real_images[i] for i in indices[:split_point]]
    set2_images = [real_images[i] for i in indices[split_point:]]

    return set1_images, set2_images


# --- Script configuration -------------------------------------------------
SEED = 0
CONFIG_PATH = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
NUM_PERTURBATIONS_TOTAL = 1138  # perturbation IDs run from 1 to 1138
NUM_PERTURBATIONS_SAMPLED = 50
NUM_CHANNELS = 6
# A covariance estimate (needed for FID) is undefined for fewer than two samples.
MIN_IMAGES_PER_SET = 2
# KID is not computed (kid=False); this placeholder keeps the CSV schema stable.
KID_PLACEHOLDER = 0.0


def channel_as_rgb_uint8(images, channel):
    """Replicate one channel three times and convert it to a uint8 image batch.

    Args:
        images: Tensor indexed as ``images[:, channel, :, :]``.
            NOTE(review): values are assumed to lie in [0, 1] before the
            ``* 255`` scaling — confirm against the data module.
        channel: Index of the channel to extract.

    Returns:
        uint8 tensor with the selected channel stacked three times along dim 1.
    """
    single = images[:, channel, :, :]
    return (torch.stack([single, single, single], dim=1) * 255).to(torch.uint8)


def real_vs_real_fid(images1_uint8, images2_uint8, use_cuda):
    """Return the FID torch_fidelity reports between two uint8 image batches."""
    metrics = torch_fidelity.calculate_metrics(
        input1=CustomDataset(images1_uint8),
        input2=CustomDataset(images2_uint8),
        cuda=use_cuda,
        fid=True,
        kid=False,
    )
    return metrics["frechet_inception_distance"]


def append_average_row(results_df):
    """Append an "Average" row holding the column-wise mean of all metrics.

    Returns:
        Tuple of (DataFrame including the average row, dict of the averages).
    """
    averages = {"perturbation_id": "Average"}
    averages.update(results_df.drop(columns=["perturbation_id"]).mean().to_dict())
    combined = pd.concat([results_df, pd.DataFrame([averages])], ignore_index=True)
    return combined, averages


def channel_average_rows(channel_results_df, num_channels):
    """Build one "Average" record per channel across all perturbations."""
    rows = []
    for ch in range(num_channels):
        ch_rows = channel_results_df[channel_results_df["channel"] == ch]
        rows.append(
            {
                "perturbation_id": "Average",
                "channel": ch,
                "fid": ch_rows["fid"].mean(),
                "kid_mean": ch_rows["kid_mean"].mean(),
                "kid_std": ch_rows["kid_std"].mean(),
            }
        )
    return rows


def main():
    """Compute real-vs-real FID baselines for randomly sampled perturbations.

    For each sampled perturbation the real images are split into two disjoint
    halves and FID is computed between them, both per channel and on the
    ``to_rgb`` rendering. Results are written to two CSV files in the current
    working directory and summarised on stdout.
    """
    seed_everything(SEED)
    # Fall back to CPU instead of crashing when no GPU is present.
    use_cuda = torch.cuda.is_available()

    config = OmegaConf.load(CONFIG_PATH)
    datamodule = CellDataModule(config)

    # Sample NUM_PERTURBATIONS_SAMPLED random perturbation IDs out of all IDs.
    all_perturbation_ids = list(range(1, NUM_PERTURBATIONS_TOTAL + 1))
    sampled_perturbation_ids = random.sample(
        all_perturbation_ids, NUM_PERTURBATIONS_SAMPLED
    )
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    results = []
    # Accumulated across ALL perturbations so that the channel-details CSV and
    # the per-channel averages cover every processed perturbation.
    channel_results = []

    for position, pert_id in enumerate(sampled_perturbation_ids, start=1):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {position}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # No cell-type filter: all cell types for this perturbation are used.
        real_filtered_dataset = datamodule.filter_samples(perturbation_id=pert_id)

        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        # Each dataset item is indexed with [0] to obtain the image.
        all_real_images = [
            real_filtered_dataset[idx][0] for idx in range(len(real_filtered_dataset))
        ]
        print(
            f"Found {len(all_real_images)} total real images for perturbation ID {pert_id}"
        )

        # Split the real images into two mutually exclusive sets
        real_set1, real_set2 = split_real_images_randomly(
            all_real_images, split_ratio=0.5
        )
        print(
            f"Split into set1: {len(real_set1)} images, set2: {len(real_set2)} images"
        )

        # Checked after the split so the RNG stream for later perturbations is
        # the same whether or not this one is skipped.
        if min(len(real_set1), len(real_set2)) < MIN_IMAGES_PER_SET:
            print(
                f"Skipping perturbation ID {pert_id}: need at least "
                f"{MIN_IMAGES_PER_SET} images per set to compute FID"
            )
            continue

        # No augmentation - use the sets as they are
        print(
            f"Using sets without augmentation - Set1: {len(real_set1)}, Set2: {len(real_set2)}"
        )

        real_set1_tensor = torch.stack(real_set1)
        real_set2_tensor = torch.stack(real_set2)

        print(f"Calculating channel-wise metrics for perturbation ID {pert_id}")
        print(f"Real set1 shape: {real_set1_tensor.shape}")
        print(f"Real set2 shape: {real_set2_tensor.shape}")

        # Channel-wise metrics: each channel is evaluated as a grey RGB image.
        ch_metrics = {}
        for ch in range(NUM_CHANNELS):
            print(f"Processing channel {ch}...")

            ch_fid = real_vs_real_fid(
                channel_as_rgb_uint8(real_set1_tensor, ch),
                channel_as_rgb_uint8(real_set2_tensor, ch),
                use_cuda,
            )
            ch_kid_mean = KID_PLACEHOLDER
            ch_kid_std = KID_PLACEHOLDER

            print(f"Channel {ch} FID (Real vs Real): {ch_fid:.4f}")
            print(
                f"Channel {ch} KID (Real vs Real): {ch_kid_mean:.4f} ± {ch_kid_std:.4f}"
            )

            channel_results.append(
                {
                    "perturbation_id": pert_id,
                    "channel": ch,
                    "fid": ch_fid,
                    "kid_mean": ch_kid_mean,
                    "kid_std": ch_kid_std,
                }
            )
            ch_metrics[f"channel_{ch}_fid"] = ch_fid
            ch_metrics[f"channel_{ch}_kid_mean"] = ch_kid_mean
            ch_metrics[f"channel_{ch}_kid_std"] = ch_kid_std

        # Overall metrics on the project's RGB rendering of the images.
        real_set1_rgb = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in real_set1_tensor]
        )
        real_set2_rgb = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in real_set2_tensor]
        )

        fid = real_vs_real_fid(
            (real_set1_rgb * 255).to(torch.uint8),
            (real_set2_rgb * 255).to(torch.uint8),
            use_cuda,
        )
        kid_mean = KID_PLACEHOLDER
        kid_std = KID_PLACEHOLDER

        print(f"Perturbation ID: {pert_id}")
        print(f"Overall FID (Real vs Real): {fid:.4f}")
        print(f"Overall KID (Real vs Real): {kid_mean:.4f} ± {kid_std:.4f}")

        results.append(
            {
                "perturbation_id": pert_id,
                "num_original_real": len(all_real_images),
                "num_set1": len(real_set1_tensor),
                "num_set2": len(real_set2_tensor),
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
                **ch_metrics,  # Include all channel metrics
            }
        )

    if not results:
        # Nothing to aggregate; avoid failing on an empty DataFrame below.
        print("\nNo perturbation produced metrics; no result files were written.")
        return

    # Per-perturbation results plus an average row.
    results_df, avg_metrics = append_average_row(pd.DataFrame(results))

    output_file = (
        f"perturbation_type_channel_metrics_real_vs_real_results_seed_{SEED}.csv"
    )
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Real vs Real Channel-wise Perturbation Metrics:")
    print(results_df.to_string(index=False))

    # Print the average metrics separately for clarity
    print("\nAverage Metrics (Real vs Real):")
    print(f"Average Overall FID: {avg_metrics['fid']:.4f}")
    print(
        f"Average Overall KID: {avg_metrics['kid_mean']:.4f} ± {avg_metrics['kid_std']:.4f}"
    )

    # Long-format channel results plus one average row per channel.
    channel_results_df = pd.DataFrame(channel_results)
    channel_avg_metrics = channel_average_rows(channel_results_df, NUM_CHANNELS)
    channel_results_df = pd.concat(
        [channel_results_df, pd.DataFrame(channel_avg_metrics)], ignore_index=True
    )

    channel_output_file = f"perturbation_type_channel_metrics_real_vs_real_channel_details_seed_{SEED}.csv"
    channel_results_df.to_csv(channel_output_file, index=False)
    print(f"\nChannel-wise detailed results saved to {channel_output_file}")

    # Print a summary of channel-wise metrics
    print("\nSummary of Channel-wise Metrics (Real vs Real):")
    for ch, ch_avg in enumerate(channel_avg_metrics):
        print(
            f"Channel {ch} - Average FID: {ch_avg['fid']:.4f}, Average KID: {ch_avg['kid_mean']:.4f} ± {ch_avg['kid_std']:.4f}"
        )

    print(
        "\nNote: These metrics represent the baseline FID/KID between two sets of real images"
    )
    print(
        "from the same perturbations on a per-channel basis. Lower values indicate better consistency in the real data per channel."
    )


if __name__ == "__main__":
    main()
