
import numpy as np
import pickle
import os
import sys
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import copy
import json

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from utils import get_metric
from semantic_functions import get_similarity_matrix_from_logits, get_similarity_matrix_from_logits_cross
from sklearn.metrics import roc_auc_score


def rectangular_rule(y_values, x_values):
    """Integrate a sampled curve with left-endpoint rectangles.

    Each sample ``i`` contributes ``y_values[i]`` times the gap to the next
    x-value. The final sample has no right neighbour, so it is given a fixed
    width of ``1 / len(x_values)``.

    Parameters:
    - y_values: sequence of heights, indexed in step with ``x_values``.
    - x_values: sequence of x-coordinates.

    Returns:
    - the accumulated area (``0`` when ``x_values`` is empty).
    """
    n_points = len(x_values)
    total = 0
    for idx in range(n_points):
        is_last = idx == n_points - 1
        # The last rectangle cannot look ahead, so it falls back to 1/n.
        step = 1/n_points if is_last else x_values[idx + 1] - x_values[idx]
        total += step * y_values[idx]
    return total


def area_under_curve(points, metric):
    """
    Calculate the area under a curve given as a set of 2D points.

    The points are first sorted by their x-coordinate. For ``"auroc"`` the
    area is computed with the trapezoidal rule; for ``"aurac"`` it is computed
    with ``rectangular_rule`` (left-endpoint rectangles).

    Parameters:
    - points: array-like of shape (N, 2) where each row is a 2D point (x, y).
    - metric: either ``"auroc"`` or ``"aurac"``; selects the integration rule.

    Returns:
    - area: float, the area under the curve.
    """
    assert metric in ['auroc', 'aurac']

    # Accept plain nested lists as well as arrays.
    points = np.asarray(points)

    # Sort points by the x-coordinate
    points = points[np.argsort(points[:, 0])]

    x_values = points[:, 0]
    y_values = points[:, 1]

    if metric == "auroc":
        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid;
        # prefer the new name and fall back on older NumPy releases.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        return trapezoid(y_values, x_values)

    return rectangular_rule(y_values, x_values)


def accuracy_at_quantile(accuracies, uncertainties, quantile):
    """Mean accuracy over the samples whose uncertainty is at or below the
    given quantile of ``uncertainties``.

    ``accuracies`` is indexed with a boolean mask built from ``uncertainties``,
    so both are expected to be NumPy arrays of the same length.
    """
    keep_mask = uncertainties <= np.quantile(uncertainties, quantile)
    kept_accuracies = accuracies[keep_mask]
    return np.mean(kept_accuracies)


def area_under_thresholded_accuracy(errors, uncertainties):
    """Area under the accuracy-vs-retained-quantile curve.

    Accuracy is ``1 - error`` per sample. The selective accuracy is evaluated
    at 50 evenly spaced quantiles in [0, 1] and summed with a constant step
    equal to the quantile spacing.

    Returns:
    - area: the summed area.
    - select_accuracies: array of the accuracy at each quantile.
    - quantiles: the quantile grid that was used.
    """
    accuracies = np.array([1 - err for err in errors])
    uncertainties = np.array(uncertainties)

    quantiles = np.linspace(0, 1, 50)
    select_accuracies = np.array(
        [accuracy_at_quantile(accuracies, uncertainties, level) for level in quantiles]
    )

    step = quantiles[1] - quantiles[0]
    area = np.sum(select_accuracies * step)
    return area, select_accuracies, quantiles


def find_b(sorted_list, p, t1):
    """Return the upper bound ``b`` of the band of scores starting at ``t1``.

    The band consists of the first ``int(len(sorted_list) * p)`` elements of
    ``sorted_list`` that are ``>= t1``; ``b`` is the last (largest) of them.

    Parameters:
    - sorted_list: scores in ascending order (list or 1-D array).
    - p: fraction of all elements the band should contain; must be > 0.
    - t1: lower bound of the band.

    Returns:
    - ``None`` when fewer than the required number of elements are ``>= t1``.
    - Otherwise the element of ``sorted_list`` that closes the band. When the
      required count rounds down to zero the band is empty: the largest
      element below ``t1`` is returned, or ``-inf`` if no such element exists.
    """
    assert p > 0

    # Total number of elements in the list
    total_elements = len(sorted_list)

    # Number of elements required to satisfy the fraction p
    required_count = int(total_elements * p)

    # The list is ascending, so a binary search locates the first element >= t1
    # (equals len(sorted_list) when every element is below t1).
    start_idx = int(np.searchsorted(sorted_list, t1, side="left"))

    # Ensure there are enough elements to satisfy p
    if total_elements - start_idx < required_count:
        return None  # Not enough elements to satisfy the fraction

    b_idx = start_idx + required_count - 1
    if b_idx < 0:
        # required_count == 0 and nothing lies below t1. Indexing with -1
        # would wrap around to the largest element and turn the empty band
        # into one that covers everything, so return a bound below all scores.
        return float("-inf")

    # The element at this index defines b
    return sorted_list[b_idx]


