import argparse
import json
import re, os
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import numpy as np
import jsonlines
from tqdm import trange
import random
import torch
import matplotlib.pyplot as plt
import multiprocessing
import time
import signal, functools
import metrics
from compute_perp import *
from compute_sc import *

MAX_INT = sys.maxsize
INVALID_ANS = "[Invalid]"
PROCESS = 60

def _vote_over_clusters(dsu, predicts, completions, answer, equal_func, m):
    """Cluster equivalent samples in ``dsu`` and credit the heaviest cluster(s).

    Shared core of :func:`pc_evaluator` and :func:`umpc_evaluator`, which differ
    only in how they seed ``dsu.attr[i]`` (a ``{completion: set_of_weights}``
    dict) for each sample before calling this.

    Args:
        dsu: ``DSU(m)`` whose ``attr[i]`` has been seeded for every ``i < m``.
        predicts: extracted answer of each sample.
        completions: completion text of each sample.
        answer: reference answer, compared against each cluster with ``check_equal``.
        equal_func: ``equal_func(ans_i, ans_j, completion_i, completion_j)``; a
            truthy result merges samples ``i`` and ``j`` into one cluster.
        m: number of leading samples to consider.

    Returns:
        ``(correct, answers)``.
        ``correct`` is the credit earned by the highest-weight cluster(s).
        ``k`` tied top clusters each contribute ``1 / k`` if their answer checks
        equal to ``answer``.
        ``answers`` holds one ``[answer, weight, check_equal result]`` list per
        cluster, in root-index order. The weights are normalised to sum to 1
        when their total is positive.
    """
    # Merge samples judged equivalent; j ranges over all earlier samples.
    for i in range(m):
        if dsu.get_father(i) != i:
            continue
        ans_i, completion_i = predicts[i], completions[i]
        for j in range(i):
            if equal_func(ans_i, predicts[j], completion_i, completions[j]):
                dsu.merge(i, j)

    # Score every cluster exactly once, from the weights gathered on its root.
    answers = []
    for i in range(m):
        if dsu.get_father(i) != i:
            continue
        weights = dsu.attr[i]
        prob_i = np.sum([np.sum(list(w)) for w in weights.values()])
        answers.append([predicts[i], prob_i, check_equal(predicts[i], answer)])

    # Highest cluster weight and how many clusters tie for it.
    max_prob, max_prob_count = 0, 0
    for _, prob, _ in answers:
        if prob > max_prob:
            max_prob, max_prob_count = prob, 0
        if prob >= max_prob:
            max_prob_count += 1

    # Tied winners share the credit. Testing `>=` (instead of "not <") keeps NaN
    # weights out, so the tie count used as divisor is never zero here.
    correct = 0
    for _, prob, is_correct in answers:
        if prob >= max_prob and is_correct:
            correct += 1.0 / max_prob_count

    # Normalise to a distribution; an all-zero total would otherwise yield NaN.
    total = np.sum([entry[1] for entry in answers])
    if total > 0:
        for entry in answers:
            entry[1] /= total

    return correct, answers


def pc_evaluator(predicts, completions, perplexities, answer, equal_func, K=MAX_INT):
    """Vote over clusters weighted by each sample's probability ``exp(perplexities[i])``.

    Args:
        predicts: extracted answer per sample.
        completions: completion text per sample.
        perplexities: per-sample log-probability. These values are exponentiated
            below, so despite the name they are treated as log-probabilities.
        answer: reference answer.
        equal_func: equivalence test used to cluster samples (see
            ``_vote_over_clusters``).
        K: use only the first ``K`` samples.

    Returns:
        ``(correct, answers)`` as described in ``_vote_over_clusters``.
    """
    m = min(len(predicts), K)
    dsu = DSU(m)
    for i in range(m):
        # Weights are kept in a set keyed by completion. Presumably DSU.merge
        # unions these, so identical (completion, weight) pairs count once.
        # Confirm in DSU.
        dsu.attr[i][completions[i]] = {np.exp(perplexities[i])}
    return _vote_over_clusters(dsu, predicts, completions, answer, equal_func, m)


