#################################################
# Evaluate across iterations
#################################################

import argparse
import torch.nn as nn
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import resnet
import os


# Backbone constructors selectable through ``args.net``; every entry is a
# factory from the project-local ``resnet`` module.
NETWORKS = dict(
    resnet18=resnet.resnet18,
    resnet34=resnet.resnet34,
    resnet50=resnet.resnet50,
    resnet101=resnet.resnet101,
    resnet152=resnet.resnet152,
)

# Sweep grid. Checkpoints are looked up as
#   OUT_DIR/sparse_advprop_<data size>_<aug>/best_imp_<iter>.ckpt
# Wider sweeps used previously: DATA_SIZES += ["0.01"] and
# AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"].
DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02"]
AUGS = ["autoaug"]
ITERS = [0, 4, 9, 14]  # IMP rounds to evaluate (0 = dense model)
OUT_DIR = "to_send"

# Shared project arguments plus the FGSM step size for this script.
parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument("--eps", type=float, default=8 / 255, help="fgsm attack eps")
args = parser.parse_args()

# Seed Python, NumPy and torch (CPU + CUDA) generators with the same value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Module-level state used by eval() and the sweep below.
device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
metric_meter = utils.AvgMeter()


class Normalize(nn.Module):
    """Channel-wise standardisation layer computing ``(inp - mean) / std``.

    In this script it is prepended to the network with ``nn.Sequential``. The
    model therefore receives raw ``ToTensor`` images, and input-space attacks
    clamp in pixel space.

    Args:
        mean: per-channel means. A ``torch.Tensor`` is stored as-is; anything
            else is converted with ``torch.tensor``.
        std: per-channel standard deviations, handled like ``mean``.

    Both statistics are registered as buffers, so they move with
    ``.to(device)`` and are included in ``state_dict``.
    """

    def __init__(self, mean, std):
        super().__init__()
        for buffer_name, stat in (("mean", mean), ("std", std)):
            if not isinstance(stat, torch.Tensor):
                stat = torch.tensor(stat)
            self.register_buffer(buffer_name, stat)

    def forward(self, inp):
        """Standardise ``inp``; the stats broadcast along dim 1 of a 4-D input."""
        # Lift the (C,) statistics to (1, C, 1, 1) so they broadcast over N, H, W.
        centred = inp - self.mean[None, :, None, None]
        return centred / self.std[None, :, None, None]


