from pathlib import Path
from json import loads as load_json_str, load as load_json_file
from instructions import BaseInstruction, AddMoreComments
from random import choice, sample, shuffle
from instructions.BaseInstruction import Code, Language
from instructions.instructions_utils import get_categories, match_lang

    

def get_question(lang: str):
    """Load the questions for ``lang`` from ``questions/<lang>_question.jsonl``.

    Each non-blank line is parsed as a JSON object. Only its ``question_id``
    and the first entry of its ``turns`` list (stored as ``question_str``)
    are kept.

    Args:
        lang: Language tag used to build the file name.

    Returns:
        A list of ``{"question_id": ..., "question_str": ...}`` dicts in file
        order.

    Raises:
        FileNotFoundError: if the question file does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError / IndexError: if a record lacks ``question_id`` or ``turns``.
    """
    q_path = Path("questions") / f"{lang}_question.jsonl"
    data = []
    with open(q_path, "r", encoding="utf-8") as f:
        for line in f:
            # Lines read from a file keep their trailing "\n", so test the
            # stripped text; otherwise blank / whitespace-only lines would
            # reach the JSON parser and raise.
            if not line.strip():
                continue

            curr_data = load_json_str(line)
            # maybe add testing code later too
            data.append({
                "question_id": curr_data["question_id"],
                "question_str": curr_data["turns"][0],
            })

    return data

def get_answers(lang: str, model:str):
    """Load ``model``'s answers for ``lang`` from ``model_answers/<lang>/<model>.jsonl``.

    Each non-blank line is parsed as a JSON object. Its answer text is taken
    from ``choices[0]["turns"][0]``.

    Args:
        lang: Language sub-directory under ``model_answers``.
        model: Model name; the file read is ``<model>.jsonl``.

    Returns:
        A dict mapping each ``question_id`` to ``{"answer": <text>}``. If a
        ``question_id`` appears more than once, the last line wins.

    Raises:
        FileNotFoundError: if the answer file does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError / IndexError: if a record lacks the expected fields.
    """
    ans_path = Path("model_answers") / lang / f"{model}.jsonl"
    ans = {}
    with open(ans_path, "r", encoding="utf-8") as f:
        for line in f:
            # Lines keep their trailing "\n"; skip blank / whitespace-only
            # lines instead of handing them to the JSON parser.
            if not line.strip():
                continue

            curr_ans = load_json_str(line)
            question_id = curr_ans["question_id"]  # avoid shadowing builtin `id`
            ans[question_id] = {"answer": curr_ans["choices"][0]["turns"][0]}

    return ans

def classify(categories: list[BaseInstruction]):
    """Group instruction categories into buckets by their primary tag.

    Categories whose ``tags`` contain "vague" are dropped. Every other
    category lands in at most one bucket: the first of "semantic",
    "structural", "cosmetic" (checked in that order) found in its ``tags``.
    Categories carrying none of those tags are dropped.

    Returns:
        A dict with keys "semantic", "structural", "cosmetic" (in that
        order); each value lists the matching categories in input order.
    """
    priority = ("semantic", "structural", "cosmetic")
    buckets = {tag: [] for tag in priority}
    for category in categories:
        if "vague" in category.tags:
            continue
        # First tag in priority order wins, so each category is placed once.
        bucket = next((tag for tag in priority if tag in category.tags), None)
        if bucket is not None:
            buckets[bucket].append(category)
    return buckets

