import os
import json
import sympy
from loguru import logger


from copy import deepcopy
from sycophancy.api import APIQuery
from sycophancy.scanner import Scanner
from sycophancy.parser import WarningType, extract_judgement
import numpy as np
from collections import Counter

def run(model_config, setting_config, config_path, skip_existing=False, output_folder="outputs", reparse=False):
    model = model_config["model"]
    n = model_config["n"]
    api = model_config["api"]


    max_tokens = model_config.get("max_tokens", setting_config["default_max_tokens"])
    temperature = model_config.get("temperature", setting_config["default_temperature"])
    kwargs = model_config.copy()
    del kwargs["model"]
    del kwargs["n"]
    del kwargs["api"]
    del kwargs["human_readable_id"]
    if 'date' in kwargs:
        del kwargs['date']
    kwargs["max_tokens"] = max_tokens
    kwargs["temperature"] = temperature
    
    logger.info(f"New run, model: {model}, competition: {setting_config['name']}")

    with open(setting_config['data_path'], 'r') as f:
        problems = json.load(f)
        if setting_config.get('n_solutions', -1) != -1:
            problems = problems[:setting_config.get('n_solutions', -1)]

    output_dir = os.path.join(output_folder, setting_config['name'])
    os.makedirs(output_dir, exist_ok=True)

    batch_prompts = []
    batch_idx_to_problem_idx = {}

    all_grades_per_problem = {i: [] for i in range(len(problems))}
    detailed_costs_per_problem = {i: [] for i in range(len(problems))}

    for i, problem in enumerate(problems):
        problem_id = problem["problem_id"].replace('/', '_')
        model_id = problem['model_id']
        output_file = os.path.join(output_dir, config_path, model_id, f"{problem_id}.json")
        if skip_existing and os.path.exists(output_file):
            data_file = json.load(open(output_file))
            existing_grades = data_file["messages"]
            
            # print all the message lengths
            if "detailed_costs" in data_file:
                detailed_costs = data_file["detailed_costs"]
            else:
                cost = data_file["cost"]
                detailed_costs = [{"cost": cost["cost"] if i == 0 else 0, 
                                    "input_tokens": cost["input_tokens"] if i == 0 else 0, 
                                    "output_tokens": cost["output_tokens"] if i == 0 else 0} 
                                    for i in range(len(existing_grades))]
                
            detailed_costs_per_problem[i] = detailed_costs
            all_grades_per_problem[i] = existing_grades
            logger.info(f"Skipping problem: {problem_id} ({len(existing_grades)} times)")

            if len(existing_grades) == n or reparse:
                calculate_problem_results(model_config, config_path, problem, output_dir, existing_grades,
                                        detailed_costs, model)
                continue

        for _ in range(n - len(all_grades_per_problem[i])):
            batch_idx_to_problem_idx[len(batch_prompts)] = i
            batch_prompts.append(problem)

    if len(batch_prompts) == 0 or reparse:
        return
    
    logger.info("Collected all queries, now running")

    with open(setting_config['prompt_path'], 'r') as f:
        prompt_template = f.read()
    api = APIQuery(
        model=model, 
        api=api,
        **kwargs
    )

    verifier = Scanner(
        querier=api,
        prompt_template = prompt_template,
        **kwargs
    )

    for idx, messages, detailed_cost in verifier.verify(batch_prompts):
        problem_idx = batch_idx_to_problem_idx[idx]
        problem = problems[problem_idx]
        all_grades_per_problem[problem_idx].append(messages)
        detailed_costs_per_problem[problem_idx].append(detailed_cost)

        # check if the whole problem is finished
        if len(all_grades_per_problem[problem_idx]) == n:
            calculate_problem_results(model_config, config_path, problem, output_dir, 
                                      all_grades_per_problem[problem_idx], 
                                      detailed_costs_per_problem[problem_idx], 
                                      model)
    if api != 'baseline':
        api.free_model()
def calculate_problem_results(model_config, config_path, problem, output_dir, messages_problem, 
                              costs_problem, verifier_id):
    problem_id = problem["problem_id"].replace('/', '_')
    model_id = problem['model_id']
    output_file = os.path.join(output_dir, config_path, model_id, f"{problem_id}.json")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    n = len(messages_problem)
    judgements = []
    warnings = []
    corrects = []
    for j in range(n):
        model_answer = messages_problem[j][-1]['content']
        if isinstance(model_answer, list):
            model_answer = model_answer[-1]
            if isinstance(model_answer, dict) and 'content' in model_answer:
                model_answer = model_answer['content']
            model_answer = str(model_answer)

        try:
            model_judgement, warning = extract_judgement(model_answer)
        except:
            model_judgement = "incorrect"
            warning = WarningType.MAJOR
        is_correct = model_judgement != "incorrect"
        
        if len(model_answer) == 0:
            logger.warning(f"Empty message in problem: {problem_id}, idx: {j}")
            warning = WarningType.MAJOR
        judgements.append(model_judgement)
        warnings.append(warning.value)
        corrects.append(is_correct)

    try:
        logger.info(f"Finished problem {problem_id}, solved by {model_id} - judgements: {judgements}, #Correct: {sum(corrects)}")
    except:
        pass

    pass_at_1 = sum(corrects)/n
    
    majority_vote= Counter(judgements).most_common(1)[0][0]


    for i in range(len(costs_problem)):
        costs_problem[i]["cost"] = model_config["read_cost"] * costs_problem[i]["input_tokens"] + model_config["write_cost"] * costs_problem[i]["output_tokens"]
        costs_problem[i]["cost"] /= 10 ** 6
    cost = {
        "cost": sum([d["cost"] for d in costs_problem]),
        "input_tokens": sum([d["input_tokens"] for d in costs_problem]),
        "output_tokens": sum([d["output_tokens"] for d in costs_problem]),
    }

    with open(output_file, "w") as f:
        json.dump({
                    "problem": problem['problem'],
                    "problem_id": problem_id,
                    "solutions": problem['solutions'],
                    "full_solution": problem.get('full_solution', ''),
                    "thinking": problem.get('thinking', ''),
                    "solution": problem.get('solution', ''),
                    "grading_scheme": problem['grading_scheme'],
                    "model_id": model_id,
                    "cost": cost,
                    "judgements": judgements,
                    "correct": corrects,
                    "pass_at_1": pass_at_1,
                    "majority_vote": majority_vote,
                    "cost": cost,
                    "solution_cost": problem.get('cost', {}).get('cost', 0),
                    "detailed_costs": costs_problem,
                    "warnings": warnings,
                    "messages": messages_problem,
                }, f)

def convert_answer(answer):
    try:
        if type(answer) == sympy.Integer:
            return int(answer)
        else:
            return str(answer)
    except:
        return "None"