def two_stage_detector(list_score_self, sorted_list_score_self, list_score_cross, p, t1, t2=None, id2=None):
    """Predict 0/1 labels with a self-score stage and a cross-score stage.

    Samples with a self score below ``t1`` are predicted 0 and samples above
    the band bound ``b`` (from ``find_b``) are predicted 1. Samples inside the
    band ``[t1, b]`` are decided by comparing their cross score with ``t2``.

    Parameters:
    - list_score_self: self scores, one per sample.
    - sorted_list_score_self: the same scores in ascending order.
    - list_score_cross: cross scores, one per sample.
    - p: fraction of samples sent to the second stage; ``0`` disables it.
    - t1: first-stage threshold on the self score.
    - t2: optional fixed second-stage threshold on the cross score.
    - id2: optional rank of the second-stage threshold among the sorted cross
      scores of the band (``"inf"`` means the sentinel threshold 1000). Ranks
      past the end of the band use the largest cross score in the band.

    Returns:
    - ``None`` if the band cannot be filled and ``t2`` is not given.
    - ``(predictions, t2s, id2s)`` otherwise. With neither ``t2`` nor ``id2``
      given, one prediction vector is produced per candidate ``t2`` and the
      matching thresholds / ranks are returned; with ``t2`` or ``id2`` given,
      a single prediction vector is returned and ``t2s``/``id2s`` are ``None``.
    """
    list_score_self = np.array(list_score_self)
    list_score_cross = np.array(list_score_cross)

    if p == 0:
        # No second stage: a plain threshold on the self score.
        prediction = (list_score_self >= t1).astype(int)
        return [list(prediction)], [1000], ["inf"]

    b = find_b(sorted_list_score_self, p, t1)
    if b is None:
        if t2 is None:
            return None
        # NOTE(review): b bounds the *self* score but falls back to the
        # maximum *cross* score here - confirm this is intended.
        b = np.max(list_score_cross)

    # Start from all zeros (covers score < t1) and flag everything above b.
    prediction = np.zeros(len(list_score_cross))
    prediction[list_score_self > b] = 1

    indices_middle = np.where((list_score_self >= t1) & (list_score_self <= b))[0]
    cross_middle = list_score_cross[indices_middle]

    if t2 is None and id2 is None:
        if len(indices_middle) == 0:
            return [list(prediction)], [1000], ["inf"]

        # Sweep every cross score in the band, plus a sentinel above all of them.
        all_t2 = list(np.sort(cross_middle)) + [1000]
        all_id2 = list(range(len(indices_middle))) + ["inf"]
        all_predictions = []
        for threshold in all_t2:
            candidate = prediction.copy()
            candidate[indices_middle] = cross_middle >= threshold
            all_predictions.append(candidate)
        return all_predictions, all_t2, all_id2

    if t2 is not None:
        assert id2 is None
        prediction[indices_middle] = cross_middle >= t2
        return [list(prediction)], None, None

    # Rank-based second-stage threshold. An empty band has nothing to decide,
    # and indexing its (empty) sorted cross scores would raise IndexError.
    if len(indices_middle) > 0:
        if id2 == "inf":
            t2 = 1000
        else:
            sorted_cross = np.sort(cross_middle)
            t2 = sorted_cross[min(id2, len(sorted_cross) - 1)]
        prediction[indices_middle] = cross_middle >= t2
    return [list(prediction)], None, None