def sample_categories(lang: Language, question: str, answer: str, k: int):
    """Pick applicable instruction categories for ``answer``, up to ``k // 3`` per type.

    The categories from ``get_categories(lang)`` are visited in random order
    and grouped with ``classify``. For each type ("semantic", "structural",
    "cosmetic") candidates are checked with ``check_applicability`` until
    that type's quota of ``k // 3`` applicable categories is reached. Any
    remainder of ``k`` not divisible by 3 goes unused.

    Args:
        lang: Language passed to ``get_categories`` and ``Code``.
        question: Currently unused; kept for interface compatibility.
        answer: Answer source text; wrapped as ``Code(answer, lang)`` before
            applicability checks.
        k: Total sample budget, split evenly across the three types.

    Returns:
        ``(cots, applicable)`` as parallel lists: ``cots[i]`` is the first
        value ``applicable[i].check_applicability`` returned. Entries are
        grouped by type in the order semantic, structural, cosmetic.
    """
    # Shuffle a copy so the list handed back by get_categories is not
    # reordered in place (in case it is shared/reused elsewhere).
    all_categories: list[BaseInstruction] = list(get_categories(lang))
    shuffle(all_categories)  # iterate over categories in random order
    classified = classify(all_categories)
    per_type_quota = k // 3
    code = Code(answer, lang)
    cots = []
    applicable = []

    for candidates in classified.values():
        type_applicable = []
        type_cots = []
        for cat in candidates:
            # Test the quota before the applicability check, so a quota of 0
            # (k < 3) selects nothing and triggers no checks at all.
            if len(type_applicable) >= per_type_quota:
                break
            print("checking: ", cat.instruction_id)
            cot, is_applicable = cat.check_applicability(code)
            if is_applicable:
                type_applicable.append(cat)
                type_cots.append(cot)
        applicable += type_applicable
        cots += type_cots

    return cots, applicable


def sample_user_strings(category: BaseInstruction) -> list[str]:
    """Choose the instruction text to present to the user for ``category``.

    Only one string is sampled for now (m = 1). It is a random entry of
    ``category.user_strings``, or ``category.description`` when
    ``user_strings`` is empty.

    Returns:
        A one-element list holding the chosen string.

    Raises:
        ValueError: if the chosen string is None.
    """
    pool = category.user_strings
    picked = category.description if len(pool) == 0 else choice(pool)
    if picked is not None:
        return [picked]
    raise ValueError(f"Category {category.instruction_id} should sample a user string or provide a canonical instruction")


def construct_question(q: str, instruction: str, past_answer: str, followup: bool):
    """Build the prompt that asks a model to solve ``q`` while obeying ``instruction``.

    Args:
        q: Original question text, possibly containing a triple-backtick
            code fence.
        instruction: Instruction the generated code must obey.
        past_answer: Previous solution; only used when ``followup`` is true.
        followup: If true, build a follow-up prompt that shows
            ``past_answer`` in a code fence and states that it did not obey
            ``instruction``. Otherwise insert the instruction into ``q`` just
            before its first triple-backtick fence, or append it at the end
            when ``q`` contains no fence.

    Returns:
        The prompt string.
    """
    # TODO: possibly add a transformation here (e.g. an LLM rewrite) to make
    # the pipeline robust to references/names.

    if followup:
        # Ensure the closing fence starts on its own line; otherwise the last
        # line of the previous solution and the fence would be merged.
        if past_answer and not past_answer.endswith("\n"):
            past_answer += "\n"
        question_and_answer = q + "\n\nYour previous solution was:\n```\n" + past_answer + "```\n"
        task = question_and_answer + f"Your previous code solution did not obey the instruction: {instruction}\n\nSolve this question while obeying the instruction\nYou MUST NOT modify the core logic of the previous code snippet UNLESS the instruction is related to the core logic (i.e. instructions about optimization, refactoring, etc)"
    else:
        code_block_idx = q.find("```")
        if code_block_idx == -1:
            # No fence: str.find returns -1, and slicing with -1 would split
            # off the last character of q. Append the instruction instead.
            code_block_idx = len(q)
        task = q[:code_block_idx] + "\n" + f"Your code must obey this instruction: {instruction}\n\n" + q[code_block_idx:]
    return task


def IF_eval(category: BaseInstruction, lang: Language, curr_answer: str, past_answer: str):
    """Ask ``category`` to verify its instruction against a before/after answer pair.

    Both answers are wrapped as ``Code(..., lang)``, the past answer first.
    ``category.verify_application(past, curr)`` is then called and its result
    returned unchanged.
    """
    before, after = Code(past_answer, lang), Code(curr_answer, lang)
    return category.verify_application(before, after)

