import pandas as pd
import string
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
import argparse
import os

def extract_caption(text):
    """Return the cleaned caption from a raw model output.

    The output is lower-cased. Only the text after the last ``assistant``
    marker, and within that the text after the last ``caption`` marker, is
    kept. Colons, ``**`` and double quotes are then deleted and surrounding
    whitespace is trimmed. Finally every ``"enhance "`` is rewritten to
    ``"generate "``.

    Args:
        text: Raw model output.

    Returns:
        The cleaned caption. If anything fails (e.g. *text* is not a string,
        such as NaN from an empty CSV cell), the error is printed and ``""``
        is returned.
    """
    try:
        lowered = text.lower()
        after_assistant = lowered.rsplit('assistant', 1)[-1]
        tail = after_assistant.rsplit('caption', 1)[-1]
        for junk in (':', '**', '"'):
            tail = tail.replace(junk, '')
        return tail.strip().replace("enhance ", "generate ")
    except Exception as err:
        print(f"Error extracting caption: {str(err)}")
        return ""

def extract_plan(text):
    """Extract the step-by-step plan from a raw model output.

    The output is lower-cased and only the part after the last ``assistant``
    marker is kept (presumably this drops the chat template's echoed prompt).
    The plan is everything before the LAST ``caption`` marker. A plan step
    that merely mentions the word "caption" therefore no longer truncates the
    plan. If no ``caption`` marker is present, the whole response is the
    plan. ``**`` and double quotes are removed and whitespace is stripped.

    Args:
        text: Raw model output. Non-string values (e.g. NaN from an empty CSV
            cell) are reported on stdout and yield ``""``.

    Returns:
        The cleaned, lower-cased plan text, or ``""`` for non-string input.
    """
    if not isinstance(text, str):
        print(f"Error extracting plan: expected str, got {type(text).__name__}")
        print(text)
        return ""
    response = text.lower().split('assistant')[-1]
    # rsplit(..., 1)[0] keeps everything before the final "caption" marker,
    # the counterpart of extract_caption, which keeps everything after it.
    plan = response.rsplit('caption', 1)[0]
    return plan.replace('**', '').replace('"', '').strip()

def process_dataframe(df, expname, temps=(0.75, 0.85, 0.95)):
    """Parse outputs per temperature and pick the chosen/rejected generations.

    For each temperature ``t`` in *temps* this reads the raw output column
    ``output_temp_{expname}_{t}`` and the score column
    ``laion_score_{expname}_{t}``. It adds these columns to *df* in place:

    - ``plan_{expname}_{t}`` / ``caption_{expname}_{t}``: parsed with
      ``extract_plan`` / ``extract_caption``.
    - ``max_laion_score`` / ``min_laion_score``: row-wise max / min score.
    - ``chosen_temp`` / ``rejected_temp``: the temperature string (e.g.
      ``"0.75"``) with the highest / lowest score, or ``None`` for rows with
      no score at all.
    - ``chosen`` / ``rejected`` and ``chosen_plan`` / ``rejected_plan``: the
      caption and plan produced at those temperatures (``None`` when the
      temperature is ``None``).

    Args:
        df: Input frame; modified in place and also returned.
        expname: Experiment name embedded in the column names.
        temps: Temperatures to compare. Column suffixes are ``str(t)``.

    Returns:
        ``(df, mean_diff, count)``. ``mean_diff`` is
        ``mean(max_laion_score) - mean(min_laion_score)``. ``count`` is the
        number of rows whose max-min spread exceeds ``mean_diff``.
    """
    # Column suffixes are the temperatures formatted with str(), e.g. "0.75".
    temp_keys = [str(t) for t in temps]
    score_cols = [f"laion_score_{expname}_{t}" for t in temp_keys]

    # Parse plan and caption out of each raw output column.
    for t in temp_keys:
        raw = df[f"output_temp_{expname}_{t}"]
        df[f"plan_{expname}_{t}"] = raw.map(extract_plan)
        df[f"caption_{expname}_{t}"] = raw.map(extract_caption)

    scores = df[score_cols]
    df['max_laion_score'] = scores.max(axis=1)
    df['min_laion_score'] = scores.min(axis=1)

    mean_diff = float(df["max_laion_score"].mean() - df["min_laion_score"].mean())
    count = int((df["max_laion_score"] - df["min_laion_score"] > mean_diff).sum())

    print(f"Number of pairs with significant difference: {count}")

    # Map the best/worst score column straight back to its temperature rather
    # than regex-parsing the column name: a digit-dot-digit run inside
    # `expname` (e.g. "v1.5") would otherwise be taken as the temperature.
    col_to_temp = dict(zip(score_cols, temp_keys))
    has_score = scores.notna().any(axis=1)
    df["chosen_temp"] = None
    df["rejected_temp"] = None
    if has_score.any():
        # Rows without any score stay None: idxmax/idxmin over an all-NaN row
        # warns in pandas 2.x and is slated to raise in later versions.
        valid = scores[has_score]
        df.loc[has_score, "chosen_temp"] = valid.idxmax(axis=1).map(col_to_temp)
        df.loc[has_score, "rejected_temp"] = valid.idxmin(axis=1).map(col_to_temp)

    def _pick(prefix, temp_col):
        # Per row, the `{prefix}_{expname}_{temp}` cell for that row's temperature.
        return [
            None if pd.isna(t) else df.at[idx, f"{prefix}_{expname}_{t}"]
            for idx, t in df[temp_col].items()
        ]

    # Assign corresponding captions and plans based on chosen/rejected temps.
    df["chosen"] = _pick("caption", "chosen_temp")
    df["rejected"] = _pick("caption", "rejected_temp")
    df["chosen_plan"] = _pick("plan", "chosen_temp")
    df["rejected_plan"] = _pick("plan", "rejected_temp")

    return df, mean_diff, count

