import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve


def load_data(dataset_name, path="../datasets/benchmark_atk/data/"):
    """Load the saved instances and labels of one benchmark dataset.

    Two files are read with ``torch.load``: ``ins_<dataset_name>.pt`` and
    ``lable_<dataset_name>.pt``.

    Args:
        dataset_name: Name inserted into both file names.
        path: String prefix prepended verbatim to each file name. It is
            concatenated, not joined, so a directory must end with a
            path separator.

    Returns:
        A ``(ins, label)`` tuple holding whatever objects the two files
        contain.
    """
    ins_file = path + "ins_%s.pt" % dataset_name
    # NOTE(review): "lable" is the spelling read at runtime; presumably it
    # matches the files on disk, so it must not be "corrected" here alone.
    label_file = path + "lable_%s.pt" % dataset_name
    return torch.load(ins_file), torch.load(label_file)


def load_curve(criterion, aggregate_loss_type, dataset_name, loss_type):
    """Read one recorded curve from its CSV file.

    The file is looked up relative to the current working directory at
    ``./<criterion>/<loss_type>_<aggregate_loss_type>_<dataset_name>.csv``.

    Args:
        criterion: Name of the sub-directory holding the curve files.
        aggregate_loss_type: Aggregate-loss tag used in the file name.
        dataset_name: Dataset tag used in the file name.
        loss_type: Loss tag used in the file name.

    Returns:
        The array produced by ``np.loadtxt`` for that file.
    """
    csv_path = "./%s/%s_%s_%s.csv" % (
        criterion,
        loss_type,
        aggregate_loss_type,
        dataset_name,
    )
    return np.loadtxt(csv_path)


def show_perf(criterion, dataset_name, loss_type="hinge"):
    """Plot one criterion for every aggregate loss and save it as a PNG.

    One line is drawn per aggregate loss type ("average", "atk",
    "atk_gd_smoothing") on a shared axis.

    * ``criterion == "precision_recall_curve"``: the dataset labels are
      loaded via ``load_data`` and, for each aggregate loss, the values in
      ``./A2/<loss_type>_<aggregate_loss_type>_<dataset_name>.csv`` are
      passed as scores to ``precision_recall_curve``; precision is plotted
      against recall.
    * any other ``criterion``: the curve returned by ``load_curve`` is
      plotted against its 1-based index, labelled "iterations".

    The figure is written to
    ``./<criterion>_<loss_type>_<dataset_name>.png`` in the current working
    directory and then closed.

    Args:
        criterion: Which plot to produce; also used in the output file name
            and, for non-precision/recall plots, as the y-axis label.
        dataset_name: Dataset tag used to locate the input files.
        loss_type: Loss tag used to locate the input files.
    """
    aggregate_loss_types = ["average", "atk", "atk_gd_smoothing"]

    fig, ax = plt.subplots()
    try:
        if criterion == "precision_recall_curve":
            # The labels do not depend on the aggregate loss, so load them
            # once instead of once per curve; the instances are not needed.
            _, Y = load_data(dataset_name=dataset_name)
            Y = Y.numpy()
            for aggregate_loss_type in aggregate_loss_types:
                # NOTE(review): A2 presumably holds the model's output scores
                # for each sample -- confirm against the code that writes it.
                A2 = np.loadtxt(
                    "./A2/%s_%s_%s.csv"
                    % (loss_type, aggregate_loss_type, dataset_name)
                ).reshape(-1, 1)
                precision, recall, _ = precision_recall_curve(Y, A2)
                ax.plot(recall, precision, label=aggregate_loss_type, linewidth=0.6)
            ax.set_xlabel("recall")
            ax.set_ylabel("precision")
        else:
            for aggregate_loss_type in aggregate_loss_types:
                y = load_curve(
                    criterion=criterion,
                    aggregate_loss_type=aggregate_loss_type,
                    dataset_name=dataset_name,
                    loss_type=loss_type,
                )
                x = np.arange(1, len(y) + 1)
                ax.plot(x, y, label=aggregate_loss_type, linewidth=0.6)
            ax.set_xlabel("iterations")
            ax.set_ylabel(criterion)

        ax.legend()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig("./%s_%s_%s.png" % (criterion, loss_type, dataset_name))
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls (one per dataset) do not accumulate open figures.
        plt.close(fig)


# Criteria to plot. "loss_curve" and "acc_score_curve" are currently disabled.
criteria = ["precision_recall_curve"]
dataset_names = [
    "appendicitis",
    "phoneme",
    "wisconsin",
    "australian",
    "german",
    "titanic",
    "spambase",
    "segment0",
]

# Produce one PNG per (criterion, dataset) pair.
for criterion in criteria:
    for dataset_name in dataset_names:
        show_perf(criterion=criterion, dataset_name=dataset_name)

