# ################################################
# Evaluate checkpoints across pruning iterations
# (older version, kept commented out for reference)
# ################################################

# import argparse
# import utils
# import random
# import numpy as np
# import torch
# from torchvision import transforms, datasets
# from torch.utils.data import DataLoader, Dataset
# import resnet
# import os

# NETWORKS = {
#     "resnet18": resnet.resnet18,
#     "resnet34": resnet.resnet34,
#     "resnet50": resnet.resnet50,
#     "resnet101": resnet.resnet101,
#     "resnet152": resnet.resnet152,
# }

# DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
# AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
# ITERS = [0, 4, 9, 14]
# OUT_DIR = "cifar10_ckpts"

# parser = argparse.ArgumentParser()
# parser = utils.add_args(parser)
# parser.add_argument("--rob_data_root", type=str, required=True, help="path to transformed data directory")
# args = parser.parse_args()

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
# torch.cuda.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)

# device, _ = utils.setup_device(False)
# criterion = torch.nn.CrossEntropyLoss()
# metric_meter = utils.AvgMeter()

# @torch.no_grad()
# def eval(loader, model, metric_meter):
#     metric_meter.reset()
#     model.eval()
#     for indx, (img, target) in enumerate(loader):
#         img, target = img.to(device), target.to(device)

#         pred = model(img)
#         loss = criterion(pred, target)

#         pred_cls = pred.argmax(dim=1)
#         acc = pred_cls.eq(target.view_as(pred_cls)).sum().item() / img.shape[0]

#         metrics = {"loss": loss.item(), "acc": acc}
#         metric_meter.add(metrics)
#         utils.pbar(indx / len(loader), msg=metric_meter.msg())
#     utils.pbar(1, msg=metric_meter.msg())


# class CIFARRobustness(Dataset):

#     def __init__(self, root, transform):
#         train_imgs, train_labels = np.load(os.path.join(root, "train.npz"))["images"], np.load(os.path.join(root, "train.npz"))["labels"]
#         test_imgs, test_labels = np.load(os.path.join(root, "test.npz"))["images"], np.load(os.path.join(root, "test.npz"))["labels"]
#         self.imgs = np.concatenate([train_imgs, test_imgs], axis=0)
#         self.labels = np.concatenate([train_labels, test_labels], axis=0)
#         self.transform = transform

#     def __getitem__(self, indx):
#         img = self.imgs[indx]
#         label = self.labels[indx]
#         img = self.transform(img)
#         return img, label

#     def __len__(self):
#         return len(self.imgs)

# f = open(f"{args.dset}_rob_d_results.txt", "w")
# for data_size in DATA_SIZES:
#     for aug in AUGS:
#         temp = []
#         for iter in ITERS:
#             ckpt = os.path.join(OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{iter}.ckpt")
#             print(f"Evaluating: {ckpt}")
#             ckpt = torch.load(ckpt)
#             if args.dset == "cifar10":
#                 norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#                 n_cls = 10
#             elif args.dset == "cifar100":
#                 norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
#                 n_cls = 100
#             else:
#                 raise NotImplementedError(f"args.dset = {args.dset} not implemented.")
#             model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small").to(device)
#             if iter:
#                 model.load_state_dict(ckpt["init"])
#                 curr_mask = utils.extract_mask(ckpt["model"])
#                 utils.mask_prune(model, curr_mask)
#                 print("remaining weight = ", utils.check_sparsity(model))
#             model.load_state_dict(ckpt["model"])

#             # basic
#             transform = transforms.Compose(
#                 [
#                     transforms.ToTensor(),
#                     norm,
#                 ]
#             )
#             dset = datasets.CIFAR10(
#                 root=args.data_root,
#                 train=False,
#                 transform=transform,
#                 download=True,
#             )
#             loader = DataLoader(
#                 dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
#             )
#             eval(loader, model, metric_meter)
#             metrics = metric_meter.get()
#             print(f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}")

#             dset = CIFARRobustness(
#                 root=args.rob_data_root, transform=transform
#             )
#             loader = DataLoader(
#                 dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
#             )
#             eval(loader, model, metric_meter)
#             metrics = metric_meter.get()
#             print(
#                 f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
#             )
#             temp.append(str(round(metrics['acc'], 4)*100))

#             print("finished evaluating on ckpt")
#             print("---------------------------")

#         f.write(" ".join(temp) + "\n")
#         f.flush()
# f.close()

