import json

import agent.agents as agent
import log_debate.log_debate
import config.config as config_set
from mad_prompt.prompt import GSMMADPrompt, LogicMADPrompt, SQAMADPrompt, CryptoMADPrompt
import random

def construct_message(agents, question, idx):
    """Build the user turn that shows an agent the other agents' solutions.

    Args:
        agents: Message histories of the other agents. Each history is
            indexed with ``idx`` and the ``"content"`` of that message is
            quoted.
        question: Text substituted into ``mad_prompt.prefix_prompt``.
        idx: Position, within every history, of the message to quote.

    Returns:
        A chat message dict with role ``"user"``.

    Note:
        Reads the module-level ``mad_prompt``; it must be set before calling.
    """
    quoted_solutions = [
        "\n\n One agent solution: ```{}```".format(history[idx]["content"])
        for history in agents
    ]
    content = (
        "These are the solutions to the problem from other agent: "
        + "".join(quoted_solutions)
        + mad_prompt.prefix_prompt.format(question)
    )
    return {"role": "user", "content": content}


def construct_assistant_message(content_input):
    """Wrap ``content_input`` in a chat message dict with role ``"assistant"``."""
    return dict(role="assistant", content=content_input)


def read_jsonl(path: str, encoding='utf-8'):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        path: Path of the file to read.
        encoding: Text encoding used to open the file.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(path, encoding=encoding) as fh:
        # Lines keep their trailing newline, so a plain truthiness test never
        # filters anything; strip first so blank/whitespace-only lines (e.g. a
        # trailing empty line) are skipped instead of breaking json.loads.
        return [json.loads(line) for line in fh if line.strip()]


def write_matrix_to_file(matrix, question, file_path):
    """Append a question's per-agent, per-round answer table to a text file.

    The table has one header row (``Round 0``, ``Round 1``, ...) sized from
    the first row of ``matrix`` and one row per agent, followed by a blank
    line.

    Args:
        matrix: Sequence of rows, one per agent; each row holds that agent's
            answer for every round (``None`` where there is no answer).
        question: Question text written above the table.
        file_path: File to append to; created if it does not exist.
    """
    # An empty matrix has no first row to size the header from.
    num_rounds = len(matrix[0]) if matrix else 0

    lines = [f"Question: {question}\n"]
    header_cells = "".join(f" Round {r} |" for r in range(num_rounds))
    lines.append("Agent ID\\Round |" + header_cells + "\n")

    for i, row in enumerate(matrix):
        cells = []
        for ans in row:
            # Only a missing answer (None or empty string) is shown as 'None';
            # legitimate falsy answers such as 0 or False are written as-is.
            shown = "None" if ans is None or ans == "" else ans
            cells.append(f" {shown} |")
        lines.append(f"Agent {i} |" + "".join(cells) + "\n")
    lines.append("\n")

    # UTF-8 so non-ASCII question text is written regardless of the
    # platform's default encoding.
    with open(file_path, "a", encoding="utf-8") as f:
        f.writelines(lines)


def update_score(score_dict, prev_ans, current_ans, round_num):
    """Fold one agent's answer for one round into the running answer scores.

    Contributions are scaled by ``1 / (round_num + 1)``. An answer kept from
    the previous round (or any answer in round 0) earns 20 scaled points. In
    later rounds, switching answers earns the new answer 30 scaled points and
    costs the abandoned answer 25 scaled points if it already has a score.

    Args:
        score_dict: Mapping from answer to accumulated score; updated in place.
        prev_ans: The agent's answer in the previous round (unused in round 0).
        current_ans: The agent's answer in this round.
        round_num: Zero-based round index.

    Returns:
        The same ``score_dict`` object, after the update.
    """
    weight = 1.0 / float(round_num + 1)
    switched = round_num != 0 and prev_ans != current_ans

    if switched:
        # Penalise the answer the agent walked away from, if it is tracked.
        if prev_ans in score_dict:
            score_dict[prev_ans] -= 25 * weight
        reward = 30 * weight
    else:
        reward = 20 * weight

    score_dict[current_ans] = score_dict.get(current_ans, 0) + reward
    return score_dict


def get_final_answer(score_dict):
    """Pick the highest-scoring answer, ignoring the ``None`` answer.

    Args:
        score_dict: Mapping from answer to accumulated score.

    Returns:
        The answer with the maximum score; ties are broken with
        ``random.choice``. ``None`` if the mapping is empty or its only key
        is ``None``.
    """
    if not score_dict:
        return None

    # Drop the None key: it is not a usable answer.
    scored = [(ans, points) for ans, points in score_dict.items() if ans is not None]
    if not scored:
        return None

    top_score = max(points for _, points in scored)
    leaders = [ans for ans, points in scored if points == top_score]
    return random.choice(leaders)



