import argparse
import json
import re, os
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import numpy as np
import jsonlines
from tqdm import trange
import random
import torch
import matplotlib.pyplot as plt
import multiprocessing
import time
import signal, functools
import metrics
from compute_perp import *
from compute_sc import *

# Default cap on samples per question: min(len(predicts), MAX_INT) keeps every sample.
MAX_INT = sys.maxsize
# Small numeric tolerance (not referenced in this part of the file).
EPS = 1e-6
# NOTE(review): presumably the worker-process count for multi-process evaluation; confirm.
PROCESS = 60
# NOTE(review): presumably the placeholder text for an unparseable answer; confirm.
INVALID_ANS = "[Invalid]"

from scipy.stats import weibull_min
import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize
from scipy.special import gamma

def weibull_pdf(x, k, lam):
    """Evaluate the two-parameter Weibull density at ``x``.

    f(x; k, lam) = (k / lam) * (x / lam)**(k - 1) * exp(-(x / lam)**k)

    Args:
        x: evaluation point(s); NumPy arrays are handled element-wise.
        k: shape parameter.
        lam: scale parameter.

    Returns:
        The density value(s), same shape as ``x``.
    """
    scaled = x / lam
    return (k / lam) * scaled ** (k - 1) * np.exp(-(scaled ** k))

def weibull_mean(k, lam):
    """Return the mean of a Weibull(k, lam) distribution, ``lam * Gamma(1 + 1/k)``."""
    gamma_factor = gamma(1 + 1 / k)
    return lam * gamma_factor

def mixture_pdf(x, w1, k1, lam1, k2, lam2):
    """Density of a two-component Weibull mixture at ``x``.

    Component 1 (shape ``k1``, scale ``lam1``) has weight ``w1``; component 2
    (shape ``k2``, scale ``lam2``) has weight ``1 - w1``.
    """
    first_density = weibull_pdf(x, k1, lam1)
    second_density = weibull_pdf(x, k2, lam2)
    return w1 * first_density + (1 - w1) * second_density

def neg_log_likelihood(params, data):
    """Negative log-likelihood of ``data`` under the two-component Weibull mixture.

    Used as the objective passed to ``scipy.optimize.minimize``.

    Args:
        params: sequence ``(w1, k1, lam1, k2, lam2)``: the weight of the first
            component, then shape/scale of component 1 and of component 2.
        data: observations (list or array) to score.

    Returns:
        ``-sum(log(mixture density))`` over ``data``.
    """
    w1, k1, lam1, k2, lam2 = params
    pdf_vals = mixture_pdf(np.asarray(data, dtype=float), w1, k1, lam1, k2, lam2)
    # A density that underflows to exactly 0 would give log(0) = -inf and an
    # infinite objective, which can turn the optimizer's finite-difference
    # gradients into inf/nan. Flooring at the smallest normal float keeps the
    # objective finite without affecting any non-underflowed value.
    pdf_vals = np.maximum(pdf_vals, np.finfo(float).tiny)
    return -np.sum(np.log(pdf_vals))

def calculate_membership_probabilities(data, w1, k1, lam1, k2, lam2):
    """Posterior probability that ``data`` belongs to each mixture component.

    Returns:
        ``(prob1, prob2)`` with ``prob1 = w1*f1 / (w1*f1 + (1 - w1)*f2)`` and
        ``prob2 = 1 - prob1``, where ``f1``/``f2`` are the component densities.
    """
    weighted_first = w1 * weibull_pdf(data, k1, lam1)
    weighted_second = (1 - w1) * weibull_pdf(data, k2, lam2)
    posterior_first = weighted_first / (weighted_first + weighted_second)
    return posterior_first, 1 - posterior_first

def _fit_weibull_mixture(probas):
    """Fit a two-component Weibull mixture to ``probas`` by maximum likelihood.

    Returns ``(w1, k1, lam1, k2, lam2)`` reordered so that component 1 is the
    one with the larger mean.
    """
    initial_guess = [0.5, 1.0, 1.0, 1.5, 2.0]
    # w1 is bounded to [0.2, 0.8]; shapes and scales are kept >= 0.01.
    bounds = [(0.2, 0.8), (0.01, None), (0.01, None), (0.01, None), (0.01, None)]
    result = minimize(neg_log_likelihood, initial_guess, args=(probas,), bounds=bounds)
    w1, k1, lam1, k2, lam2 = result.x
    if weibull_mean(k1, lam1) < weibull_mean(k2, lam2):
        k1, lam1, k2, lam2 = k2, lam2, k1, lam1
        w1 = 1 - w1
    return w1, k1, lam1, k2, lam2


def _cluster_probability(dsu, root):
    """Sum every probability stored in ``dsu.attr[root]`` (over all its keys)."""
    return np.sum([np.sum(list(dsu.attr[root][key])) for key in dsu.attr[root].keys()])


