import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from tqdm import tqdm

# Minimal dataset wrapper used to feed in-memory image tensors to torch_fidelity.


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items straight from an indexable container.

    In this script it wraps stacked uint8 image tensors so they can be passed
    to ``torch_fidelity.calculate_metrics`` as ``input1``/``input2``.
    """

    def __init__(self, data):
        # Any object supporting ``len()`` and integer indexing.
        self.data = data

    def __len__(self):
        """Return how many items the wrapped container holds."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return the item stored at position ``idx`` without modification."""
        item = self.data[idx]
        return item


def find_generated_files_by_cell_type(generated_path, cell_type_id):
    """
    Find all generated numpy files for a specific cell type across all perturbation folders.

    Only ``.npy`` files located directly inside the immediate subdirectories of
    ``generated_path`` are considered. Files sitting directly in
    ``generated_path`` are ignored. A file matches when its *filename* (not its
    full path) contains ``_c<cell_type_id>_sample``, following the naming scheme
    ``p<pid>_c<cell_type_id>_sample<sample_id>.npy``.

    Args:
        generated_path: Path to the directory containing generated data
        cell_type_id: Cell type ID to filter by (0-3 in this script)

    Returns:
        Sorted-by-folder-then-filename list of file paths matching the cell type.
        The order is deterministic, so seeded random subsampling downstream
        picks the same files regardless of filesystem listing order.
    """
    # Pattern to match in npy filenames (p<pid>_c<cell_type_id>_sample<sample_id>.npy)
    pattern = f"_c{cell_type_id}_sample"

    all_files = []
    # os.listdir / glob return entries in arbitrary, filesystem-dependent order;
    # sort both so the result is reproducible.
    for pert_folder in sorted(os.listdir(generated_path)):
        pert_path = os.path.join(generated_path, pert_folder)
        if not os.path.isdir(pert_path):
            continue

        # Escape the folder path so names containing glob metacharacters
        # ("[", "*", "?") are treated literally.
        npy_files = sorted(glob.glob(os.path.join(glob.escape(pert_path), "*.npy")))

        # Match on the filename only so a parent directory whose name happens
        # to contain the pattern cannot make every file match.
        all_files.extend(
            f for f in npy_files if pattern in os.path.basename(f)
        )

    return all_files


def load_numpy_files(file_paths, max_samples):
    """
    Load up to ``max_samples`` ``.npy`` files and stack them into one float tensor.

    When more than ``max_samples`` paths are given, a subset of exactly
    ``max_samples`` paths is drawn with ``random.sample``. Otherwise every path
    is used in the given order. Each file is loaded with ``np.load`` and
    converted with ``torch.from_numpy(...).float()``. Files that raise during
    loading or conversion are reported on stdout and skipped (best-effort).

    Args:
        file_paths: List of numpy file paths to load
        max_samples: Maximum number of samples to load

    Returns:
        ``torch.stack`` of the successfully loaded arrays, or ``None`` if none
        could be loaded. Stacking requires all arrays to share one shape.
    """
    if len(file_paths) > max_samples:
        chosen_paths = random.sample(file_paths, max_samples)
    else:
        chosen_paths = file_paths

    loaded = []
    for path in tqdm(chosen_paths, desc="Loading numpy files"):
        try:
            # Conversion stays inside the try so unsupported dtypes are also skipped.
            loaded.append(torch.from_numpy(np.load(path)).float())
        except Exception as err:
            print(f"Error loading {path}: {err}")

    if not loaded:
        return None
    return torch.stack(loaded)


def _to_uint8_rgb(images):
    """Convert a batch of images to a stacked ``torch.uint8`` RGB tensor.

    Each image is moved to CPU and passed through ``to_rgb`` with a singleton
    batch dimension added and then removed. The stacked result is scaled by
    255, clamped to [0, 255] and cast (truncating) to ``torch.uint8``. The
    clamp makes out-of-range values saturate instead of wrapping around in the
    integer cast.

    Args:
        images: Iterable of per-image tensors accepted by ``to_rgb``.

    Returns:
        ``torch.uint8`` tensor holding the stacked RGB images.
    """
    rgb = torch.stack([to_rgb(img.cpu()[None]).squeeze(0) for img in images])
    # NOTE(review): assumes to_rgb returns floats in [0, 1]; confirm in sc_perturb.dataset.
    return (rgb * 255).clamp(0, 255).to(torch.uint8)


