import os
import sys
import re
import pickle
import argparse
import random
import jsonlines
import json

from sklearn.metrics import roc_curve, auc, precision_recall_curve
from scipy import interpolate


current_dir = os.path.dirname(os.path.abspath(__file__))
project_root_path = os.path.abspath(os.path.join(current_dir, "..")) 
sys.path.append(project_root_path)

#from Data.load_data import DatasetInfo
from config_pool import MODEL_POOL, DATASET_POOL, LANGUAGE_MAPPING
from Evaluation.match import AnswerParsing


# Default directory that holds the ``<name>.jsonl`` dataset / gold-label files.
dataset_path = "/home/hpc/b232dd/b232dd25/CoT-Kinetics/Data/"


class DatasetInfo:
    """In-memory copy of one JSON Lines dataset file.

    The file ``<data_dir>/<dataset_name>.jsonl`` is read completely at
    construction time; every parsed line becomes one entry of ``self.data``.

    Attributes:
        dataset_name: Name passed to the constructor (file stem, no extension).
        data: List of the records parsed from the file, in file order.
        data_size: Number of records in ``data``.
    """

    def __init__(self, dataset_name, data_dir=None):
        """Load ``<dataset_name>.jsonl`` from ``data_dir``.

        Args:
            dataset_name: File stem of the dataset inside the data directory.
            data_dir: Directory to read from. ``None`` (the default) uses the
                module-level ``dataset_path``.

        Raises:
            OSError: If the file cannot be opened for reading.
        """
        self.dataset_name = dataset_name
        base_dir = dataset_path if data_dir is None else data_dir
        file_path = os.path.join(base_dir, f"{dataset_name}.jsonl")
        # Read-only mode: loading must not require write permission on the
        # dataset file (the previous "r+" mode did).
        with open(file_path, "r", encoding="utf8") as f:
            self.data = list(jsonlines.Reader(f))
        self.data_size = len(self.data)

    def load_one_sample(self, idx):
        """Return the record stored at position ``idx`` of ``self.data``."""
        return self.data[idx]

