import argparse
import json
import re, os
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import numpy as np
import jsonlines
from tqdm import trange
import random
import torch
import matplotlib.pyplot as plt
import multiprocessing
import time
import signal, functools
import metrics

# Names exposed by a star-import of this module. prep_evaluator and the
# timeout decorator are not listed, so a star-import does not bring them in.
__all__ = ["check_equal", "check_equal_without_timeout", "completion_compare", "answer_compare", "numberic_compare", "Evaluator"]

# Global matplotlib styling, applied as a side effect of importing this module:
# 12pt serif (Times New Roman) text, regular mathtext, and explicit DPI values
# for on-screen figures (64) and saved figures (256).
font = {'family': 'serif', 'serif': 'Times New Roman', 'weight': 'normal', 'size': 12}
plt.rc('font', **font)
plt.rc('mathtext', default='regular')
plt.rcParams['savefig.dpi'] = 256
plt.rcParams['figure.dpi'] = 64
# Line-format strings and a colour palette. Neither list is referenced by the
# code visible in this file; the plotting methods spell their colours out as
# literals.
symbols = ["-^", "-o", "-p", "-s", '-v', '-+', '-h',"-s", "-<", "->", "-p", "-x", '-v', '-.', '-']
colors = ["#82B0D2", "#8ECFC9",   "#FFBE7A", "#FA7F6F",  "#BEB8DC", "#E7DAD2"]

