import argparse
import random
import pandas as pd
import numpy as np
import ast

from collections import defaultdict
from tqdm import tqdm
from pathlib import Path
from pipeline.pipeline_utils import get_answers, get_question, match_lang, sample_categories, sample_user_strings
from pipeline.models import construct_client
from instructions.BaseInstruction import BaseInstruction

def get_model_family(model_name):
    """Return the known family prefix that `model_name` starts with.

    The prefixes "gpt", "claude" and "gemini" are tried in that order and the
    match is case-sensitive. Names matching none of them map to "unknown".
    """
    known_prefixes = ("gpt", "claude", "gemini")
    matching = (prefix for prefix in known_prefixes if model_name.startswith(prefix))
    return next(matching, "unknown")


def balanced_sample_answers(model_answers, num_samples=78):
    """
    Sample answers from model_answers with balanced representation across model families.

    For every question one family is drawn with probability proportional to
    1 / (times_family_was_picked + 1) + 0.001 (uniform for the very first
    question), then one model is drawn uniformly from that family. Draws use
    the global `random` module, so seed it beforehand for reproducible output.

    Args:
        model_answers: Dict where keys are model names and values are dicts of {question_id: answer_data}.
                       answer_data must be a mapping; it is copied, never modified.
        num_samples: Kept for backward compatibility only. It has never limited the output:
                     one answer is sampled for every question ID of the first model.

    Returns:
        mixed_answers: Dict of {question_id: answer_data} where each answer_data is a shallow
                       copy of the chosen model's entry with an added "source" key holding
                       the name of the model it was taken from.

    Raises:
        ValueError: If model_answers is empty.
        KeyError: If the selected model has no answer for a question ID.
    """
    if not model_answers:
        # Previously this surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError("model_answers is empty; there is nothing to sample from")

    # Get all question IDs (assuming all models have answers for the same questions)
    first_model = next(iter(model_answers.values()))
    all_question_ids = list(first_model.keys())

    # Group models by family
    family_models = defaultdict(list)
    for model_name in model_answers:
        family_models[get_model_family(model_name)].append(model_name)

    print(f"Model families found: {dict(family_models)}")

    families = list(family_models)
    mixed_answers = {}
    family_sample_counts = defaultdict(int)  # Track how many samples taken from each family

    for question_id in all_question_ids:
        if not mixed_answers:
            # First sample - equal probability for all families
            family_weights = [1.0 for _ in families]
        else:
            # Inverse of (count + 1) favours families picked less often so far;
            # the small epsilon keeps every family's weight strictly positive.
            family_weights = [
                1.0 / (family_sample_counts[family] + 1) + 0.001
                for family in families
            ]

        # Normalize weights to create probability distribution
        total_weight = sum(family_weights)
        probabilities = [weight / total_weight for weight in family_weights]

        # Sample a family, then a model within that family
        selected_family = random.choices(families, weights=probabilities)[0]
        selected_model = random.choice(family_models[selected_family])

        # Copy before tagging so the caller's model_answers is left untouched.
        chosen_answer = dict(model_answers[selected_model][question_id])
        chosen_answer["source"] = selected_model
        mixed_answers[question_id] = chosen_answer

        # Update the sample count for this family
        family_sample_counts[selected_family] += 1

    print(f"Final family sample counts: {dict(family_sample_counts)}")
    print(f"Total samples generated: {len(mixed_answers)}")

    return mixed_answers


def _stable_question_seed(base_seed, question_id):
    """Derive a per-question RNG seed that is identical across interpreter runs."""
    import zlib  # local import: only this helper needs it

    # The built-in hash() of a str is salted per process (PYTHONHASHSEED), so it
    # yields a different seed on every run. CRC32 of the text is stable.
    digest = zlib.crc32(str(question_id).encode("utf-8"))
    return base_seed + digest % (2**31)


def _load_existing_results(save_path):
    """Load a previous run's CSV so processing can resume.

    Returns:
        A tuple (output_questions, existing_categories, question_id_counts):
        the saved rows as a list of dicts, a dict of {question_id: set of
        saved category ids}, and a dict of {question_id: number of saved rows}.
    """
    existing_df = pd.read_csv(save_path, index_col=0)
    output_questions = existing_df.to_dict('records')

    # Lists are stored in the CSV as their string repr; parse them back.
    for result in output_questions:
        raw_value = result.get('instruction_string')
        if isinstance(raw_value, str):
            try:
                result['instruction_string'] = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError):
                # Best effort: if parsing fails, keep the value as a string.
                pass

    question_id_counts = existing_df['id'].value_counts().to_dict()

    existing_categories = {}
    for qid, category in zip(existing_df['id'].tolist(), existing_df['category'].tolist()):
        existing_categories.setdefault(qid, set()).add(category)

    return output_questions, existing_categories, question_id_counts


