import argparse
import pandas as pd
import numpy as np
import random

from pipeline.pipeline_utils import (get_question, get_answers, 
                            get_categories, sample_categories, 
                            sample_user_strings, construct_question, 
                            IF_eval, match_lang)
from pipeline.models import generate_answer, construct_client
from pathlib import Path
from json import dumps
from tqdm import tqdm
from os import environ
from instructions.BaseInstruction import BaseInstruction


# Columns whose cells hold Python lists; they are written to CSV as their repr
# and must be parsed back when resuming from a checkpoint file.
_LIST_COLUMNS = ("instruction_string", "IF_answers", "result", "verify_cot")


def _build_save_path(args):
    """Return the results CSV path for this run, creating its directory.

    Layout: results/<lang>/<followup|predefined>/judge_<judge_model>/<code_model>.csv
    """
    mode = "followup" if args.followup else "predefined"
    save_dir = Path("results") / args.lang / mode / f"judge_{args.judge_model}"
    save_dir.mkdir(parents=True, exist_ok=True)
    return save_dir / f"{args.code_model}.csv"


def _load_existing_results(save_path):
    """Load rows from a previous (possibly interrupted) run.

    Args:
        save_path: Path to an existing results CSV written by ``main``.

    Returns:
        tuple[list[dict], set]: the saved rows, with the list-valued columns
        parsed back from their string repr where possible, and the set of
        question ids (as strings) already present in the file.
    """
    import ast

    try:
        # Read ids as strings so ids such as "007" survive the CSV round trip
        # and compare equal to str(question_id) in the caller.
        existing_df = pd.read_csv(save_path, index_col=0, dtype={"id": str})
    except pd.errors.EmptyDataError:
        # A zero-byte file carries no results; treat it as a fresh start.
        return [], set()

    rows = existing_df.to_dict("records")
    for row in rows:
        for key in _LIST_COLUMNS:
            value = row.get(key)
            if isinstance(value, str):
                try:
                    row[key] = ast.literal_eval(value)
                except (ValueError, SyntaxError):
                    # Not a Python literal; keep the raw string.
                    pass

    # A checkpoint written with no rows has no "id" column at all.
    processed_ids = set(existing_df["id"]) if "id" in existing_df.columns else set()
    return rows, processed_ids


def _run_category(args, lang, q, category, past_answer, answer_client):
    """Generate and judge answers for every sampled instruction string of one category.

    Returns:
        dict: one result row for this (question, category) pair. The
        list-valued fields hold one entry per sampled instruction string.
    """
    row = {
        "id": q["question_id"],
        "category": category.instruction_id,
        "prev_ans": past_answer,
        "instruction_string": [],
        "IF_answers": [],
        "result": [],
        "verify_cot": [],
    }
    # NOTE(review): args.m is parsed on the command line but not forwarded to
    # sample_user_strings here -- confirm whether it should be.
    for user_IF in sample_user_strings(category):
        IF_question = construct_question(q["question_str"], user_IF, past_answer, args.followup)
        curr_answer = generate_answer(args.code_model, IF_question, answer_client)
        cot, result = IF_eval(category, lang, curr_answer, past_answer)

        row["instruction_string"].append(user_IF)
        row["IF_answers"].append(curr_answer)
        row["result"].append(result)
        row["verify_cot"].append(cot)
    return row


def _score_result(result_value):
    """Convert one row's ``result`` cell into a numeric score.

    Accepts a scalar (int or numeric string) or a list/tuple of per-instruction
    results. Each individual result of -1 counts as 0. A list is scored as the
    mean of its entries, so every (question, category) row carries equal weight
    regardless of how many instruction strings were sampled for it.
    Unconvertible values are reported and scored 0.
    """
    def _to_int(item):
        try:
            value = int(item)
        except (ValueError, TypeError):
            print(f"Warning: Could not convert '{item}' to int, using 0")
            return 0
        return 0 if value == -1 else value

    if isinstance(result_value, (list, tuple)):
        if not result_value:
            print(f"Warning: Empty result list '{result_value}', using 0")
            return 0
        return sum(_to_int(item) for item in result_value) / len(result_value)
    if isinstance(result_value, (str, int, np.integer)):
        return _to_int(result_value)
    print(f"Warning: Unexpected result format '{result_value}' (type: {type(result_value)}), using 0")
    return 0


def main(args):
    """Run the IF (presumably instruction-following) experiment and print the overall score.

    For every question not already present in the results CSV:
      - sample ``args.k`` applicable categories;
      - for each category, ask the code model to answer under each sampled
        instruction string, then judge each answer with ``IF_eval``;
      - append one row per (question, category).

    Results are checkpointed to CSV after each question, so a rerun with the
    same arguments resumes where the previous one stopped.

    Args:
        args: parsed CLI namespace providing ``lang``, ``code_model``,
            ``judge_model``, ``k`` and ``followup``.
    """
    random.seed(42)
    np.random.seed(42)
    lang = match_lang(args.lang)
    questions = get_question(args.lang)
    original_answers = get_answers(args.lang, args.code_model)

    # Build the model clients once and reuse them for the whole run.
    judge_client = construct_client(args.judge_model)
    answer_client = construct_client(args.code_model)

    # Initialize the singleton judge client for all instruction instances
    BaseInstruction.set_judge_client(judge_client)

    save_path = _build_save_path(args)

    # Resume from an existing results file, skipping already-processed questions.
    if save_path.is_file():
        IF_results, processed_question_ids = _load_existing_results(save_path)
        questions = [q for q in questions if str(q["question_id"]) not in processed_question_ids]
        print(f"Found {len(processed_question_ids)} already processed questions. Continuing with {len(questions)} remaining questions.")
    else:
        IF_results = []
        print("No existing results file found. Starting fresh.")

    # The context manager closes the bar even if a model call raises.
    with tqdm(total=len(questions), desc="Running Experiment") as progress_bar:
        for q in questions:
            past_answer = original_answers[q["question_id"]]["answer"]
            applicable_categories = sample_categories(lang, q, past_answer, args.k)
            for category in applicable_categories:
                # One result row per category per question.
                IF_results.append(
                    _run_category(args, lang, q, category, past_answer, answer_client)
                )

            # Checkpoint after each question so an interrupted run can resume.
            pd.DataFrame(IF_results).to_csv(save_path)
            progress_bar.update()

    # Final save (also rewrites the file when there was nothing left to run).
    df = pd.DataFrame(IF_results)
    df.to_csv(save_path)

    if "result" not in df.columns:
        print("No results to score.")
        return

    average = df["result"].apply(_score_result).mean()
    print(f"Overall IF: {average}")



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, required=True, help="lang for the questions")
    parser.add_argument("--code_model", type=str, required=True, help="model to use for IF code generations")
    parser.add_argument("--judge_model", type=str, required=True, help="model to use for judgements")
    parser.add_argument("--k", type=int, required=True, help="number of applicable categories to sample")
    # Previously declared with required=True, which made default=1 unreachable;
    # the flag is optional so the default actually applies.
    # NOTE(review): main() does not currently read args.m.
    parser.add_argument("--m", type=int, default=1, help="number of user strings to sample from each category")
    parser.add_argument("--followup", action="store_true", help="Include past solution in the code gen prompt")

    main(parser.parse_args())