import numpy as np
import torch
from sklearn import metrics
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F


def accuracy_coverage_tradeoff(lab, final_predictions, score):
    """Tabulate the accuracy/coverage trade-off obtained by thresholding ``score``.

    A sample is accepted at threshold ``t`` when ``score <= t``. Accuracy and
    coverage are evaluated at 100000 thresholds evenly spaced in
    ``[0, max(score)]``:

    * coverage = accepted samples / ``len(final_predictions)``
    * accuracy = correct accepted samples / accepted samples, where an empty
      acceptance set counts as accuracy 1.0.

    The curve is then read off at 101 evenly spaced accuracy targets and 101
    evenly spaced coverage targets in ``[0, 1]``. The threshold whose value is
    nearest to each target is used, and one line is printed per target.

    Args:
        lab: Ground-truth labels, compared element-wise with
            ``final_predictions.long()``.
        final_predictions: Predicted labels, one per sample.
        score: One score per sample; lower scores are accepted first.

    Returns:
        ``(covs_at_fa, accs_at_fa, covs_at_fc, accs_at_fc)`` as numpy arrays
        of length 101. ``covs_at_fc`` holds the requested coverage targets
        themselves, not the coverage actually achieved. The other three hold
        values achieved at the selected threshold.
    """
    thresholds = torch.linspace(0.0, torch.max(score), 100000)

    # Vectorized threshold sweep. Sort the scores once. A binary search then
    # counts, for every threshold, how many scores are <= t, and a prefix sum
    # of the sorted correctness flags says how many of those are correct.
    # This is O((N + T) log N) instead of one boolean-mask pass per threshold.
    # Compare in the dtype that ``score <= t`` would promote to, so the
    # accepted sets match the element-wise comparison exactly.
    cmp_dtype = torch.result_type(score, thresholds[0])
    flat_scores = score.detach().reshape(-1).cpu().to(cmp_dtype)
    correct = torch.eq(final_predictions.long(), lab).reshape(-1).cpu()
    sorted_scores, order = torch.sort(flat_scores)
    correct_sorted = correct[order].long()
    # correct_upto[k] = number of correct predictions among the k lowest scores.
    correct_upto = torch.cat(
        [correct_sorted.new_zeros(1), torch.cumsum(correct_sorted, dim=0)]
    )
    n_accepted = torch.searchsorted(
        sorted_scores, thresholds.to(cmp_dtype), right=True
    )

    # 0 / 0 yields nan when nothing is accepted; nan_to_num maps it to 1.0.
    accur = (correct_upto[n_accepted] / n_accepted).to(thresholds.dtype)
    cover = (n_accepted.double() / len(final_predictions)).to(thresholds.dtype)
    accur = torch.nan_to_num(accur, nan=1.0)

    fixed_accuracies = np.linspace(0, 1, 101)
    covs_at_fa = []
    accs_at_fa = []
    for fa in fixed_accuracies:
        _, idx = torch.min(torch.abs(accur - fa), dim=0)
        cov_at_fa = cover[idx]
        acc_at_fa = accur[idx]
        covs_at_fa.append(cov_at_fa)
        accs_at_fa.append(acc_at_fa)
        print(
            f"Achieved {cov_at_fa:3.3f} coverage at fixed {fa:3.3f}({acc_at_fa:3.3f}) accuracy"
        )

    fixed_coverages = np.linspace(0, 1, 101)
    covs_at_fc = []
    accs_at_fc = []
    for fc in fixed_coverages:
        _, idx = torch.min(torch.abs(cover - fc), dim=0)
        acc_at_fc = accur[idx]
        cov_at_fc = cover[idx]
        # Records the requested target, not the achieved coverage.
        covs_at_fc.append(fc)
        accs_at_fc.append(acc_at_fc)
        print(
            f"Achieved {1 - acc_at_fc:3.3f} error at fixed {fc:3.2f}({cov_at_fc:3.3f}) coverage"
        )

    return (
        np.array(covs_at_fa),
        np.array(accs_at_fa),
        np.array(covs_at_fc),
        np.array(accs_at_fc),
    )


