from glob import glob
import pandas as pd
import json
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from clustering import KmeansClustering
import re
import math
import numpy as np
import heapq
import os

def construct_message(current_agent, agents, question, idx, cluster_result=None, conflict_scores=None, uncertainties=None, T=None, top_k=4):
    """Build the next-round user prompt for one agent from other agents' answers.

    NOTE(review): this function reads the module-level global ``all_probs``
    (assigned in the ``__main__`` loop from ``Kmeans.kmeans``). It is not a
    parameter, so calling this before that global exists raises NameError.

    Args:
        current_agent: Index of the agent the prompt is being built for.
        agents: List of every agent's chat context. In the ``__main__`` loop
            this includes ``current_agent``'s own context, so an agent's own
            previous answer may be quoted back to it.
        question: Question text, embedded at the top of the prompt.
        idx: Position in each agent's context whose ``"content"`` is quoted.
            The ``__main__`` loop passes ``2*round``, which is the previous
            round's assistant reply (contexts are system, user, assistant,
            user, assistant, ...).
        cluster_result: List of clusters, each a list of agent indices.
        conflict_scores: Pairwise cluster conflict matrix (as built by
            ``estimate_conflict``).
        uncertainties: Per-cluster uncertainty list (as built by
            ``estimate_uncertainty``).
        T: ``0`` selects the supporting/conflicting mode (the ``__main__``
            loop uses it after a "partial consensus"). Any other value
            selects the "leader" mode (used after "no consensus").
        top_k: In leader mode, how many of the highest-``all_probs`` members
            of the agent's own cluster to quote.

    Returns:
        A ``{"role": "user", "content": ...}`` chat message. If ``agents`` is
        empty, a fixed "double check" prompt is returned instead.
    """
    if len(agents) == 0:
        return {"role": "user", "content": "Can you double check that your answer is correct. Put your final answer in the form (answer) at the end of your response. (answer) represents choice (A), (B), (C), or (D)."}

    # Locate the (last) cluster containing current_agent.
    # NOTE(review): if no cluster contains it, cluster_idx stays -1 and the
    # lookups below silently use the LAST cluster's row/uncertainty.
    cluster_i = []
    cluster_idx = -1
    for idx0, cluster in enumerate(cluster_result): 
        if current_agent in cluster: 
            cluster_i = cluster
            cluster_idx = idx0
    # Members' weights come from the global all_probs (see docstring).
    cluster_i_probs = [all_probs[i] for i in cluster_i]
    cluster_i_agents = [agents[i] for i in cluster_i]
    prefix_string = "Here is the question:\n" + question + "\n\nThese are the solutions to the problem from other agents: "
    if T == 0: 
        # Supporting/conflicting mode.
        if uncertainties[cluster_idx] == max(uncertainties):
            # The agent sits in the most uncertain cluster.
            if all_probs[current_agent] == min(cluster_i_probs): 
                # Lowest-weight member of that cluster: show, for every cluster
                # whose conflict score with ours is > 2, its highest-weight
                # member(s). Every member tied for the max is quoted.
                for c in range(len(conflict_scores[cluster_idx])): 
                    if conflict_scores[cluster_idx][c] > 2:
                        for agent_idx, agent in enumerate(agents):
                            if agent_idx not in cluster_result[c]: 
                                continue
                            if all_probs[agent_idx] != max([all_probs[j] for j in cluster_result[c]]): 
                                continue
                            agent_response = agent[idx]["content"]
                            response = "\n\n One conflicting agent solution: ```{}```".format(agent_response)
                            prefix_string = prefix_string + response
            else: 
                # Other members of the most uncertain cluster: show the
                # highest-weight member(s) of clusters scoring within [0, 2].
                # estimate_conflict leaves the diagonal at 0, so the agent's
                # own cluster also qualifies here.
                for c in range(len(conflict_scores[cluster_idx])): 
                    if 0 <= conflict_scores[cluster_idx][c] <= 2: 
                        for agent_idx, agent in enumerate(agents):
                            if agent_idx not in cluster_result[c]: 
                                continue
                            if all_probs[agent_idx] != max([all_probs[j] for j in cluster_result[c]]): 
                                continue
                            agent_response = agent[idx]["content"]
                            response = "\n\n One supporting agent solution: ```{}```".format(agent_response)
                            prefix_string = prefix_string + response
        else: 
            # Agent is outside the most uncertain cluster: show supporting
            # solutions as above, but never the lowest-weight agent of the
            # most uncertain cluster.
            
            cluster_max_unc_idx = uncertainties.index(max(uncertainties))
            cluster_max_unc_probs = [all_probs[i] for i in cluster_result[cluster_max_unc_idx]]
            # Position (within that cluster) of its lowest-weight member.
            min_probs = cluster_max_unc_probs.index(min(cluster_max_unc_probs))

            for c in range(len(conflict_scores[cluster_idx])): 
                if 0 <= conflict_scores[cluster_idx][c] <= 2: 
                    for agent_idx, agent in enumerate(agents):
                        if agent_idx not in cluster_result[c]: 
                            continue
                        # Skip the least-confident agent of the most uncertain cluster.
                        if c == cluster_max_unc_idx and agent_idx == cluster_result[cluster_max_unc_idx][min_probs]:
                            continue
                        if all_probs[agent_idx] != max([all_probs[j] for j in cluster_result[c]]): 
                            continue
                        agent_response = agent[idx]["content"]
                        response = "\n\n One supporting agent solution: ```{}```".format(agent_response)
                        prefix_string = prefix_string + response
        # NOTE(review): .format(question) is a no-op; the string has no "{}".
        prefix_string = prefix_string + """\n\n Judging which solutions are trustable and using the solutions from other agents as additional advice, can you give an updated answer? Examine your solution and that other agents step by step. Put your answer in the form (answer) at the end of your response. (answer) represents choice (A), (B), (C), or (D).""".format(question)
    else: 
        # Leader mode: quote the top_k highest-weight members of the agent's
        # own cluster. find_top_indices returns positions within cluster_i
        # (not agent ids), in heap order rather than sorted order.
        agent_idxs = find_top_indices(cluster_i_probs, top_k)
        for agent_idx in agent_idxs: 
            agent = cluster_i_agents[agent_idx]
            agent_response = agent[idx]["content"]
            response = "\n\n One leader solution: ```{}```".format(agent_response)
            prefix_string = prefix_string + response
        # NOTE(review): .format(question) is a no-op; the string has no "{}".
        prefix_string = prefix_string + """\n\n Judging which solutions can lead the trend of thought and using the solutions from other agents as additional advice, can you give an updated answer? Examine your solution and that other agents step by step. Put your answer in the form (answer) at the end of your response. (answer) represents choice (A), (B), (C), or (D).""".format(question)
    return {"role": "user", "content": prefix_string}