def compute_tpr_fpr(predictions, labels):
    """Return ``(tpr, fpr)`` for binary predictions against binary labels.

    Pairs in which either value is neither 0 nor 1 are ignored. A rate whose
    denominator is zero is reported as ``0``.
    """
    tp = fp = tn = fn = 0

    for pred, label in zip(predictions, labels):
        if pred == 1:
            if label == 1:
                tp += 1
            elif label == 0:
                fp += 1
        elif pred == 0:
            if label == 0:
                tn += 1
            elif label == 1:
                fn += 1

    n_positive = tp + fn
    n_negative = fp + tn

    tpr = tp / n_positive if n_positive > 0 else 0
    fpr = fp / n_negative if n_negative > 0 else 0
    return tpr, fpr


def compute_acc_quantile(predictions, labels):
    """Accuracy on the kept samples and the fraction of samples kept.

    A sample is kept when its prediction is 0; a kept sample is correct when
    its label is 0. If nothing is kept the accuracy is defined as ``1``.

    Returns:
    - acc: correct-kept / kept (or ``1`` when nothing is kept).
    - quantile: kept / ``len(labels)``.
    """
    kept_labels = [label for pred, label in zip(predictions, labels) if pred == 0]

    n_keep = len(kept_labels)
    n_correct = sum(1 for label in kept_labels if label == 0)

    acc = n_correct/n_keep if n_keep != 0 else 1
    quantile = n_keep/len(labels)

    return acc, quantile
        

def point_to_line_distance(point, line_start, line_end):
    """Euclidean distance from ``point`` to the segment ``line_start``-``line_end``.

    The point is projected onto the segment's direction and the projection is
    clipped to the segment, so the result is the distance to the nearest point
    of the segment. The arguments are combined with ``-`` and ``np.dot``, so
    they are expected to be NumPy arrays.
    """
    segment = line_end - line_start
    squared_length = np.dot(segment, segment)

    if squared_length == 0:
        # Degenerate segment: both ends coincide.
        return np.linalg.norm(point - line_start)

    # Position of the projection along the segment, limited to [0, 1].
    t = np.clip(np.dot(point - line_start, segment) / squared_length, 0, 1)
    nearest = line_start + t * segment
    return np.linalg.norm(point - nearest)


def select_best_point_indices(fp_tp, n_bins):
    """
    Pick, per x-bin, the index of the point with the largest y-value.

    The x-range ``[min_x, max_x]`` is cut into ``n_bins`` equal, half-open
    bins ``[lo, hi)``. Empty bins contribute nothing. Afterwards the best
    point at exactly ``min_x`` and the best point at exactly ``max_x`` are
    appended, so both end points are always represented (an index may
    therefore appear more than once).

    Parameters:
    - fp_tp: numpy array of shape (N, 2) where each row is a 2D point (x, y).
    - n_bins: integer, the number of bins to divide the points into.

    Returns:
    - best_indices: list of row indices into ``fp_tp``, bins first (left to
      right), then the ``min_x`` and ``max_x`` end points.
    """
    xs, ys = fp_tp[:, 0], fp_tp[:, 1]
    x_lo, x_hi = np.min(xs), np.max(xs)
    edges = np.linspace(x_lo, x_hi, n_bins + 1)

    def selection_masks():
        # Half-open bins first ...
        for lo, hi in zip(edges[:-1], edges[1:]):
            yield (xs >= lo) & (xs < hi)
        # ... then the two end points, matched exactly.
        for end_point in (x_lo, x_hi):
            yield xs == end_point

    best_indices = []
    for mask in selection_masks():
        candidates = np.flatnonzero(mask)
        if candidates.size > 0:
            best_indices.append(candidates[np.argmax(ys[candidates])])

    return best_indices


def _curve_point(predictions, list_gt, metric):
    """Map one prediction vector to its [x, y] point on the curve for ``metric``."""
    if metric == "auroc":
        tpr, fpr = compute_tpr_fpr(predictions, list_gt)
        return [fpr, tpr]
    acc, quantile = compute_acc_quantile(predictions, list_gt)
    return [quantile, acc]


