import json
import os
import pickle
from typing import List

from src.entity.problems.Problem import Problem
from root import root
from tqdm import tqdm

from config import config


# Schema-generation settings pulled from the project configuration, in order:
#   prompt          -- system prompt sent first in every conversation
#   Response        -- response model class used both as the requested JSON
#                      format and to parse/validate the model's reply
#   sample_problem  -- example user question for the one-shot demonstration
#   sample_response -- example answer object (serialized with model_dump_json)
prompt, Response, sample_problem, sample_response = (
    config[setting_key]
    for setting_key in (
        "SCHEMA_PROMPT",
        "SCHEMA_RESPONSE_CLASS",
        "SCHEMA_SAMPLE_QUESTION",
        "SCHEMA_SAMPLE_RESPONSE",
    )
)


def generate_mental_representation(problem_string: str, model):
    """Ask ``model`` for a knowledge schema and a summary of ``problem_string``.

    The conversation is one-shot: the configured system prompt, then the
    configured sample question and its sample answer, then the real problem.

    Args:
        problem_string: Text of the problem to describe.
        model: Object exposing ``interact(messages=..., json_format=...,
            temperature=..., max_tokens=...)``. Its return value must be
            JSON-serializable; presumably a dict shaped like ``Response``.

    Returns:
        A tuple ``(schema, summary)``. ``schema`` is
        ``knowledge_schema.model_dump()`` (a dict), and ``summary`` is the
        parsed ``summary`` field.
    """
    system_turn = {"role": "system", "content": prompt}
    demonstration_turns = [
        {"role": "user", "content": sample_problem},
        {"role": "assistant", "content": sample_response.model_dump_json()},
    ]
    query_turn = {"role": "user", "content": problem_string}
    conversation = [system_turn, *demonstration_turns, query_turn]

    # Deterministic decoding (temperature 0) with the reply constrained to
    # the Response JSON format.
    raw_reply = model.interact(
        messages=conversation,
        json_format=Response,
        temperature=0.0,
        max_tokens=4096,
    )

    # model_validate_json only accepts str/bytes, so re-serialize the reply.
    reply_json = json.dumps(raw_reply)
    parsed_reply = Response.model_validate_json(reply_json)

    schema_dict = parsed_reply.knowledge_schema.model_dump()
    return schema_dict, parsed_reply.summary

def _atomic_write(path, mode, write_payload):
    """Write ``path`` via a temp file and ``os.replace`` so no partial file survives."""
    # The cache is trusted purely on file existence, so a half-written file
    # (crash, Ctrl-C, or a serialization error mid-dump) would poison every
    # later run. Writing to a side file and renaming avoids that.
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, mode) as file:
            write_payload(file)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the side file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _load_or_create_mental_representation(problem, directory, memory_model, including_answer):
    """Return the cached ``{"knowledge_schema", "summary"}`` dict for ``problem``, generating it if absent."""
    json_path = f'{directory}/{problem.id}.json'
    if os.path.exists(json_path):
        with open(json_path, "r") as file:
            return json.load(file)

    os.makedirs(directory, exist_ok=True)
    if including_answer:
        # Same text as before: the problem, a trailing space, a newline, then the answer line.
        problem_string = "{past_knowledge} \nAnswer to this question is: {previous_problem_answer}".format(
            past_knowledge=str(problem),
            previous_problem_answer=problem.get_ground_truth(),
        )
    else:
        problem_string = str(problem)

    # ``schema`` is a dict (the dumped knowledge_schema model).
    schema, summary = generate_mental_representation(problem_string, memory_model)
    mental_representation = {
        "knowledge_schema": schema,
        "summary": summary,
    }
    _atomic_write(json_path, "w", lambda file: json.dump(mental_representation, file))
    return mental_representation


def _format_mental_representation(mental_representation):
    """Render a schema/summary dict as the markdown-style text stored on the problem."""
    formatted_schema = "\n".join(
        f"##### {key}:\n{value}\n"
        for key, value in mental_representation["knowledge_schema"].items()
    )
    return "#### Schema:\n{schema}\n#### Summary:\n{summary}".format(
        schema=formatted_schema, summary=mental_representation["summary"]
    )


def _load_or_create_embedding(problem, embedding_path, embedder):
    """Return the cached embedding of ``problem.mental_representation``, computing it if absent."""
    # NOTE(review): the cache key does not include the embedder, so switching
    # embedders reuses stale vectors. Clear the directory when changing models.
    pkl_path = f'{embedding_path}/{problem.id}.pkl'
    if os.path.exists(pkl_path):
        # pickle.load can execute arbitrary code; only use a trusted cache directory.
        with open(pkl_path, "rb") as file:
            return pickle.load(file)

    embedding = embedder.encode(problem.mental_representation)
    os.makedirs(embedding_path, exist_ok=True)
    _atomic_write(pkl_path, "wb", lambda file: pickle.dump(embedding, file))
    return embedding


def GenerateMentalRepresentationController(problems: List[Problem], memory_type, memory_model, including_answer, embedder):
    """Attach a mental representation and its embedding to each problem, using an on-disk cache.

    Cache files live under
    ``{root}/mental_representations/<Response module name>/<memory_type>_by_<model class>_including_answer_<flag>``.
    That directory holds one ``<problem.id>.json`` per problem and an
    ``embeddings/<problem.id>.pkl`` per problem.

    Args:
        problems: Problems to process; each is mutated in place.
        memory_type: ``"semantic"`` is implemented, and ``"episode"`` is a placeholder.
        memory_model: Model passed to ``generate_mental_representation``. Its
            class name is part of the cache path.
        including_answer: If truthy, the ground-truth answer is appended to the
            problem text before generation. Also part of the cache path.
        embedder: Object with ``encode(text)``, used to embed the formatted representation.

    Returns:
        ``problems`` for ``"semantic"``; ``(None, None)`` for ``"episode"``;
        ``None`` for any other ``memory_type``.
    """
    schema_template_name = Response.__module__.split(".")[-1]
    mental_representations_path = (f'{root}/mental_representations/{schema_template_name}/{memory_type}_by_'
                                   f'{memory_model.__class__.__name__}_including_answer'
                                   f'_{including_answer}')

    if memory_type == "semantic":
        embedding_path = f'{mental_representations_path}/embeddings'
        for problem in tqdm(problems, desc="Handling embedding ☘️..."):
            mental_representation = _load_or_create_mental_representation(
                problem, mental_representations_path, memory_model, including_answer
            )
            problem.mental_representation = _format_mental_representation(mental_representation)
            problem.add_embedding(_load_or_create_embedding(problem, embedding_path, embedder))
        return problems

    elif memory_type == "episode":
        # NOTE(review): placeholder; episodic memory is not implemented yet.
        return None, None