def construct_assistant_message(completion):
    """Turn the first sampled output of a generation result into a chat message.

    Args:
        completion: Generation result indexed as ``completion[0].outputs[0]``;
            that output's ``.text`` becomes the message content.

    Returns:
        ``{"role": "assistant", "content": <text>}``.
    """
    first_output = completion[0].outputs[0]
    return dict(role="assistant", content=first_output.text)


def generate_answer(agent_context):
    """Render a chat history with the tokenizer's template and sample a reply.

    Relies on the module-level ``tokenizer``, ``llm`` and ``sampling_params``
    globals created in the ``__main__`` block.

    Args:
        agent_context: List of ``{"role": ..., "content": ...}`` messages.

    Returns:
        Whatever ``llm.generate`` returns for the single-prompt batch (callers
        index it as ``completion[0].outputs[0]``).
    """
    prompt = tokenizer.apply_chat_template(agent_context, tokenize=False, add_generation_prompt=True)
    batch = [prompt]
    return llm.generate(batch, sampling_params)

def list_mean(lst):
    """Return ``np.mean(lst)``, or the integer ``0`` when ``lst`` is empty."""
    return np.mean(lst) if len(lst) else 0

def estimate_consensus(cluster_result, all_probs, support_results):
    """Classify how strongly the agents agree on a single answer.

    The most common valid answer (one of ``ABCDabcd``) is the candidate
    consensus. Two quantities are then thresholded:

    * the fraction of agents giving that answer, and
    * ``consensus_score``: the share of the total ``all_probs`` weight held
      by those agents.

    Args:
        cluster_result: Unused; kept so the signature matches the other
            ``estimate_*`` helpers.
        all_probs: Per-agent weights, aligned index-for-index with
            ``support_results``.
        support_results: Per-agent parsed answers (``None`` when unparseable).

    Returns:
        ``'full consensus'`` when more than 2/3 of agents agree and hold more
        than 80% of the weight; ``'partial consensus'`` when at least two
        agents agree and hold more than 50% of the weight; otherwise
        ``'no consensus'``. That includes the cases where no agent produced a
        valid answer or the total weight is zero.
    """
    valid_letters = frozenset('ABCDabcd')
    valid_answers = [r for r in support_results if r is not None and r in valid_letters]
    if not valid_answers:
        # Without this guard the None placeholders were counted as a unanimous
        # answer and reported as 'full consensus', ending the debate early.
        return 'no consensus'
    # Ties go to the first answer reaching the maximal count.
    top_answer = max(valid_answers, key=valid_answers.count)
    frequency = support_results.count(top_answer)

    score = 0
    total_score = 0
    for answer, prob in zip(support_results, all_probs):
        if answer == top_answer:
            score += prob
        total_score += prob
    if total_score == 0:
        # No weight at all: nothing can be said to agree (avoids 0/0).
        return 'no consensus'

    agreement = frequency / len(support_results)
    consensus_score = score / total_score
    if agreement > 2 / 3 and consensus_score > 0.8:
        return 'full consensus'
    # Equivalent to the former (frequency/n) >= (2/n): at least two agents agree.
    if frequency >= 2 and consensus_score > 0.5:
        return 'partial consensus'
    return 'no consensus'