def wpc_evaluator(predicts, completions, perplexities, answer, equal_func, K=MAX_INT):
    """Weibull-pruned, probability-weighted self-consistency vote.

    Fits a two-component Weibull mixture to the sample probabilities. It then
    discards samples that are more likely to come from the lower-mean component
    AND lie below the mean probability. Equal answers are merged into clusters
    and the cluster with the largest summed probability wins.

    Args:
        predicts: extracted answer for each sample.
        completions: completion for each sample; used as a dict key, so must
            be hashable.
        perplexities: per-sample log-probabilities (they are exponentiated
            to obtain probabilities).
        answer: reference answer, compared with ``check_equal``.
        equal_func: ``equal_func(ans_i, ans_j, completion_i, completion_j)``;
            truthy means samples i and j belong to the same cluster.
        K: maximum number of leading samples to use.

    Returns:
        ``(correct, answers)``.
        - ``correct`` is ``1 / n_tied`` for each top-probability cluster whose
          answer matches ``answer`` (0 when there are no samples).
        - ``answers`` holds one ``[answer, normalized_prob, is_correct]`` entry
          per cluster root.
    """
    m = min(len(predicts), K)
    if m == 0:
        # Nothing to vote on; previously this crashed on the debug print below.
        return 0, []
    dsu = DSU(m)

    probas = [np.exp(perplexities[i]) for i in range(m)]
    mean_proba = np.mean(probas)
    w1, k1, lam1, k2, lam2 = _fit_weibull_mixture(probas)

    # Prune: a sample is dropped only when it leans towards the lower-mean
    # component (p1 < p2) and is below the average probability. Dropped
    # samples get no entry in dsu.attr, so they add nothing to any cluster
    # total (assumes DSU starts each attr entry empty; DSU is defined elsewhere).
    remove = 0
    for i in range(m):
        proba_i = probas[i]
        p1, p2 = calculate_membership_probabilities(proba_i, w1, k1, lam1, k2, lam2)
        if p1 < p2 and proba_i < mean_proba:
            remove += 1
        else:
            dsu.attr[i][completions[i]] = {proba_i}
    # Debug trace; p1/p2 are the membership probabilities of the LAST sample only.
    print("Remove", remove / len(probas), w1, k1, lam1, k2, lam2, p1, p2)

    # Merge answers for self-consistency.
    for i in range(m):
        if dsu.get_father(i) != i:
            continue
        for j in range(i):
            if equal_func(predicts[i], predicts[j], completions[i], completions[j]):
                dsu.merge(i, j)

    roots = [i for i in range(m) if dsu.get_father(i) == i]
    cluster_probs = [_cluster_probability(dsu, root) for root in roots]

    # Majority vote: find the largest cluster probability and count the ties.
    max_prob, max_prob_count = 0, 0
    for prob in cluster_probs:
        if prob > max_prob:
            max_prob, max_prob_count = prob, 0
        if prob >= max_prob:
            max_prob_count += 1

    # Accuracy: split one unit of credit among the tied top clusters.
    correct, answers = 0, []
    for root, prob in zip(roots, cluster_probs):
        ans = predicts[root]
        is_correct = check_equal(ans, answer)
        answers.append([ans, prob, is_correct])
        if prob < max_prob:
            continue
        if is_correct:
            correct += 1.0 / max_prob_count

    # Normalize cluster probabilities; skip when the total is 0 (all kept
    # probabilities underflowed) to avoid turning every entry into nan.
    sum_proba = np.sum([entry[1] for entry in answers])
    if sum_proba > 0:
        for entry in answers:
            entry[1] /= sum_proba

    return correct, answers

class CDWPCEvaluator(Evaluator):
    """WPC evaluator that clusters samples with ``completion_compare``."""

    def __init__(self):
        self.name = "CDWPC"

    def interface(self):
        """Run ``multi_process_compute`` with ``wpc_evaluator`` and ``completion_compare``."""
        args = self.args
        return self.multi_process_compute(
            path=args.path,
            start=args.start,
            end=args.end,
            K=args.K,
            equal_func=completion_compare,
            evaluator=wpc_evaluator,
            title=self.name,
        )
    
class NDWPCEvaluator(Evaluator):
    """WPC evaluator that clusters samples with ``numberic_compare``."""

    def __init__(self):
        self.name = "NDWPC"

    def interface(self):
        """Run ``multi_process_compute`` with ``wpc_evaluator`` and ``numberic_compare``."""
        args = self.args
        return self.multi_process_compute(
            path=args.path,
            start=args.start,
            end=args.end,
            K=args.K,
            equal_func=numberic_compare,
            evaluator=wpc_evaluator,
            title=self.name,
        )
    
class ADWPCEvaluator(Evaluator):
    """WPC evaluator that clusters samples with ``answer_compare``."""

    def __init__(self):
        self.name = "ADWPC"

    def interface(self):
        """Run ``multi_process_compute`` with ``wpc_evaluator`` and ``answer_compare``."""
        args = self.args
        return self.multi_process_compute(
            path=args.path,
            start=args.start,
            end=args.end,
            K=args.K,
            equal_func=answer_compare,
            evaluator=wpc_evaluator,
            title=self.name,
        )
    

    
if __name__ == "__main__":
    # Script entry point: runs the NDWPC variant (numberic_compare clustering).
    runner = NDWPCEvaluator()
    runner.process()
        