import pandas as pd
import os
import ast
import shutil
import argparse
import sys


        
def extract_caption(text):
    """Pull the caption portion out of raw model output.

    The input is stringified and lower-cased, everything up to and including
    the last 'assistant' marker and then the last 'caption' marker is dropped,
    colons, '**' markers and double quotes are removed, surrounding whitespace
    is stripped, and finally every "enhance " is rewritten to "generate ".

    Args:
        text: Raw model output; any object accepted by ``str()``, or None.

    Returns:
        str: The cleaned, lower-cased caption, or "" when ``text`` is None or
        the extraction fails.
    """
    if text is None:
        return ""

    try:
        lowered = str(text).lower()
        # Only the tail after the final marker matters for each keyword.
        after_role = lowered.rpartition('assistant')[2]
        cleaned = after_role.rpartition('caption')[2]
        # Strip formatting noise in the same order every time.
        for noise in (':', '**', '"'):
            cleaned = cleaned.replace(noise, '')
        cleaned = cleaned.strip()
        return cleaned.replace("enhance ", "generate ")
    except Exception as e:
        print(f"Error extracting caption: {str(e)}")
        return ""

def _temperature_from_score_column(col_name):
    """Return the temperature token embedded in a ``laion_score_<temp>`` column name."""
    return col_name.replace('laion_score', '').strip('_').split('_')[0]


def _resolve_threshold(best_scores, threshold_type, threshold_value):
    """Turn the threshold settings into the numeric cut-off applied to the best scores."""
    if threshold_type == 'mean':
        return best_scores.mean()
    if threshold_type == 'value':
        if threshold_value is None:
            raise ValueError("threshold_value is required when threshold_type is 'value'")
        return float(threshold_value)
    raise ValueError(f"Unknown threshold_type '{threshold_type}' (expected 'mean' or 'value')")


def process_images(input_file, exp_name, image_dir, threshold_type='mean', threshold_value=None):
    """
    Select rows whose best LAION score beats a threshold and copy their images.

    For every row the highest of the three ``laion_score_*`` columns is taken as
    the row's best score, and the temperature encoded in that column's name
    picks the sub-directory of ``image_dir`` the image is read from. Rows whose
    best score is strictly greater than the threshold are kept; their images
    are copied into ``diffusion_images_<exp_name>/`` in the current working
    directory.

    Args:
        input_file (str): Path to the input CSV file. It must contain the
            columns 'filename', 'caption', 'laion_score_0.75',
            'laion_score_0.85' and 'laion_score_0.95'.
        exp_name (str): Experiment name used to name the output directory
        image_dir (str): Directory containing one sub-directory of images per
            temperature
        threshold_type (str): 'mean' to use the mean of the best scores as the
            threshold, 'value' to use ``threshold_value``
        threshold_value (float): Specific threshold value (required if
            threshold_type is 'value', ignored otherwise)

    Returns:
        tuple: ``(result_df, output_dir)`` where ``result_df`` is a
        pd.DataFrame with the columns 'filename', 'text' and 'image_path'
        ('image_path' is None when the image is missing or could not be
        copied) and ``output_dir`` is the directory the images were copied to.

    Any error (unreadable CSV, missing column, invalid threshold settings, ...)
    is reported on stdout and terminates the process via ``sys.exit(1)``.
    """
    score_cols = ['laion_score_0.75', 'laion_score_0.85', 'laion_score_0.95']
    result_cols = ['filename', 'text', 'image_path']

    try:
        # Load the dataset
        print(f"Loading data from {input_file}...")
        df = pd.read_csv(input_file)
        print(f"Loaded {len(df)} rows")

        # Apply threshold. Rows without any score have a NaN best score, which
        # never compares greater than the threshold, so they are dropped here.
        best_scores = df[score_cols].max(axis=1)
        threshold = _resolve_threshold(best_scores, threshold_type, threshold_value)
        print(f"Applying threshold ({threshold_type}): {threshold}")
        selected = df[best_scores > threshold]

        # idxmax is only evaluated on kept rows, which all have at least one
        # real score; the empty case is skipped entirely.
        if selected.empty:
            temps = []
        else:
            best_cols = selected[score_cols].idxmax(axis=1)
            temps = best_cols.map(_temperature_from_score_column)

        # Create output directory
        output_dir = f'diffusion_images_{exp_name}/'
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

        # Process each row
        total = len(selected)
        rows = []
        for position, ((idx, row), temp) in enumerate(zip(selected.iterrows(), temps)):
            # Report progress by position: index labels are no longer
            # contiguous after filtering.
            if position % 100 == 0:
                print(f"Processing row {position}/{total}...")

            # Handle text column: captions are stored as a Python-literal list
            # and only the first entry is used.
            raw_caption = row['caption']
            if pd.isna(raw_caption):
                text = ""
            else:
                try:
                    text = ast.literal_eval(raw_caption)[0]
                except (ValueError, SyntaxError, TypeError, IndexError, KeyError) as e:
                    # Fall back to the raw value instead of aborting the run.
                    print(f"Warning: Error parsing caption for row {idx}: {e}")
                    text = str(raw_caption)

            # Construct image path
            image_path = os.path.join(image_dir, str(temp), str(row['filename']))
            image_path_new = None

            # Copy image if it exists
            if os.path.exists(image_path):
                image_path_new = os.path.join(output_dir, os.path.basename(image_path))
                try:
                    shutil.copy(image_path, image_path_new)
                except OSError as e:
                    print(f"Failed to copy {image_path}: {e}")
                    image_path_new = None

            # Collect all fields
            rows.append({
                'filename': row['filename'],
                'text': text,
                'image_path': image_path_new
            })

        # Create result DataFrame; explicit columns keep the schema stable
        # even when no row was selected.
        result_df = pd.DataFrame(rows, columns=result_cols)
        return result_df, output_dir

    except Exception as e:
        # Boundary handler: callers rely on this function exiting on failure.
        print(f"Error processing images: {e}")
        sys.exit(1)