def create_dpopairs(df, expname, mean_diff, output_dir='data_formats'):
    """Write DPO preference pairs for rows with a large enough score spread.

    A row is kept when ``max_laion_score - min_laion_score`` is strictly
    greater than *mean_diff*. Rows whose spread is NaN are skipped. Each kept
    row becomes one JSON object with keys ``image_path``, ``prompt``,
    ``chosen`` and ``rejected``. The objects are written one per line to
    ``{output_dir}/flickr_train_{expname}_dpopairs.jsonl``.

    Args:
        df: Frame as returned by ``process_dataframe``. It needs the columns
            ``filename``, ``caption``, ``max_laion_score``,
            ``min_laion_score``, ``chosen``, ``rejected``, ``chosen_plan``
            and ``rejected_plan``.
        expname: Experiment name used in the output file name.
        mean_diff: Score-spread threshold.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        The number of pairs written.
    """
    import ast

    def _first_caption(raw):
        # NOTE(review): `caption` looks like a list of reference captions
        # stored in the CSV as its Python repr (e.g. "['a', 'b']"); the first
        # entry is used. A string that is not a Python literal (or evaluates to
        # a str) is used whole instead of aborting the export or being cut to
        # its first character -- confirm against the CSV producer.
        if isinstance(raw, str):
            try:
                raw = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                return raw
        return raw if isinstance(raw, str) else raw[0]

    new_data = []
    for _, row in df.iterrows():
        if not (row['max_laion_score'] - row['min_laion_score'] > mean_diff):
            continue

        image_path = "base_flickr_images_train/" + row["filename"]
        caption = _first_caption(row['caption'])
        # This text (including its embedded indentation) is written verbatim
        # into the JSONL; keep it byte-identical.
        prompt = f"""Think step by step and analyze the image to iteratively enhance it. Generate a 60-word diffusion model prompt to generate an enhanced version of the image with improved aesthetics. Ensure all original details are 
                         preserved, and the enhancements match the style and essence of the original image and should be better in aesthetics.

                        The base caption for the image is: {caption}. Make sure the prompt is always present, if you need to shorten the prompt, that is fine.

                        Output format:
                        Step1_Plan: [Specific plan for step-1]

                        Step2_Plan: [Specific plan for step-2, building on step-1]

                        Step3_Plan: [Specific plan for step-3, building on step-1 & step-2] ...

                        Caption: [Generated prompt]

                        Make sure to generate a plan of upto six steps and the prompt. The prompt is super important.
                    """

        new_data.append({
            'image_path': image_path,
            'prompt': prompt,
            'chosen': row['chosen_plan'] + "\nCaption:" + row['chosen'],
            'rejected': row['rejected_plan'] + "\nCaption:" + row['rejected'],
        })

    count = len(new_data)
    print(f"Created {count} DPO pairs")

    # Create the destination so a fresh checkout does not fail in open().
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'flickr_train_{expname}_dpopairs.jsonl')

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in new_data)

    print(f"Saved DPO pairs to {output_path}")
    return count

def main(argv=None):
    """Command-line entry point: read the CSV, build DPO pairs and save them.

    Args:
        argv: Argument list to parse. With ``None`` (the default), argparse
            reads ``sys.argv[1:]``, so running the script is unchanged.

    On any failure (missing input file, unreadable CSV, processing error or
    write error), an error message goes to stderr and the process exits with
    status 1. Shell pipelines can therefore detect the failure.
    """
    parser = argparse.ArgumentParser(description='Process caption data and create DPO pairs.')
    parser.add_argument('--input_csv', required=True, help='Path to the input CSV file')
    parser.add_argument('--exp_name', required=True, help='Experiment name')
    args = parser.parse_args(argv)

    def _fail(stage, err):
        # Top-level boundary: include the exception type (a bare KeyError
        # message is only the missing key) and exit non-zero via argparse.
        parser.exit(1, f"Error {stage}: {type(err).__name__}: {err}\n")

    # Print arguments for debugging
    print(f"Input CSV: {args.input_csv}")
    print(f"Experiment name: {args.exp_name}")

    if not os.path.exists(args.input_csv):
        parser.exit(1, f"Error: Input file {args.input_csv} does not exist!\n")

    try:
        df = pd.read_csv(args.input_csv)
    except Exception as e:
        _fail("reading CSV", e)
    print(f"Successfully read CSV with {len(df)} rows")

    try:
        processed_df, mean_diff, _count = process_dataframe(df, args.exp_name)
    except Exception as e:
        _fail("processing dataframe", e)
    print(f"Processed dataframe with mean difference of {mean_diff}")

    try:
        num_pairs = create_dpopairs(processed_df, args.exp_name, mean_diff)
    except Exception as e:
        _fail("creating DPO pairs", e)
    print(f"Successfully created {num_pairs} DPO pairs")

# Run the command-line pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()