import torch
from loader import Box
from models.unet_model import UNet
import cfg
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# Parse command-line options and build the experiment container from them.
opt = cfg.get_arguments().parse_args()
box = Box(opt)

# Opaque state from the container: param1/param2 are handed to box.poisoned
# (presumably the trigger/poisoning parameters — TODO confirm) and cls_model
# is the classifier under evaluation.
param1, param2, cls_model = box.get_state_dict()

# Clean CIFAR-10 test split with the evaluation transform, iterated in a
# fixed order so every evaluation pass sees the same batches.
test_tf = box.get_transform(train="test")
cln_testset = CIFAR10(
    root="./datasets",
    train=False,
    transform=test_tf,
    download=True,
)
cln_testloader = DataLoader(
    dataset=cln_testset,
    batch_size=opt.batch_size,
    shuffle=False,
)

# --- Backdoored classifier: clean accuracy (BA) and attack success rate (ASR) ---
# BA  = % of test images whose clean prediction equals the label.
# ASR = % of test images with label != box.tlabel whose poisoned version is
#       predicted as box.tlabel.
cls_model.eval()
total_ba = 0
total_asr = 0
correct_ba = 0
correct_asr = 0
pbar = tqdm(cln_testloader, desc="Test Poisoned Samples")
for cln_imgs, labels in pbar:
    cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)
    # Kept outside no_grad: box.poisoned's internals are not visible here.
    poi_imgs = box.poisoned(cln_imgs, param1, param2)
    # Evaluation only: do not build autograd graphs for the classifier.
    with torch.no_grad():
        cln_outputs = cls_model(cln_imgs)
        poi_outputs = cls_model(poi_imgs)

    _, cln_pred = cln_outputs.max(1)
    _, poi_pred = poi_outputs.max(1)

    # Count per batch with tensor ops instead of indexing element by element
    # (on a GPU device each scalar comparison forced a host sync).
    non_target = labels != box.tlabel
    total_ba += cln_imgs.shape[0]
    correct_ba += (cln_pred == labels).sum().item()
    total_asr += non_target.sum().item()
    correct_asr += (non_target & (poi_pred == box.tlabel)).sum().item()

    # Guard: a batch made only of target-class samples leaves total_asr at 0.
    ba = 100. * correct_ba / total_ba if total_ba else 0.0
    asr = 100. * correct_asr / total_asr if total_asr else 0.0

    pbar.set_postfix({"BA": "{:.2f}".format(ba), "ASR": "{:.2f}".format(asr)})

# --- BTI-DBF inversion generator: % of generator outputs classified as box.tlabel ---
cls_model.eval()
total = 0
correct_asr = 0
inv_generator = UNet(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4)
inv_generator.load_state_dict(torch.load("./inv_generator/inv_cifar10_ia_t0.pt", map_location="cpu"))
inv_generator.to(box.device)
inv_generator.eval()
pbar = tqdm(cln_testloader, desc="Test BTI-DBF")
for cln_imgs, labels in pbar:
    cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)
    # Evaluation only: no autograd graphs for the generator or classifier.
    with torch.no_grad():
        inv_imgs = inv_generator(cln_imgs)
        inv_outputs = cls_model(inv_imgs)

    _, inv_pred = inv_outputs.max(1)

    # NOTE(review): unlike the other ASR metrics in this script, samples whose
    # label already equals box.tlabel are NOT excluded here — confirm intended.
    total += cln_imgs.shape[0]
    correct_asr += (inv_pred == box.tlabel).sum().item()

    asr = 100. * correct_asr / total

    pbar.set_postfix({"ASR": "{:.2f}".format(asr)})


from copy import deepcopy

# --- BTI-DBF (U): BA / ASR of the unlearned model loaded from ul_model/ ---
# BA  = % of test images whose clean prediction equals the label.
# ASR = % of test images with label != box.tlabel whose poisoned version is
#       predicted as box.tlabel.
total_ba = 0
total_asr = 0
correct_ba = 0
correct_asr = 0
pbar = tqdm(cln_testloader, desc="Test BTI-DBF (U)")
# Same architecture as cls_model; weights replaced by the unlearned checkpoint.
unlearn_model = deepcopy(cls_model)
unlearn_model.load_state_dict(torch.load("ul_model/unlearn_model.pt", map_location="cpu"))
unlearn_model = unlearn_model.to(box.device)
unlearn_model.eval()
for cln_imgs, labels in pbar:
    cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)
    # Kept outside no_grad: box.poisoned's internals are not visible here.
    poi_imgs = box.poisoned(cln_imgs, param1, param2)
    # Evaluation only: do not build autograd graphs for the classifier.
    with torch.no_grad():
        cln_outputs = unlearn_model(cln_imgs)
        poi_outputs = unlearn_model(poi_imgs)

    _, cln_pred = cln_outputs.max(1)
    _, poi_pred = poi_outputs.max(1)

    # Count per batch with tensor ops instead of per-element indexing.
    non_target = labels != box.tlabel
    total_ba += cln_imgs.shape[0]
    correct_ba += (cln_pred == labels).sum().item()
    total_asr += non_target.sum().item()
    correct_asr += (non_target & (poi_pred == box.tlabel)).sum().item()

    # Guard: a batch made only of target-class samples leaves total_asr at 0.
    ba = 100. * correct_ba / total_ba if total_ba else 0.0
    asr = 100. * correct_asr / total_asr if total_asr else 0.0

    pbar.set_postfix({"BA": "{:.2f}".format(ba), "ASR": "{:.2f}".format(asr)})