def _save_debug_plot(x_y, vol, list_gt, list_score_cross, metric, p, th_given, x_y_best=None):
    """Save a scatter plot of the curve points under ./logs/<metric>/ (debug aid)."""
    plt.scatter(x_y[:, 0], x_y[:, 1])
    if x_y_best is not None:
        plt.scatter(x_y_best[:, 0], x_y_best[:, 1], color='red')
    aurac, select_accuracies, quantiles = area_under_thresholded_accuracy(list_gt, list_score_cross)
    if x_y_best is None:
        plt.scatter(quantiles, select_accuracies, color='orange', alpha=0.2)
    plt.title(f"{vol}, aurac cross {aurac}")
    print(f"metric: {vol}")
    tmp_figure_path = f"./logs/{metric}/p={p}_{th_given}.png"
    create_directory_if_not_exists(tmp_figure_path)
    plt.savefig(tmp_figure_path, bbox_inches='tight', dpi=200)
    plt.close()


def compute_auroc_or_aurac(list_score_self, sorted_list_score_self, list_score_cross, p, list_gt, metric, t1_t2=None, id1_id2=None):
    """Compute the AUROC or AURAC of the two-stage detector.

    The function works in one of three modes:

    - sweep (``t1_t2`` and ``id1_id2`` both ``None``): every self score, plus
      the sentinels -1000 and 1000, is tried as ``t1`` and
      ``two_stage_detector`` sweeps ``t2`` for each. The best point per x-bin
      (50 bins) is kept and the area under those points is returned together
      with all candidates.
    - thresholds given (``t1_t2``): only the listed ``[t1, t2]`` pairs are used.
    - ids given (``id1_id2``): only the listed ``[id1, id2]`` rank pairs are
      used; ``id1`` indexes ``sorted_list_score_self`` (``"-inf"`` / ``"inf"``
      select the sentinels).

    Parameters:
    - list_score_self, list_score_cross: per-sample scores of the two stages.
    - sorted_list_score_self: NumPy array with the self scores in ascending order.
    - p: fraction of samples routed to the second stage.
    - list_gt: ground-truth labels, compared against the 0/1 predictions.
    - metric: ``"auroc"`` (points are [fpr, tpr]) or ``"aurac"`` (points are
      [kept fraction, accuracy]).

    Returns:
    - sweep mode: ``(t1_t2, id1_id2, x_y, vol)`` where the three lists/arrays
      are aligned row by row and ``vol`` is the area.
    - otherwise: ``vol`` only.
    """
    # Fail before the (potentially long) sweep rather than after it.
    assert metric in ["auroc", "aurac"]

    log_plot = False

    sweep = t1_t2 is None and id1_id2 is None
    th_given = (not sweep) and t1_t2 is not None
    id_given = (not sweep) and t1_t2 is None

    x_y = []

    def run_detector(t1, t2=None, id2=None):
        # Run the detector once and record a curve point per prediction vector.
        r = two_stage_detector(list_score_self, sorted_list_score_self, list_score_cross, p, t1, t2=t2, id2=id2)
        if r is not None:
            x_y.extend(_curve_point(predictions, list_gt, metric) for predictions in r[0])
        return r

    if sweep:
        t1_t2 = []
        id1_id2 = []
        for id1, t1 in enumerate(sorted_list_score_self.tolist() + [-1000, 1000]):
            if t1 == -1000:
                id1 = "-inf"
            elif t1 == 1000:
                id1 = "inf"

            r = run_detector(t1)
            if r is not None:
                _, t2s, id2s = r
                # Keep thresholds / ranks aligned with the rows added to x_y.
                t1_t2.extend([t1, t2] for t2 in t2s)
                id1_id2.extend([id1, id2] for id2 in id2s)
    elif th_given:
        assert id1_id2 is None
        for t1, t2 in t1_t2:
            run_detector(t1, t2=t2)
    else:
        last_rank = len(sorted_list_score_self) - 1
        for id1, id2 in id1_id2:
            if id1 == "inf":
                t1 = 1000
            elif id1 == "-inf":
                t1 = -1000
            else:
                # Ranks may come from a larger split; clamp them the same way
                # two_stage_detector clamps id2 instead of raising IndexError.
                t1 = sorted_list_score_self[min(id1, last_rank)]
            run_detector(t1, id2=id2)

    # Fixed anchor points that close the curve.
    # NOTE: [1, 0] for auroc was considered in the original code (TODO: double check).
    anchors = [[0, 0], [1, 1]] if metric == "auroc" else [[0, 1]]
    n_swept = len(x_y)
    x_y = np.array(x_y + anchors)

    if th_given or id_given:
        vol = area_under_curve(x_y, metric)
        if log_plot:
            _save_debug_plot(x_y, vol, list_gt, list_score_cross, metric, p, th_given)
        return vol

    # Sweep mode: select the best candidates without the anchors, then put
    # the anchors back before integrating.
    swept = x_y[:n_swept]
    best_indices = select_best_point_indices(swept, n_bins=50)  # TODO: initially i used 100
    x_y_best = np.vstack((x_y[best_indices], np.array(anchors)))

    vol = area_under_curve(x_y_best, metric)

    if log_plot:
        _save_debug_plot(x_y, vol, list_gt, list_score_cross, metric, p, th_given, x_y_best=x_y_best)

    return t1_t2, id1_id2, swept, vol