def estimate_conflict(cluster_result, all_probs, support_results):
    """Return a symmetric matrix of pairwise conflict scores between clusters.

    For each pair of clusters (A = higher index, B = lower index), the answers
    given in both clusters are *shared* and those given in only one are
    *exclusive*. Each member's ``all_probs`` value is bucketed by which kind
    of answer it gave. With ``mean`` meaning ``list_mean``:

    * shared and exclusive both present:
      ``(mean(excl) / mean(all)) * (|mean(A_shared) - mean(B_shared)| /
      |mean(A_excl) - mean(B_excl)|)``
    * only shared answers: ``|mean(A_shared) - mean(B_shared)|``
    * no shared answers:
      ``(mean(excl) / mean(all)) / |mean(A_excl) - mean(B_excl)|``

    NOTE(review): the zero-denominator case is not guarded here.

    Args:
        cluster_result: List of clusters, each a list of agent indices.
        all_probs: Per-agent weights indexed by agent index.
        support_results: Per-agent parsed answers indexed by agent index.

    Returns:
        ``len(cluster_result) x len(cluster_result)`` nested list; the
        diagonal stays 0.
    """
    size = len(cluster_result)
    matrix = [[0] * size for _ in range(size)]

    def partition(members, shared, exclusive):
        # Bucket the members' weights by whether their answer is shared or exclusive.
        shared_probs, exclusive_probs = [], []
        for member in members:
            answer = support_results[member]
            if answer in exclusive:
                exclusive_probs.append(all_probs[member])
            if answer in shared:
                shared_probs.append(all_probs[member])
        return shared_probs, exclusive_probs

    for a in range(size):
        for b in range(a):
            answers_a = {support_results[m] for m in cluster_result[a]}
            answers_b = {support_results[m] for m in cluster_result[b]}
            shared = answers_a & answers_b
            exclusive = answers_a ^ answers_b
            a_shared, a_excl = partition(cluster_result[a], shared, exclusive)
            b_shared, b_excl = partition(cluster_result[b], shared, exclusive)
            excl = a_excl + b_excl
            if shared and excl:
                weight = list_mean(excl) / list_mean(excl + a_shared + b_shared)
                score = weight * (abs(list_mean(a_shared) - list_mean(b_shared)) / abs(list_mean(a_excl) - list_mean(b_excl)))
            elif shared:
                score = abs(list_mean(a_shared) - list_mean(b_shared))
            else:
                score = list_mean(excl) / list_mean(excl + a_shared + b_shared) / abs(list_mean(a_excl) - list_mean(b_excl))
            matrix[a][b] = matrix[b][a] = score
    return matrix