if __name__ == "__main__":
    # Script entry point: load the config, run a multi-round debate between
    # the configured agents for each dataset question, pick a final answer
    # from the accumulated scores, and write transcripts, the answer matrix,
    # accuracy and per-agent token counts to disk.
    #
    # ``mad_prompt`` must stay a module-level name: construct_message() reads
    # it as a global.
    mad_prompt = None

    conf = config_set.load_config("./config/config.json")
    if not conf:
        # Every setting below is read from conf, so there is nothing useful
        # to do without it; fail here with a clear message.
        raise RuntimeError("could not load configuration from ./config/config.json")
    print("------------config load success--------------------")
    for key, value in conf.items():
        print(f"{key}: {value}")

    dataset_id = conf.get("dataset_id")
    is_nonconformity = conf.get("is_nonconformity")

    # Init prompt
    prompt_classes = {
        "gsm": GSMMADPrompt,
        "logic": LogicMADPrompt,
        "sqa": SQAMADPrompt,
        "crypto": CryptoMADPrompt,
    }
    prompt_cls = prompt_classes.get(dataset_id)
    if prompt_cls is None:
        # Without this check mad_prompt would stay None and the run would
        # fail later with an opaque AttributeError.
        raise ValueError(
            f"unsupported dataset_id {dataset_id!r}; expected one of {sorted(prompt_classes)}"
        )
    mad_prompt = prompt_cls(is_nonconformity)

    rounds = conf.get("round")
    no_question = conf.get("dataset_size")
    log_path = conf.get("log_path")
    agent_models = conf.get("agent_models")
    agent_no = conf.get("agent_no")
    # Explicit raise rather than assert: asserts are stripped under ``-O``.
    if agent_models is None or agent_no != len(agent_models):
        raise ValueError("config 'agent_no' must equal the number of entries in 'agent_models'")

    generated_description = {}
    matrix_file = conf.get("matrix_file")
    dataset_path = conf.get("dataset_path")
    result_path = conf.get("result_path")

    questions = read_jsonl(dataset_path)

    logger = log_debate.log_debate.Logger(log_file=log_path, log_level='INFO')

    # One agent per configured model, numbered by position.
    agent_Set = [
        agent.Agent(model_name=model_name, no=i, token=0)
        for i, model_name in enumerate(agent_models)
    ]

    # The dataset may hold fewer entries than ``dataset_size``; keep the
    # actual selection so accuracy is computed over what was really run.
    selected_questions = questions[:no_question]
    correct_count = 0

    for question_no, data in enumerate(selected_questions, start=1):
        question = data['question']
        answer = data['answer']

        logger.info("Question: " + question)
        print("---------------", question_no)

        # Each agent starts from its own history holding only the question.
        agent_contexts = [
            [{"role": "user", "content": mad_prompt.context_prompt.format(question)}]
            for _ in agent_Set
        ]

        # answer_matrix[agent][round] is the parsed answer for that turn.
        answer_matrix = [[None] * rounds for _ in agent_Set]

        score_dict = {}

        for round_idx in range(rounds):
            logger.info(f"=============Round {round_idx}===============")

            for i, agent_context in enumerate(agent_contexts):
                if round_idx != 0:
                    # Show this agent the other agents' replies from the
                    # previous round. A history is [user, assistant] after
                    # round 0 and grows by a (user, assistant) pair per later
                    # round, so round r-1's reply sits at index 2*r - 1.
                    agent_contexts_other = agent_contexts[:i] + agent_contexts[i + 1:]
                    message = construct_message(agent_contexts_other, question, 2 * round_idx - 1)
                    agent_context.append(message)

                response = agent_Set[i].calling_api(agent_context, logger)
                agent_context.append(construct_assistant_message(response))

                current_ans = mad_prompt.parse_answer(response)
                answer_matrix[i][round_idx] = current_ans

                prev_ans = answer_matrix[i][round_idx - 1] if round_idx > 0 else None
                score_dict = update_score(score_dict, prev_ans, current_ans, round_idx)

            print(f"\nAfter Round {round_idx}:")
            print("Answer Matrix:")
            for row in answer_matrix:
                print(row)
            print("Score Dict:", score_dict)

        final_answer = get_final_answer(score_dict)
        gt = mad_prompt.solve_problems(answer)
        logger.info(f"Final Answer: {final_answer}, Standard Answer: {answer}")

        if mad_prompt.is_correct(final_answer, gt):
            correct_count += 1
        else:
            print("error answer:", question, "\n", "answer: ", final_answer, "gt: ", gt)

        write_matrix_to_file(answer_matrix, question, matrix_file)

        generated_description[question] = (agent_contexts, answer)

    # Divide by the number of questions actually processed; guard the empty
    # case so an empty dataset reports 0% instead of raising.
    total_questions = len(selected_questions)
    accuracy = correct_count / total_questions if total_questions else 0.0

    # Context manager so the transcript file is flushed and closed.
    transcript_path = result_path + "/xxx_{}_{}.json".format(len(agent_Set), rounds)
    with open(transcript_path, "w") as result_file:
        json.dump(generated_description, result_file)

    print("\n\n===== Final Results Summary =====")
    print(f"Total Questions: {total_questions}")
    print(f"Correct Answers: {correct_count}")
    print(f"Accuracy: {accuracy * 100:.2f}%")

    total_tokens = 0

    with open(result_path + "/agent_tokens.txt", "w") as file:
        # Loop variable deliberately not named ``agent`` so the imported
        # ``agent`` module is not rebound.
        for idx, debate_agent in enumerate(agent_Set):
            token_value = debate_agent.token
            total_tokens += token_value
            file.write(f"Agent {idx + 1}: Token = {token_value}\n")

        file.write(f"\nTotal Tokens: {total_tokens}\n")

    print("The token values and total amount for all agents have been written to agent_tokens.txt.")

