import openai 
import pandas as pd
from tqdm import tqdm
from json import load as load_json_file, dump as json_dump

import argparse

def gpt_answer(prompt: str, model: str = "o4-mini-2025-04-16") -> str:
    """Send *prompt* as a single user message and return the tagged answer.

    Args:
        prompt: Full prompt text, sent as the only message of the chat.
        model: Name of the chat model to query.

    Returns:
        The text between the last ``<answer>`` tag and the ``</answer>`` tag
        that follows it, stripped of surrounding whitespace. If the tags are
        absent the whole stripped reply is returned; if the reply carries no
        text at all, an empty string is returned.
    """
    completion = openai.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt}
        ],
    )
    # `content` is optional in the API response (it can be None when the model
    # returns no text), so fall back to "" instead of failing on `.split`.
    output_text = completion.choices[0].message.content or ""
    # Take what follows the last opening tag, then cut at the closing tag.
    ans = output_text.split("<answer>")[-1].split("</answer>")[0].strip()
    return ans

def _parse_batch_response(response, batch_len):
    """Return the category parsed for each of the ``batch_len`` rows of a batch.

    The response is expected to hold one ``<n>: <category>`` line per
    instruction, numbered from 1. For each number the first matching line
    wins; a number with no matching line yields an empty string.
    """
    first_match = {}
    for line in response.strip().split('\n'):
        number, colon, text = line.strip().partition(":")
        if colon:
            first_match.setdefault(number, text.strip())
    return [first_match.get(str(n), "") for n in range(1, batch_len + 1)]


def main(args: argparse.Namespace) -> None:
    """Label every instruction of the input CSV with an LLM-chosen category.

    Reads the prompt template and the examples file, then walks the
    ``instruction`` column of ``args.input_csv`` in batches of 5, asking the
    model (via ``gpt_answer``) to assign each instruction a category. Results
    are stored in the ``llm`` column and written to ``args.output_csv``.

    If the input already has an ``llm`` column, processing resumes at the
    first row whose value is missing or blank, and the categories of the rows
    before it seed the category list shown to the model.

    Args:
        args: Parsed command-line arguments providing ``input_csv``,
            ``examples_file`` and ``output_csv``.
    """
    with open("prompts/initial_coding_direct.txt", "r") as f:
        initial_code_prompt = f.read()
    with open(args.examples_file, "r") as f:
        examples = f.read()

    all_df = pd.read_csv(args.input_csv)

    # Holdout-id filtering and loading of pre-existing categories from
    # "initial_categories.json" are currently disabled: every row is eligible
    # and the category list starts empty unless a run is being resumed.
    categories = []

    # Positional (iloc-style) index of the first row still to be processed.
    start_row_idx = 0
    if 'llm' in all_df.columns:
        # An entirely empty 'llm' column is parsed as float NaN; use object
        # dtype so string categories can be written into it.
        all_df['llm'] = all_df['llm'].astype(object)
        llm_missing = all_df['llm'].isna() | (all_df['llm'].astype(str).str.strip() == '')
        if not llm_missing.any():
            print("All rows already have 'llm' values. Nothing to process.")
            return

        # Position (not index label) of the first missing value.
        start_row_idx = int(llm_missing.to_numpy().argmax())
        print(f"Found existing 'llm' column. Resuming from row {start_row_idx}")

        # Every row before the first missing one is filled by construction,
        # so its values can seed the category list without further filtering.
        categories = list(all_df['llm'].iloc[:start_row_idx].unique())
        print(f"Found {len(categories)} existing categories: {categories}")
    else:
        # Add 'llm' column if it doesn't exist
        all_df['llm'] = ''
        print("Added new 'llm' column to dataframe")

    # Calculate remaining rows to process
    remaining_rows = len(all_df) - start_row_idx
    if remaining_rows == 0:
        print("No rows to process.")
        return

    print(f"Processing {remaining_rows} rows starting from index {start_row_idx}")

    # Process 5 rows at a time, starting from the determined index
    batch_size = 5
    batch_starts = range(start_row_idx, len(all_df), batch_size)
    llm_col = all_df.columns.get_loc('llm')

    # NOTE(review): to_csv also writes the dataframe index as an unnamed first
    # column by default; left unchanged here so the output format stays the
    # same -- confirm whether index=False is wanted.
    try:
        for batch_idx, batch_start in enumerate(tqdm(batch_starts, desc="Processing batches")):
            batch_end = min(batch_start + batch_size, len(all_df))
            batch_instructions = all_df['instruction'].iloc[batch_start:batch_end]

            # Create numbered instruction list for this batch
            instruction_list_str = "\n".join(
                f"{i}: {text}" for i, text in enumerate(batch_instructions, 1)
            )

            # Format prompt with batch of instructions
            curr_prompt = initial_code_prompt.format(
                categories=categories if categories else "No categories yet; create a categories",
                instruction_list=instruction_list_str,
                examples=examples,
            )

            # Get response from API and split it into one category per row
            response = gpt_answer(curr_prompt)
            batch_categories = _parse_batch_response(response, len(batch_instructions))

            # Store results back to dataframe
            for row_pos, category in enumerate(batch_categories, batch_start):
                if not category:
                    # Leave the cell blank so a later resume picks the row up
                    # again, and keep the blank out of the category list that
                    # is fed to subsequent prompts.
                    tqdm.write(f"Warning: no category parsed for row {row_pos}")
                elif category not in categories:
                    categories.append(category)
                all_df.iat[row_pos, llm_col] = category

            # Checkpoint every third batch
            if batch_idx % 3 == 0:
                all_df.to_csv(args.output_csv)
    finally:
        # Always persist what has been labelled so far, including when an API
        # call fails part-way, so the run can be resumed from the output file.
        all_df.to_csv(args.output_csv)



if __name__ == '__main__':
    # Command-line entry point: all three paths are required and are passed
    # to main() as a parsed namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_csv", type=str, required=True, help="input csv file name with instruction_id and instructions")
    parser.add_argument("--examples_file", type=str, required=True, help="file containing the exemplars for coding")
    parser.add_argument("--output_csv", type=str, required=True, help="output csv file name; the input rows are written here with the added 'llm' category column")
    args = parser.parse_args()
    main(args)