import argparse
import pandas as pd
import numpy as np
import random
import ast

from pipeline.pipeline_utils import (get_question, get_answers, 
                            sample_categories, 
                            sample_user_strings, construct_question, 
                            IF_eval, match_lang)
from pipeline.models import generate_answer, construct_client
from pathlib import Path
from json import dumps
from tqdm import tqdm
from os import environ
from instructions.BaseInstruction import BaseInstruction
from instructions.instructions_utils import get_categories

def convert_result_to_int(result_value):
    """Convert a stored judgement result to an int.

    Args:
        result_value: A string holding an int, a list with exactly one
            int-like entry, or a plain numeric scalar (for example a value
            pandas parsed as a number when reading the CSV).

    Returns:
        int: The converted value. Anything that cannot be converted yields 0
        and prints a warning. Floats are truncated by ``int()``.
    """
    if isinstance(result_value, str):
        try:
            return int(result_value)
        except ValueError:
            print(f"Warning: Could not convert string '{result_value}' to int, using 0")
            return 0

    if isinstance(result_value, list) and len(result_value) == 1:
        try:
            return int(result_value[0])
        except (ValueError, TypeError, OverflowError):
            print(f"Warning: Could not convert list entry '{result_value[0]}' to int, using 0")
            return 0

    # Numeric scalars are converted directly; lists of any other length,
    # None, NaN and other objects fail int() and fall back to 0.
    try:
        return int(result_value)
    except (ValueError, TypeError, OverflowError):
        print(f"Warning: Unexpected result format '{result_value}' (type: {type(result_value)}), using 0")
        return 0


def _load_existing_results(save_path):
    """Load previously saved judgement rows so a run can resume.

    Args:
        save_path: Path of the results CSV written by an earlier run.

    Returns:
        tuple: ``(records, processed_pairs)`` where ``records`` is a list of
        row dicts and ``processed_pairs`` is the set of ``(id, category)``
        tuples found in the file. A file without any data rows yields
        ``([], set())``.
    """
    try:
        existing_df = pd.read_csv(save_path, index_col=0)
    except pd.errors.EmptyDataError:
        # Zero-byte file: nothing to resume from.
        return [], set()

    if existing_df.empty:
        # A header-only file (e.g. saved from an empty result list) has no
        # 'id'/'category' columns to read; treat it as a fresh start.
        return [], set()

    records = existing_df.to_dict('records')

    # List-valued columns are written to CSV as their repr; parse them back.
    for record in records:
        for key in ('instruction_string', 'IF_answers', 'result'):
            value = record.get(key)
            if isinstance(value, str):
                try:
                    record[key] = ast.literal_eval(value)
                except (ValueError, SyntaxError):
                    # If parsing fails, keep the raw string.
                    pass

    processed_pairs = set(zip(existing_df['id'], existing_df['category']))
    return records, processed_pairs


