import argparse
import glob
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import textwrap
import re
from collections import defaultdict
from typing import List, Tuple
from dataset import CellDataModule
from omegaconf import OmegaConf


# Mapping from RxRx1 cell-type name to the integer label used in output file names.
CELL_TYPE_TO_LABEL = {
    "HEPG2": 0,
    "HUVEC": 1,
    "RPE": 2,
    "U2OS": 3,
}

# Perturbation IDs exported when no --pids are given on the command line.
DEFAULT_PIDS = (1138, 1137, 1108, 1124)


def cell_type_from_path(numpy_path: str) -> str:
    """Return the cell-type prefix encoded in a numpy image path.

    Paths are expected to look like
    ``/mnt/pvc/AutoSync/data/rxrx1/numpy_images/U2OS-03/Plate4/B02_s1.npy``:
    the directory two levels above the file is ``<CELLTYPE>-<NN>`` and the
    part before the first ``-`` is returned (``"U2OS"`` in the example).

    Raises:
        ValueError: if the path has fewer than three components.
    """
    # Path.parts handles platform separators and repeated slashes, unlike split('/').
    parts = Path(numpy_path).parts
    if len(parts) < 3:
        raise ValueError(
            f"Cannot infer cell type from {numpy_path!r}: expected "
            ".../<CELLTYPE>-<NN>/<plate>/<file>.npy"
        )
    return parts[-3].split("-", 1)[0]


def cell_label_for_path(numpy_path: str, cell_type_to_label) -> int:
    """Return the integer label of the cell type encoded in ``numpy_path``.

    ``cell_type_to_label`` maps cell-type names (e.g. ``"U2OS"``) to labels.

    Raises:
        ValueError: if the path's cell type is not a key of ``cell_type_to_label``
            (or the path is too short to contain one).
    """
    cell_type = cell_type_from_path(numpy_path)
    try:
        return cell_type_to_label[cell_type]
    except KeyError:
        raise ValueError(
            f"Unknown cell type {cell_type!r} in {numpy_path!r}; "
            f"expected one of {sorted(cell_type_to_label)}"
        ) from None


def copy_perturbation_images(datamodule, pid: int, output_dir: str, cell_type_to_label) -> int:
    """Copy every image of perturbation ``pid`` into ``<output_dir>/p<pid>/``.

    Image paths come from ``datamodule.filter_metadata(perturbation_id=pid)``'s
    ``"numpy_path"`` column. Each copy is named ``p<pid>_c<label>_sample<k>.npy``,
    where ``k`` counts from 1 separately for each cell-type label.

    Returns:
        The number of files copied.
    """
    pert_dir = os.path.join(output_dir, f"p{pid}")
    os.makedirs(pert_dir, exist_ok=True)

    filtered_metadata = datamodule.filter_metadata(perturbation_id=pid)
    numpy_paths = filtered_metadata["numpy_path"].tolist()

    # Per-label sample counters so numbering restarts at 1 for every cell type.
    sample_counts = dict.fromkeys(cell_type_to_label.values(), 0)
    for numpy_path in numpy_paths:
        cell_id = cell_label_for_path(numpy_path, cell_type_to_label)
        sample_counts[cell_id] += 1
        dest_filename = f"p{pid}_c{cell_id}_sample{sample_counts[cell_id]}.npy"
        dest_path = os.path.join(pert_dir, dest_filename)
        shutil.copy(numpy_path, dest_path)
        print(f"Copied {numpy_path} to {dest_path}")
    return len(numpy_paths)


def main(argv=None) -> None:
    """Export the original images of selected perturbations, grouped by perturbation ID.

    With no arguments this loads ``diffusion_sit_full.yaml``, writes to
    ``./original_imgs`` and exports perturbations 1138, 1137, 1108 and 1124.
    """
    parser = argparse.ArgumentParser(
        description="Copy original numpy images for selected perturbations."
    )
    parser.add_argument("--config", default="diffusion_sit_full.yaml",
                        help="YAML config used to build the CellDataModule.")
    parser.add_argument("--output-dir", default="./original_imgs",
                        help="Directory that receives one p<pid>/ subfolder per perturbation.")
    parser.add_argument("--pids", type=int, nargs="+", default=list(DEFAULT_PIDS),
                        help="Perturbation IDs to export.")
    args = parser.parse_args(argv)

    config = OmegaConf.load(args.config)
    datamodule = CellDataModule(config)

    os.makedirs(args.output_dir, exist_ok=True)
    for pid in args.pids:
        copy_perturbation_images(datamodule, pid, args.output_dir, CELL_TYPE_TO_LABEL)

    print(f"All images saved to {args.output_dir}")


if __name__ == "__main__":
    main()