import json
import numpy as np
import random
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from util import _strip_string
from clustering import KmeansClustering
import math
import heapq
from parser import *


def _format_cluster_leaders(agents, idx, all_probs, cluster, label, skip_agent=None):
    """Format the ``idx``-th message of the top-probability agent(s) in ``cluster``.

    Agents are visited in ascending index order. An agent is quoted when it
    belongs to ``cluster``, is not ``skip_agent``, and its probability equals
    the maximum probability over the whole cluster (ties are all quoted).

    Returns a list of ready-to-append prompt fragments (possibly empty).
    """
    candidates = [a for a in range(len(agents)) if a in cluster and a != skip_agent]
    if not candidates:
        return []
    # The maximum is taken over the full cluster, including a skipped agent.
    best_prob = max(all_probs[j] for j in cluster)
    return [
        "\n\n One {} agent solution: ```{}```".format(label, agents[a][idx]["content"])
        for a in candidates
        if all_probs[a] == best_prob
    ]


def construct_message(current_agent, agents, question, idx, all_probs, cluster_result=None, conflict_scores=None, uncertainties=None, T=None, top_k=4):
    """Build the next user message for ``current_agent`` from other agents' solutions.

    Args:
        current_agent: Index of the agent the message is built for.
        agents: List of per-agent message lists; ``agents[a][idx]["content"]``
            is the solution text that gets quoted.
        question: Problem statement, repeated at the end of the message.
        idx: Position of the message to quote inside each agent's list.
        all_probs: Per-agent probability/confidence values, indexed by agent.
        cluster_result: List of clusters, each a collection of agent indices.
        conflict_scores: Square matrix of pairwise cluster conflict scores.
        uncertainties: List with one uncertainty value per cluster.
        T: Mode switch. ``T == 0`` selects solutions via conflict scores and
            uncertainties; any other value quotes the ``top_k`` most probable
            agents of the current agent's own cluster.
        top_k: Number of leaders quoted when ``T != 0``.

    Returns:
        A ``{"role": "user", "content": ...}`` dict whose content is already
        wrapped in the chat-template markers.
    """
    # Find the cluster holding the current agent. If the agent is in no
    # cluster the index stays -1, which later addresses the *last* cluster.
    own_cluster = []
    own_cluster_idx = -1
    for position, cluster in enumerate(cluster_result):
        if current_agent in cluster:
            own_cluster = cluster
            own_cluster_idx = position
    own_probs = [all_probs[i] for i in own_cluster]

    parts = ["These are the solutions to the problem from other agents: "]
    if T == 0:
        label = "supporting"
        skip = None  # (cluster index, agent index) that must not be quoted
        if uncertainties[own_cluster_idx] == max(uncertainties):
            # Own cluster is the most uncertain one: its least confident
            # agent is shown conflicting solutions, the others supporting ones.
            if all_probs[current_agent] == min(own_probs):
                label = "conflicting"
        else:
            # Never quote the least confident agent of the most uncertain cluster.
            most_uncertain_idx = uncertainties.index(max(uncertainties))
            most_uncertain = cluster_result[most_uncertain_idx]
            most_uncertain_probs = [all_probs[i] for i in most_uncertain]
            weakest = most_uncertain[most_uncertain_probs.index(min(most_uncertain_probs))]
            skip = (most_uncertain_idx, weakest)

        for c, score in enumerate(conflict_scores[own_cluster_idx]):
            if label == "conflicting":
                wanted = score > 2
            else:
                wanted = 0 <= score <= 2
            if not wanted:
                continue
            skip_agent = skip[1] if skip is not None and skip[0] == c else None
            parts.extend(_format_cluster_leaders(agents, idx, all_probs, cluster_result[c], label, skip_agent))

        parts.append("""\n\n Selecting and using the trustable solutions from current collaboration as additional information, can you provide your answer to the problem? \n {}. """.format(question))
    else:
        # Positions returned by find_top_indices refer to own_cluster, not to agents.
        for position in find_top_indices(own_probs, top_k):
            leader = agents[own_cluster[position]]
            parts.append("\n\n One leader solution: ```{}```".format(leader[idx]["content"]))
        parts.append("""\n\n Selecting and using the leading solutions from current collaboration as additional information, can you provide your answer to the problem? \n {}. """.format(question))

    user_text = "".join(parts)
    # NOTE(review): only the second literal is passed through .format(), so the
    # system prompt reaches the model with literal double braces
    # ("\boxed{{}}"). Kept byte-identical here; confirm whether that is intended.
    content = "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"+"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(user_text)
    return {"role": "user", "content": content}