def calculate_entropy(lst):
    """Return the average ``-p * log(p)`` term over the values in ``lst``.

    This is the per-item Shannon-entropy summand averaged over the list; the
    values are not normalised first.

    Args:
        lst: Sequence of non-negative numbers (negative values still make
            ``math.log`` raise ``ValueError``).

    Returns:
        ``sum(-p*log(p)) / len(lst)``. Zero values contribute 0 (the standard
        ``0 * log 0 = 0`` limit) instead of raising a math domain error, and
        an empty list returns ``0.0`` instead of raising ZeroDivisionError.
    """
    if len(lst) == 0:
        return 0.0
    # Same left-to-right accumulation as a manual loop, so results are unchanged
    # for strictly positive inputs.
    return sum(-p * math.log(p) for p in lst if p != 0) / len(lst)

def estimate_uncertainty(cluster_result, all_probs, support_results):
    """Score each cluster by ``calculate_entropy`` of its members' weights.

    Args:
        cluster_result: List of clusters, each a list of agent indices.
        all_probs: Per-agent weights indexed by agent index.
        support_results: Unused; accepted so the signature matches the other
            ``estimate_*`` helpers.

    Returns:
        One uncertainty value per cluster, in ``cluster_result`` order.
    """
    member_probs = [[all_probs[member] for member in members] for members in cluster_result]
    return list(map(calculate_entropy, member_probs))

def parse_answer(input_str):
    """Extract the final multiple-choice letter from a model response.

    Finds every single-character parenthesised token such as ``(b)`` and
    returns the last one that, upper-cased, is ``A``, ``B``, ``C`` or ``D``.

    Args:
        input_str: Raw response text.

    Returns:
        ``'A'``, ``'B'``, ``'C'`` or ``'D'``; ``None`` when no parenthesised
        choice letter appears.
    """
    for token in reversed(re.findall(r'\((\w)\)', input_str)):
        choice = token.upper()
        if choice in {'A', 'B', 'C', 'D'}:
            return choice
    # Explicit None: the old loop fell through holding the last *rejected*
    # token (e.g. 'X' for "(x)"), which downstream code treated as an answer.
    return None


def parse_question_answer(df, ix):
    """Turn row ``ix`` of a question DataFrame into a prompt and its answer.

    Columns are read positionally: 0 is the question stem, 1-4 are options
    A-D, and 5 is the reference answer.

    Args:
        df: DataFrame whose first six columns follow the layout above.
        ix: Positional row index.

    Returns:
        ``(prompt, answer)``: the formatted multiple-choice prompt and the raw
        value of column 5.
    """
    stem, opt_a, opt_b, opt_c, opt_d = (df.iloc[ix, col] for col in range(5))
    template = (
        "Question: {}: A) {}, B) {}, C) {}, D) {} Explain your answer, "
        "putting the answer in the form (answer) at the end of your response. "
        "(answer) represents choice (A), (B), (C), or (D)."
    )
    prompt = template.format(stem, opt_a, opt_b, opt_c, opt_d)
    return prompt, df.iloc[ix, 5]

