import argparse
import json
import re, os
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import numpy as np
import jsonlines
from tqdm import trange
import random
import torch
import matplotlib.pyplot as plt
import multiprocessing
import time
import signal, functools
import metrics
from compute_perp import *
MAX_INT = sys.maxsize
INVALID_ANS = "[Invalid]"
PROCESS = 60

__all__ = ["DSU"]

class DSU:
    def __init__(self, n):
        self.n = n
        self.father = [i for i in range(n)]
        self.size = [1 for i in range(n)]
        self.attr = [{} for i in range(n)]
    
    def get_father(self, x):
        if self.father[x] == x: return x
        self.father[x] = self.get_father(self.father[x])
        return self.father[x]
    
    def merge(self, x, y):
        fx = self.get_father(x)
        fy = self.get_father(y)
        if fx == fy: return
        # Commented for parallel
        # if fx > fy: fx, fy = fy, fx 
        self.father[fy] = fx
        # Update size of DSU
        self.size[fx] += self.size[fy]
        self.size[fy] = 0
        # Merge attribute dict
        for key in self.attr[fy].keys():
            if key not in self.attr[fx]: 
                self.attr[fx][key] = self.attr[fy][key]
            else:
                self.attr[fx][key] |= self.attr[fy][key]
        self.attr[fy] = {}

def sc_evaluator(predicts, completions, perplexities, answer, equal_func, K=MAX_INT):
    m = min(len(predicts), K)
    dsu = DSU(m)

    # Merge answer for self-consistency
    for i in range(m):
        if dsu.get_father(i) != i: continue
        for j in range(i):
            ans_i = predicts[i]
            ans_j = predicts[j]
            completion_i = completions[i]
            completion_j = completions[j]
            if equal_func(ans_i, ans_j, completion_i, completion_j): dsu.merge(i, j)
    
    # Compute majority votes
    max_size, max_size_count = 0, 0
    for i in range(m):
        if dsu.get_father(i) != i: continue
        if dsu.size[i] > max_size: 
            max_size = dsu.size[i]
            max_size_count = 0
        if dsu.size[i] == max_size: max_size_count += 1
    
    # Compute accuracy
    correct, answers = 0, []
    for i in range(m):
        if dsu.get_father(i) != i: continue
        ans_i = predicts[i]
        answers.append([ans_i, dsu.size[i] / m, check_equal(ans_i, answer)])
        if dsu.size[i] < max_size: continue
        if check_equal(ans_i, answer): correct += 1.0 / max_size_count
    
    # Normalize probabilities
    sum_proba = np.sum([x[1] for x in answers])
    for i in range(len(answers)):
        answers[i][1] /= sum_proba

    return correct, answers

class NSCEvaluator(Evaluator):
    def __init__(self,):
        self.name = "NSC"

    def interface(self):
        return self.multi_process_compute(path=self.args.path, 
                                          start=self.args.start, end=self.args.end, K=self.args.K, 
                                          equal_func=numberic_compare, 
                                          evaluator=sc_evaluator, title=self.name)
    
class ASCEvaluator(Evaluator):
    def __init__(self,):
        self.name = "ASC"

    def interface(self):
        return self.multi_process_compute(path=self.args.path, 
                                          start=self.args.start, end=self.args.end, K=self.args.K, 
                                          equal_func=answer_compare, 
                                          evaluator=sc_evaluator, title=self.name)
    
class CSCEvaluator(Evaluator):
    def __init__(self,):
        self.name = "CSC"

    def interface(self):
        return self.multi_process_compute(path=self.args.path, 
                                          start=self.args.start, end=self.args.end, K=self.args.K, 
                                          equal_func=completion_compare, 
                                          evaluator=sc_evaluator, title=self.name)
    
if __name__ == "__main__": 
    CSCEvaluator().process()
    NSCEvaluator().process()
    # ASCEvaluator().process()
        