def timeout(sec):
    """
    timeout decorator based on SIGALRM

    The wrapped function runs with a real-time interval timer armed. If it has
    not returned when the timer expires, ``TimeoutError`` is raised inside it.

    :param sec: function raise TimeoutError after ? seconds; may be an int or
        a float, and 0 disables the timer
    :return: a decorator that applies the time limit to a function
    :raises TimeoutError: from the wrapped call when the limit is exceeded

    The SIGALRM handler that was active before the call is put back afterwards,
    so the decorator does not permanently replace handlers installed by other
    code. Any timer already pending is cancelled on exit, so nesting decorated
    calls is not supported. ``signal.signal`` only works in the main thread of
    the interpreter, so decorated functions must be called from there.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):

            def _handle_timeout(signum, frame):
                err_msg = f'Function {func.__name__} timed out after {sec} seconds'
                raise TimeoutError(err_msg)

            previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            try:
                # setitimer accepts fractional seconds, unlike signal.alarm.
                # Arming inside the try guarantees the cleanup below runs even
                # if an invalid ``sec`` makes this call raise.
                signal.setitimer(signal.ITIMER_REAL, sec)
                return func(*args, **kwargs)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)
                # signal.signal returns None when the previous handler was not
                # installed from Python; such a handler cannot be reinstated.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return wrapped_func
    return decorator

from data_processing.answer_extraction import *
from eval.eval_script import *

# Largest native integer; serves as the "no limit" default for K and --end.
MAX_INT = sys.maxsize
# Presumably a marker for answers that could not be extracted; not referenced
# by the code visible in this file.
INVALID_ANS = "[Invalid]"
# Magnitude negated to seed the running maximum in prep_evaluator.
INF = 1e9
# Number of worker processes in the multiprocessing pool.
PROCESS = 96

@timeout(5)
def check_equal_without_timeout(ans_1, ans_2):
    """Compare two answers with ``math_equal`` under a 5-second limit.

    Returns whatever ``math_equal(ans_1, ans_2)`` returns. ``math_equal`` is
    not defined in this file (presumably it comes from one of the star
    imports). A ``TimeoutError`` raised by the ``timeout`` decorator is not
    caught here.
    """
    verdict = math_equal(ans_1, ans_2)
    return verdict

def check_equal(ans_1, ans_2):
    """Compare two answers, treating a timed-out comparison as "not equal".

    Delegates to ``check_equal_without_timeout`` and returns its result; if
    that call raises ``TimeoutError`` the function returns ``False`` instead.
    Other exceptions propagate.
    """
    try:
        outcome = check_equal_without_timeout(ans_1, ans_2)
    except TimeoutError:
        outcome = False
    return outcome

def completion_compare(ai, aj, ci, cj):
    """Return ``ci == cj``: two samples match when their completions are equal.

    ``ai`` and ``aj`` are ignored; they are accepted so that every
    ``*_compare`` function in this file takes the same four arguments.
    """
    same_completion = ci == cj
    return same_completion

def answer_compare(ai, aj, ci, cj):
    """Return ``ai == aj``: two samples match when their answers are equal.

    ``ci`` and ``cj`` are ignored; they are accepted so that every
    ``*_compare`` function in this file takes the same four arguments.
    """
    same_answer = ai == aj
    return same_answer

def numberic_compare(ai, aj, ci, cj):
    """Return ``check_equal(ai, aj)`` for the two answers.

    ``ci`` and ``cj`` are ignored; they are accepted so that every
    ``*_compare`` function in this file takes the same four arguments.
    """
    answers_match = check_equal(ai, aj)
    return answers_match

def prep_evaluator(predicts, completions, perplexities, answer, equal_func, K=MAX_INT):
    """Score one question by its highest-scoring sampled prediction(s).

    Only the first ``min(len(predicts), K)`` samples are used. The samples
    whose score equals the maximum score are the selected ones; when several
    tie, each contributes an equal share, so ``correct`` is the fraction of
    tied top-scoring samples whose prediction matches ``answer``.

    :param predicts: predicted answers, one per sample
    :param completions: unused; kept so all evaluators share one signature
    :param perplexities: per-sample scores where larger is better. The caller
        in this file passes the ``mean_logprob`` entries, and ``np.exp`` of
        each score is reported as that sample's confidence
    :param answer: reference answer, compared with ``check_equal``
    :param equal_func: unused; correctness is always decided by ``check_equal``
    :param K: maximum number of samples to consider
    :return: ``(correct, answers)`` where ``correct`` is a number in [0, 1]
        and ``answers`` is a list of ``[prediction, exp(score), is_correct]``
        entries, one per considered sample
    """
    m = min(len(predicts), K)

    # Find the best score and the number of samples that share it.
    max_perplexity = -INF
    max_perplexity_count = 0
    for i in range(m):
        score = perplexities[i]
        if score > max_perplexity:
            max_perplexity = score
            max_perplexity_count = 0
        if score >= max_perplexity:
            max_perplexity_count += 1

    # Compute accuracy
    correct, answers = 0, []
    for i in range(m):
        ans_i = predicts[i]
        score = perplexities[i]
        # Evaluate the comparison once and reuse it: every check_equal call
        # may run until its timeout, and two separate calls could disagree.
        is_correct = check_equal(ans_i, answer)
        answers.append([ans_i, np.exp(score), is_correct])
        if score < max_perplexity:
            continue
        if is_correct:
            correct += 1.0 / max_perplexity_count

    # The confidences in ``answers`` are returned as-is; they are not
    # normalised to sum to 1 across samples.
    return correct, answers

class Evaluator:
    """Command-line evaluator for sampled predictions stored in a JSON file.

    ``process`` is the entry point: it parses the CLI arguments, runs the
    per-question evaluator in a process pool, draws a reliability diagram as a
    PDF next to the input file, and merges the metrics into a ``results*.json``
    file in the same directory. ``self.name`` labels this evaluator in the
    printed output, the PDF file name and the results file.
    """

    def __init__(self,):
        self.name = "Perp"

    def multi_process_compute(self, path, start, end, K, equal_func, evaluator, title):
        """Evaluate questions ``[start, end)`` of ``path`` in parallel.

        The JSON file at ``path`` must contain the keys ``"predict"``,
        ``"completion"``, ``"mean_logprob"`` and ``"answer"``, each indexable
        by question number.

        :param path: JSON file holding the sampled results
        :param start: first question index (clamped to 0)
        :param end: one past the last question index (clamped to the number
            of questions)
        :param K: maximum number of samples per question, forwarded to
            ``evaluator``
        :param equal_func: comparison function forwarded to ``evaluator``
        :param evaluator: callable invoked once per question as
            ``evaluator(predict, completion, mean_logprob, answer, equal_func, K)``;
            must return ``(accuracy, answers)``
        :param title: label printed together with the number of questions
        :return: ``(accuracy_percent, maximum, average, max_bins, avg_bins)``
            where the last four values come from
            ``metrics.compute_maximum_metrics`` and
            ``metrics.compute_average_metrics``
        """
        with open(path, "r") as fr:
            results = json.load(fr)

        n = len(results["predict"])
        first, last = max(0, start), min(end, n)

        print(title, last - first)

        # Prepare parameters
        parameters = [
            (results["predict"][idx], results["completion"][idx], results["mean_logprob"][idx], results["answer"][idx], equal_func, K)
            for idx in range(first, last)
        ]

        # Process in parallel. The context manager tears the pool down even
        # when a worker raises, so no processes are leaked on failure.
        with multiprocessing.Pool(processes=PROCESS) as pool:
            outputs = pool.starmap(evaluator, parameters)

        # Each metric function gets its own outer list, as the two are
        # independent consumers of the per-question answer lists.
        maximum, max_bins = metrics.compute_maximum_metrics([x[1] for x in outputs])
        average, avg_bins = metrics.compute_average_metrics([x[1] for x in outputs])
        accs = np.mean([x[0] for x in outputs])

        return accs * 100.0, maximum, average, max_bins, avg_bins

    def parse_args(self, ):
        """Parse the command line.

        :return: namespace with ``path``, ``K``, ``start`` and ``end``
        """
        parser = argparse.ArgumentParser()
        # Required: every later step opens or derives file names from it.
        parser.add_argument("--path", type=str, required=True)
        parser.add_argument("-K", type=int, default=MAX_INT)
        parser.add_argument("--start", type=int, default=0)
        parser.add_argument("--end", type=int, default=MAX_INT)
        return parser.parse_args()

    def interface(self):
        """Run the evaluation configured by ``self.args``.

        Reads ``self.args``, which ``process`` sets before calling this
        method, and returns the tuple produced by ``multi_process_compute``.
        """
        return self.multi_process_compute(path=self.args.path,
                                          start=self.args.start, end=self.args.end, K=self.args.K,
                                          equal_func=numberic_compare,
                                          evaluator=prep_evaluator, title=self.name)

    def plot_ece(self, ax, siz, acc, conf, ece, title, n_bins=10):
        """Draw a reliability diagram on ``ax``.

        :param ax: matplotlib axes to draw on
        :param siz: unused here; accepted so the signature parallels
            ``plot_count``
        :param acc: per-bin accuracy (assumed to be a numpy array of length
            ``n_bins`` with values in [0, 1] -- TODO confirm against ``metrics``)
        :param conf: per-bin confidence, same layout as ``acc``
        :param ece: calibration error shown in the title as a percentage
        :param title: prefix for the axes title
        :param n_bins: number of equal-width bins spanning [0, 1]
        """
        x_bins = (np.arange(n_bins) + 0.5) / n_bins
        ax.bar(x_bins, acc * 100.0, width=1.0 / n_bins, linewidth=2, edgecolor="black", color="#82B0D2", label="Predict")
        # The gap bar spans from the smaller to the larger of accuracy and
        # confidence in each bin.
        ax.bar(x_bins, np.abs(acc - conf) * 100.0, bottom=np.minimum(acc, conf) * 100.0, width=1.0 / n_bins, linewidth=2, edgecolor=(1, 0, 0), facecolor="#FA7F6F", label="Gap")
        # Diagonal where accuracy equals confidence.
        ax.plot([0,1], [0,100], lw=2, linestyle = ":", color="black")
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 100])
        ax.set_ylabel('Accuracy (%)')
        ax.set_title(f'{title} ECE = {ece * 100.0:.2f}%', fontsize=16)
        ax.legend()

    def plot_count(self, ax, siz, acc, conf, ece, n_bins=10):
        """Draw the per-bin sample proportions on ``ax`` as downward bars.

        :param ax: matplotlib axes to draw on
        :param siz: per-bin sample counts (assumed numpy array -- TODO confirm)
        :param acc: per-bin accuracy, used for the size-weighted average
        :param conf: per-bin confidence, used for the size-weighted average
        :param ece: unused here; accepted so the signature parallels
            ``plot_ece``
        :param n_bins: number of equal-width bins spanning [0, 1]
        """
        x_bins = (np.arange(n_bins) + 0.5) / n_bins
        # Heights are negated so the bars hang below the x axis.
        ax.bar(x_bins, -siz/np.sum(siz)*100.0, width=1.0 / n_bins, color="#8ECFC9")
        avg_acc = np.sum(acc * siz) / np.sum(siz)
        avg_conf = np.sum(conf * siz) / np.sum(siz)
        ax.axvline(x=avg_acc, ls="solid", lw=2,  c="black", label="Avg. Acc.")
        ax.axvline(x=avg_conf, ls="dotted", lw=2, c="black", label="Avg. Conf.")
        # Relabel the (negative) ticks with their absolute integer values.
        # NOTE(review): matplotlib documents set_yticklabels as meant to be
        # used after fixing tick positions with set_yticks; left as-is so the
        # rendered figure does not change.
        new_ticks = np.abs(ax.get_yticks()).astype(int)
        ax.set_yticklabels(new_ticks)
        ax.set_ylabel('Proportion (%)')
        ax.set_xlabel('Confidence')
        ax.legend()

    def process(self,):
        """Parse arguments, evaluate, save the figure and update the results file.

        Writes ``<input stem>-<name>[-K].pdf`` next to the input file and
        stores the metrics under ``results[path][name]`` in ``results.json``
        (``results-K.json`` when ``-K`` is given) in the input's directory.
        """
        self.args = self.parse_args()
        args = self.args
        start_time = time.time()
        acc, maximum, average, max_bins, avg_bins = self.interface()

        k_suffix = "" if args.K == MAX_INT else f"-{args.K}"

        # Save figures. Only the "maximum" metrics are drawn; the "average"
        # metrics are written to the results file but not plotted.
        fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(4, 6), gridspec_kw={"height_ratios": [2, 1]})
        self.plot_ece(axes[0], max_bins[2], max_bins[0], max_bins[1], maximum[0], "")
        self.plot_count(axes[1], max_bins[2], max_bins[0], max_bins[1], maximum[0])
        # splitext swaps only the final extension; a plain str.replace would
        # also rewrite ".json" inside directory names and would leave a path
        # without that suffix unchanged.
        figure_stem = os.path.splitext(args.path)[0]
        plt.savefig(figure_stem + "-" + self.name + k_suffix + ".pdf", bbox_inches='tight', pad_inches=0)
        plt.close(fig)

        # Save results
        result_path = os.path.join(os.path.dirname(args.path), "results.json" if args.K == MAX_INT else "results-%d.json" % args.K)
        try:
            with open(result_path, "r") as fr:
                results = json.load(fr)
        except (OSError, ValueError):
            # Missing, unreadable or malformed file: start from scratch.
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            results = {}
        if not isinstance(results, dict):
            results = {}
        results.setdefault(args.path, {})[self.name] = {
            "Acc": acc,
            "Max_ECE": maximum[0], "Max_BS":  maximum[1], "Max_NLL": maximum[2],
            "Avg_ECE": average[0], "Avg_BS":  average[1], "Avg_NLL": average[2],
        }
        print(self.name, args.path, results[args.path][self.name], time.time() - start_time)
        with open(result_path, "w") as fw:
            json.dump(results, fw)

# Script entry point: parse the command line, evaluate, plot and save results.
if __name__ == "__main__":
    Evaluator().process()
        