def process_csv(file_path):
    """Summarise a results CSV.

    For each of AUROC and AURAC, takes the maximum of the ``best_*`` column
    and the ``final_*`` value found on the row where that maximum occurs.

    Returns:
    - ``None`` when the CSV has no rows.
    - ``((best_auroc, final_auroc), (best_aurac, final_aurac))`` otherwise.
    """
    table = pd.read_csv(file_path)

    if table.empty:
        return None

    summary = []
    for best_name, final_name in (('best_auroc', 'final_auroc'), ('best_aurac', 'final_aurac')):
        best_column = table[best_name]
        # Row label of the maximum, used to look up the paired "final" value.
        top_row = best_column.idxmax()
        summary.append((best_column.max(), table.loc[top_row, final_name]))

    return tuple(summary)


def get_GCN_result(result_dir, dataset, target, verifier, n_subset):
    """Load the summary for ``<dataset>_<target>_<verifier>_<n_subset>.csv``
    inside ``result_dir`` by passing its path to ``process_csv``."""
    csv_name = f"{dataset}_{target}_{verifier}_{n_subset}.csv"
    return process_csv(os.path.join(result_dir, csv_name))


def create_directory_if_not_exists(file_path):
    """Make sure the directory that will contain ``file_path`` exists.

    Missing parent directories are created as needed. A path without a
    directory component (a bare file name) refers to the current directory,
    so nothing is created for it.
    """
    # Extract the directory from the file path
    directory = os.path.dirname(file_path)

    # os.makedirs("") raises, so skip bare file names.
    if not directory:
        return

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(directory, exist_ok=True)


def get_average_result(dict_all_results, list_p, metric):
    """Aggregate one metric across several result dictionaries.

    Starts from a deep copy of the first result dictionary and, for every
    ``p`` in ``list_p``, replaces ``[p][metric]`` with the tuple
    ``(mean, std)`` of that entry taken over all dictionaries. The inputs are
    left unmodified.
    """
    averaged = copy.deepcopy(dict_all_results[0])
    for p in list_p:
        per_run_values = [run[p][metric] for run in dict_all_results]
        mean_value = np.mean(per_run_values)
        std_value = np.std(per_run_values)
        averaged[p][metric] = (mean_value, std_value)
    return averaged


def make_plot(ax, dict_all_results, list_p, metric, color, label=None, annotate_improv=None):
    """Draw the mean +/- std of ``metric`` over ``list_p`` on ``ax``.

    Parameters:
    - ax: matplotlib Axes to draw on.
    - dict_all_results: list of per-trial result dictionaries, averaged with
      ``get_average_result``.
    - list_p: the x-values (keys of the result dictionaries).
    - metric: name of the metric entry to plot.
    - color, label: passed to the drawing calls.
    - annotate_improv: optional iterable of fractions ``a``. For each one, the
      curve point whose value is closest to ``y[0] + a * (y[-1] - y[0])`` is
      marked and connected to both axes with dashed lines.
    """
    averaged = get_average_result(dict_all_results, list_p, metric)
    y = [averaged[p][metric][0] for p in list_p]
    yerr = [averaged[p][metric][1] for p in list_p]

    # Draw on the Axes that was passed in rather than on pyplot's implicit
    # "current" Axes, so the limits read below belong to the same plot.
    ax.errorbar(list_p, y, yerr=yerr, label=label, color=color)

    if annotate_improv is None:
        return

    # Total change from the first to the last p; does not depend on `a`.
    delta = y[-1] - y[0]
    for a in annotate_improv:
        y_tar = y[0] + a*delta
        dist = [np.abs(v-y_tar) for v in y]
        min_index = dist.index(min(dist))

        xx, yy = list_p[min_index], y[min_index]

        # Remember the limits before drawing so the guide lines start at the
        # current axis edges.
        y_lim = ax.get_ylim()
        x_lim = ax.get_xlim()

        ax.scatter(xx, yy, color=color)  # Plot the point

        # Draw lines to projections
        ax.plot([xx, xx], [y_lim[0], yy], linestyle='--', color=color)  # Line to x-axis
        ax.plot([x_lim[0], xx], [yy, yy], linestyle='--', color=color)  # Line to y-axis

        # Restore the x-limits that the guide lines may have stretched.
        ax.set_xlim(x_lim)


