import argparse
import json
import numpy as np
import torch
import tqdm
import os
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.calibration import calibration_curve
from sklearn.metrics import f1_score, roc_curve, accuracy_score, roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm
import json
import os

from sklearn.model_selection import train_test_split

def get_best_f1(true_labels, scores):
    """Return the highest F1 score reachable by thresholding ``scores``.

    Candidate thresholds are the ones reported by ``roc_curve``; a sample is
    predicted positive when its score is greater than or equal to the
    threshold.

    Args:
        true_labels: Binary ground-truth labels, one per sample.
        scores: Array-like of scores, one per sample (higher means "more
            positive").

    Returns:
        The maximum F1 score over all candidate thresholds.
    """
    # Coerce once so the element-wise comparison below is an explicit array
    # operation even when the caller passes a plain Python list.
    scores = np.asarray(scores)
    _, _, thresholds = roc_curve(true_labels, scores)
    f1_scores = [f1_score(true_labels, (scores >= t).astype(int)) for t in thresholds]
    return max(f1_scores)

def get_accuracy(true_labels, scores):
    """Return the accuracy obtained at the F1-maximizing score threshold.

    The threshold is chosen among the candidates reported by ``roc_curve`` as
    the first one that maximizes F1; accuracy is then computed for the
    predictions ``scores >= threshold``.

    Args:
        true_labels: Binary ground-truth labels, one per sample.
        scores: Array-like of scores, one per sample (higher means "more
            positive").

    Returns:
        Accuracy of the thresholded predictions.
    """
    # Coerce once so the element-wise comparisons below are explicit array
    # operations even when the caller passes a plain Python list.
    scores = np.asarray(scores)
    _, _, thresholds = roc_curve(true_labels, scores)
    f1_scores = [f1_score(true_labels, (scores >= t).astype(int)) for t in thresholds]
    best_threshold = thresholds[int(np.argmax(f1_scores))]
    y_pred = (scores >= best_threshold).astype(int)
    return accuracy_score(true_labels, y_pred)

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Compute the expected calibration error (ECE) with uniform bins.

    The interval [0, 1] is split into ``n_bins`` equal-width bins. For every
    non-empty bin the absolute gap between the fraction of positives and the
    mean predicted probability is weighted by the share of samples in that
    bin, and the weighted gaps are summed.

    Each gap is weighted by the population of the very bin it was computed
    from. (Pairing the per-non-empty-bin gaps with a separately computed
    histogram misaligns them as soon as any bin before the last populated
    one is empty.)

    Args:
        y_true: Binary labels, one per sample; the value 1 is the positive
            class.
        y_prob: Predicted probabilities of the positive class, in [0, 1].
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a float in [0, 1].

    Raises:
        ValueError: If the inputs are empty, differ in length, contain more
            than two distinct labels, or ``y_prob`` leaves [0, 1].
    """
    y_true = np.asarray(y_true).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    if y_true.size == 0 or y_true.size != y_prob.size:
        raise ValueError("y_true and y_prob must be non-empty and of equal length.")
    if np.unique(y_true).size > 2:
        raise ValueError("Only binary classification is supported.")
    if y_prob.min() < 0 or y_prob.max() > 1:
        raise ValueError("y_prob has values outside [0, 1].")

    positives = (y_true == 1).astype(float)

    # Bin membership: a value lying exactly on an inner edge belongs to the
    # lower bin, and 1.0 belongs to the last bin.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.searchsorted(edges[1:-1], y_prob)

    counts = np.bincount(bin_ids, minlength=n_bins)
    positive_sums = np.bincount(bin_ids, weights=positives, minlength=n_bins)
    prob_sums = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)

    nonempty = counts > 0
    gaps = np.abs(positive_sums[nonempty] - prob_sums[nonempty]) / counts[nonempty]
    weights = counts[nonempty] / y_true.size
    return float(np.sum(gaps * weights))

def read_json_file(file_path):
    """Load and return the parsed contents of a JSON file.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII text loads the same way on every machine.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON value.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def parse_arguments():
    """Parse the command-line options of the evaluation script.

    All options are required strings.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data_read_path``,
        ``model_path``, ``path_before``, ``result_save_path``,
        ``uncertainty_save_path`` and ``data_model``.
    """
    parser = argparse.ArgumentParser(description="Evaluate model performance.")
    parser.add_argument("--data_read_path", type=str, required=True,
                        help="Path to the data JSON file")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the model directory")
    parser.add_argument("--path_before", type=str, required=True,
                        help="Path to the previous scores JSON file")
    parser.add_argument("--result_save_path", type=str, required=True,
                        help="Path to the text file that evaluation results are appended to")
    parser.add_argument("--uncertainty_save_path", type=str, required=True,
                        help="Path prefix of the per-model prediction cache files "
                             "('<model dir>_uncertainty.json' is appended)")
    parser.add_argument("--data_model", type=str, required=True,
                        help="Identifier of the data/model setup (required; not read by main())")

    return parser.parse_args()