if __name__ == "__main__":
    seed = 1337
    seed_everything(seed)

    # Config used to build the real-data module, and the root directory of the
    # generated samples.
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    # Example numpy path: generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    cell_types = [0, 1, 2, 3]
    cell_type_to_label = {
        "HEPG2": 0,
        "HUVEC": 1,
        "RPE": 2,
        "U2OS": 3,
    }
    # Explicit ID -> name lookup instead of relying on the dict's insertion
    # order lining up with the numeric IDs.
    label_to_cell_type = {label: name for name, label in cell_type_to_label.items()}

    # Upper bound on the number of real and generated images compared per cell
    # type; also the preferred KID subset size.
    NUM_SAMPLES = 500
    KID_SUBSETS = 100
    # Fall back to CPU feature extraction when no GPU is available.
    use_cuda = torch.cuda.is_available()

    results = []

    for cell_type in cell_types:
        cell_type_name = label_to_cell_type[cell_type]
        print(f"\n\n{'='*80}")
        print(f"Processing cell type: {cell_type_name} (ID: {cell_type})")
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(cell_type_id=cell_type)

        # An empty dataset would make torch.stack below fail, so treat it like None.
        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for cell type {cell_type}")
            continue

        # Sample up to NUM_SAMPLES real images so both sides are comparably sized.
        num_real_available = len(real_filtered_dataset)
        if num_real_available > NUM_SAMPLES:
            indices = random.sample(range(num_real_available), NUM_SAMPLES)
            real_images = [real_filtered_dataset[i][0] for i in indices]
            print(
                f"Sampled {NUM_SAMPLES} real images from {num_real_available} total"
            )
        else:
            real_images = [
                real_filtered_dataset[i][0] for i in range(num_real_available)
            ]
            print(f"Using all {len(real_images)} available real images")

        real_images_tensor = torch.stack(real_images)

        # Find all generated files for this cell type
        generated_files = find_generated_files_by_cell_type(generated_path, cell_type)
        print(f"Found {len(generated_files)} generated files for cell type {cell_type}")

        if not generated_files:
            print(f"No generated data found for cell type {cell_type}")
            continue

        # Load generated images (sample up to NUM_SAMPLES)
        generated_images_tensor = load_numpy_files(
            generated_files, max_samples=NUM_SAMPLES
        )

        if generated_images_tensor is None:
            print(f"Failed to load generated images for cell type {cell_type}")
            continue

        print(f"Calculating metrics for cell type {cell_type}")
        print(f"Real images shape: {real_images_tensor.shape}")
        print(f"Generated images shape: {generated_images_tensor.shape}")

        real_uint8 = _to_uint8_rgb(real_images_tensor)
        fake_uint8 = _to_uint8_rgb(generated_images_tensor)

        # torch_fidelity rejects a KID subset larger than either input. That
        # happens whenever fewer than NUM_SAMPLES real or generated images exist.
        kid_subset_size = min(NUM_SAMPLES, len(real_uint8), len(fake_uint8))
        metrics = torch_fidelity.calculate_metrics(
            input1=CustomDataset(real_uint8),
            input2=CustomDataset(fake_uint8),
            cuda=use_cuda,
            fid=True,
            kid=True,
            kid_subset_size=kid_subset_size,
            kid_subsets=KID_SUBSETS,
        )
        fid = metrics["frechet_inception_distance"]
        kid_mean = metrics["kernel_inception_distance_mean"]
        kid_std = metrics["kernel_inception_distance_std"]

        print(f"Cell Type: {cell_type_name}")
        print(f"FID: {fid:.4f}")
        print(f"KID: {kid_mean:.4f} ± {kid_std:.4f}")

        # Save results
        results.append(
            {
                "cell_type": cell_type_name,
                "cell_type_id": cell_type,
                "num_real": len(real_uint8),
                "num_generated": len(fake_uint8),
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
            }
        )

    # Create a DataFrame and save results to CSV
    results_df = pd.DataFrame(results)
    output_file = f"cell_type_metrics_results_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Cell Type Metrics:")
    print(results_df.to_string(index=False))