def evaluate(list_p, trial, split, args, hp, dict_hp1_hp2=None):
    """Evaluate the two-stage detector for every ``p`` in ``list_p``.

    Scores are derived from pickled entailment logits: for each question the
    self / cross score is ``1 - mean(similarity matrix)``.

    Parameters:
    - list_p: fractions of samples routed to the second stage.
    - trial: index used to pick ``trial_<trial>.pickle``.
    - split: ``'train'`` (uses ``<dataset>_train``), ``'train_on_test'`` or
      ``'test'``. The first two select hyper-parameters; ``'test'`` applies
      the ones passed in ``dict_hp1_hp2``.
    - args: parsed CLI arguments; reads ``dataset``, ``model_target``,
      ``model_verifier``, ``n_subset``, ``cross_type`` and ``metric``.
    - hp: ``"th"`` to store / use threshold pairs, ``"id"`` for rank pairs.
    - dict_hp1_hp2: per-``p`` hyper-parameters; required when ``split == 'test'``
      and ignored (rebuilt) otherwise.

    Returns:
    - dict_results: ``{p: {args.metric: value}}``.
    - dict_hp1_hp2: the selected (or passed-through) hyper-parameters.

    Raises:
    - NotImplementedError: ``args.cross_type`` is not ``"multiple"``.
    - ValueError: ``split == 'test'`` and ``hp`` is neither ``"th"`` nor ``"id"``.
    """
    assert split in ['train', 'test', 'train_on_test']

    # Only the "multiple" cross type defines the cross directory and the
    # cross similarity below; fail here instead of with a NameError later.
    if args.cross_type != "multiple":
        raise NotImplementedError(f"cross_type {args.cross_type!r} is not supported (only 'multiple')")

    dataset_name = args.dataset + '_train' if split == 'train' else args.dataset
    specs_target_multiple = f'{args.model_target.replace("/", "_")}_p_0.9_temp_1.0_samples_100'
    dirname_target_multiple = f'../answers_{dataset_name}/{specs_target_multiple}'
    list_gt = get_metric(dataset_name, args.model_target, 0.1, 1, "gt")

    specs_verifier = f'{args.model_verifier.replace("/", "_")}_p_0.9_temp_1.0_samples_100'
    name1, name2 = sorted([specs_target_multiple, specs_verifier])
    cross_dir = os.path.join(f'../answers_{dataset_name}/cross', f"{name1}-X-{name2}")

    dict_results = {}

    selecting = split in ['train', 'train_on_test']
    if selecting:
        dict_hp1_hp2 = {}

    entailment_target_path = os.path.join(dirname_target_multiple, "random_subsample_labels_logits_and_reps", f"subset_{args.n_subset}", f"trial_{trial}.pickle")
    entailment_cross_path = os.path.join(cross_dir, "random_subsample_labels_logits_and_reps", f"subset_{args.n_subset}", f"trial_{trial}.pickle")

    # NOTE: pickle.load executes arbitrary code; only load locally produced files.
    with open(entailment_target_path, 'rb') as f:
        entailment_target = pickle.load(f)
    with open(entailment_cross_path, 'rb') as f:
        entailment_cross = pickle.load(f)

    list_score_self = []
    list_score_cross = []
    for q, (_, logit_tensor_target, _) in tqdm(entailment_target.items()):
        similarity_matrix_target = get_similarity_matrix_from_logits(logit_tensor_target)

        logit_tensor_pair = {k: v[1] for k, v in entailment_cross[q].items()}
        similarity_matrix_cross = get_similarity_matrix_from_logits_cross(logit_tensor_pair, target_name=specs_target_multiple, verifier_name=specs_verifier)

        list_score_self.append(1-np.mean(similarity_matrix_target))
        list_score_cross.append(1-np.mean(similarity_matrix_cross))

    # The sorted self scores do not depend on p; sort once.
    sorted_score_self = np.sort(list_score_self)

    for p in tqdm(list_p):

        if selecting:
            t1_t2, id1_id2, fp_tp, eval_result = compute_auroc_or_aurac(list_score_self, sorted_score_self, list_score_cross, p, list_gt, args.metric)
            best_indices = select_best_point_indices(fp_tp, n_bins=50)

            dict_results[p] = {args.metric: eval_result}
            # Keep only the hyper-parameters of the best point in each bin.
            if hp == "th":
                dict_hp1_hp2[p] = [t1_t2[idx] for idx in best_indices]
            elif hp == "id":
                dict_hp1_hp2[p] = [id1_id2[idx] for idx in best_indices]
        else:
            assert dict_hp1_hp2 is not None

            if hp == "th":
                eval_result = compute_auroc_or_aurac(list_score_self, sorted_score_self, list_score_cross, p, list_gt, args.metric, t1_t2=dict_hp1_hp2[p])
            elif hp == "id":
                eval_result = compute_auroc_or_aurac(list_score_self, sorted_score_self, list_score_cross, p, list_gt, args.metric, id1_id2=dict_hp1_hp2[p])
            else:
                raise ValueError(f"hp must be 'th' or 'id', got {hp!r}")
            dict_results[p] = {args.metric: eval_result}

    return dict_results, dict_hp1_hp2