# Prefix under which the combined scores and reliability diagrams are written.
_CALIBRATION_OUTPUT_PREFIX = "path/uncertainty/0_result/jiaozhun/"
# Maximum number of tokens fed to the classifier.
_MAX_TOKENS = 512
# Number of candidate mixing weights tried: 0/steps, 1/steps, ..., (steps-1)/steps.
_WEIGHT_GRID_STEPS = 1000


def _append_text(file_path, text):
    """Append ``text`` to the file at ``file_path``."""
    with open(file_path, "a") as file:
        file.write(text)


def _load_or_compute_predictions(model_path, data, uncertainty_save_path):
    """Return per-item classifier predictions, using the on-disk cache if present.

    When ``uncertainty_save_path`` does not exist, the sequence classifier in
    ``model_path`` scores every ``item['text']`` of ``data`` and the records
    are written to that path as JSON before being returned. The tokenizer and
    model are only loaded when predictions actually have to be computed.
    """
    if os.path.exists(uncertainty_save_path):
        with open(uncertainty_save_path, 'r') as file:
            return json.load(file)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    labels = ['correct']
    id2label = {idx: label for idx, label in enumerate(labels)}
    label2id = {label: idx for idx, label in enumerate(labels)}
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, problem_type="multi_label_classification",
        num_labels=1, id2label=id2label, label2id=label2id)
    model.eval()  # make inference mode explicit

    predictions = []
    # Pure inference: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for item in tqdm.tqdm(data):
            text = item['text']
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) > _MAX_TOKENS:
                # Keep the tail of over-long inputs rather than the head.
                text = tokenizer.decode(tokens[-_MAX_TOKENS:], skip_special_tokens=True)
            encoding = tokenizer(text, truncation=True, max_length=_MAX_TOKENS, return_tensors='pt')
            logits = model(**encoding).logits
            prob = float(torch.sigmoid(logits.squeeze().cpu()))
            predictions.append({
                'prompt': text,
                'prediction': 1 if prob > 0.5 else 0,
                'probs': prob,
                'data_uncertainty': 1 - prob,
            })

    with open(uncertainty_save_path, 'w') as file:
        json.dump(predictions, file, indent=4)
    return predictions


def _search_mixing_weight(true_labels_dev, scores_new_dev, scores_sar_dev, steps=_WEIGHT_GRID_STEPS):
    """Grid-search the weight of the classifier score that maximizes dev AUC.

    The mixed score is ``w * scores_new + (1 - w) * scores_sar``; the first
    weight reaching the highest ROC AUC on the dev split is returned.
    """
    dev_new = np.array(scores_new_dev)
    dev_sar = np.array(scores_sar_dev)
    best_auc = -np.inf
    best_weight = 0
    for i in range(steps):
        weight_new = float(i) / steps
        weight_sar = 1 - weight_new
        auc = roc_auc_score(true_labels_dev, weight_new * dev_new + weight_sar * dev_sar)
        if auc > best_auc:
            best_auc = auc
            best_weight = weight_new
    return best_weight