def construct_assistant_message(completion):
    """Wrap the first output text of the first completion as an assistant message."""
    first_request = completion[0]
    generated_text = first_request.outputs[0].text
    return dict(role="assistant", content=generated_text)

def find_top_indices(lst, top_k):
    """Return the indices of the ``top_k`` largest values in ``lst``.

    A min-heap of size ``top_k`` is maintained while scanning ``lst``; a later
    value replaces the heap minimum only when it is strictly larger, so on
    ties the earlier index is kept. The indices are returned in heap order,
    which is NOT sorted by value.

    Returns an empty list when ``top_k`` is not positive or ``lst`` is empty;
    returns every index when ``top_k`` exceeds ``len(lst)``.
    """
    if top_k <= 0:
        # Without this guard the scan below would index an empty heap.
        return []
    heap = [(value, position) for position, value in enumerate(lst[:top_k])]
    heapq.heapify(heap)
    for position, value in enumerate(lst[top_k:], top_k):
        if value > heap[0][0]:
            heapq.heapreplace(heap, (value, position))
    return [position for _, position in heap]

def solve_math_problems(input_str):
    """Extract the 'math' answer from ``input_str``; an empty extraction becomes 'None'."""
    extracted = extract_answer(input_str, 'math')
    return 'None' if extracted == '' else extracted

def list_mean(lst):
    """Return ``np.mean(lst)``, or 0 when ``lst`` has no elements."""
    return np.mean(lst) if len(lst) != 0 else 0

def estimate_conflict(cluster_result, all_probs, support_results):
    """Compute a symmetric matrix of pairwise conflict scores between clusters.

    For each pair of clusters the agents' answers (with one trailing '.'
    removed) are split into answers *shared* by both clusters and answers that
    are *disputed* (present in only one of them). The score combines mean
    probabilities of those groups:

    * shared and disputed answers exist:
      ``mean(disputed) / mean(all) * |mean(shared_i) - mean(shared_j)|
      / |mean(disputed_i) - mean(disputed_j)|``
    * only shared answers: ``|mean(shared_i) - mean(shared_j)|``
    * no shared answers:
      ``mean(disputed) / mean(all) / |mean(disputed_i) - mean(disputed_j)|``

    Args:
        cluster_result: List of clusters, each a collection of agent indices.
        all_probs: Per-agent probability values, indexed by agent.
        support_results: Per-agent answer strings, indexed by agent.

    Returns:
        An ``n x n`` list of lists (``n = len(cluster_result)``) with 0 on the
        diagonal.
    """
    # Normalise once so that the set construction and the membership tests
    # below agree; otherwise an answer ending in '.' would match neither group.
    answers = [ans[:-1] if ans.endswith('.') else ans for ans in support_results]

    n_clusters = len(cluster_result)
    conflict_scores = [[0 for _ in range(n_clusters)] for _ in range(n_clusters)]
    for i in range(n_clusters):
        for j in range(i):
            members_i, members_j = cluster_result[i], cluster_result[j]
            shared = {answers[k] for k in members_i} & {answers[k] for k in members_j}

            disputed_i = [all_probs[k] for k in members_i if answers[k] not in shared]
            shared_i = [all_probs[k] for k in members_i if answers[k] in shared]
            disputed_j = [all_probs[k] for k in members_j if answers[k] not in shared]
            shared_j = [all_probs[k] for k in members_j if answers[k] in shared]

            disputed = disputed_i + disputed_j
            everyone = disputed + shared_i + shared_j
            # NOTE(review): the divisions below are unguarded. Equal disputed
            # means give a zero denominator (inf/nan with numpy floats,
            # ZeroDivisionError with plain ints) — confirm the intended value.
            if len(shared) != 0 and len(disputed) > 0:
                conflict_score = (list_mean(disputed) / list_mean(everyone)) * (
                    abs(list_mean(shared_i) - list_mean(shared_j))
                    / abs(list_mean(disputed_i) - list_mean(disputed_j))
                )
            elif len(shared) != 0:
                conflict_score = abs(list_mean(shared_i) - list_mean(shared_j))
            else:
                conflict_score = (list_mean(disputed) / list_mean(everyone)) / abs(
                    list_mean(disputed_i) - list_mean(disputed_j)
                )
            conflict_scores[j][i] = conflict_scores[i][j] = conflict_score
    return conflict_scores

