from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from dataset import load_data
from plot_reject_rate import fetch_data

save_dir = Path("results/edgeonly")
save_dir_local = Path("results/run1")

# Dataset pre-processing settings
Batch_size=16  # NOTE(review): not referenced elsewhere in this script
Batch_size_test=16

# Prefer Apple MPS, then CUDA, and fall back to CPU so the script still runs
# on machines without an accelerator (previously it always fell back to "cuda").
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

def _margin_rank(local_probs):
    """Return sample indices sorted by ascending top-1 minus top-2 probability margin.

    Samples whose two most likely classes are closest come first, so they
    are the first to be rejected (answered by the edge model instead).
    """
    top2 = torch.topk(local_probs, k=2, dim=1).values
    margin = top2[:, 0] - top2[:, 1]
    return torch.argsort(margin)


def _confidence_reject_curve(local_logits, edge_logits, labels, reject_rates):
    """Compute cascade accuracy for each reject rate.

    For a given rate r, the ``int(r * N)`` least-confident samples (by local
    margin) take the edge model's prediction; all others keep the local one.

    Args:
        local_logits: local-model scores, one row per sample (softmax is
            applied here, so presumably raw logits — TODO confirm).
        edge_logits: edge-model scores, one row per sample (only argmax used).
        labels: ground-truth class indices, aligned with the prediction rows.
        reject_rates: iterable of fractions in [0, 1].

    Returns:
        List of ``(reject_rate, accuracy)`` tuples, accuracy as a Python float.

    Raises:
        ValueError: if the inputs are empty or their lengths disagree.
    """
    num_samples = len(labels)
    if num_samples == 0 or not (len(local_logits) == len(edge_logits) == num_samples):
        raise ValueError(
            f"mismatched or empty inputs: local={len(local_logits)}, "
            f"edge={len(edge_logits)}, labels={num_samples}"
        )
    # Keep everything on the labels' device so indexing/comparison never mixes devices.
    local_logits = local_logits.to(labels.device)
    edge_logits = edge_logits.to(labels.device)

    local_probs = torch.nn.functional.softmax(local_logits, dim=1)
    rank = _margin_rank(local_probs)
    local_pred = torch.argmax(local_probs, dim=1)
    edge_pred = torch.argmax(edge_logits, dim=1)

    curve = []
    for reject_rate in reject_rates:
        num_reject = int(reject_rate * num_samples)
        use_edge = torch.zeros(num_samples, dtype=torch.bool, device=labels.device)
        use_edge[rank[:num_reject]] = True
        predicted = torch.where(use_edge, edge_pred, local_pred)
        # Vectorized count; exact int division instead of Python-level sum over a tensor.
        correct = (predicted == labels).sum().item()
        curve.append((reject_rate, correct / num_samples))
    return curve


all_data = {}
for dataset in ["cifar10", "SVHN"]:
    _, testset_full = load_data(dataset)

    # Only the test labels are needed here. shuffle=False keeps them in dataset
    # order, presumably the same order the saved predictions were written in.
    testloader = torch.utils.data.DataLoader(testset_full, batch_size=Batch_size_test, shuffle=False)
    labels = torch.cat([labels_bs for _, labels_bs in testloader]).to(device)

    # map_location lets predictions saved on another device (e.g. CUDA) load on this one.
    predicted_local = torch.load(
        save_dir_local / dataset / f'{dataset}-predicted_local.pth', map_location=device
    )
    predicted_edge_e = torch.load(save_dir / f'{dataset}-predicted_edge.pth', map_location=device)

    all_data[dataset] = _confidence_reject_curve(
        predicted_local, predicted_edge_e, labels, np.linspace(0, 1, 11)
    )

# Reuse the datasets evaluated above instead of repeating the list literal.
samples_compare = fetch_data(save_dir_local, list(all_data))

# Key of the comparison curve returned by fetch_data that is plotted as "Ours".
ours_key = "$c_1=1.25$"

plt.figure(figsize=(8, 3))
num_panels = len(all_data)
for i, (dataset, data) in enumerate(all_data.items(), start=1):
    plt.subplot(1, num_panels, i)

    # Smooth our samples with a polynomial fit: cubic, or lower degree when
    # there are too few points for a well-posed cubic fit.
    samples = sorted(samples_compare[dataset][ours_key], key=lambda s: s[0])
    x, y, _ = zip(*samples)
    fit_fn = np.poly1d(np.polyfit(x, y, min(3, len(x) - 1)))
    plt.plot(x, fit_fn(x), "-", c="red", label="Ours")

    plt.plot(*zip(*data), '--', color='blue', label="Confidence-based")

    plt.legend()
    plt.xlabel('Reject rate')
    plt.ylabel('Accuracy')
    plt.title(f"{dataset}")

plt.tight_layout()
# Same file as before ("results/edgeonly/reject_rate.pdf"), derived from save_dir.
plt.savefig(save_dir / "reject_rate.pdf")