def umpc_evaluator(predicts, completions, perplexities, answer, equal_func, K=MAX_INT):
    """Like :func:`pc_evaluator`, but zero the weight of below-mean-probability samples.

    Args:
        predicts: extracted answer per sample.
        completions: completion text per sample.
        perplexities: per-sample log-probability (exponentiated below).
        answer: reference answer.
        equal_func: equivalence test used to cluster samples (see
            ``_vote_over_clusters``).
        K: use only the first ``K`` samples.

    Returns:
        ``(correct, answers)`` as described in ``_vote_over_clusters``.
    """
    m = min(len(predicts), K)
    dsu = DSU(m)
    probas = [np.exp(perplexities[i]) for i in range(m)]
    if probas:
        # Mathematically mean <= max, but rounding can push the computed mean
        # above the max when all probabilities are equal. For example,
        # np.mean([0.1] * 3) > 0.1, which would zero out every sample. Clamping
        # to the max only undoes that rounding and always keeps the top sample(s).
        threshold = min(np.mean(probas), max(probas))
    else:
        # Avoid np.mean([]) (RuntimeWarning + NaN) when there are no samples.
        threshold = 0.0
    for i, proba in enumerate(probas):
        if proba < threshold:
            proba = 0
        dsu.attr[i][completions[i]] = {proba}
    return _vote_over_clusters(dsu, predicts, completions, answer, equal_func, m)

class CPCEvaluator(Evaluator):
    """Evaluator "CPC": ``pc_evaluator`` voting, clustering with ``completion_compare``."""

    def __init__(self):
        self.name = "CPC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=completion_compare,
            evaluator=pc_evaluator,
            title=self.name,
        )
    
class NPCEvaluator(Evaluator):
    """Evaluator "NPC": ``pc_evaluator`` voting, clustering with ``numberic_compare``."""

    def __init__(self):
        self.name = "NPC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=numberic_compare,
            evaluator=pc_evaluator,
            title=self.name,
        )
    
class APCEvaluator(Evaluator):
    """Evaluator "APC": ``pc_evaluator`` voting, clustering with ``answer_compare``."""

    def __init__(self):
        self.name = "APC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=answer_compare,
            evaluator=pc_evaluator,
            title=self.name,
        )
    
class CUMPCEvaluator(Evaluator):
    """Evaluator "CUMPC": ``umpc_evaluator`` voting, clustering with ``completion_compare``."""

    def __init__(self):
        self.name = "CUMPC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=completion_compare,
            evaluator=umpc_evaluator,
            title=self.name,
        )
    
class NUMPCEvaluator(Evaluator):
    """Evaluator "NUMPC": ``umpc_evaluator`` voting, clustering with ``numberic_compare``."""

    def __init__(self):
        self.name = "NUMPC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=numberic_compare,
            evaluator=umpc_evaluator,
            title=self.name,
        )
    
class AUMPCEvaluator(Evaluator):
    """Evaluator "AUMPC": ``umpc_evaluator`` voting, clustering with ``answer_compare``."""

    def __init__(self):
        self.name = "AUMPC"

    def interface(self):
        """Delegate to ``multi_process_compute`` with path/start/end/K from ``self.args``."""
        # self.args is not set here; presumably the Evaluator base populates it.
        cfg = self.args
        return self.multi_process_compute(
            path=cfg.path,
            start=cfg.start,
            end=cfg.end,
            K=cfg.K,
            equal_func=answer_compare,
            evaluator=umpc_evaluator,
            title=self.name,
        )
    
if __name__ == "__main__":
    # Evaluators run in this order. The other variants defined above
    # (CPCEvaluator, APCEvaluator, CUMPCEvaluator, AUMPCEvaluator) can be added
    # to this tuple to enable them.
    for evaluator_cls in (NPCEvaluator, NUMPCEvaluator):
        evaluator_cls().process()
        