import argparse
import pandas as pd
import numpy as np
import random
import ast

from pipeline.pipeline_utils import (get_question, get_answers, 
                            sample_categories, 
                            sample_user_strings, construct_question, 
                            IF_eval, match_lang)
from pipeline.models import generate_answer, construct_client
from pathlib import Path
from json import dumps
from tqdm import tqdm
from os import environ
from instructions.BaseInstruction import BaseInstruction
from instructions.instructions_utils import get_categories

# Result columns whose cells hold Python lists; a CSV round-trip stores them
# as their repr strings, so they are parsed back with ast.literal_eval on resume.
_LIST_COLUMNS = ("instruction_string", "IF_answers", "result", "verify_cot")


def _results_path(args) -> Path:
    """Return the checkpoint CSV path for this run, creating its directory.

    Layout: results/<lang>/<followup|predefined>/judge_<judge_model>/<code_model>.csv
    """
    mode = "followup" if args.followup else "predefined"
    directory = Path("results") / args.lang / mode / f"judge_{args.judge_model}"
    directory.mkdir(parents=True, exist_ok=True)
    return directory / f"{args.code_model}.csv"


def _load_existing_results(save_path: Path):
    """Load previously checkpointed results from ``save_path``.

    Returns:
        A ``(records, processed_pairs)`` tuple. ``records`` is the list of row
        dicts, with list columns parsed back from their string form when
        possible. ``processed_pairs`` is the set of ``(id, category)`` tuples
        already present in the file.
    """
    existing_df = pd.read_csv(save_path, index_col=0)
    if existing_df.empty:
        # A run with no questions writes a header-only file that has no
        # 'id'/'category' columns; treat it as "nothing processed yet".
        return [], set()

    records = existing_df.to_dict("records")
    for record in records:
        for key in _LIST_COLUMNS:
            value = record.get(key)
            if not isinstance(value, str):
                continue
            try:
                record[key] = ast.literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal (e.g. a plain string): keep the raw text.
                pass

    processed_pairs = set(zip(existing_df["id"], existing_df["category"]))
    return records, processed_pairs


def _save_results(results, save_path: Path) -> None:
    """Write ``results`` (a list of row dicts) to ``save_path`` as CSV.

    The data is first written to a sibling temp file and then renamed over the
    target. An interruption mid-write therefore cannot corrupt the checkpoint
    that a later resume depends on.
    """
    tmp_path = save_path.with_name(save_path.name + ".tmp")
    pd.DataFrame(results).to_csv(tmp_path)
    tmp_path.replace(save_path)


def main(args):
    """Run the instruction-following (IF) experiment for one language/model pair.

    Reads ``preprocessed_answers/<lang>/judge_<judge_model>/<seed>.csv``.
    Each row of that file is one (question id, category) pair with a list of
    instruction strings. For each row:

    - for every instruction, the code model generates an answer;
    - the answer is graded with ``IF_eval`` against the row's category;
    - one result row per (id, category) is appended to the results list.

    The results are checkpointed to
    ``results/<lang>/<followup|predefined>/judge_<judge_model>/<code_model>.csv``
    after every row. Pairs already present in that file are skipped, so an
    interrupted run can be resumed.

    Args:
        args: parsed CLI namespace with ``lang``, ``code_model``,
            ``judge_model``, ``seed`` and ``followup``.

    Raises:
        ValueError: if a row's category matches no instruction category for
            the language. This is checked before any model call for that row.
    """
    lang = match_lang(args.lang)
    input_path = Path("preprocessed_answers") / args.lang / f"judge_{args.judge_model}" / f"{args.seed}.csv"
    questions = pd.read_csv(input_path)
    # Instruction lists are stored in the CSV as their Python repr.
    questions["instruction_string"] = questions["instruction_string"].apply(ast.literal_eval)

    # Build the lookup once rather than scanning per instruction. As with the
    # previous linear scan, a later category with a duplicate id wins.
    # Assumes instruction_id values are hashable (they are compared against CSV
    # cell values, presumably strings).
    categories_by_id = {category.instruction_id: category for category in get_categories(lang)}

    judge_client = construct_client(args.judge_model)
    answer_client = construct_client(args.code_model)
    # Initialize the singleton judge client for all instruction instances.
    BaseInstruction.set_judge_client(judge_client)

    save_path = _results_path(args)
    if save_path.is_file():
        # Continue from where a previous run left off.
        IF_results, processed_pairs = _load_existing_results(save_path)
        keep = [
            (qid, category) not in processed_pairs
            for qid, category in zip(questions["id"], questions["category"])
        ]
        questions = questions.loc[keep]
        print(f"Found {len(processed_pairs)} already processed (id, category) combinations. Continuing with {len(questions)} remaining questions.")
    else:
        IF_results = []
        print("No existing results file found. Starting fresh.")

    # The context manager closes the bar even if a model call raises mid-run.
    with tqdm(total=len(questions), desc="Running Experiment") as progress_bar:
        # One row in questions per (qid, category) combination.
        for _, q in questions.iterrows():
            curr_cat = categories_by_id.get(q["category"])
            if curr_cat is None:
                raise ValueError(f"No matching category for category {q['category']}")

            past_answer = q["prev_ans"]
            curr_results = {
                "id": q["id"],
                "category": q["category"],
                "prev_ans": past_answer,
                "instruction_string": [],
                "IF_answers": [],
                "result": [],
                "verify_cot": [],
            }
            for user_instruct in q["instruction_string"]:
                IF_question = construct_question(q["question_str"], user_instruct, past_answer, args.followup)
                curr_answer = generate_answer(args.code_model, IF_question, answer_client)
                cot, result = IF_eval(curr_cat, lang, curr_answer, past_answer)

                curr_results["instruction_string"].append(user_instruct)
                curr_results["IF_answers"].append(curr_answer)
                curr_results["result"].append(result)
                curr_results["verify_cot"].append(cot)

            # One result row per (question, category) pair.
            IF_results.append(curr_results)

            # Checkpoint after each question so the run can be resumed.
            _save_results(IF_results, save_path)
            progress_bar.update()

    # Final save (also creates the file when there was nothing to process).
    _save_results(IF_results, save_path)
    




def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this experiment script.

    ``--seed`` selects ``preprocessed_answers/<lang>/judge_<judge_model>/<seed>.csv``.
    It was previously declared both required and defaulted to 42, which left
    the default unused. It is now optional, and existing invocations that pass
    it explicitly behave as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, required=True, help="lang for the questions")
    parser.add_argument("--code_model", type=str, required=True, help="model to use for IF code generations")
    parser.add_argument("--judge_model", type=str, required=True, help="model to use for judgements")
    parser.add_argument("--seed", type=int, default=42, help="seed file in preprocessed")
    parser.add_argument("--followup", action="store_true", help="Include past solution in the code gen prompt")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())