def sn_coverage_accuracy(lab, final_predictions, score):
    """Coverage and accuracy of the samples whose ``score`` is at least 0.5.

    Args:
        lab: Ground-truth labels, compared element-wise with
            ``final_predictions.long()``.
        final_predictions: Predicted labels, one per sample.
        score: One score per sample; samples with ``score >= 0.5`` are kept.

    Returns:
        ``(coverage, accuracy)``. ``coverage`` is a Python float: kept samples
        divided by ``len(final_predictions)``. ``accuracy`` is a 0-dim tensor:
        correct kept samples divided by kept samples. It is nan when nothing
        is kept (0 / 0).
    """
    keep_mask = score >= 0.5
    hits = torch.eq(final_predictions.long(), lab)[keep_mask]
    n_kept = len(hits)
    return n_kept / len(final_predictions), hits.sum() / n_kept


def sn_accuracy_for_coverage(lab, final_predictions, score, fixed_coverage):
    """Accuracy at the threshold whose coverage is nearest to ``fixed_coverage``.

    A sample is accepted at threshold ``t`` when ``score <= t``. Over 100000
    thresholds evenly spaced in ``[0, max(score)]``, the function computes:

    * coverage = accepted samples / ``len(final_predictions)``
    * accuracy = correct accepted samples / accepted samples, where an empty
      acceptance set counts as accuracy 1.0.

    It then picks the threshold whose coverage is closest to ``fixed_coverage``
    and prints the error and achieved coverage there.

    Args:
        lab: Ground-truth labels, compared element-wise with
            ``final_predictions.long()``.
        final_predictions: Predicted labels, one per sample.
        score: One score per sample; lower scores are accepted first.
        fixed_coverage: Target coverage (a number, formatted with ``3.2f``).

    Returns:
        ``(covs_at_fc, accs_at_fc)``: one-element numpy arrays holding the
        requested coverage target (not the achieved coverage) and the accuracy
        achieved at the selected threshold.
    """
    thresholds = torch.linspace(0.0, torch.max(score), 100000)

    # Vectorized threshold sweep. Sort the scores once. A binary search then
    # counts how many scores are <= t for each threshold, and a prefix sum of
    # the sorted correctness flags says how many of those are correct.
    # Compare in the dtype that ``score <= t`` would promote to, so the
    # accepted sets match the element-wise comparison exactly.
    cmp_dtype = torch.result_type(score, thresholds[0])
    flat_scores = score.detach().reshape(-1).cpu().to(cmp_dtype)
    correct = torch.eq(final_predictions.long(), lab).reshape(-1).cpu()
    sorted_scores, order = torch.sort(flat_scores)
    correct_sorted = correct[order].long()
    # correct_upto[k] = number of correct predictions among the k lowest scores.
    correct_upto = torch.cat(
        [correct_sorted.new_zeros(1), torch.cumsum(correct_sorted, dim=0)]
    )
    n_accepted = torch.searchsorted(
        sorted_scores, thresholds.to(cmp_dtype), right=True
    )

    # 0 / 0 yields nan when nothing is accepted; nan_to_num maps it to 1.0.
    accur = (correct_upto[n_accepted] / n_accepted).to(thresholds.dtype)
    cover = (n_accepted.double() / len(final_predictions)).to(thresholds.dtype)
    accur = torch.nan_to_num(accur, nan=1.0)

    fc = fixed_coverage
    _, idx = torch.min(torch.abs(cover - fc), dim=0)
    acc_at_fc = accur[idx]
    cov_at_fc = cover[idx]
    print(
        f"Achieved {1 - acc_at_fc:3.3f} error at fixed {fc:3.2f}({cov_at_fc:3.3f}) coverage"
    )

    return (
        np.array([fc]),
        np.array([acc_at_fc]),
    )


def calculate_optimal_accuracy(coverage, accuracy):
    """Build a reference accuracy curve from the last accuracy value.

    Let ``a = accuracy[-1]`` (presumably the accuracy at full coverage; confirm
    with callers). Find the index of the coverage value closest to ``a``. The
    returned array is 1 before that index and ``a / coverage`` from that index
    on. This looks like the ideal curve of a selector that rejects all errors
    first; that reading is an interpretation, not enforced here.

    Args:
        coverage: Numpy array of coverage values.
        accuracy: Numpy array of accuracies; only its last element is read.

    Returns:
        Array shaped and typed like ``coverage``.
    """
    full_cov_acc = accuracy[-1]
    start = np.argmin(np.abs(coverage - full_cov_acc))
    optimal = np.ones_like(coverage)
    optimal[start:] = full_cov_acc / coverage[start:]
    return optimal


