import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from sc_perturb.models.sit import SiT_models
from sc_perturb.openphenom import OpenPhenomEncoder
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torchvision.utils import save_image
from tqdm import tqdm

# Paths and limits for this qualitative RGB export script.
CONFIG_PATH = "/mnt/pvc/MorphGen/sc_perturb/cfgs/diffusion_sit_full.yaml"
# Generated samples are expected at:
#   GENERATED_PATH/p<pid>/p<pid>_c<c_id>_sample<sample_id>.npy
GENERATED_PATH = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW/numpy_data"
OUTPUT_ROOT = "/mnt/pvc/MorphGen/sc_perturb/evaluation/qualitative/rgb_samples"
# Upper bound on real images exported per perturbation. It is large enough to
# mean "export all of them"; it is an int so it can be passed to range().
MAX_REAL_SAMPLES = 5_000_000
# Upper bound on generated images exported per perturbation.
MAX_GEN_SAMPLES = 3


def _save_rgb_png(image_tensor: torch.Tensor, output_path: str) -> None:
    """Convert one multi-channel image tensor to RGB and write it as a PNG.

    ``to_rgb`` is called on a batch of one (``unsqueeze(0)``), and the batch
    dimension is then removed again. Values are clamped to [0, 1] because
    ``save_image`` scales float tensors by 255 and assumes that range.

    Args:
        image_tensor: A single unbatched image. This presumably means
            (C, H, W), as the original comments state — TODO confirm
            against ``to_rgb``.
        output_path: Destination PNG path.
    """
    rgb_image = to_rgb(image_tensor.unsqueeze(0)).squeeze(0)
    save_image(torch.clamp(rgb_image, 0, 1), output_path)


def _save_real_samples(
    real_dataset, pert_id: int, output_dir: str, max_samples: int = MAX_REAL_SAMPLES
) -> None:
    """Export up to ``max_samples`` real images from ``real_dataset`` as RGB PNGs.

    Each item is indexed directly from the dataset. If the item is a tuple,
    its first element is used as the image. This export is best-effort: a
    failure on one sample is reported and the loop moves on to the next one.
    """
    num_samples_to_save = min(int(max_samples), len(real_dataset))
    print(f"Saving {num_samples_to_save} real images for perturbation {pert_id}")

    for sample_idx in range(num_samples_to_save):
        try:
            sample_data = real_dataset[sample_idx]
            if isinstance(sample_data, tuple):
                image_tensor = sample_data[0]  # first element is assumed to be the image
                print(
                    f"    Dataset returned {len(sample_data)} items: {[type(x) for x in sample_data]}"
                )
            else:
                image_tensor = sample_data

            print(f"  Processing sample {sample_idx}: shape = {image_tensor.shape}")

            output_path = os.path.join(
                output_dir, f"real_pert{pert_id}_sample{sample_idx:03d}_rgb.png"
            )
            _save_rgb_png(image_tensor, output_path)
            print(f"    Saved: {output_path}")
        except Exception as e:  # best-effort export: report and continue
            print(f"    Error processing sample {sample_idx}: {e}")


def _save_generated_samples(
    pert_id: int,
    cell_type_id: int,
    output_dir: str,
    generated_path: str = GENERATED_PATH,
    max_samples: int = MAX_GEN_SAMPLES,
) -> None:
    """Export up to ``max_samples`` generated ``.npy`` images as RGB PNGs.

    This function looks in ``<generated_path>/p<pert_id>/`` for ``.npy`` files
    whose *filename* contains ``_c<cell_type_id>_sample``. The matches are
    sorted, so the exported subset is deterministic.
    """
    pert_path = os.path.join(generated_path, f"p{pert_id}")
    if not os.path.exists(pert_path):
        print(f"Generated data directory not found: {pert_path}")
        return

    pattern = f"_c{cell_type_id}_sample"
    # glob order is arbitrary, so sort it. Match on the basename so that
    # directory names cannot produce false positives.
    filtered_files = sorted(
        f
        for f in glob.glob(os.path.join(pert_path, "*.npy"))
        if pattern in os.path.basename(f)
    )
    if not filtered_files:
        print(
            f"No generated images found for perturbation {pert_id} and cell type {cell_type_id}"
        )
        return

    num_gen_samples = min(int(max_samples), len(filtered_files))
    print(f"Saving {num_gen_samples} generated images for perturbation {pert_id}")

    for gen_idx, file_path in enumerate(filtered_files[:num_gen_samples]):
        try:
            gen_tensor = torch.from_numpy(np.load(file_path)).float()
            print(
                f"  Processing generated sample {gen_idx}: shape = {gen_tensor.shape}"
            )

            output_path = os.path.join(
                output_dir, f"gen_pert{pert_id}_sample{gen_idx:03d}_rgb.png"
            )
            _save_rgb_png(gen_tensor, output_path)
            print(f"    Saved: {output_path}")
        except Exception as e:  # best-effort export: report and continue
            print(f"    Error processing generated sample {gen_idx}: {e}")


def main() -> None:
    """Export real and generated RGB PNGs for a fixed set of perturbation IDs."""
    seed = 0
    cell_type_id = 0
    seed_everything(seed)

    config = OmegaConf.load(CONFIG_PATH)
    datamodule = CellDataModule(config)

    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Select the real images for this perturbation and cell type.
        real_filtered_dataset = datamodule.filter_samples(
            perturbation_id=pert_id,
            cell_type_id=cell_type_id,
        )
        if real_filtered_dataset is None or len(real_filtered_dataset) == 0:
            print(
                f"No real data found for perturbation {pert_id} and cell type {cell_type_id}"
            )
            continue

        output_dir = os.path.join(OUTPUT_ROOT, f"perturbation_{pert_id}")
        os.makedirs(output_dir, exist_ok=True)

        _save_real_samples(real_filtered_dataset, pert_id, output_dir)
        _save_generated_samples(pert_id, cell_type_id, output_dir)

    print(f"\n{'='*80}")
    print("RGB image saving completed!")
    print(f"Images saved to: {OUTPUT_ROOT}/")
    print(f"{'='*80}")


if __name__ == "__main__":
    main()