def find_top_indices(lst, top_k):
    """Return the indices of the ``top_k`` largest values in ``lst``.

    A min-heap of ``(value, index)`` pairs is seeded with the first ``top_k``
    items. Each later item replaces the heap minimum only when strictly
    larger.

    Args:
        lst: Sequence of comparable values.
        top_k: Number of indices wanted.

    Returns:
        List of indices in heap order (NOT sorted by value). Returns ``[]``
        when ``top_k <= 0``, and every index when ``top_k >= len(lst)``.
    """
    if top_k <= 0:
        # heap[0] below would raise IndexError on an empty heap, and a negative
        # slice bound would silently select the wrong items.
        return []
    heap = [(num, i) for i, num in enumerate(lst[:top_k])]
    heapq.heapify(heap)
    for i, num in enumerate(lst[top_k:], top_k):
        if num > heap[0][0]:
            heapq.heapreplace(heap, (num, i))
    return [idx for _, idx in heap]

if __name__ == "__main__":
    # Everything here runs at module scope on purpose: generate_answer() reads
    # the globals `tokenizer`, `llm` and `sampling_params`, and
    # construct_message() reads the global `all_probs` assigned below.
    model_path = "Qwen/Qwen2.5-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.7, logprobs=True)
    llm = LLM(model=model_path)
    SYSTEM_PROMPT = "Please reason step by step, and answer the question."
    agents = 7  # number of debating agents
    rounds = 3  # maximum number of debate rounds
    top_k = 2  # leader solutions quoted per agent in leader mode (T == 1)
    tasks = glob("MMLU DATASET")  # placeholder glob pattern for the task CSVs
    TARGET_DIR = "MMLU OUTPUT"
    # Create the output directory up front so the json writes cannot fail
    # with FileNotFoundError after hours of generation.
    os.makedirs(TARGET_DIR, exist_ok=True)
    dfs = [pd.read_csv(task) for task in tasks]
    # Task name = file name up to "_test" (assumes "<task>_test..." file names).
    tasks_names = [i.split('/')[-1].split('_test')[0] for i in tasks]
    assert len(dfs) == len(tasks_names)
    for df, task_name in zip(dfs, tasks_names):
        for idx in range(len(df)):
            question, answer = parse_question_answer(df, idx)
            agent_contexts = [[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": question}] for _ in range(agents)]
            for round_idx in range(rounds):
                documents = []
                support_results = []
                for i, agent_context in enumerate(agent_contexts):
                    if round_idx != 0:
                        # cluster_result / conflict_scores / uncs / T were set at the
                        # end of the previous round; index 2*round_idx is that
                        # round's assistant reply in every context.
                        message = construct_message(i, agent_contexts, question, 2 * round_idx, cluster_result, conflict_scores, uncs, T=T, top_k=top_k)
                        agent_context.append(message)
                    completion = generate_answer(agent_context)
                    documents.append(completion[0].outputs[0].logprobs)
                    assistant_message = construct_assistant_message(completion)
                    agent_context.append(assistant_message)
                    support_results.append(parse_answer(assistant_message['content']))
                # T selects construct_message's mode for the next round:
                # 0 after partial consensus, 1 after no consensus.
                T = 0
                Kmeans = KmeansClustering(stopwords_path='en_stop_words.txt')
                cluster_result, all_probs = Kmeans.kmeans(documents, n_clusters=3)
                cluster_result = list(cluster_result.values())
                consensus_score = estimate_consensus(cluster_result, all_probs, support_results)
                if consensus_score == 'full consensus':
                    break
                if consensus_score == 'no consensus':
                    T = 1
                conflict_scores = estimate_conflict(cluster_result, all_probs, support_results)
                uncs = estimate_uncertainty(cluster_result, all_probs, support_results)
            out_name = task_name + '-' + str(idx)
            response_dict = {out_name: {question: (agent_contexts, answer), 'task': task_name}}
            # `with` guarantees the file is flushed and closed even if dump raises.
            with open(os.path.join(TARGET_DIR, out_name + '.json'), "w") as fout:
                json.dump(response_dict, fout, indent=4)