def eval(loader, model, metric_meter, attack=False):
    """Score ``model`` on every batch of ``loader``, optionally under FGSM.

    With ``attack`` true, each batch is first replaced by a single-step FGSM
    example ``clamp(x + args.eps * sign(d loss / d x), 0, 1)``. The gradient is
    taken with respect to the inputs only. The [0, 1] clamp matches how this
    script feeds the model: raw ``ToTensor`` images, with normalisation done
    inside ``model``.

    Args:
        loader: iterable of ``(images, labels)`` batches supporting ``len``.
        model: classifier returning logits; put into eval mode here.
        metric_meter: accumulator exposing ``reset``/``add``/``msg``. It is
            reset first, then receives one ``{"loss", "acc"}`` dict per batch.
            How batches are averaged is up to the meter.
        attack: evaluate on FGSM-perturbed inputs instead of clean ones.

    Reads the module-level ``device``, ``criterion`` and ``args`` globals.
    Results are left in ``metric_meter``; nothing is returned.
    """
    metric_meter.reset()
    model.eval()
    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        if attack:
            # One FGSM step: push every pixel by eps along the sign of the loss
            # gradient, then clamp back to the valid pixel range.
            images.requires_grad = True
            attack_loss = criterion(model(images), labels)
            (images_grad,) = torch.autograd.grad(
                attack_loss, images, retain_graph=False, create_graph=False
            )
            images = (images + args.eps * images_grad.sign()).clamp(min=0, max=1).detach()

        # Score the (possibly perturbed) batch without building a graph.
        with torch.no_grad():
            logits = model(images)
            loss = criterion(logits, labels)

        predicted = logits.argmax(dim=1)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()
        metric_meter.add({"loss": loss.item(), "acc": n_correct / images.shape[0]})
        utils.pbar(batch_idx / len(loader), msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


# ---------------------------------------------------------------------------
# Main sweep: FGSM-evaluate every checkpoint. Each (data size, augmentation)
# pair produces one row in the results file, with one accuracy column per IMP
# round in ITERS.
# ---------------------------------------------------------------------------

# Dataset-specific settings, resolved once so an unsupported --dset fails before
# any checkpoint or data is loaded. The test set must be the same dataset the
# classifier head (n_cls) was trained for.
if args.dset == "cifar10":
    _dset_cls = datasets.CIFAR10
    _norm_stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    n_cls = 10
elif args.dset == "cifar100":
    _dset_cls = datasets.CIFAR100
    _norm_stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    n_cls = 100
else:
    raise NotImplementedError(f"args.dset = {args.dset} not implemented.")

# The test set is independent of the checkpoint, so it is built once. Images are
# only converted to [0, 1] tensors. Normalisation is applied inside the model
# (see _load_model), so the FGSM step and clamp in eval() act in pixel space.
dset = _dset_cls(
    root=args.data_root,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
loader = DataLoader(
    dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
)


def _load_model(ckpt_path, imp_iter):
    """Rebuild the network saved at ``ckpt_path``, wrapped with input normalisation.

    Args:
        ckpt_path: checkpoint file holding a ``"model"`` state dict. It must also
            hold ``"init"`` when ``imp_iter`` is non-zero.
        imp_iter: IMP round of the checkpoint. Round 0 is loaded directly; any
            other round first has its pruning mask re-applied.

    Returns:
        ``nn.Sequential(Normalize, network)`` moved to ``device``.
    """
    # map_location lets GPU-saved checkpoints load on whatever device we run on.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(device)
    utils.modify_bn(model)
    # NOTE(review): presumably attached so the module layout matches training;
    # eval() in this script never calls it directly.
    model.attacker = utils.PGDAttacker(
        args.attack_n_iter, args.attack_eps, args.attack_step_size, 0.2
    )
    model = model.to(device)
    if imp_iter:
        # Pruned round: start from the initial weights and re-apply the mask
        # recovered from the trained weights. NOTE(review): mask_prune presumably
        # changes the parameter layout (mask buffers), which is why it must run
        # before ckpt["model"] can be loaded. Confirm in utils.
        model.load_state_dict(ckpt["init"])
        utils.mask_prune(model, utils.extract_mask(ckpt["model"]))
        print("remaining weight = ", utils.check_sparsity(model))
    model.load_state_dict(ckpt["model"])
    return nn.Sequential(Normalize(*_norm_stats), model).to(device)


with open(f"{args.dset}_rob_a_results.txt", "w") as f:
    for data_size in DATA_SIZES:
        for aug in AUGS:
            for imp_iter in ITERS:
                ckpt_path = os.path.join(
                    OUT_DIR,
                    f"sparse_advprop_{data_size}_{aug}",
                    f"best_imp_{imp_iter}.ckpt",
                )
                print(f"Evaluating: {ckpt_path}")
                model = _load_model(ckpt_path, imp_iter)

                # Clean accuracy can be measured the same way with attack=False.
                eval(loader, model, metric_meter, attack=True)
                metrics = metric_meter.get()
                print(
                    f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                )

                # Accuracy as a fixed two-decimal percentage. round(x, 4) * 100
                # can emit float artefacts such as 56.99999999999999.
                f.write(f"{metrics['acc'] * 100:.2f} ")
                print("finished evaluating on ckpt")
                print("---------------------------")
            f.write("\n")
            f.flush()

# #################################################
# # Evaluate best aug and best ticket
# #################################################

# import argparse
# import torch.nn as nn
# import utils
# import random
# import numpy as np
# import torch
# from torchvision import transforms, datasets
# from torch.utils.data import DataLoader
# import resnet
# import os


# NETWORKS = {
#     "resnet18": resnet.resnet18,
#     "resnet34": resnet.resnet34,
#     "resnet50": resnet.resnet50,
#     "resnet101": resnet.resnet101,
#     "resnet152": resnet.resnet152,
# }

# CKPTS = [
#     # best augs
#     "sparse_1_autoaug/best_imp_0.ckpt",
#     "sparse_0.5_autoaug/best_imp_0.ckpt",
#     "sparse_0.2_autoaug/best_imp_0.ckpt",
#     "sparse_0.1_autoaug/best_imp_0.ckpt",
#     "sparse_0.02_autoaug/best_imp_0.ckpt",
#     "sparse_0.01_contrastaug/best_imp_0.ckpt",
#     # best winning tickets
#     "sparse_1_autoaug/best_imp_2.ckpt",
#     "sparse_0.5_autoaug/best_imp_3.ckpt",
#     "sparse_0.2_autoaug/best_imp_2.ckpt",
#     "sparse_0.1_autoaug/best_imp_4.ckpt",
#     "sparse_0.02_autoaug/best_imp_15.ckpt",
#     "sparse_0.01_randaug/best_imp_15.ckpt",
# ]
# OUT_DIR = "cifar10_ckpts"

# parser = argparse.ArgumentParser()
# parser = utils.add_args(parser)
# parser.add_argument("--eps", type=float, default=8/255, help="fgsm attack eps")
# args = parser.parse_args()

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)

# device, _ = utils.setup_device(False)
# criterion = torch.nn.CrossEntropyLoss()
# metric_meter = utils.AvgMeter()

# class Normalize(nn.Module):
#     def __init__(self, mean, std):
#         super(Normalize, self).__init__()
#         if not isinstance(mean, torch.Tensor):
#             mean = torch.tensor(mean)
#         if not isinstance(std, torch.Tensor):
#             std = torch.tensor(std)
#         self.register_buffer("mean", mean)
#         self.register_buffer("std", std)

#     def forward(self, inp):
#         mean = self.mean[None, :, None, None]
#         std = self.std[None, :, None, None]
#         return inp.sub(mean).div(std)

# def eval(loader, model, metric_meter, attack=False):
#     metric_meter.reset()
#     model.eval()
#     for indx, (img, target) in enumerate(loader):
#         img, target = img.to(device), target.to(device)

#         if attack:
#             img.requires_grad = True
#             pred = model(img)
#             cost = criterion(pred, target)
#             grad = torch.autograd.grad(cost, img, retain_graph=False, create_graph=False)[0]
#             adv_img = img + args.eps * grad.sign()
#             img = torch.clamp(adv_img, min=0, max=1).detach()

#         with torch.no_grad():
#             pred = model(img)
#             loss = criterion(pred, target)

#         pred_cls = pred.argmax(dim=1)
#         acc = pred_cls.eq(target.view_as(pred_cls)).sum().item() / img.shape[0]

#         metrics = {"loss": loss.item(), "acc": acc}
#         metric_meter.add(metrics)
#         utils.pbar(indx / len(loader), msg=metric_meter.msg())
#     utils.pbar(1, msg=metric_meter.msg())


# f = open(f"{args.dset}_rob_a_results.txt", "w")
# for ckpt_f in CKPTS:
#     ckpt = os.path.join(OUT_DIR, ckpt_f)
#     print(f"Evaluating: {ckpt}")
#     ckpt = torch.load(ckpt)
#     if args.dset == "cifar10":
#         norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#         n_cls = 10
#     elif args.dset == "cifar100":
#         norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
#         n_cls = 100
#     else:
#         raise NotImplementedError(f"args.dset = {args.dset} not implemented.")
#     model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small").to(device)
#     if int(os.path.basename(ckpt_f).split(".")[0].split("_")[-1]):
#         model.load_state_dict(ckpt["init"])
#         curr_mask = utils.extract_mask(ckpt["model"])
#         utils.mask_prune(model, curr_mask)
#         print("remaining weight = ", utils.check_sparsity(model))
#     model.load_state_dict(ckpt["model"])
#     model = nn.Sequential(norm, model)
#     model = model.to(device)

#     # basic
#     transform = transforms.ToTensor()
#     dset = datasets.CIFAR10(
#         root=args.data_root,
#         train=False,
#         transform=transform,
#         download=True,
#     )
#     loader = DataLoader(
#         dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
#     )
#     # eval(loader, model, metric_meter, attack=False)
#     # metrics = metric_meter.get()
#     # print(f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}")

#     eval(loader, model, metric_meter, attack=True)
#     metrics = metric_meter.get()
#     print(f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}")

#     f.write(f"{round(metrics['acc'], 4)*100}" + "\n")
#     f.flush()

#     print("finished evaluating on ckpt")
#     print("---------------------------")
# f.close()