def _save_reliability_diagram(confidence_scores, accuracies, output_file):
    """Save a 10-bin accuracy-vs-confidence bar chart as a PNG file.

    ``confidence_scores`` and ``accuracies`` are equally long NumPy arrays;
    empty bins are drawn with height 0.
    """
    bins = np.linspace(0, 1, 11)  # [0, 0.1, 0.2, ..., 1.0]
    bin_indices = np.digitize(confidence_scores, bins)
    # np.digitize uses half-open bins, which would push a confidence of
    # exactly 1.0 past the last bin and silently drop it; count it there.
    # Values outside [0, 1] still fall into no bin.
    bin_indices[confidence_scores == bins[-1]] = len(bins) - 1

    bin_acc = []
    for i in range(1, len(bins)):
        bin_data = accuracies[bin_indices == i]
        bin_acc.append(np.mean(bin_data) if len(bin_data) > 0 else 0)

    fig = plt.figure(figsize=(6, 6))
    try:
        width = (bins[1] - bins[0]) * 0.9  # leave a small gap between bars
        colors = cm.Blues(np.linspace(0.1, 0.5, len(bin_acc)))
        shift = 0.005  # nudge bars right so the gap is centred
        plt.bar(bins[:-1] + shift, bin_acc, width=width, align='edge',
                color=colors, edgecolor='black')

        # Diagonal = perfect calibration.
        plt.plot([0, 1], [0, 1], 'darkgray', linewidth=1, linestyle='--')
        plt.grid(color='lightgray', linestyle='--', linewidth=0.5)

        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.ylim(0, 1)
        plt.xlim(0, 1)
        plt.xticks(np.arange(0, 1.1, 0.2))
        plt.yticks(np.arange(0, 1.1, 0.2))

        plt.savefig(output_file, format='png', bbox_inches='tight')
    finally:
        # One figure is created per (model, method) pair; release it so
        # figures do not pile up in memory over the run.
        plt.close(fig)