# --- BTI-DBF (P): BA / ASR of cls_model on inputs passed through pur_generator ---
# BA  = % of test images whose purified clean prediction equals the label.
# ASR = % of test images with label != box.tlabel whose purified poisoned
#       version is predicted as box.tlabel.
cls_model.eval()
total_ba = 0
total_asr = 0
correct_ba = 0
correct_asr = 0
pur_generator = UNet(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4)
pur_generator.load_state_dict(torch.load("pur_generator/pur_cifar10_ia_t0.pt", map_location="cpu"))
pur_generator.to(box.device)
pur_generator.eval()
pbar = tqdm(cln_testloader, desc="Test BTI-DBF (P)")
for cln_imgs, labels in pbar:
    cln_imgs, labels = cln_imgs.to(box.device), labels.to(box.device)
    # Kept outside no_grad: box.poisoned's internals are not visible here.
    poi_imgs = box.poisoned(cln_imgs, param1, param2)
    # Evaluation only: no autograd graphs for the generator or classifier.
    with torch.no_grad():
        cln_pur_imgs = pur_generator(cln_imgs)
        poi_pur_imgs = pur_generator(poi_imgs)
        cln_pur_outputs = cls_model(cln_pur_imgs)
        poi_pur_outputs = cls_model(poi_pur_imgs)

    _, cln_pur_pred = cln_pur_outputs.max(1)
    _, poi_pur_pred = poi_pur_outputs.max(1)

    # Count per batch with tensor ops instead of per-element indexing.
    non_target = labels != box.tlabel
    total_ba += cln_imgs.shape[0]
    correct_ba += (cln_pur_pred == labels).sum().item()
    total_asr += non_target.sum().item()
    correct_asr += (non_target & (poi_pur_pred == box.tlabel)).sum().item()

    # Guard: a batch made only of target-class samples leaves total_asr at 0.
    ba = 100. * correct_ba / total_ba if total_ba else 0.0
    asr = 100. * correct_asr / total_asr if total_asr else 0.0

    pbar.set_postfix({"BA": "{:.2f}".format(ba), "ASR": "{:.2f}".format(asr)})

# --- Detection: flag an input when purification changes cls_model's prediction ---
# Only samples whose label differs from box.tlabel are scored. Each scored
# sample contributes its poisoned copy (flagged -> TP, else FN) and its clean
# copy (flagged -> FP, else TN).
tp = 0
fp = 0
tn = 0
fn = 0
total_poi = 0
pbar = tqdm(cln_testloader, desc="Test Detection")
cls_model.eval()
pur_generator.eval()
for cln_img, targets in pbar:
    cln_img, targets = cln_img.to(box.device), targets.to(box.device)
    # Kept outside no_grad: box.poisoned's internals are not visible here.
    poi_img = box.poisoned(cln_img, param1, param2)
    # Evaluation only: no autograd graphs for the generator or classifier.
    with torch.no_grad():
        poi_outputs = cls_model(poi_img)
        cln_outputs = cls_model(cln_img)
        pur_poi_img = pur_generator(poi_img)
        pur_cln_img = pur_generator(cln_img)
        pur_poi_outputs = cls_model(pur_poi_img)
        pur_cln_outputs = cls_model(pur_cln_img)

    _, poi_pred = poi_outputs.max(1)
    _, cln_pred = cln_outputs.max(1)
    _, pur_poi_pred = pur_poi_outputs.max(1)
    _, pur_cln_pred = pur_cln_outputs.max(1)

    # Count per batch with tensor ops instead of per-element indexing.
    scored = targets != box.tlabel
    poi_flagged = poi_pred != pur_poi_pred
    cln_flagged = cln_pred != pur_cln_pred
    total_poi += scored.sum().item()
    tp += (scored & poi_flagged).sum().item()
    fn += (scored & ~poi_flagged).sum().item()
    fp += (scored & cln_flagged).sum().item()
    tn += (scored & ~cln_flagged).sum().item()

# Zero denominators (e.g. nothing flagged at all) report 0.0 instead of
# raising ZeroDivisionError at the very end of the run.
precision = 100. * tp / (tp + fp) if tp + fp else 0.0
recall = 100. * tp / (tp + fn) if tp + fn else 0.0
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 score: {f1_score}")