import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from sc_perturb.models.sit import SiT_models
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from tqdm import tqdm

# Minimal in-memory dataset wrapper used to feed image tensors to torch_fidelity.


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over an already materialised, indexable collection.

    The script wraps uint8 image tensors in this class so they can be passed
    as ``input1``/``input2`` to ``torch_fidelity.calculate_metrics``.

    Args:
        data: Any object supporting ``len()`` and ``[]`` indexing
            (e.g. a tensor whose first dimension is the sample axis, or a list).
    """

    def __init__(self, data):
        # Stored as-is; no copy is made.
        self.data = data

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are expected at
    ``<generated_path>/p<pid>/p<pid>_c<cell_type_id>_sample<sample_id>.npy``.
    The perturbation is selected by the ``p<pid>`` folder. The cell type is
    matched against the file *name* only, so directory names that happen to
    contain the pattern cannot produce false matches.

    Args:
        generated_path: Path to the directory containing generated data.
        perturbation_id: Perturbation ID to filter by.
        cell_type_id: Cell type ID to filter by (0-3).

    Returns:
        Sorted list of matching file paths; empty if the perturbation folder
        does not exist. The list is sorted because ``glob`` returns files in
        arbitrary filesystem order, which would make downstream seeded random
        sampling irreproducible.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")

    if not os.path.isdir(pert_path):
        return []

    # "_c<id>_sample" cannot collide across ids (e.g. _c1_ vs _c11_) because
    # the id is delimited on both sides.
    pattern = f"_c{cell_type_id}_sample"

    # Escape the directory so glob metacharacters ([, *, ?) in the path are literal.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    return sorted(f for f in npy_files if pattern in os.path.basename(f))


def load_numpy_files(file_paths, max_samples):
    """
    Read ``.npy`` files into a single float tensor, subsampling at random if needed.

    When more than ``max_samples`` paths are given, ``max_samples`` of them are
    drawn without replacement via ``random.sample``. Files that fail to load or
    convert are reported with a printed message and skipped (best-effort).

    Args:
        file_paths: List of numpy file paths to load.
        max_samples: Maximum number of files to read.

    Returns:
        ``torch.stack`` of the loaded arrays converted to float32, or ``None``
        if nothing could be loaded. All arrays must share one shape for the
        stack to succeed.
    """
    selected = file_paths
    if len(selected) > max_samples:
        selected = random.sample(selected, max_samples)

    tensors = []
    for file_path in tqdm(selected, desc="Loading numpy files"):
        # Both the read and the tensor conversion are inside the guard so a
        # malformed file (e.g. unsupported dtype) is skipped rather than fatal.
        try:
            tensors.append(torch.from_numpy(np.load(file_path)).float())
        except Exception as e:
            print(f"Error loading {file_path}: {e}")

    if not tensors:
        return None
    return torch.stack(tensors)


def create_cell_type_metadata(num_samples=500, perturbation_id=1138, cell_type=1):
    """
    Build a list of identical metadata entries for one perturbation and cell type.

    Args:
        num_samples: Number of metadata entries to create.
        perturbation_id: Perturbation ID stored in every entry.
        cell_type: Cell type ID stored in every entry (defaults to 1).

    Returns:
        List of ``num_samples`` independent dicts with keys ``perturbation_id``,
        ``cell_type_id`` (= ``cell_type``) and ``is_generated`` (always False).
    """
    # A comprehension builds a fresh dict per entry, so callers may mutate
    # one entry without affecting the others.
    perturbation_metadata = [
        {
            "perturbation_id": perturbation_id,
            "cell_type_id": cell_type,
            "is_generated": False,  # This is typically False for real data
        }
        for _ in range(num_samples)
    ]

    print(f"Created perturbation metadata with {len(perturbation_metadata)} entries")
    print(f"All entries have cell_type_id set to {cell_type}")

    return perturbation_metadata


def augment_image(image, augmentation_type=None):
    """
    Return an augmented copy of an image tensor (or the image itself).

    Args:
        image: Tensor image of shape [C, H, W].
        augmentation_type: One of ``'rotate'`` (random 90/180/270 degree
            rotation via ``TF.rotate``), ``'flip'`` (horizontal or vertical
            flip, chosen by a coin toss) or ``'unchanged'``. ``None`` picks one
            of the three at random; any other value behaves like ``'unchanged'``.

    Returns:
        The augmented tensor, or ``image`` itself when no augmentation applies.
    """
    kind = augmentation_type
    if kind is None:
        kind = random.choice(["rotate", "flip", "unchanged"])

    if kind == "flip":
        flip = TF.hflip if random.random() > 0.5 else TF.vflip
        return flip(image)

    if kind == "rotate":
        return TF.rotate(image, random.choice([90, 180, 270]))

    return image


def _load_manual_generation_components(device):
    """Load the SiT model and VAE used to generate images on the fly.

    Args:
        device: torch.device that the model, encoders and VAE are moved to.

    Returns:
        dict with keys ``model``, ``vae``, ``latent_size``, ``resolution``,
        ``latents_bias``, ``latents_scale`` and ``path_type``. These are the
        arguments passed to ``generate_perturbation_matched_samples``.
    """
    ckpt = torch.load(
        "/mnt/pvc/REPA/exps/Plain-ophenomdeneme-b-enc8-in512/checkpoints/min_FID_75.060546875.pt",
        map_location="cpu",
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        weights_only=False,
    )
    enc_type = "openphenom-vit-b"
    resolution = 512
    latent_size = resolution // 8  # spatial size of the latent grid fed to SiT
    encoders, _encoder_types, _architectures = load_encoders(
        enc_type, device, resolution
    )
    z_dims = [encoder.embed_dim for encoder in encoders] if enc_type != "None" else [0]
    block_kwargs = {"fused_attn": True, "qk_norm": False}
    model = SiT_models["SiT-XL/2"](
        input_size=latent_size,
        num_classes=1139,
        use_cfg=True,
        z_dims=z_dims,
        encoder_depth=8,
        in_channels=24,
        **block_kwargs,
    )
    model.load_state_dict(ckpt["model"])
    model = model.to(device)
    model.eval()
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    latents_scale = (
        torch.tensor([0.18215, 0.18215, 0.18215, 0.18215]).view(1, 4, 1, 1).to(device)
    )
    latents_bias = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1).to(device)
    return {
        "model": model,
        "vae": vae,
        "latent_size": latent_size,
        "resolution": resolution,
        "latents_bias": latents_bias,
        "latents_scale": latents_scale,
        "path_type": "linear",
    }


def _generate_images(components, pert_id, num_samples, device, batch_size=8):
    """Generate ``num_samples`` images for ``pert_id`` with the loaded model.

    ``generate_perturbation_matched_samples`` is called once per batch with
    ``batch_size`` metadata entries. It presumably returns one image per entry
    (TODO confirm). The batch count is rounded up, so a request smaller than
    one batch still produces images. The surplus from the last batch is dropped.

    Returns:
        CPU tensor of generated images, truncated to ``num_samples`` rows.
    """
    perturbation_metadata = create_cell_type_metadata(
        num_samples=batch_size, perturbation_id=pert_id
    )
    num_batches = -(-num_samples // batch_size)  # ceiling division
    batches = []
    with torch.no_grad():
        for _ in tqdm(
            range(num_batches),
            desc=f"Generating images for perturbation ID {pert_id}",
        ):
            batch, _batch_metadata = generate_perturbation_matched_samples(
                components["model"],
                pert_id,
                perturbation_metadata,
                components["vae"],
                components["latent_size"],
                components["resolution"],
                components["latents_bias"],
                components["latents_scale"],
                components["path_type"],
                device,
            )
            batches.append(batch)
    # Move to CPU so real and generated tensors share a device downstream.
    return torch.cat(batches, dim=0)[:num_samples].cpu()


def _load_generated_images(generated_path, pert_id, num_samples):
    """Load up to ``num_samples`` pre-generated images of ``pert_id`` (cell types 0-3).

    Returns:
        Tensor of loaded images, or ``None`` (after printing why) when no files
        exist or none could be loaded.
    """
    generated_files = []
    for cell_type_id in (0, 1, 2, 3):
        generated_files.extend(
            find_generated_files_by_perturbation_and_celltype(
                generated_path, pert_id, cell_type_id
            )
        )
    print(f"Found {len(generated_files)} generated files for perturbation ID {pert_id}")

    if not generated_files:
        print(f"No generated data found for perturbation ID {pert_id}")
        return None

    images = load_numpy_files(generated_files, max_samples=num_samples)
    if images is None:
        print(f"Failed to load generated images for perturbation ID {pert_id}")
    return images


def _resample_to_count(images, target):
    """Return a list of exactly ``target`` images drawn from ``images``.

    If there are too few images, randomly augmented copies are appended. If
    there are too many, a random subset is taken. Otherwise the list is
    returned unchanged.
    """
    if len(images) < target:
        print(
            f"Only {len(images)} real samples available, applying augmentations to reach {target}"
        )
        augmented = [
            augment_image(random.choice(images)) for _ in range(target - len(images))
        ]
        resampled = images + augmented
        print(
            f"Added {len(augmented)} augmented images to reach {len(resampled)} total samples"
        )
        return resampled
    if len(images) > target:
        indices = random.sample(range(len(images)), target)
        print(f"Sampled {target} real images from {len(images)} total")
        return [images[idx] for idx in indices]
    print(f"Using all {len(images)} available real images")
    return images


def _to_uint8(images):
    """Convert float images (expected in [0, 1]) to uint8.

    Values are clamped first: casting out-of-range floats to uint8 wraps around
    instead of saturating, which would silently corrupt the FID inputs.
    """
    return (images.clamp(0, 1) * 255).to(torch.uint8)


def _replicate_channel(images, ch):
    """Select channel ``ch`` of [N, C, H, W] images and repeat it into [N, 3, H, W]."""
    single = images[:, ch, :, :]
    return torch.stack([single, single, single], dim=1)


def _compute_fid(real_images, generated_images, use_cuda):
    """Compute FID between two 3-channel float image batches with torch_fidelity.

    Returns:
        ``(fid, kid_mean, kid_std)``. KID is disabled, so both KID values are
        0.0 placeholders that keep the output schema stable.
    """
    metrics = torch_fidelity.calculate_metrics(
        input1=CustomDataset(_to_uint8(real_images)),
        input2=CustomDataset(_to_uint8(generated_images)),
        cuda=use_cuda,
        fid=True,
        kid=False,
    )
    return metrics["frechet_inception_distance"], 0.0, 0.0


def _evaluate_perturbation(pert_id, real_images, generated_images, num_channels, use_cuda):
    """Compute per-channel and RGB-composite FID for one perturbation.

    Returns:
        ``(result_entry, channel_rows)``. ``result_entry`` is the per-perturbation
        summary row, including ``channel_<ch>_*`` columns. ``channel_rows`` has one
        dict per channel.
    """
    print(f"Calculating metrics for perturbation ID {pert_id}")
    print(f"Real images shape: {real_images.shape}")
    print(f"Generated images shape: {generated_images.shape}")

    channel_rows = []
    ch_metrics = {}
    for ch in range(num_channels):
        # Each channel is replicated to 3 channels to fit the RGB feature extractor.
        ch_fid, ch_kid_mean, ch_kid_std = _compute_fid(
            _replicate_channel(real_images, ch),
            _replicate_channel(generated_images, ch),
            use_cuda,
        )
        print(f"Channel {ch} FID: {ch_fid:.4f}")
        print(f"Channel {ch} KID: {ch_kid_mean:.4f} ± {ch_kid_std:.4f}")
        channel_rows.append(
            {
                "perturbation_id": pert_id,
                "channel": ch,
                "fid": ch_fid,
                "kid_mean": ch_kid_mean,
                "kid_std": ch_kid_std,
            }
        )
        ch_metrics[f"channel_{ch}_fid"] = ch_fid
        ch_metrics[f"channel_{ch}_kid_mean"] = ch_kid_mean
        ch_metrics[f"channel_{ch}_kid_std"] = ch_kid_std

    # Convert multi-channel images to RGB composites for the overall metric.
    real_rgb = torch.stack([to_rgb(img.cpu()[None]).squeeze(0) for img in real_images])
    generated_rgb = torch.stack(
        [to_rgb(img.cpu()[None]).squeeze(0) for img in generated_images]
    )
    fid, kid_mean, kid_std = _compute_fid(real_rgb, generated_rgb, use_cuda)

    print(f"Perturbation ID: {pert_id}")
    print(f"FID: {fid:.4f}")
    print(f"KID: {kid_mean:.4f} ± {kid_std:.4f}")

    result_entry = {
        "perturbation_id": pert_id,
        "num_real": len(real_rgb),
        "num_generated": len(generated_rgb),
        "fid": fid,
        "kid_mean": kid_mean,
        "kid_std": kid_std,
        **ch_metrics,  # Include all channel metrics
    }
    return result_entry, channel_rows


def _save_summary(results, channel_results, num_channels, seed):
    """Append average rows, write both CSV files and print summaries."""
    results_df = pd.DataFrame(results)

    avg_metrics = {
        "perturbation_id": "Average",
        "num_real": results_df["num_real"].mean(),
        "num_generated": results_df["num_generated"].mean(),
        "fid": results_df["fid"].mean(),
        "kid_mean": results_df["kid_mean"].mean(),
        "kid_std": results_df["kid_std"].mean(),
    }
    for ch in range(num_channels):
        for metric in ("fid", "kid_mean", "kid_std"):
            column = f"channel_{ch}_{metric}"
            if column in results_df.columns:
                avg_metrics[column] = results_df[column].mean()

    results_df = pd.concat([results_df, pd.DataFrame([avg_metrics])], ignore_index=True)

    output_file = f"perturbation_type_metrics_results_all_cell_types_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    print("\nSummary of Perturbation Metrics:")
    print(results_df.to_string(index=False))

    print("\nAverage Metrics:")
    print(f"Average FID: {avg_metrics['fid']:.4f}")
    print(f"Average KID: {avg_metrics['kid_mean']:.4f} ± {avg_metrics['kid_std']:.4f}")

    # Channel rows from ALL evaluated perturbations, plus one average row per channel.
    channel_results_df = pd.DataFrame(channel_results)
    channel_avg_metrics = []
    for ch in range(num_channels):
        ch_data = channel_results_df[channel_results_df["channel"] == ch]
        channel_avg_metrics.append(
            {
                "perturbation_id": "Average",
                "channel": ch,
                "fid": ch_data["fid"].mean(),
                "kid_mean": ch_data["kid_mean"].mean(),
                "kid_std": ch_data["kid_std"].mean(),
            }
        )
    channel_results_df = pd.concat(
        [channel_results_df, pd.DataFrame(channel_avg_metrics)], ignore_index=True
    )

    channel_output_file = f"perturbation_type_channel_metrics_results_seed_{seed}.csv"
    channel_results_df.to_csv(channel_output_file, index=False)
    print(f"\nChannel-wise results saved to {channel_output_file}")

    print("\nSummary of Channel-wise Metrics:")
    for ch_avg in channel_avg_metrics:
        print(
            f"Channel {ch_avg['channel']} - Average FID: {ch_avg['fid']:.4f}, Average KID: {ch_avg['kid_mean']:.4f} ± {ch_avg['kid_std']:.4f}"
        )


def main():
    """Compare real and generated images for 50 random perturbations and save FID tables."""
    seed = 1337
    MANUAL_GENERATION = True
    seed_everything(seed)
    device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
    # torch_fidelity must not be asked for CUDA when the script fell back to CPU.
    use_cuda = device.type == "cuda"
    filename = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
    # Example numpy path: generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    generated_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    components = _load_manual_generation_components(device) if MANUAL_GENERATION else None

    # Sample 50 random perturbation IDs out of 1..1138
    all_perturbation_ids = list(range(1, 1139))
    sampled_perturbation_ids = random.sample(all_perturbation_ids, 50)
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    num_channels = 6
    results = []
    # Accumulated across ALL perturbations (not reset per iteration).
    all_channel_results = []

    for pert_idx, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {pert_idx+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        real_filtered_dataset = datamodule.filter_samples(perturbation_id=pert_id)
        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        real_images = [
            real_filtered_dataset[j][0] for j in range(len(real_filtered_dataset))
        ]
        print(f"Found {len(real_images)} real images for perturbation ID {pert_id}")

        # pick half of the real images
        real_images = random.sample(real_images, len(real_images) // 2)
        if not real_images:
            # A single real image halves to zero; torch.stack([]) would crash.
            print(f"Not enough real data to evaluate perturbation ID {pert_id}")
            continue

        # The generated set is sized to match the real subset.
        num_samples = len(real_images)
        real_images_tensor = torch.stack(_resample_to_count(real_images, num_samples))

        if MANUAL_GENERATION:
            generated_images_tensor = _generate_images(
                components, pert_id, num_samples, device
            )
        else:
            generated_images_tensor = _load_generated_images(
                generated_path, pert_id, num_samples
            )
            if generated_images_tensor is None:
                continue

        result_entry, channel_rows = _evaluate_perturbation(
            pert_id,
            real_images_tensor,
            generated_images_tensor,
            num_channels,
            use_cuda,
        )
        results.append(result_entry)
        all_channel_results.extend(channel_rows)

    if not results:
        print("No perturbations could be evaluated; nothing to save.")
        return

    _save_summary(results, all_channel_results, num_channels, seed)


if __name__ == "__main__":
    main()
