import matplotlib.pyplot as plt
import pandas as pd
import torch

import wandb


def plot_cov_acc_tradeoff(covs_at_fa, accs_at_fa, covs_at_fc, accs_at_fc, run):
    """Log two coverage/accuracy line charts to a wandb run.

    The first chart ("cov_acc_fa") puts accuracy on the x axis and coverage on
    the y axis, built from ``accs_at_fa`` / ``covs_at_fa``.  The second chart
    ("cov_acc_fc") puts coverage on the x axis and accuracy on the y axis,
    built from ``covs_at_fc`` / ``accs_at_fc``.

    Args:
        covs_at_fa: Coverage values paired element-wise with ``accs_at_fa``.
        accs_at_fa: Accuracy values paired element-wise with ``covs_at_fa``.
        covs_at_fc: Coverage values paired element-wise with ``accs_at_fc``.
        accs_at_fc: Accuracy values paired element-wise with ``covs_at_fc``.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # Fixed-accuracy chart: one [accuracy, coverage] row per paired sample.
    # zip() silently truncates to the shorter of the two sequences.
    fa_rows = [[acc, cov] for acc, cov in zip(accs_at_fa, covs_at_fa)]
    fa_table = wandb.Table(data=fa_rows, columns=["Accuracy", "Coverage"])
    fa_chart = wandb.plot.line(
        fa_table,
        "Accuracy",
        "Coverage",
        title="Coverage-Accuracy Tradeoff (Fixed Accuracy)",
    )
    run.log({"cov_acc_fa": fa_chart})

    # Fixed-coverage chart: one [coverage, accuracy] row per paired sample.
    fc_rows = [[cov, acc] for cov, acc in zip(covs_at_fc, accs_at_fc)]
    fc_table = wandb.Table(data=fc_rows, columns=["Coverage", "Accuracy"])
    fc_chart = wandb.plot.line(
        fc_table,
        "Coverage",
        "Accuracy",
        title="Coverage-Accuracy Tradeoff (Fixed Coverage)",
    )
    run.log({"cov_acc_fc": fc_chart})


def plot_e_t_metric(e_t_corr, e_t_incorr, run):
    """Log the ``e_t`` metric as a two-series line chart under key "e_t".

    Each series is plotted against its own 0-based index, labelled
    "Checkpoints" on the x axis.

    Args:
        e_t_corr: Sized sequence of values for the "corr" series
            (presumably correctly-classified samples -- confirm with caller).
        e_t_incorr: Sized sequence of values for the "incorr" series.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # Each series gets an index axis matching its own length.
    corr_steps = torch.arange(len(e_t_corr))
    incorr_steps = torch.arange(len(e_t_incorr))

    chart = wandb.plot.line_series(
        xs=[corr_steps, incorr_steps],
        ys=[e_t_corr, e_t_incorr],
        keys=["corr", "incorr"],
        title="e_t",
        xname="Checkpoints",
    )
    run.log({"e_t": chart})


def plot_v_t_metric(v_t_corr, v_t_incorr, run):
    """Log the ``v_t`` metric as a two-series line chart under key "v_t".

    Each series is plotted against its own 0-based index, labelled
    "Checkpoints" on the x axis.

    Args:
        v_t_corr: Sized sequence of values for the "corr" series
            (presumably correctly-classified samples -- confirm with caller).
        v_t_incorr: Sized sequence of values for the "incorr" series.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # Each series gets an index axis matching its own length.
    corr_steps = torch.arange(len(v_t_corr))
    incorr_steps = torch.arange(len(v_t_incorr))

    chart = wandb.plot.line_series(
        xs=[corr_steps, incorr_steps],
        ys=[v_t_corr, v_t_incorr],
        keys=["corr", "incorr"],
        title="v_t",
        xname="Checkpoints",
    )
    run.log({"v_t": chart})


def plot_score_dist(args, scores_targets, run):
    """Save a histogram of scores to disk and log one to a wandb run.

    A pandas/matplotlib histogram is written to
    ``f"{args.results_path}hist_pd.png"``; a wandb histogram built from the
    same data is logged under the key "scores_target".

    Args:
        args: Object with a ``results_path`` attribute.  The value is used as
            a raw string prefix, so it must already end with a path separator
            for the file to land inside that directory.
        scores_targets: Values placed in the single "scores" column.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    df = pd.DataFrame({"scores": scores_targets})

    # DataFrame.hist() opens a new pyplot figure on every call.  Grab it and
    # close it once saved: merely clearing it (plt.clf) would leave the figure
    # registered with pyplot, so repeated calls would accumulate open figures.
    df.hist()
    fig = plt.gcf()
    try:
        fig.savefig(f"{args.results_path}hist_pd.png")
    finally:
        plt.close(fig)

    table = wandb.Table(dataframe=df)
    run.log({"scores_target": wandb.plot.histogram(table, "scores", title="Histogram")})


def plot_roc_confusion_matrix(true_targets, predicted_targets, run):
    """Print both label sequences, then log a confusion matrix as "conf_mat".

    Despite the name, only a confusion matrix is produced here; no ROC curve
    is computed or logged.

    Args:
        true_targets: Ground-truth labels, passed to wandb as ``y_true``.
        predicted_targets: Predicted labels, passed to wandb as ``preds``.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # Echo the raw inputs to stdout (true labels first, then predictions).
    for targets in (true_targets, predicted_targets):
        print(targets)

    run.log(
        {
            "conf_mat": wandb.plot.confusion_matrix(
                y_true=true_targets,
                preds=predicted_targets,
            )
        }
    )


def plot_sc_mi_precision_performance(sc_performance, mi_performance, run):
    """Log a single-point scatter of SC vs. MI (precision) performance.

    The chart is logged under the key "sc_mi_prec" with "sc" on the x axis and
    "mi" on the y axis.

    Args:
        sc_performance: Value stored in the "sc" column.
        mi_performance: Value stored in the "mi" column.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # The table holds exactly one row: the (sc, mi) pair for this run.
    point_table = wandb.Table(
        data=[[sc_performance, mi_performance]],
        columns=["sc", "mi"],
    )
    chart = wandb.plot.scatter(
        point_table,
        "sc",
        "mi",
        title="MI (Precision) - SC Tradeoff",
    )
    run.log({"sc_mi_prec": chart})


def plot_sc_mi_auc_performance(sc_performance, mi_performance, run):
    """Log a single-point scatter of SC vs. MI (AUC) performance.

    The chart is logged under the key "sc_mi_auc" with "sc" on the x axis and
    "mi" on the y axis.

    Args:
        sc_performance: Value stored in the "sc" column.
        mi_performance: Value stored in the "mi" column.
        run: Object exposing ``log(dict)`` (presumably a wandb run).
    """
    # The table holds exactly one row: the (sc, mi) pair for this run.
    point_table = wandb.Table(
        data=[[sc_performance, mi_performance]],
        columns=["sc", "mi"],
    )
    chart = wandb.plot.scatter(
        point_table,
        "sc",
        "mi",
        title="MI (AUC) - SC Tradeoff",
    )
    run.log({"sc_mi_auc": chart})