def calculate_sc_performance(coverage, accuracy):
    """Area between the reference curve and the observed accuracy curve.

    Subtracts ``accuracy`` from ``calculate_optimal_accuracy(coverage,
    accuracy)`` and integrates the gap over ``coverage`` with
    ``sklearn.metrics.auc`` (trapezoidal rule; ``coverage`` must be monotonic).

    Args:
        coverage: Numpy array of coverage values.
        accuracy: Numpy array of accuracies at those coverages.

    Returns:
        The integrated gap as a float.
    """
    gap = calculate_optimal_accuracy(coverage, accuracy) - accuracy
    return metrics.auc(coverage, gap)


def save_cov_acc_tradeoff(covs_at_fc, accs_at_fc, args, mode="test"):
    """Write coverage/accuracy pairs to ``{args.results_path}cov_acc_{mode}.csv``.

    ``args.results_path`` is used as a plain string prefix, so it must include
    any trailing path separator.
    """
    out_path = f"{args.results_path}cov_acc_{mode}.csv"
    table = pd.DataFrame(data={"coverage": covs_at_fc, "accuracy": accs_at_fc})
    table.to_csv(out_path, index=False)


def save_et_vt_scores(e_t_corr_te, v_t_corr_te, args, mode):
    """Write the ``et``/``vt`` columns to ``{args.results_path}et_vt_{mode}.csv``.

    ``args.results_path`` is used as a plain string prefix.
    """
    out_path = f"{args.results_path}et_vt_{mode}.csv"
    columns = {"et": e_t_corr_te, "vt": v_t_corr_te}
    pd.DataFrame(columns).to_csv(out_path, index=False)


def save_scores(scores, args, mode="test"):
    """Write ``scores`` as a single ``scores`` column to ``{args.results_path}scores_{mode}.csv``."""
    out_path = f"{args.results_path}scores_{mode}.csv"
    pd.DataFrame(data={"scores": scores}).to_csv(out_path, index=False)


def save_targets_preds_scores(
    true_targets_te, predicted_targets_te, scores_targets_te, args
):
    """Save targets, predictions and scores as ``.npy`` files under ``args.results_path``.

    Each tensor is moved to CPU and converted to numpy before saving. Files are
    written in order: ``targets.npy``, ``pred_targets.npy``,
    ``scores_targets.npy`` (``args.results_path`` is used as a string prefix).
    """
    outputs = (
        ("targets.npy", true_targets_te),
        ("pred_targets.npy", predicted_targets_te),
        ("scores_targets.npy", scores_targets_te),
    )
    for file_name, tensor in outputs:
        np.save(f"{args.results_path}{file_name}", tensor.cpu().numpy())


def compute_mi_precision_score(scores_targets_tr, scores_targets_te):
    """Average precision of the scores at separating the two input groups.

    Samples from ``scores_targets_tr`` get label 1 and samples from
    ``scores_targets_te`` get label 0; the concatenated scores rank them. This
    presumably measures membership-inference leakage (train vs. test), which
    the name suggests; confirm with callers.

    Both tensors are detached and moved to CPU first. sklearn converts its
    inputs with numpy, which fails on CUDA tensors and on tensors that require
    grad.

    Args:
        scores_targets_tr: 1-D score tensor for the positive group.
        scores_targets_te: 1-D score tensor for the negative group.

    Returns:
        ``sklearn.metrics.average_precision_score`` of the labeling.
    """
    scores_tr = scores_targets_tr.detach().cpu()
    scores_te = scores_targets_te.detach().cpu()
    mi_indicator = torch.cat([torch.ones_like(scores_tr), torch.zeros_like(scores_te)])
    mi_pred = torch.cat([scores_tr, scores_te])
    return metrics.average_precision_score(mi_indicator.numpy(), mi_pred.numpy())


def compute_mi_auc_score(scores_targets_tr, scores_targets_te):
    """ROC AUC of the scores at separating the two input groups.

    Samples from ``scores_targets_tr`` get label 1 and samples from
    ``scores_targets_te`` get label 0. The ROC curve of the concatenated scores
    is computed with ``sklearn.metrics.roc_curve`` and integrated with
    ``sklearn.metrics.auc``. The name suggests a membership-inference (train
    vs. test) metric; confirm with callers.

    Both tensors are detached and moved to CPU first. sklearn converts its
    inputs with numpy, which fails on CUDA tensors and on tensors that require
    grad.

    Args:
        scores_targets_tr: 1-D score tensor for the positive group.
        scores_targets_te: 1-D score tensor for the negative group.

    Returns:
        Area under the ROC curve as a float.
    """
    scores_tr = scores_targets_tr.detach().cpu()
    scores_te = scores_targets_te.detach().cpu()
    mi_indicator = torch.cat([torch.ones_like(scores_tr), torch.zeros_like(scores_te)])
    mi_pred = torch.cat([scores_tr, scores_te])
    fpr, tpr, _ = metrics.roc_curve(mi_indicator.numpy(), mi_pred.numpy())
    return metrics.auc(fpr, tpr)