def calculate_entropy(lst):
    """Return the mean of ``-p * log(p)`` over the values in ``lst``.

    Values for which the term cannot be computed (zero or negative numbers,
    non-numeric entries) contribute 0 to the sum but still count in the
    divisor. An empty ``lst`` yields 0.
    """
    if len(lst) == 0:
        return 0
    entropy = 0
    for prob in lst:
        try:
            entropy += -prob * math.log(prob)
        except (TypeError, ValueError):
            # math.log rejects p <= 0 (ValueError) and non-numbers (TypeError).
            continue
    return entropy / len(lst)

def estimate_uncertainty(cluster_result, all_probs, support_results):
    """Return one ``calculate_entropy`` value per cluster, computed from its members' probabilities.

    ``support_results`` is accepted for signature parity but not used.
    """
    uncertainties = []
    for members in cluster_result:
        member_probs = [all_probs[member] for member in members]
        uncertainties.append(calculate_entropy(member_probs))
    return uncertainties


def estimate_consensus(cluster_result, all_probs, support_results):
    """Classify agreement among the agents' answers.

    ``support_results`` is normalised IN PLACE first: one trailing '.' is
    removed and an empty answer becomes 'None'. The most frequent answer
    (first one wins on ties) is then rated by its share of the answers and by
    its share of the total probability mass.

    Args:
        cluster_result: Unused; kept for signature parity with the other estimators.
        all_probs: Per-agent probability values, aligned with ``support_results``.
        support_results: Per-agent answer strings; mutated as described above.

    Returns:
        'full consensus' when the top answer covers more than 2/3 of the
        agents and more than 0.8 of the probability mass; 'partial consensus'
        when at least two agents give it and it holds more than 0.5 of the
        mass; otherwise 'no consensus' (also for an empty input).
    """
    if len(support_results) == 0:
        return 'no consensus'

    for position, answer in enumerate(support_results):
        if answer.endswith('.'):
            answer = answer[:-1]
        if answer == '':
            answer = 'None'
        # Always write back, so that the '' -> 'None' replacement is kept.
        support_results[position] = answer

    top_answer = max(support_results, key=support_results.count)
    frequency = support_results.count(top_answer)

    top_mass = 0
    total_mass = 0
    for answer, prob in zip(support_results, all_probs):
        if answer == top_answer:
            top_mass += prob
        total_mass += prob
    try:
        consensus_score = top_mass / total_mass
    except ZeroDivisionError:
        # No probability mass at all: treat the top answer as fully backed.
        consensus_score = 1

    n_answers = len(support_results)
    if (frequency / n_answers) > (2 / 3) and consensus_score > 0.8:
        return 'full consensus'
    elif (frequency / n_answers) >= (2 / n_answers) and consensus_score > 0.5:
        return 'partial consensus'
    else:
        return 'no consensus'