def main(args):
    """Judge baseline answers and report the overall IF score.

    Reads the answers CSV selected by ``args``, evaluates every remaining
    (id, category) row with ``IF_eval``, saves the accumulated results to the
    output CSV after each row, and finally prints the mean of the integer
    results (a result of -1 is counted as 0).

    Args:
        args: Parsed CLI namespace providing ``lang``, ``code_model``,
            ``input_judge_model``, ``curr_judge_model`` and ``followup``.

    Raises:
        ValueError: If a row's category has no matching instruction category.
    """
    lang = match_lang(args.lang)
    mode_dir = "followup" if args.followup else "predefined"

    input_path = Path("baseline_answers") / args.lang / mode_dir / f"judge_{args.input_judge_model}" / f"{args.code_model}.csv"
    print("Using input from: ", input_path)
    answers = pd.read_csv(input_path)

    # Update based on the column name; this is what it becomes in baseline.
    answers["example.instruction_string"] = answers["example.instruction_string"].apply(ast.literal_eval)

    # NOTE(review): the original comment here suggested the client throttles
    # repeated requests in a short time frame and that this does nothing for
    # OpenAI calls -- confirm against construct_client.
    judge_client = construct_client(args.curr_judge_model)

    # Initialize the singleton judge client for all instruction instances.
    BaseInstruction.set_judge_client(judge_client)

    # Index categories once instead of rescanning the list for every row.
    # On duplicate ids the later category wins, as in a full linear scan.
    categories_by_id = {}
    for category in get_categories(lang):
        categories_by_id[category.instruction_id] = category

    save_dir = Path("results_judgements") / args.lang / mode_dir / f"judge_{args.curr_judge_model}"
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{args.code_model}.csv"
    print("Outputting to: ", save_path)

    # Load existing results and avoid duplicates.
    IF_results = []
    if save_path.is_file():
        IF_results, processed_id_category_pairs = _load_existing_results(save_path)

        # An explicit boolean Series keeps the filter well-defined even when
        # `answers` has no rows.
        already_done = pd.Series(
            [
                pair in processed_id_category_pairs
                for pair in zip(answers["example.id"], answers["example.category"])
            ],
            index=answers.index,
            dtype=bool,
        )
        answers = answers.loc[~already_done]
        print(f"Found {len(processed_id_category_pairs)} already processed (id, category) combinations. Continuing with {len(answers)} remaining questions.")
    else:
        print("No existing results file found. Starting fresh.")

    # The context manager closes the progress bar even if evaluation raises.
    with tqdm(total=len(answers), desc="Running Experiment") as progress_bar:
        # One row in answers per qid x category combination.
        for _, row in answers.iterrows():
            category_id = row["example.category"]
            curr_cat = categories_by_id.get(category_id)
            if curr_cat is None:
                raise ValueError(f"No matching category for category {category_id}")

            # Update these column names to match your dataset; this is the
            # format for baseline.
            curr_results = {
                "id": row["example.id"],
                "category": category_id,
                "prev_ans": row["example.prev_ans"],
                "instruction_string": row["example.instruction_string"],
                "IF_answers": [row["outputs.get_raw_response"]],
                "result": [],
                "verify_cot": [],
            }

            for answer in curr_results["IF_answers"]:
                cot, result = IF_eval(curr_cat, lang, answer, row["example.prev_ans"])
                curr_results["verify_cot"].append(cot)
                curr_results["result"].append(result)

            # One result entry per (question, category) row.
            IF_results.append(curr_results)

            # Save intermediate results after each row so an interrupted run
            # can resume from the file.
            pd.DataFrame(IF_results).to_csv(save_path)

            progress_bar.update()

    # Final save (also writes the file when no new rows were processed).
    df = pd.DataFrame(IF_results)
    df.to_csv(save_path)

    if df.empty or 'result' not in df.columns:
        # Without rows there is no 'result' column to average.
        print("No results to aggregate; skipping Overall IF.")
        return

    # Convert result column to integers before calculating the average;
    # a result of -1 is counted as 0.
    df['result_int'] = df['result'].apply(convert_result_to_int)
    df.loc[df['result_int'] == -1, 'result_int'] = 0

    average = df['result_int'].mean()
    print(f"Overall IF: {average}")



if __name__ == "__main__":
    # Command-line entry point: every string option below is mandatory.
    cli = argparse.ArgumentParser()

    required_string_options = (
        ("--lang", "lang for the questions"),
        ("--code_model", "model to use for IF code generations"),
        ("--input_judge_model", "judge model for applicability checks (where the data is from)"),
        ("--curr_judge_model", "model used for judgements in this current run (can be same as input_judge model)"),
    )
    for flag, description in required_string_options:
        cli.add_argument(flag, type=str, required=True, help=description)

    # Boolean switch; main() uses it to pick the "followup" vs "predefined" directories.
    cli.add_argument("--followup", action="store_true", help="Include past solution in the code gen prompt")

    main(cli.parse_args())