def main(args):
    """Sample instruction categories for every question and save them to a CSV.

    Loads each model's answers for `args.lang`, mixes them with
    `balanced_sample_answers`, then for every question samples up to `args.k`
    categories and their user instruction strings. Results are written to
    preprocessed_answers/<lang>/judge_<judge_model>/<seed>.csv after each
    question, and an existing file is used to resume an interrupted run.

    Args:
        args: Parsed CLI namespace providing `lang`, `judge_model` and `k`.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)

    ans_path = Path("model_answers") / args.lang
    model_answers = {}
    # iterdir() yields entries in arbitrary order; sort so the seeded sampling
    # below sees the models in the same order on every machine and run.
    for ans_file in sorted(ans_path.iterdir()):
        model = ans_file.stem
        model_answers[model] = get_answers(args.lang, model)

    # Generate balanced mixed answers using the sampling algorithm
    print("Generating balanced mixed answers...")
    mixed_answers = balanced_sample_answers(model_answers)

    questions = get_question(args.lang)
    lang = match_lang(args.lang)

    judge_client = construct_client(args.judge_model)
    BaseInstruction.set_judge_client(judge_client)

    save_dir = Path("preprocessed_answers") / args.lang / f"judge_{args.judge_model}"
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{seed}.csv"

    if save_path.is_file():
        output_questions, existing_categories, question_id_counts = _load_existing_results(save_path)

        # A question is complete once it has at least k saved rows.
        # NOTE(review): ids read back from the CSV are compared directly with
        # q["question_id"]; confirm both sides have the same type.
        complete_question_ids = {qid for qid, count in question_id_counts.items() if count >= args.k}
        questions = [q for q in questions if q["question_id"] not in complete_question_ids]

        incomplete_questions = {qid: count for qid, count in question_id_counts.items() if count < args.k}
        if incomplete_questions:
            print(f"Found {len(incomplete_questions)} incomplete question_ids with counts: {incomplete_questions}")

        print(f"Found {len(complete_question_ids)} completely processed question_ids (with {args.k} categories each).")
        print(f"Continuing with {len(questions)} remaining questions.")
    else:
        output_questions = []
        existing_categories = {}
        print("No existing results file found. Starting fresh.")

    progress_bar = tqdm(total=len(questions), desc="Running Experiment")
    try:
        for q in questions:
            question_id = q["question_id"]
            answer_entry = mixed_answers[question_id]
            past_answer = answer_entry["answer"]

            # Re-seed per question so each question's randomness is independent
            # of processing order and reproducible when a run is resumed.
            question_seed = _stable_question_seed(seed, question_id)
            random.seed(question_seed)
            np.random.seed(question_seed)

            existing_cats_for_question = existing_categories.get(question_id, set())
            categories_needed = args.k - len(existing_cats_for_question)

            if categories_needed <= 0:
                # Should not happen since complete questions were filtered out.
                progress_bar.update()
                continue

            # Sample all applicable categories first (to maintain reproducibility)
            cots, all_applicable_categories = sample_categories(lang, q, past_answer, args.k)

            # Keep each category's original position so it stays paired with its
            # own cot after already-saved categories are filtered out.
            pending = [
                (position, category)
                for position, category in enumerate(all_applicable_categories)
                if category.instruction_id not in existing_cats_for_question
            ][:categories_needed]

            for position, category in pending:
                instruction_strings = sample_user_strings(category)
                output_questions.append({
                    "id": question_id,
                    "category": category.instruction_id,
                    "prev_ans": past_answer,
                    "question_str": q["question_str"],
                    "soln_source": answer_entry["source"],
                    "instruction_string": list(instruction_strings),
                    "cot": cots[position],
                })

            # Checkpoint after every question so an interrupted run can resume.
            pd.DataFrame(output_questions).to_csv(save_path)
            progress_bar.update()
    finally:
        progress_bar.close()


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the preprocessing pipeline.
    parser = argparse.ArgumentParser()
    # --lang and --k are required: main() builds a path from lang and does
    # integer arithmetic with k, so a missing value fails with a TypeError.
    parser.add_argument("--lang", type=str, required=True, help="language we are evaluating")
    parser.add_argument("--judge_model", type=str, help="llm used to judge evals")
    parser.add_argument("--k", type=int, required=True, help="num categories per question")
    args = parser.parse_args()

    main(args)