#################################################
# Evaluate the best-augmentation and best-winning-ticket checkpoints
#################################################

import argparse
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import resnet
import os

# Maps the value of the --net CLI flag to the matching constructor in the local
# `resnet` module (looked up as NETWORKS[args.net] when building each model).
NETWORKS = dict(
    resnet18=resnet.resnet18,
    resnet34=resnet.resnet34,
    resnet50=resnet.resnet50,
    resnet101=resnet.resnet101,
    resnet152=resnet.resnet152,
)

# Checkpoints to evaluate, as paths relative to OUT_DIR of the form
# "sparse_<data fraction>_<augmentation>/best_imp_<iteration>.ckpt".
# The results file receives one line per entry, in CKPTS order.
#
# Earlier selections, kept for reference (to re-run one, put its paths in CKPTS):
#     # # best augs
#     # "sparse_1_autoaug/best_imp_0.ckpt",
#     # "sparse_0.5_autoaug/best_imp_0.ckpt",
#     # "sparse_0.2_autoaug/best_imp_0.ckpt",
#     # "sparse_0.1_autoaug/best_imp_0.ckpt",
#     # "sparse_0.02_autoaug/best_imp_0.ckpt",
#     # "sparse_0.01_contrastaug/best_imp_0.ckpt",
#     # # best winning tickets
#     # "sparse_1_autoaug/best_imp_2.ckpt",
#     # "sparse_0.5_autoaug/best_imp_3.ckpt",
#     # "sparse_0.2_autoaug/best_imp_2.ckpt",
#     # "sparse_0.1_autoaug/best_imp_4.ckpt",
#     # "sparse_0.02_autoaug/best_imp_15.ckpt",
#     # "sparse_0.01_randaug/best_imp_15.ckpt",
#
#     # "sparse_1_baseaug/best_imp_2.ckpt",
#     # "sparse_0.5_baseaug/best_imp_6.ckpt",
#     # "sparse_0.2_baseaug/best_imp_6.ckpt",
#     # "sparse_0.1_baseaug/best_imp_15.ckpt",
#     # "sparse_0.02_baseaug/best_imp_14.ckpt",
#     # "sparse_0.01_baseaug/best_imp_15.ckpt",
_CKPT_DATA_SIZES = ("1", "0.5", "0.2", "0.1", "0.02", "0.01")

# Selected iteration for each entry of _CKPT_DATA_SIZES, per augmentation.
# Dict insertion order is the output order.
_BEST_IMP_ITERS = {
    "contrastaug": (3, 3, 2, 12, 15, 15),
    "randaug": (1, 1, 1, 10, 13, 15),
    "autoaug": (2, 1, 2, 4, 15, 15),
}

CKPTS = [
    f"sparse_{size}_{aug}/best_imp_{iteration}.ckpt"
    for aug, iterations in _BEST_IMP_ITERS.items()
    for size, iteration in zip(_CKPT_DATA_SIZES, iterations)
]
OUT_DIR = "output/cifar10_ckpts"

# Command-line interface: the shared flags registered by utils.add_args plus the
# location of the pre-transformed robustness data.
parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument(
    "--rob_data_root",
    type=str,
    required=True,
    help="path to transformed data directory",
)
args = parser.parse_args()

# Seed Python, NumPy and torch (CPU, current CUDA device, all CUDA devices)
# from the same args.seed value, in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Shared evaluation state used by eval() below.
device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
metric_meter = utils.AvgMeter()


