import argparse
import pandas as pd
import numpy as np
import random
import ast

from pipeline.pipeline_utils import (get_question, get_answers, 
                            sample_categories, 
                            sample_user_strings, construct_question, 
                            IF_eval, match_lang)
from pipeline.models import generate_answer, construct_client
from pathlib import Path
from json import dumps
from tqdm import tqdm
from os import environ
from instructions.BaseInstruction import BaseInstruction
from instructions.instructions_utils import get_categories

# Columns that may hold Python literals (lists) and are read back from CSV as
# plain strings.  "instruction" is the key this script writes; the other names
# are kept so result files carrying those columns are still parsed.
_LITERAL_COLUMNS = ("instruction", "instruction_string", "IF_answers", "result")


def _parse_literal(value):
    """Return *value* parsed as a Python literal, or unchanged if that fails."""
    if not isinstance(value, str):
        return value
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Not a literal (e.g. free text): keep the raw string.
        return value


def _load_previous_results(save_path):
    """Read an existing results CSV.

    Returns a ``(records, processed_pairs)`` tuple: the rows as a list of
    dicts (with literal columns parsed back), and the set of
    ``(qid, category_id)`` pairs found in the file.
    """
    existing_df = pd.read_csv(save_path, index_col=0)
    records = existing_df.to_dict("records")
    for record in records:
        for key in _LITERAL_COLUMNS:
            if key in record:
                record[key] = _parse_literal(record[key])
    processed_pairs = set(zip(existing_df["qid"], existing_df["category_id"]))
    return records, processed_pairs


def main(args: argparse.Namespace) -> None:
    """Generate one answer per (question, category) row and checkpoint to CSV.

    Reads ``preprocessed_answers/<lang>/judge_<judge_model>/<seed>.csv``,
    builds a prompt for every row with ``construct_question``, queries
    ``args.code_model`` through ``generate_answer`` and writes the results to
    ``generations/<lang>/<followup|predefined>/judge_<judge_model>/<code_model>.csv``.
    If that output file already exists, its rows are kept and the
    (id, category) pairs it contains are skipped, so an interrupted run can be
    resumed.

    Args:
        args: Parsed command-line arguments providing ``lang``, ``code_model``,
            ``judge_model``, ``seed`` and ``followup``.
    """
    input_path = (
        Path("preprocessed_answers")
        / args.lang
        / f"judge_{args.judge_model}"
        / f"{args.seed}.csv"
    )
    questions = pd.read_csv(input_path)
    questions["instruction_string"] = questions["instruction_string"].apply(ast.literal_eval)

    # Build the client once and reuse it for every request.
    # NOTE(review): the original comment here was truncated; it said this
    # matters for "repeated requests in short time frame" and "does nothing
    # for openai calls" -- confirm against construct_client.
    answer_client = construct_client(args.code_model)

    mode_dir = "followup" if args.followup else "predefined"
    save_path = (
        Path("generations")
        / args.lang
        / mode_dir
        / f"judge_{args.judge_model}"
        / f"{args.code_model}.csv"
    )
    # Create the file's actual parent so a model name that contains a path
    # separator still gets a directory to be written into.
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Load existing results and avoid duplicates
    IF_results = []
    if save_path.is_file():
        IF_results, processed_id_category_pairs = _load_previous_results(save_path)

        # Keep only rows whose (id, category) pair has not been generated yet.
        # The mask is built by zipping columns (the same way the processed
        # pairs were built), which also yields a proper boolean mask when
        # ``questions`` is empty.
        pending = np.array(
            [
                pair not in processed_id_category_pairs
                for pair in zip(questions["id"], questions["category"])
            ],
            dtype=bool,
        )
        questions = questions.loc[pending]
        print(f"Found {len(processed_id_category_pairs)} already processed (id, category) combinations. Continuing with {len(questions)} remaining questions.")
    else:
        print("No existing results file found. Starting fresh.")

    # The context manager closes the bar even if a generation call raises.
    with tqdm(total=len(questions), desc="Running Experiment") as progress_bar:
        # one row in questions per qid x category combination
        for _, q in questions.iterrows():
            past_answer = q["prev_ans"]
            instruct_string = q["instruction_string"]

            IF_question = construct_question(
                q["question_str"], instruct_string, past_answer, args.followup
            )
            curr_answer = generate_answer(args.code_model, IF_question, answer_client)

            # one result per category per question
            IF_results.append(
                {
                    "qid": q["id"],
                    "category_id": q["category"],
                    "prev_ans": past_answer,
                    "IF_answers": "",
                    "question_str": q["question_str"],
                    "soln_source": q["soln_source"],
                    "instruction": [instruct_string],
                    # A single value, not a list (the original note says this
                    # mirrors how the baseline stores it).
                    "curr_answer": curr_answer,
                }
            )

            # Checkpoint after each question so an interrupted run can resume.
            pd.DataFrame(IF_results).to_csv(save_path)
            progress_bar.update()

    # Final save (also writes the file when there was nothing left to process).
    pd.DataFrame(IF_results).to_csv(save_path)
    

if __name__ == "__main__":
    # Command-line entry point: parse the experiment options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, required=True, help="lang for the questions")
    parser.add_argument("--code_model", type=str, required=True, help="model to use for IF code generations")
    # Selects the judge_<model> sub-directory for both the preprocessed input
    # and the generated output.
    parser.add_argument("--judge_model", type=str, required=True, help="judge model whose preprocessed answers are used (judge_<model> directory)")
    # Previously declared with both required=True and default=42, which made
    # the default unreachable; the option is now optional and falls back to 42.
    parser.add_argument("--seed", type=int, default=42, help="seed file in preprocessed")
    parser.add_argument("--followup", action="store_true", help="Include past solution in the code gen prompt")

    main(parser.parse_args())