def main():
    """Evaluate every classifier under ``--model_path`` against baseline scores.

    For each model sub-directory the per-item uncertainty (1 - sigmoid
    probability) is computed or loaded from cache. For each baseline method in
    the ``--path_before`` file, the data is split 50/50 into dev and test
    parts, a mixing weight between classifier and baseline scores is tuned on
    dev AUC, and AUC / ECE / best F1 / accuracy are reported on the test part
    for the baseline, the classifier and the mixture. Reports are printed and
    appended to ``--result_save_path``; mixed scores and a reliability diagram
    are written under the hard-coded output prefix.
    """
    args = parse_arguments()
    # Bound up front: it is also needed after the loops, which may not run.
    result_save_path = args.result_save_path

    data = read_json_file(args.data_read_path)
    data_old = read_json_file(args.path_before)

    # Positive class for the uncertainty scores is "answer was wrong".
    true_labels = [1 - item['correct'] for item in data]
    accuracies = np.array([item['correct'] for item in data])

    folder_path = Path(args.model_path)
    subfolders = [f for f in folder_path.iterdir() if f.is_dir()]
    adds = []
    for path in subfolders:
        print(path)
        path_tag = str(path).replace('/', '_')
        uncertainty_save_path = args.uncertainty_save_path + path_tag + "_uncertainty.json"
        print(uncertainty_save_path)

        predictions = _load_or_compute_predictions(path, data, uncertainty_save_path)
        scores_new = [entry["data_uncertainty"] for entry in predictions]

        add = 0
        for method_entry in data_old:
            method = method_entry['method']
            scores_sar = method_entry["scores"]

            (true_labels_dev, true_labels_test,
             scores_new_dev, scores_new_test,
             scores_sar_dev, scores_sar_test) = train_test_split(
                true_labels, scores_new, scores_sar, test_size=0.5, random_state=12
            )

            _append_text(result_save_path, f"{path}###########{method}\n"
                                           f"\n\n\n")

            weight_opt = _search_mixing_weight(true_labels_dev, scores_new_dev, scores_sar_dev)

            # Stand-alone test metrics of the baseline and of the classifier.
            auc_sar = roc_auc_score(true_labels_test, scores_sar_test)
            auc_new = roc_auc_score(true_labels_test, scores_new_test)
            ece_sar = expected_calibration_error(true_labels_test, scores_sar_test)
            ece_new = expected_calibration_error(true_labels_test, scores_new_test)
            f1_sar = get_best_f1(true_labels_test, scores_sar_test)
            f1_new = get_best_f1(true_labels_test, scores_new_test)
            accuracy_sar = get_accuracy(true_labels_test, scores_sar_test)
            accuracy_new = get_accuracy(true_labels_test, scores_new_test)

            baseline_report = (f"({method}) AUC: {auc_sar:.4f}\n"
                               f"(training) AUC: {auc_new:.4f}\n"
                               f"({method}) ECE: {ece_sar:.4f}\n"
                               f"(training) ECE: {ece_new:.4f}\n"
                               f"({method}) f1: {f1_sar:.4f}\n"
                               f"(training) f1: {f1_new:.4f}\n"
                               f"({method}) accuracy: {accuracy_sar:.4f}\n"
                               f"(training) accuracy: {accuracy_new:.4f}\n")
            print(baseline_report)
            _append_text(result_save_path, baseline_report + "\n\n\n")

            # Mixed scores over the full data set, saved and plotted.
            new_scores_all = weight_opt * np.array(scores_new) + (1 - weight_opt) * np.array(scores_sar)
            output_stem = _CALIBRATION_OUTPUT_PREFIX + method + path_tag
            # NOTE: creates a directory named like the stem (kept from the
            # original layout); this also ensures the parent folder of the
            # files written below exists.
            os.makedirs(output_stem, exist_ok=True)
            with open(output_stem + "_uncertainty.json", "w") as file:
                json.dump(new_scores_all.tolist(), file)

            _save_reliability_diagram(1 - new_scores_all, accuracies,
                                      output_stem + "_uncertainty.png")

            # Test metrics of the mixture using the dev-tuned weight.
            new_scores_test = weight_opt * np.array(scores_new_test) + (1 - weight_opt) * np.array(scores_sar_test)
            auc_max = roc_auc_score(true_labels_test, new_scores_test)
            ece_new = expected_calibration_error(true_labels_test, new_scores_test)
            f1_new = get_best_f1(true_labels_test, new_scores_test)
            accuracy_new = get_accuracy(true_labels_test, new_scores_test)

            combined_report = (f"MAX AUC: {auc_max:.4f}\n"
                               f"optimal weight: {weight_opt:.4f}\n"
                               f"ece_new: {ece_new:.4f}\n"
                               f"f1_new: {f1_new:.4f}\n"
                               f"accuracy_new: {accuracy_new:.4f}\n")
            print(combined_report)

            add += (auc_max - auc_sar)

            _append_text(result_save_path,
                         combined_report + "\n\n\n"
                         + f"weight: {weight_opt:.4f}\n"
                           f"add_AUC: {auc_max-auc_sar:.4f}\n"
                           f"add_ece: {ece_new-ece_sar:.4f}\n"
                           f"add_f1: {f1_new-f1_sar:.4f}\n"
                           f"add_accuracy: {accuracy_new-accuracy_sar:.4f}\n"
                           f"\n\n\n")
        # Summed AUC gain of the mixture over all baseline methods for this model.
        adds.append({"pici": path, "addation": add})
    _append_text(result_save_path, f"{adds}\n")

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