@torch.no_grad()
def eval(loader, model, metric_meter):
    """Evaluate ``model`` on every batch of ``loader`` and record metrics.

    Resets ``metric_meter``, puts ``model`` in eval mode and, for each batch, adds
    ``{"loss": <criterion value>, "acc": <fraction of correct argmax predictions
    in the batch>}`` to the meter while updating a ``utils.pbar`` progress bar.
    Uses the module-level ``device`` and ``criterion``. Gradients are disabled.

    NOTE(review): whether the meter's final "acc" is batch-size weighted depends on
    ``utils.AvgMeter``; a smaller last batch may be over-weighted — confirm.

    The name shadows the builtin ``eval``; it is kept because the script calls it.
    """
    metric_meter.reset()
    model.eval()
    n_batches = len(loader)
    for step, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        predicted = logits.argmax(dim=1)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()

        metric_meter.add(
            {
                "loss": criterion(logits, labels).item(),
                "acc": n_correct / images.shape[0],
            }
        )
        utils.pbar(step / n_batches, msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


class CIFARRobustness(Dataset):
    """Map-style dataset over pre-transformed CIFAR images stored in ``.npz`` files.

    ``root`` must contain ``train.npz`` and ``test.npz``, each holding an ``images``
    and a ``labels`` array. Both splits are concatenated (train first, then test)
    into a single evaluation set.

    Args:
        root: Directory containing ``train.npz`` and ``test.npz``.
        transform: Callable applied to each image (one entry of the ``images``
            array) in ``__getitem__``.
    """

    def __init__(self, root, transform):
        train_imgs, train_labels = self._load_split(root, "train.npz")
        test_imgs, test_labels = self._load_split(root, "test.npz")
        self.imgs = np.concatenate([train_imgs, test_imgs], axis=0)
        self.labels = np.concatenate([train_labels, test_labels], axis=0)
        self.transform = transform

    @staticmethod
    def _load_split(root, filename):
        """Return the ``(images, labels)`` arrays stored in ``root/filename``.

        The archive is opened once and closed before returning. Indexing an
        ``NpzFile`` reads the array into memory, so the returned arrays remain
        valid after the file is closed.
        """
        with np.load(os.path.join(root, filename)) as archive:
            return archive["images"], archive["labels"]

    def __getitem__(self, indx):
        img = self.imgs[indx]
        label = self.labels[indx]
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


def _dataset_config(dset_name):
    """Return ``(normalize, n_cls, test_dset_cls)`` for the requested CIFAR variant.

    Raises:
        NotImplementedError: if ``dset_name`` is neither "cifar10" nor "cifar100".
    """
    if dset_name == "cifar10":
        norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return norm, 10, datasets.CIFAR10
    if dset_name == "cifar100":
        norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        return norm, 100, datasets.CIFAR100
    raise NotImplementedError(f"args.dset = {dset_name} not implemented.")


def _imp_iteration(ckpt_name):
    """Parse the trailing iteration index of a checkpoint name.

    E.g. "sparse_0.5_autoaug/best_imp_12.ckpt" -> 12.
    """
    return int(os.path.basename(ckpt_name).split(".")[0].split("_")[-1])


def _make_loader(dset):
    """Wrap ``dset`` in an unshuffled DataLoader configured from the CLI args."""
    return DataLoader(
        dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
    )


# Everything that depends only on the dataset is the same for every checkpoint,
# so build it once. The clean test split now matches args.dset: it previously
# always loaded CIFAR10, even for cifar100.
norm, n_cls, test_dset_cls = _dataset_config(args.dset)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        norm,
    ]
)
clean_loader = _make_loader(
    test_dset_cls(
        root=args.data_root,
        train=False,
        transform=transform,
        download=True,
    )
)
rob_loader = _make_loader(CIFARRobustness(root=args.rob_data_root, transform=transform))

# One line per checkpoint (in CKPTS order): robustness-set accuracy in percent.
with open(f"{args.dset}_rob_d_results.txt", "w") as f:
    for ckpt_f in CKPTS:
        ckpt_path = os.path.join(OUT_DIR, ckpt_f)
        print(f"Evaluating: {ckpt_path}")
        ckpt = torch.load(ckpt_path)

        model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(device)
        if _imp_iteration(ckpt_f):
            # Non-zero iterations: load the stored "init" weights, apply the mask
            # extracted from the trained weights, then load the trained weights below.
            # NOTE(review): presumably mask_prune reparametrizes modules so the pruned
            # checkpoint's state_dict keys match — confirm against utils.mask_prune.
            model.load_state_dict(ckpt["init"])
            curr_mask = utils.extract_mask(ckpt["model"])
            utils.mask_prune(model, curr_mask)
            print("remaining weight = ", utils.check_sparsity(model))
        model.load_state_dict(ckpt["model"])

        # Clean test split: printed only.
        eval(clean_loader, model, metric_meter)
        metrics = metric_meter.get()
        print(f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}")

        # Robustness split: printed and recorded.
        eval(rob_loader, model, metric_meter)
        metrics = metric_meter.get()
        print(
            f"{args.dset} (robustness): loss {round(metrics['loss'], 5)}, "
            f"acc: {round(metrics['acc'], 5)}"
        )
        # Scale before rounding; round(acc, 4) * 100 produced float artifacts such
        # as "87.64999999999999".
        f.write(f"{round(metrics['acc'] * 100, 2)}\n")
        f.flush()
        print("finished evaluating on ckpt")
        print("---------------------------")