def compute_mean_conf_pred(softmaxes):
    """Average several probability tensors and derive scores and predictions.

    The tensors are stacked along a new last axis (dim 2) and averaged over it.
    Each is presumably an ``(N, C)`` softmax output; confirm with callers.
    The class axis is dim 1.

    Args:
        softmaxes: Sequence of equally shaped tensors.

    Returns:
        ``(scores_targets, predicted_targets)``. ``predicted_targets`` is the
        argmax class of the averaged tensor. ``scores_targets`` is
        ``1 - max`` averaged probability, so lower means more confident.
    """
    ensemble_probs = torch.mean(torch.stack(softmaxes, dim=2), dim=2)
    top_probs, predicted_targets = torch.max(ensemble_probs, dim=1)
    return 1 - top_probs, predicted_targets


def plot_decision_boundary(args, device, X, y, model, steps=1000, cmap="Paired"):
    """Plot a model's output surface over the 2-D data region and save artifacts.

    The model is evaluated on a ``steps`` x ``steps`` grid covering the range
    of columns 0 and 1 of ``X``, padded by 1 on each side. A filled contour of
    the resulting surface is drawn, with the points of ``X`` on top colored by
    ``y``.

    Files written (each name prefixed by the string ``args.results_path``):
      * ``{dataset}_decision_region_eps{epsilon}.pdf``: the figure
      * ``{dataset}_model.pt``: ``model.state_dict()``
      * ``{dataset}_X.pt`` and ``{dataset}_y.pt``: the inputs

    Args:
        args: Object with ``results_path``, ``dataset`` and ``epsilon``
            attributes (used only to build output file names).
        device: Device the grid is moved to before calling ``model``.
        X: Input tensor; columns 0 and 1 are the plot coordinates.
        y: Per-point values used to color the scatter plot.
        model: Callable taking a ``(steps**2, 2)`` float tensor and returning
            per-class outputs with classes on dim 1; softmax is applied on
            that dim.
        steps: Grid resolution per axis.
        cmap: Matplotlib colormap name.
    """
    cmap = plt.get_cmap(cmap)

    X_np = X.cpu().numpy()
    # Define region of interest by data limits
    xmin, xmax = X_np[:, 0].min() - 1, X_np[:, 0].max() + 1
    ymin, ymax = X_np[:, 1].min() - 1, X_np[:, 1].max() + 1
    x_span = np.linspace(xmin, xmax, steps)
    y_span = np.linspace(ymin, ymax, steps)
    xx, yy = np.meshgrid(x_span, y_span)

    # Evaluate the model on every grid point. This is inference only, so skip
    # autograd: at the default steps=1000 that is 1e6 points, and building a
    # graph for them wastes memory.
    grid = torch.Tensor(np.c_[xx.ravel(), yy.ravel()]).to(device)
    with torch.no_grad():
        probs = F.softmax(model(grid), dim=1)
    # Subtract 1 from class 0's probability so it lies in [-1, 0]. Every other
    # class stays in [0, 1], so the row max below is the largest probability
    # among classes 1..C-1 (for two classes: P(class 1)).
    # NOTE(review): confirm this is the intended surface.
    probs[:, 0] = probs[:, 0] - 1
    surface, _ = probs.max(dim=1)

    # Plot decision boundary in region of interest
    z = surface.reshape(xx.shape).cpu().numpy()

    fig, ax = plt.subplots()
    ax.contourf(xx, yy, z, cmap=cmap, alpha=0.33)
    # Overlay the data points, colored by the given labels y.
    ax.scatter(X_np[:, 0], X_np[:, 1], c=y.cpu().numpy(), cmap=cmap, lw=0, alpha=0.5)
    fig.tight_layout()
    fig.savefig(
        f"{args.results_path}{args.dataset}_decision_region_eps{args.epsilon}.pdf"
    )
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)

    torch.save(model.state_dict(), f"{args.results_path}{args.dataset}_model.pt")
    torch.save(X, f"{args.results_path}{args.dataset}_X.pt")
    torch.save(y, f"{args.results_path}{args.dataset}_y.pt")