def parse_question_answer(subdir, file):
    """Load one problem JSON file and return its fields.

    Args:
        subdir: Directory containing the file.
        file: File name inside ``subdir``.

    Returns:
        A tuple ``(problem, level, type, solution, answer)``. ``level`` is the
        integer following "Level " in the file's ``level`` field, or ``None``
        when that field does not have this form.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a required field is missing.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(os.path.join(subdir, file), 'r', encoding='utf-8') as fp:
        problem_data = json.load(fp)

    prob_content = problem_data["problem"]
    prob_level = problem_data["level"]
    prob_type = problem_data["type"]
    answer = problem_data["answer"]
    try:
        prob_level = int(prob_level.split("Level ")[1])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Not a string, no "Level " prefix, or a non-numeric level such as "Level ?".
        prob_level = None
    return prob_content, prob_level, prob_type, problem_data['solution'], answer

if __name__ == "__main__":
    # Multi-round debate: every agent answers each problem, the answers are
    # clustered, and later rounds feed selected peer solutions back to the agents.
    model_path = "Qwen/Qwen2.5-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.7, logprobs=True)
    llm = LLM(model=model_path)
    # NOTE(review): this literal is never passed through .format(), so the
    # model sees "\boxed{{}}" with double braces — confirm that is intended.
    SYSTEM_PROMPT = "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
    agents = 7
    rounds = 3
    top_k = 2
    random.seed(0)
    DIR = 'MATH DATASET'
    TARGET_DIR = "MATH OUTPUT"
    # Create the output directory up front so the final dump cannot fail on it.
    os.makedirs(TARGET_DIR, exist_ok=True)
    SUB_DIR = os.listdir(DIR)
    for sub_dir in SUB_DIR:
        files = os.listdir(os.path.join(DIR, sub_dir))
        for file in files:
            question, prob_level, prob_type, solution, answer = parse_question_answer(os.path.join(DIR, sub_dir), file)
            file_res = prob_type.lower().replace(' ', '_').strip() + '-' + file
            prompt = SYSTEM_PROMPT + "<|im_start|>user\n{}<|im_end|>\n".format(question) + "<|im_start|>assistant\n"
            agent_contexts = [[{"role": "user", "content": prompt}] for _ in range(agents)]
            T = 0
            for round_idx in range(rounds):
                documents = []
                support_results = []
                for i, agent_context in enumerate(agent_contexts):
                    if round_idx != 0:
                        # 2*round_idx - 1 is the position of each agent's
                        # assistant reply from the previous round.
                        message = construct_message(i, agent_contexts, question, 2*round_idx - 1, all_probs, cluster_result, conflict_scores, uncs, T=T, top_k=top_k)
                        agent_context.append(message)
                    text = tokenizer.apply_chat_template(
                        agent_context,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                    completion = llm.generate([text], sampling_params)
                    documents.append(completion[0].outputs[0].logprobs)
                    assistant_message = construct_assistant_message(completion)
                    agent_context.append(assistant_message)
                    support_results.append(solve_math_problems(assistant_message['content']))
                T = 0
                Kmeans = KmeansClustering(stopwords_path='en_stop_words.txt')
                cluster_result, all_probs = Kmeans.kmeans(documents, n_clusters=3)
                cluster_result = [cluster_result[key] for key in cluster_result.keys()]
                consensus_score = estimate_consensus(cluster_result, all_probs, support_results)
                if consensus_score == 'full consensus':
                    # Everyone agrees: stop debating this problem.
                    break
                elif consensus_score == 'partial consensus':
                    pass
                elif consensus_score == 'no consensus':
                    # Switch the next round to the "leader solutions" prompt.
                    T = 1
                conflict_scores = estimate_conflict(cluster_result, all_probs, support_results)
                uncs = estimate_uncertainty(cluster_result, all_probs, support_results)
            res_dict = {'problem': question, 'level': prob_level, 'type': prob_type, 'solution': solution, 'answer': answer, 'agent_contexts': agent_contexts}
            # Use a context manager so the output file is flushed and closed.
            with open(os.path.join(TARGET_DIR, file_res), "w", encoding='utf-8') as out_fp:
                json.dump(res_dict, out_fp, indent=4)