def main():
    """Parse arguments and execute the image processing pipeline.

    Returns:
        int: 0 on success, 1 when the arguments fail validation. Processing
        failures terminate the process from inside ``process_images``.
    """
    parser = argparse.ArgumentParser(description='Process diffusion images based on scores')

    # Required arguments
    parser.add_argument('--input', '-i', required=True, help='Input CSV file with scores')
    parser.add_argument('--expname', '-e', required=True, help='Experiment name')
    parser.add_argument('--image-dir', '-d', required=True, help='Directory containing the images')

    # Optional arguments
    parser.add_argument('--output', '-o', default=None,
                        help='Output CSV file (default: <expname>_diff_data.csv)')
    parser.add_argument('--threshold-type', choices=['mean', 'value'], default='mean',
                        help="Threshold to apply: 'mean' of the best scores or a fixed 'value'")
    parser.add_argument('--threshold-value', type=float, default=None,
                        help="Threshold used when --threshold-type is 'value'")

    args = parser.parse_args()

    # Validate arguments
    if not os.path.exists(args.input):
        print(f"Error: Input file '{args.input}' does not exist")
        return 1

    if not os.path.exists(args.image_dir):
        print(f"Error: Image directory '{args.image_dir}' does not exist")
        return 1

    if args.threshold_type == 'value' and args.threshold_value is None:
        print("Error: --threshold-value is required when --threshold-type is 'value'")
        return 1

    # Set output filename
    output_file = args.output if args.output else f"{args.expname}_diff_data.csv"

    # Process the images
    print(f"Starting image processing for experiment: {args.expname}")
    result_df, output_dir = process_images(
        args.input,
        args.expname,
        args.image_dir,
        threshold_type=args.threshold_type,
        threshold_value=args.threshold_value,
    )

    # Save the results
    result_df.to_csv(output_file, index=False)
    print(f"✅ Saved {len(result_df)} rows to {output_file}")
    print(f"✅ Images copied to {output_dir}")

    return 0

# Script entry point: main()'s return code becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())