class StandardEvaluation:
    """Accuracy evaluation of generated answers against gold labels.

    Attributes:
        true_samples: Gold-label records collected from the label file(s).
        data_size: Total number of gold-label records loaded.
    """

    def __init__(self, dataset_list, gold_labels=None):
        """Load the gold labels once per entry of ``dataset_list``.

        Args:
            dataset_list: Iterable of dataset names. Only its length matters:
                the same gold-label file is loaded for every entry.
            gold_labels: File stem of the gold-label ``.jsonl`` file. When
                ``None``, the module-level ``args.gold_labels`` set by the
                script entry point is used (legacy behaviour).
        """
        self.true_samples = []
        self.data_size = 0
        for _ in dataset_list:
            # Resolve the fallback lazily so that an explicit ``gold_labels``
            # (or an empty ``dataset_list``) never touches the global ``args``,
            # which only exists when this file is run as a script.
            label_name = gold_labels if gold_labels is not None else args.gold_labels
            label_loader = DatasetInfo(label_name)
            self.true_samples.extend(label_loader.data)
            self.data_size += label_loader.data_size

    def std_eval(self, args):
        """Mark every model output as correct/incorrect and compute accuracy.

        Args:
            args: Object exposing ``output_jsonl_path`` (JSON Lines file whose
                records carry ``"id"`` and ``"output_seq"``) and ``dataset``
                (name handed, lower-cased, to ``AnswerParsing``).

        Returns:
            A tuple ``(accuracy, output_list, binary_list)``: the fraction of
            correct outputs rounded to 3 decimals, the output records with an
            added ``"evaluation"`` field, and the per-record correctness flags
            in the same order.

        Raises:
            ValueError: If the output file contains no records.
            KeyError: If an output id has no matching gold-label record.
        """
        with open(args.output_jsonl_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # inside json.loads.
            outputs = [json.loads(line) for line in f if line.strip()]
        if not outputs:
            raise ValueError(
                f"No model outputs found in {args.output_jsonl_path!r}; "
                "cannot compute accuracy."
            )

        answerparsing = AnswerParsing(args.dataset.lower())
        valid_ids = {output["id"] for output in outputs}
        true_samples_dict = {
            sample["id"]: sample
            for sample in self.true_samples
            if sample["id"] in valid_ids
        }

        # Fail early with a readable message rather than part-way through.
        missing_ids = [oid for oid in valid_ids if oid not in true_samples_dict]
        if missing_ids:
            raise KeyError(
                f"{len(missing_ids)} output id(s) have no gold label, "
                f"e.g. {missing_ids[0]!r}"
            )

        output_list, binary_list = [], []
        acc = 0

        for output in outputs:
            data_id = output["id"]
            print(data_id)  # progress trace, one id per evaluated record

            true_sample = true_samples_dict[data_id]
            # dataset_parse yields (extracted_answer, correctness flag); only
            # the flag is used here.
            _, binary = answerparsing.dataset_parse(
                output["output_seq"], true_sample["answer"], true_sample
            )
            output['evaluation'] = binary

            if binary:
                acc += 1

            output_list.append(output)
            binary_list.append(binary)

        return round(acc / len(outputs), 3), output_list, binary_list


class SelfEvaluation:
    """Measures how well a confidence score separates correct from wrong outputs.

    Attributes:
        true_samples: Gold-label records collected from the label file(s).
        data_size: Total number of gold-label records loaded.
    """

    def __init__(self, dataset_list, gold_labels=None):
        """Load the gold labels once per entry of ``dataset_list``.

        Args:
            dataset_list: Iterable of dataset names. Only its length matters:
                the same gold-label file is loaded for every entry.
            gold_labels: File stem of the gold-label ``.jsonl`` file. When
                ``None``, the module-level ``args.gold_labels`` set by the
                script entry point is used (legacy behaviour).
        """
        self.true_samples = []
        self.data_size = 0
        for _ in dataset_list:
            # Resolve the fallback lazily so that an explicit ``gold_labels``
            # (or an empty ``dataset_list``) never touches the global ``args``,
            # which only exists when this file is run as a script.
            label_name = gold_labels if gold_labels is not None else args.gold_labels
            label_loader = DatasetInfo(label_name)
            self.true_samples.extend(label_loader.data)
            self.data_size += label_loader.data_size

    def self_eval(self, score_list, binary_list):
        """Compute AUROC, FPR at 95% TPR and AUPR for one score.

        Args:
            score_list: One confidence score per sample; passed to
                scikit-learn as ``y_score``, so larger means "more likely
                correct".
            binary_list: Per-sample correctness flags, used as ``y_true``.

        Returns:
            ``(auroc, fpr95, aupr)`` as percentages rounded to 2 decimals.
        """
        fpr, tpr, thresholds = roc_curve(binary_list, score_list)
        auroc = auc(fpr, tpr)
        # FPR95: false-positive rate read off the ROC curve at TPR = 0.95,
        # linearly interpolated between the neighbouring curve points.
        fpr95 = float(interpolate.interp1d(tpr, fpr)(0.95))
        precision, recall, _ = precision_recall_curve(binary_list, score_list)
        # Area under the precision-recall curve via the trapezoidal rule.
        aupr = auc(recall, precision)

        return round(auroc * 100, 2), round(fpr95 * 100, 2), round(aupr * 100, 2)


if __name__ == '__main__':
    # Script entry point: score the model outputs, write the annotated
    # outputs, then evaluate every confidence score and save all metrics.
    parser = argparse.ArgumentParser(description="eval")
    parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-1.5Bqwq", choices=MODEL_POOL)
    parser.add_argument("--dataset", type=str, default="GSM8K", choices=DATASET_POOL)
    parser.add_argument("--gold_labels", type=str, default="processed_gsm8k_comma")
    parser.add_argument("--language", type=str, default="en")
    parser.add_argument("--output_jsonl_path", type=str, default="/home/hpc/b232dd/b232dd23/CoT-Kinetics/GSM8K_7B.jsonl")
    parser.add_argument("--updated_jsonl_path", type=str, default="/home/hpc/b232dd/b232dd23/CoT-Kinetics/GSM8K_7B_results.jsonl")
    parser.add_argument(
        "--metrics_jsonl_path", type=str, default=None,
        help="Where to write the metrics record. Defaults to the "
             "--updated_jsonl_path with '_eval' inserted before the extension.",
    )

    # NOTE: ``args`` must remain a module-level name; the evaluation classes
    # fall back to it for the gold-label file name.
    args = parser.parse_args()

    stdeval = StandardEvaluation([args.dataset])
    acc, output_list, binary_list = stdeval.std_eval(args)
    print(f"# Accuracy: {acc}")

    # Write the outputs back out, now carrying the "evaluation" field.
    with open(args.updated_jsonl_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in output_list)

    # Baseline confidence scores. Perplexity and entropy are inverted so that,
    # like maxprob, a larger score means higher confidence.
    # NOTE(review): a record with ppl == 0 or entropy == 0 raises
    # ZeroDivisionError here - confirm upstream guarantees positive values.
    maxprob_list = [item["maxprob"] for item in output_list]
    ppl_list = [1 / item["ppl"] for item in output_list]
    entropy_list = [1 / item["entropy"] for item in output_list]

    # Chance-level baseline: the maxprob scores in a reproducibly shuffled order.
    random.seed(42)
    random_list = list(maxprob_list)
    random.shuffle(random_list)

    selfeval = SelfEvaluation([args.dataset])

    baseline_scores = {
        "maxprob": maxprob_list,
        "ppl": ppl_list,
        "entropy": entropy_list,
        "random": random_list,
    }

    # Compute every baseline before printing any of them.
    baseline_metrics = {
        name: selfeval.self_eval(scores, binary_list)
        for name, scores in baseline_scores.items()
    }

    # Collect all metrics in a dict (insertion order defines the saved order).
    results = {'accuracy': acc}
    for name, (auroc, fpr95, aupr) in baseline_metrics.items():
        print(f"{(name + '_auroc').rjust(13)}: {auroc:.2f}    "
              f"{(name + '_fpr95').rjust(13)}: {fpr95:.2f}    "
              f"{(name + '_aupr').rjust(13)}: {aupr:.2f}")
        results[f"{name}_auroc"] = auroc
        results[f"{name}_fpr95"] = fpr95
        results[f"{name}_aupr"] = aupr

    # Prepare additional metrics
    metrics_to_eval = {
        "CoT-Kinetics": [item["CoT-Kinetics"] for item in output_list],
        "coe_r": [item["score_coe_r"] for item in output_list],
        "coe_c": [item["score_coe_c"] for item in output_list],
    }

    # Compute and print additional metrics, and add to results
    for name, score_list in metrics_to_eval.items():
        auroc, fpr95, aupr = selfeval.self_eval(score_list, binary_list)
        print(f"{name.rjust(13)}_auroc: {auroc:.2f}    "
              f"{name.rjust(13)}_fpr95: {fpr95:.2f}    "
              f"{name.rjust(13)}_aupr: {aupr:.2f}")
        results[f"{name}_auroc"] = auroc
        results[f"{name}_fpr95"] = fpr95
        results[f"{name}_aupr"] = aupr

    # Save all metrics to a JSONL file. The path follows the updated-outputs
    # path so that runs on different models/datasets no longer overwrite one
    # shared, hard-coded metrics file.
    metrics_path = args.metrics_jsonl_path
    if metrics_path is None:
        root, ext = os.path.splitext(args.updated_jsonl_path)
        metrics_path = f"{root}_eval{ext or '.jsonl'}"
    with open(metrics_path, 'w', encoding='utf-8') as mf:
        mf.write(json.dumps(results, ensure_ascii=False) + "\n")
    print(f"Saved metrics to {metrics_path}")