# ---------------------------------------------------------------------------
# Command-line driver: select hyper-parameters, evaluate them on every test
# trial and write the collected results to a JSON file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
for flag in ("--dataset", "--model_target", "--model_verifier"):
    parser.add_argument(flag, type=str)
parser.add_argument("--n_subset", type=int, default=10)
for flag in ("--cross_type", "--train_type", "--metric"):
    parser.add_argument(flag, type=str)
args = parser.parse_args()

# Echo the configuration of this run.
for name, setting in vars(args).items():
    print(f"{name}: {setting}")

# Second-stage fractions: 0.00, 0.02, ..., 0.98 and 1.
list_p = [i*0.02 for i in range(50)] + [1]

n_trials = 5

metric = "auroc"
hp = "id"

# Only the combined train/test protocol is implemented.
if args.train_type != 'both':
    raise NotImplementedError

results_to_save = {}

# Fractions used by make_plot when the plotting calls are enabled.
annotate_improv = [0.9, 0.95]

# Hyper-parameters selected on a held-out trial of the test data
# (trial index n_trials), then applied to test trials 0..n_trials-1.
print("*************train on test*************")
_, dict_hp1_hp2 = evaluate(list_p, n_trials, 'train_on_test', args, hp)

print("*************test*************")
results_to_save['test-test'] = [
    evaluate(list_p, trial, 'test', args, hp, dict_hp1_hp2=dict_hp1_hp2)[0]
    for trial in range(n_trials)
]

# Hyper-parameters selected on the training data, applied to the same test trials.
print("*************train*************")
_, dict_hp1_hp2 = evaluate(list_p, 0, 'train', args, hp)

print("*************test*************")
results_to_save['train-test'] = [
    evaluate(list_p, trial, 'test', args, hp, dict_hp1_hp2=dict_hp1_hp2)[0]
    for trial in range(n_trials)
]

print(results_to_save['test-test'])

################### Uncomment this to read results for approximated ceiling
# name_gcn = 'Supervised GNN'
# result_dir = "../reverse_eng/results_cross/"
# r = get_GCN_result(result_dir, args.dataset, args.model_target.replace("/", "_"), args.model_verifier.replace("/", "_"), args.n_subset)

# if args.metric == "auroc":
#     results_to_save['GNN'] = (r[0][0], r[0][1])
# elif args.metric == "aurac":
#     results_to_save['GNN'] = (r[1][0], r[1][1])
#####################################

path_file = f'./saved_results_{args.train_type}/{args.metric}/{args.dataset}/{args.model_target.replace("/", "_")}/{args.model_verifier.replace("/", "_")}.json'
create_directory_if_not_exists(path_file)

with open(path_file, "w") as json_file:
    json.dump(results_to_save, json_file, indent=4)
