import argparse
import glob
import logging
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_fidelity
import torchvision.transforms.functional as TF
from diffusers.models import AutoencoderKL
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from sc_perturb.dataset import CellDataModule, to_rgb
from sc_perturb.metrics_utils import calculate_metrics_from_scratch
from sc_perturb.models.sit import SiT_models
from sc_perturb.openphenom import OpenPhenomEncoder
from sc_perturb.utils.generation_utils import generate_perturbation_matched_samples
from sc_perturb.utils.utils import load_encoders
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image

if __name__ == "__main__":
    seed = 0
    cell_type_id = 1
    seed_everything(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Path to generated images
    generated_base_path = "/mnt/pvc/REPA/fulltrain_model_74_all_perts_NEW"
    
    # Perturbations to process (same as in save_real_rgb.py)
    sampled_perturbation_ids = [1138, 1137, 1108, 1124]
    print(f"Sampled perturbation IDs: {sampled_perturbation_ids}")

    for i, pert_id in enumerate(sampled_perturbation_ids):
        print(f"\n\n{'='*80}")
        print(
            f"Processing perturbation ID: {pert_id}, {i+1}/{len(sampled_perturbation_ids)}"
        )
        print(f"{'='*80}")

        # Construct path for this perturbation
        generated_pert_path = os.path.join(
            generated_base_path, f"generated_p{pert_id}", f"cell_{cell_type_id}"
        )
        
        if not os.path.exists(generated_pert_path):
            print(f"Generated data directory not found: {generated_pert_path}")
            continue

        # Find all PNG files for this perturbation and cell type
        pattern = f"p{pert_id}_c{cell_type_id}_sample*.png"
        png_files = glob.glob(os.path.join(generated_pert_path, pattern))
        
        if not png_files:
            print(f"No generated PNG files found for perturbation {pert_id} and cell type {cell_type_id}")
            continue

        # Create output directory for this perturbation
        output_dir = f"/mnt/pvc/MorphGen/sc_perturb/evaluation/qualitative/rgb_samples/perturbation_{pert_id}"
        os.makedirs(output_dir, exist_ok=True)

        # Process and save generated images
        num_samples_to_save = min(5e6, len(png_files))  # Save up to 3 samples for testing
        print(f"Found {len(png_files)} generated images, saving {num_samples_to_save} for perturbation {pert_id}")

        # Sort files to ensure consistent ordering
        png_files.sort()

        for sample_idx in range(num_samples_to_save):
            try:
                file_path = png_files[sample_idx]
                
                # Extract sample number from filename for consistent naming
                filename = os.path.basename(file_path)
                sample_match = re.search(r'sample(\d+)', filename)
                original_sample_num = sample_match.group(1) if sample_match else str(sample_idx)

                print(f"  Processing generated sample {sample_idx}: {filename}")

                # Load the PNG image and convert to tensor
                pil_image = Image.open(file_path)
                
                # Convert PIL image to numpy array
                image_array = np.array(pil_image)
                
                # Handle different image formats
                if len(image_array.shape) == 2:  # Grayscale
                    # Convert grayscale to RGB by repeating channels
                    image_array = np.stack([image_array] * 3, axis=2)
                elif image_array.shape[2] == 4:  # RGBA
                    # Remove alpha channel
                    image_array = image_array[:, :, :3]
                
                # Convert to tensor and normalize to [0, 1]
                if image_array.dtype == np.uint8:
                    image_tensor = torch.from_numpy(image_array).float() / 255.0
                else:
                    image_tensor = torch.from_numpy(image_array).float()
                
                # Rearrange dimensions from (H, W, C) to (C, H, W)
                image_tensor = image_tensor.permute(2, 0, 1)
                
                print(f"    Original image shape: {image_tensor.shape}")
                
                # Resize from 3600x3600 to 512x512
                image_tensor = TF.resize(image_tensor, (512, 512))
                
                print(f"    Resized image shape: {image_tensor.shape}")
                
                # Ensure values are in [0, 1] range for saving
                image_tensor = torch.clamp(image_tensor, 0, 1)

                # Create output filename
                output_path = os.path.join(
                    output_dir, f"gen_pert{pert_id}_sample{original_sample_num}_rgb.png"
                )
                
                # Save using torchvision save_image
                save_image(image_tensor, output_path)

                print(f"    Saved: {output_path}")

            except Exception as e:
                print(f"    Error processing sample {sample_idx} ({file_path}): {e}")
                continue

    print(f"\n{'='*80}")
    print("Generated RGB image saving completed!")
    print(
        "Images saved to: /mnt/pvc/MorphGen/sc_perturb/evaluation/qualitative/rgb_samples/"
    )
    print(f"{'='*80}")