import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from dataset import CellDataModule, to_rgb
from diffusers.models import AutoencoderKL
from metrics_utils import calculate_metrics_from_scratch
from models.sit import SiT_models
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from tqdm import tqdm
from train import generate_perturbation_matched_samples
from utils import load_encoders

# Minimal in-memory Dataset wrapper; instances are handed to torch_fidelity below.


class CustomDataset(torch.utils.data.Dataset):
    """Expose an already-loaded, indexable collection as a torch Dataset.

    The wrapped object is stored as-is; length and item access are delegated
    to it without any transformation.
    """

    def __init__(self, data):
        # Keep a reference only; nothing is copied or converted.
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx`` of the wrapped collection."""
        item = self.data[idx]
        return item


def find_generated_files_by_perturbation_and_celltype(
    generated_path, perturbation_id, cell_type_id
):
    """
    Find all generated numpy files for a specific perturbation ID and cell type ID.

    Files are looked up in ``<generated_path>/p<perturbation_id>/`` and kept when
    their *file name* contains ``_c<cell_type_id>_sample``.

    Args:
        generated_path: Path to the directory containing generated data
        perturbation_id: Perturbation ID to filter by
        cell_type_id: Cell type ID to filter by

    Returns:
        Sorted list of file paths matching both the perturbation and cell type;
        an empty list when the perturbation folder does not exist.
    """
    pert_path = os.path.join(generated_path, f"p{perturbation_id}")
    if not os.path.isdir(pert_path):
        return []

    # Pattern to match cell type in filenames (p<pid>_c<cell_type_id>_sample<sample_id>.npy)
    pattern = f"_c{cell_type_id}_sample"

    # Escape the directory part so glob metacharacters in the path are literal.
    npy_files = glob.glob(os.path.join(glob.escape(pert_path), "*.npy"))

    # Test the file name only: matching against the full path would accept every
    # file whenever a parent directory happens to contain the pattern.
    # Sorting makes the result independent of filesystem listing order, so a
    # seeded random.sample over it is reproducible.
    return sorted(f for f in npy_files if pattern in os.path.basename(f))


def load_numpy_files(file_paths, max_samples):
    """
    Read ``.npy`` files into one stacked float tensor, subsampling if needed.

    Args:
        file_paths: List of numpy file paths to load
        max_samples: Upper bound on how many files are read; when more paths
            are given, ``max_samples`` of them are drawn with ``random.sample``

    Returns:
        Tensor made by stacking the successfully loaded arrays along a new
        first dimension, or None when nothing could be loaded
    """
    selected = file_paths
    if len(selected) > max_samples:
        selected = random.sample(selected, max_samples)

    tensors = []
    for path in tqdm(selected, desc="Loading numpy files"):
        try:
            array = np.load(path)
            tensors.append(torch.from_numpy(array).float())
        except Exception as err:
            # Best effort: report the unreadable file and continue with the rest.
            print(f"Error loading {path}: {err}")

    if not tensors:
        return None
    return torch.stack(tensors)


def create_cell_type_metadata(num_samples=500, perturbation_id=1138, cell_type=1):
    """
    Create a perturbation metadata list in which every entry shares one cell type.

    Args:
        num_samples: Number of metadata entries to create
        perturbation_id: The perturbation ID to use for all entries
        cell_type: The cell type ID to use for all entries (defaults to 1)

    Returns:
        List of ``num_samples`` independent metadata dictionaries with keys
        ``perturbation_id``, ``cell_type_id`` and ``is_generated`` (always False)
    """
    # Build a fresh dict per entry so callers can mutate one without affecting
    # the others.
    perturbation_metadata = [
        {
            "perturbation_id": perturbation_id,
            "cell_type_id": cell_type,  # previously hard-coded to 1, ignoring the argument
            "is_generated": False,
        }
        for _ in range(num_samples)
    ]

    print(f"Created perturbation metadata with {len(perturbation_metadata)} entries")
    print(f"All entries have cell_type_id set to {cell_type}")

    return perturbation_metadata


def augment_image(image, augmentation_type=None):
    """
    Return a rotated, flipped, or untouched version of an image.

    Args:
        image: Image to augment; it is passed straight to the torchvision
            functional ops (the original documentation states a tensor of
            shape [C, H, W])
        augmentation_type: One of 'rotate', 'flip', 'unchanged'. When None, one
            of the three is drawn at random. Any other value leaves the image
            unchanged.

    Returns:
        The augmented image, or the input object itself when unchanged
    """
    kind = augmentation_type
    if kind is None:
        kind = random.choice(["rotate", "flip", "unchanged"])

    if kind == "flip":
        # Horizontal when the draw exceeds 0.5, vertical otherwise.
        mirror = TF.hflip if random.random() > 0.5 else TF.vflip
        return mirror(image)

    if kind == "rotate":
        # Rotate by a randomly chosen multiple of 90 degrees.
        return TF.rotate(image, random.choice([90, 180, 270]))

    # 'unchanged' and any unrecognised value fall through to here.
    return image


if __name__ == "__main__":
    # Script entry point: for a fixed list of perturbation IDs and one cell type,
    # compare real images against generated ones with torch_fidelity (FID/KID)
    # and write the per-perturbation results plus their mean to a CSV file.
    seed = 128
    MANUAL_GENERATION = True
    cell_type_id = 1
    # Directory of pre-generated samples; only read when MANUAL_GENERATION is False.
    # Expected layout: generated_path/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
    # Example value: "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
    generated_path = None
    if not MANUAL_GENERATION and generated_path is None:
        # Fail before any model/data loading instead of hitting an undefined
        # name deep inside the perturbation loop.
        raise SystemExit("generated_path must be set when MANUAL_GENERATION is False")

    seed_everything(seed)
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    # load yaml file
    filename = "diffusion_sit_full.yaml"
    config = OmegaConf.load(filename)
    datamodule = CellDataModule(config)

    if MANUAL_GENERATION:
        # Alternative checkpoints (kept for reference):
        #   /mnt/pvc/REPA/exps/Trainfull-ophenomdeneme-b-enc8-in512/checkpoints/min_FID_74.74455261230469.pt
        #   /mnt/pvc/REPA/exps/NOREPA-ophenomdeneme-b-enc8-in512/checkpoints/min_AVG_FID_83.24544197.pt
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(
            "/mnt/pvc/REPA/exps/OOD_ct1_p1137_ophenomdeneme-b-enc8-in512/checkpoints/min_AVG_FID_78.83929616.pt",
            map_location="cpu",
            weights_only=False,
        )
        enc_type = "openphenom-vit-b"
        resolution = 512
        latent_size = resolution // 8
        encoders, encoder_types, architectures = load_encoders(enc_type, device, 512)
        z_dims = (
            [encoder.embed_dim for encoder in encoders] if enc_type != "None" else [0]
        )
        block_kwargs = {"fused_attn": True, "qk_norm": False}
        model = SiT_models["SiT-XL/2"](
            input_size=latent_size,
            num_classes=1139,
            use_cfg=True,
            z_dims=z_dims,
            encoder_depth=8,
            in_channels=24,
            **block_kwargs,
        )
        model.load_state_dict(ckpt["model"])
        model = model.to(device)
        model.eval()
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
        latents_scale = (
            torch.tensor([0.18215, 0.18215, 0.18215, 0.18215])
            .view(1, 4, 1, 1)
            .to(device)
        )
        latents_bias = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1).to(device)
        path_type = "linear"

    # The evaluated perturbation IDs are the fixed list assigned below. The
    # random draw of 50 IDs is overwritten immediately; it is kept because it
    # advances Python's `random` state, so removing it would change the
    # augmentation/subsampling draws made later in this script.
    all_perturbation_ids = list(range(1, 1139))  # 1 to 1138
    sampled_perturbation_ids = random.sample(all_perturbation_ids, 50)
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    NUM_SAMPLES = 500
    # Iterate through each perturbation ID and calculate metrics
    results = []

    for pert_index, pert_id in enumerate(sampled_perturbation_ids, start=1):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {pert_index}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Filter real images using CellDataModule
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id, cell_type_id=cell_type_id
        )

        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(f"No real data found for perturbation ID {pert_id}")
            continue

        # Each dataset item is indexed with [0] to obtain the image.
        real_images = [
            real_filtered_dataset[j][0] for j in range(len(real_filtered_dataset))
        ]
        print(f"Found {len(real_images)} real images for perturbation ID {pert_id}")

        # If we have fewer than NUM_SAMPLES real images, apply augmentations to reach NUM_SAMPLES
        if len(real_images) < NUM_SAMPLES:
            print(
                f"Only {len(real_images)} real samples available, applying augmentations to reach {NUM_SAMPLES}"
            )
            additional_samples_needed = NUM_SAMPLES - len(real_images)
            # Each extra sample is a randomly augmented copy of a randomly
            # chosen original (augmented images are never re-augmented because
            # real_images is extended only after the loop).
            augmented_samples = [
                augment_image(random.choice(real_images))
                for _ in range(additional_samples_needed)
            ]
            real_images.extend(augmented_samples)
            print(
                f"Added {len(augmented_samples)} augmented images to reach {len(real_images)} total samples"
            )
        # If we have more than NUM_SAMPLES, randomly sample from them
        elif len(real_images) > NUM_SAMPLES:
            indices = random.sample(range(len(real_images)), NUM_SAMPLES)
            real_images = [real_images[j] for j in indices]
            print(
                f"Sampled {NUM_SAMPLES} real images from {len(real_filtered_dataset)} total"
            )
        else:
            print(f"Using all {len(real_images)} available real images")

        # Convert to tensor
        real_images_tensor = torch.stack(real_images)

        if not MANUAL_GENERATION:
            # Find all generated files for this perturbation
            generated_files = find_generated_files_by_perturbation_and_celltype(
                generated_path,
                pert_id,
                cell_type_id,
            )
            print(
                f"Found {len(generated_files)} generated files for perturbation ID {pert_id}"
            )

            if not generated_files:
                print(f"No generated data found for perturbation ID {pert_id}")
                continue

            # Load generated images (sample up to NUM_SAMPLES)
            generated_images_tensor = load_numpy_files(
                generated_files, max_samples=NUM_SAMPLES
            )

            if generated_images_tensor is None:
                print(f"Failed to load generated images for perturbation ID {pert_id}")
                continue
        else:
            batch_size = 8
            perturbation_metadata = create_cell_type_metadata(
                num_samples=batch_size, perturbation_id=pert_id, cell_type=cell_type_id
            )
            # Every generated image is also written to disk as .npy; the target
            # directory is the same for all batches, so create it once.
            save_dir = Path(f"./generated_ood_{seed}/p{pert_id}")
            save_dir.mkdir(parents=True, exist_ok=True)

            generated_batches = []
            sample_counter = 0
            with torch.no_grad():
                # NOTE(review): integer division gives 500 // 8 = 62 calls. How
                # many images each call returns is decided inside
                # generate_perturbation_matched_samples (not visible here), so
                # the generated count may differ from NUM_SAMPLES — confirm.
                for _ in tqdm(
                    range(NUM_SAMPLES // batch_size),
                    desc=f"Generating images for perturbation ID {pert_id}",
                ):
                    generated_images_tensor_batch, gen_images_metadata = (
                        generate_perturbation_matched_samples(
                            model,
                            pert_id,  # This is the perturbation_id (pid)
                            perturbation_metadata,  # This metadata should contain cell_type_id (cid)
                            vae,
                            latent_size,
                            resolution,
                            latents_bias,
                            latents_scale,
                            path_type,
                            device,
                        )
                    )
                    generated_batches.append(generated_images_tensor_batch)

                    # Save each image in the batch
                    for idx in range(generated_images_tensor_batch.shape[0]):
                        img_to_save = generated_images_tensor_batch[idx].cpu().numpy()
                        # Attempt to get cell_type_id from metadata, fallback to global cell_type_id
                        current_cell_type_id = gen_images_metadata[idx].get(
                            "cell_type_id", cell_type_id
                        )

                        file_name = f"p_{pert_id}_c{current_cell_type_id}_sample_{sample_counter}.npy"
                        np.save(save_dir / file_name, img_to_save)
                        sample_counter += 1

            generated_images_tensor = torch.cat(generated_batches, dim=0)

        print(f"Calculating metrics for perturbation ID {pert_id}")
        print(f"Real images shape: {real_images_tensor.shape}")
        print(f"Generated images shape: {generated_images_tensor.shape}")

        # Convert images to RGB format (to_rgb is applied one image at a time
        # with a temporary leading dimension that is squeezed away again).
        real_images_tensor = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in real_images_tensor]
        )
        generated_images_tensor = torch.stack(
            [to_rgb(img.cpu()[None]).squeeze(0) for img in generated_images_tensor]
        )
        # assumes to_rgb returns values in [0, 1] so that *255 fits uint8 — TODO confirm
        real_uint8 = (real_images_tensor * 255).to(torch.uint8)
        fake_uint8 = (generated_images_tensor * 255).to(torch.uint8)

        # Create datasets
        real_images_dataset = CustomDataset(real_uint8)
        generated_images_dataset = CustomDataset(fake_uint8)

        # Calculate metrics. Only request CUDA when the script itself selected
        # a CUDA device, so the CPU fallback above keeps working end to end.
        metrics = torch_fidelity.calculate_metrics(
            input1=real_images_dataset,
            input2=generated_images_dataset,
            cuda=device.type == "cuda",
            fid=True,
            kid=True,
            kid_subset_size=100,
            kid_subsets=100,
        )
        fid = metrics["frechet_inception_distance"]
        kid_mean = metrics["kernel_inception_distance_mean"]
        kid_std = metrics["kernel_inception_distance_std"]

        print(f"Perturbation ID: {pert_id}")
        print(f"FID: {fid:.4f}")
        print(f"KID: {kid_mean:.4f} ± {kid_std:.4f}")

        # Save results
        results.append(
            {
                "perturbation_id": pert_id,
                "num_real": len(real_images_tensor),
                "num_generated": len(generated_images_tensor),
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
            }
        )

    if not results:
        # Every perturbation was skipped; an empty DataFrame has no metric
        # columns, so the averaging below would fail with a KeyError.
        raise SystemExit("No perturbation produced metrics; nothing to save.")

    # Create a DataFrame and save results to CSV
    results_df = pd.DataFrame(results)

    # Calculate average metrics
    avg_metrics = {
        "perturbation_id": "Average",
        "num_real": results_df["num_real"].mean(),
        "num_generated": results_df["num_generated"].mean(),
        "fid": results_df["fid"].mean(),
        "kid_mean": results_df["kid_mean"].mean(),
        "kid_std": results_df["kid_std"].mean(),
    }

    # Add average row to the DataFrame
    results_df = pd.concat([results_df, pd.DataFrame([avg_metrics])], ignore_index=True)

    output_file = f"OOD_perturbation_type_metrics_results_seed_{seed}.csv"
    # output_file = f"sanity_check_perturbation_type_metrics_results_seed_{seed}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")

    # Print a summary table
    print("\nSummary of Perturbation Metrics:")
    print(results_df.to_string(index=False))

    # Print the average metrics separately for clarity
    print("\nAverage Metrics:")
    print(f"Average FID: {avg_metrics['fid']:.4f}")
    print(f"Average KID: {avg_metrics['kid_mean']:.4f} ± {avg_metrics['kid_std']:.4f}")
