"""
    setup model and datasets
"""


import copy
import os
import random

# from advertorch.utils import NormalizeByChannelMeanStd
import shutil
import sys
import time

import numpy as np
import torch
from dataset import *
from dataset import TinyImageNet
from imagenet import prepare_data
from models import *
from torchvision import transforms
import torch.nn as nn

# Names exported by a star-import of this module.  Helpers defined in this
# file but not listed here (for example ``load_checkpoint``) must be imported
# explicitly by name.
__all__ = [
    "setup_model_dataset",
    "setup_model",
    "AverageMeter",
    "warmup_lr",
    "save_checkpoint",
    "setup_seed",
    "accuracy",
]


def warmup_lr(epoch, step, optimizer, one_epoch_step, args):
    """Apply a linear learning-rate warm-up to every parameter group.

    The rate grows in proportion to the number of optimisation steps taken so
    far and is capped at ``args.lr`` once ``args.warmup`` epochs' worth of
    steps have elapsed.

    Args:
        epoch: Index of the current epoch.
        step: Index of the current step inside that epoch.
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten in place.
        one_epoch_step: Number of steps that make up one epoch.
        args: Options object providing ``lr`` (target rate) and ``warmup``
            (warm-up length, in epochs).

    Note:
        ``args.warmup * one_epoch_step`` is used as a divisor, so it must be
        non-zero.
    """
    warmup_steps = args.warmup * one_epoch_step
    steps_done = epoch * one_epoch_step + step

    ramped = args.lr * steps_done / warmup_steps
    # Never exceed the configured target rate.
    new_lr = args.lr if args.lr < ramped else ramped

    for group in optimizer.param_groups:
        group["lr"] = new_lr


def save_checkpoint(
    state, is_SA_best, save_path, pruning, filename="checkpoint.pth.tar"
):
    """Serialize ``state`` to disk and optionally keep it as the best model.

    The checkpoint is written to ``<save_path>/<pruning><filename>``.  When
    ``is_SA_best`` is truthy the file is additionally copied to
    ``<save_path>/<pruning>model_SA_best.pth.tar``.

    Args:
        state: Object handed to ``torch.save``.
        is_SA_best: Whether this checkpoint should also be stored as the
            best one.
        save_path: Existing directory that receives the files.
        pruning: Value whose ``str()`` form prefixes both file names.
        filename: Base name of the regular checkpoint file.
    """
    prefix = str(pruning)
    checkpoint_file = os.path.join(save_path, prefix + filename)
    torch.save(state, checkpoint_file)

    if not is_SA_best:
        return

    best_file = os.path.join(save_path, prefix + "model_SA_best.pth.tar")
    shutil.copyfile(checkpoint_file, best_file)


def load_checkpoint(device, save_path, pruning, filename="checkpoint.pth.tar"):
    """Load the checkpoint stored at ``<save_path>/<pruning><filename>``.

    Args:
        device: Passed to ``torch.load`` as its second positional argument
            (the map location).
        save_path: Directory that holds the checkpoint.
        pruning: Value whose ``str()`` form prefixes the file name.
        filename: Base name of the checkpoint file.

    Returns:
        Whatever ``torch.load`` returns, or ``None`` when the file does not
        exist (a message is printed in both cases).
    """
    filepath = os.path.join(save_path, str(pruning) + filename)

    if not os.path.exists(filepath):
        print("Checkpoint not found! path:{}".format(filepath))
        return None

    print("Load checkpoint from:{}".format(filepath))
    return torch.load(filepath, device)


class AverageMeter:
    """Track the most recent value and the running weighted mean of a metric.

    Attributes are ``val`` (last value seen), ``sum`` (weighted sum),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        ``n`` is typically the batch size when ``val`` is a per-batch mean.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Recomputed on every update so readers never see a stale average.
        self.avg = self.sum / self.count


def dataset_convert_to_train(dataset):
    """Install the training augmentation on the innermost wrapped dataset.

    Wrappers are peeled off by following their ``dataset`` attribute until an
    object without one is reached; that object receives a transform made of a
    padded 32x32 random crop, a random horizontal flip and ``ToTensor``.

    Args:
        dataset: A dataset, possibly nested inside wrappers that expose the
            wrapped object as ``.dataset``.  The innermost object is modified
            in place.
    """
    innermost = dataset
    while hasattr(innermost, "dataset"):
        innermost = innermost.dataset

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    innermost.transform = transforms.Compose(augmentations)
    # NOTE(review): ``train`` is set to False even though this is the
    # "convert to train" helper; this looks copied from the test variant.
    # Kept as-is because consumers of the flag are not visible here - confirm.
    innermost.train = False


def dataset_convert_to_test(dataset, args=None):
    """Install the evaluation transform on the innermost wrapped dataset.

    Wrappers are peeled off by following their ``dataset`` attribute until an
    object without one is reached; that object gets the new transform and has
    its ``train`` attribute set to ``False``.

    Args:
        dataset: A dataset, possibly nested inside wrappers that expose the
            wrapped object as ``.dataset``.  The innermost object is modified
            in place.
        args: Optional options object.  When given and ``args.dataset`` equals
            ``"TinyImagenet"`` a bicubic resize to 224 is used; otherwise
            (including ``args=None``) the transform is a plain ``ToTensor``.
    """
    # ``args`` is optional, so guard the attribute access instead of failing
    # with AttributeError when the caller relies on the default.
    is_tiny_imagenet = args is not None and args.dataset == "TinyImagenet"

    if is_tiny_imagenet:
        # NOTE(review): no ToTensor in this pipeline - presumably TinyImageNet
        # samples are already tensors; confirm against the dataset class.
        test_transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC)
        ])
    else:
        test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    while hasattr(dataset, "dataset"):
        dataset = dataset.dataset
    dataset.transform = test_transform
    dataset.train = False


def setup_model_dataset(args, rank=None):
    """Build the model and the data loaders selected by ``args.dataset``.

    Model construction (class count, normalization statistics, LoRA /
    ImageNet-style variants and the ``cifar10`` re-seeding) is delegated to
    ``setup_model`` so the two entry points cannot drift apart.  Loaders are
    always created *before* the model, which keeps the order of side effects
    of the previous hand-unrolled implementation.

    Args:
        args: Options object.  Depending on the dataset it must provide
            ``dataset``, ``arch``, ``batch_size``, ``data``, ``workers``,
            ``seed``, ``class_to_replace``, ``num_indexes_to_replace``,
            ``indexes_to_replace``, ``no_aug``, ``imagenet_arch``,
            ``train_seed``, ``train_y_file`` and ``val_y_file``.
        rank: LoRA rank; required when ``"lora"`` occurs in ``args.arch``.

    Returns:
        A tuple whose layout depends on the dataset:

        * ``cifar10`` / ``svhn`` / ``cifar100`` / ``TinyImagenet``:
          ``(model, train_full_loader, val_loader, test_loader, marked_loader)``
        * ``imagenet`` with ``args.class_to_replace is None``:
          ``(model, train_loader, val_loader)``
        * ``imagenet`` otherwise:
          ``(model, retain_loader, forget_loader, val_loader)``
        * ``cifar100_no_val`` / ``cifar10_no_val``:
          ``(model, train_set_loader, val_loader, test_loader)``

    Raises:
        ValueError: If ``args.dataset`` is not one of the names above.
    """
    name = args.dataset

    # Keyword arguments are built lazily so that each branch only touches the
    # ``args`` attributes it actually needs.
    def plain_kwargs():
        return dict(
            batch_size=args.batch_size,
            data_dir=args.data,
            num_workers=args.workers,
        )

    def marking_kwargs():
        return dict(
            plain_kwargs(),
            class_to_replace=args.class_to_replace,
            num_indexes_to_replace=args.num_indexes_to_replace,
            indexes_to_replace=args.indexes_to_replace,
            seed=args.seed,
            only_mark=True,
            shuffle=True,
        )

    if name in ("cifar10", "svhn", "cifar100"):
        if name == "cifar10":
            make_loaders = cifar10_dataloaders
        elif name == "svhn":
            make_loaders = svhn_dataloaders
        else:
            make_loaders = cifar100_dataloaders

        train_full_loader, val_loader, _ = make_loaders(**plain_kwargs())

        mark_kwargs = marking_kwargs()
        if name != "svhn":
            # Only the CIFAR loaders are given the ``no_aug`` switch.
            mark_kwargs["no_aug"] = args.no_aug
        marked_loader, _, test_loader = make_loaders(**mark_kwargs)

        model = setup_model(args, rank)
        return model, train_full_loader, val_loader, test_loader, marked_loader

    if name == "TinyImagenet":
        # The test loader comes from the unmarked call; the marked call only
        # contributes its training loader.
        train_full_loader, val_loader, test_loader = TinyImageNet(args).data_loaders(
            **plain_kwargs()
        )
        marked_loader, _, _ = TinyImageNet(args).data_loaders(**marking_kwargs())

        model = setup_model(args, rank)
        return model, train_full_loader, val_loader, test_loader, marked_loader

    if name == "imagenet":
        train_ys = torch.load(args.train_y_file)
        val_ys = torch.load(args.val_y_file)
        model = setup_model(args, rank)

        if args.class_to_replace is None:
            loaders = prepare_data(dataset="imagenet", batch_size=args.batch_size)
            return model, loaders["train"], loaders["val"]

        # 1 marks samples to keep, 0 marks samples of the class to forget.
        train_subset_indices = torch.ones_like(train_ys)
        val_subset_indices = torch.ones_like(val_ys)
        train_subset_indices[train_ys == args.class_to_replace] = 0
        val_subset_indices[val_ys == args.class_to_replace] = 0
        loaders = prepare_data(
            dataset="imagenet",
            batch_size=args.batch_size,
            train_subset_indices=train_subset_indices,
            val_subset_indices=val_subset_indices,
        )
        return model, loaders["train"], loaders["fog"], loaders["val"]

    if name == "cifar100_no_val":
        loaders = cifar100_dataloaders_no_val(**plain_kwargs())
    elif name == "cifar10_no_val":
        loaders = cifar10_dataloaders_no_val(**plain_kwargs())
    else:
        raise ValueError("Dataset not supprot yet !")

    train_set_loader, val_loader, test_loader = loaders
    model = setup_model(args, rank)
    return model, train_set_loader, val_loader, test_loader



# ImageNet channel statistics, shared by the TinyImagenet and imagenet entries.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)
_CIFAR100_MEAN = (0.5071, 0.4866, 0.4409)
_CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# dataset name -> (number of classes, per-channel mean, per-channel std)
_DATASET_SPECS = {
    "cifar10": (10, _CIFAR10_MEAN, _CIFAR10_STD),
    "svhn": (10, (0.4377, 0.4438, 0.4728), (0.1201, 0.1231, 0.1052)),
    "cifar100": (100, _CIFAR100_MEAN, _CIFAR100_STD),
    "TinyImagenet": (200, _IMAGENET_MEAN, _IMAGENET_STD),
    "imagenet": (1000, _IMAGENET_MEAN, _IMAGENET_STD),
    "cifar100_no_val": (100, _CIFAR100_MEAN, _CIFAR100_STD),
    "cifar10_no_val": (10, _CIFAR10_MEAN, _CIFAR10_STD),
}


def _build_model(args, classes, rank, force_imagenet=False):
    """Instantiate ``model_dict[args.arch]`` with the right keyword set.

    ``rank`` is forwarded only for architectures whose name contains
    ``"lora"``; ``imagenet=True`` is forwarded when ``force_imagenet`` is set
    or ``args.imagenet_arch`` is truthy (the latter is not read when
    ``force_imagenet`` is set).
    """
    kwargs = {"num_classes": classes}
    if "lora" in args.arch:
        assert rank is not None, "LoRA architectures require a rank"
        kwargs["rank"] = rank
    if force_imagenet or args.imagenet_arch:
        kwargs["imagenet"] = True
    return model_dict[args.arch](**kwargs)


def setup_model(args, rank=None):
    """Build the model for ``args.dataset`` and attach its input normalizer.

    Args:
        args: Options object providing ``dataset`` and ``arch``;
            ``imagenet_arch`` is read for every dataset except ``imagenet``
            (which always uses the ImageNet variant); ``train_seed`` and
            ``seed`` are read for ``cifar10``.
        rank: LoRA rank; required when ``"lora"`` occurs in ``args.arch``.

    Returns:
        The model, with a ``NormalizeByChannelMeanStd`` module stored in its
        ``normalize`` attribute.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.

    Side effects:
        For ``cifar10`` only, ``args.train_seed`` defaults to ``args.seed``
        when it is ``None`` and the global RNGs are seeded with it both
        before and after the model is constructed.
    """
    spec = _DATASET_SPECS.get(args.dataset)
    if spec is None:
        raise ValueError("Dataset not supprot yet !")
    classes, mean, std = spec
    normalization = NormalizeByChannelMeanStd(mean=mean, std=std)

    reseed = args.dataset == "cifar10"
    if reseed:
        if args.train_seed is None:
            args.train_seed = args.seed
        setup_seed(args.train_seed)

    model = _build_model(
        args, classes, rank, force_imagenet=args.dataset == "imagenet"
    )

    if reseed:
        # Reset the RNG state so later training does not depend on how many
        # random numbers the model initialisation consumed.
        setup_seed(args.train_seed)

    model.normalize = normalization
    return model

def setup_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators.

    Also switches cuDNN to deterministic mode and prints the seed in use.

    Args:
        seed: Integer seed handed to every generator.
    """
    print("setup random seed = {}".format(seed))
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


class NormalizeByChannelMeanStd(torch.nn.Module):
    """Module that standardizes an input with per-channel mean and std.

    ``mean`` and ``std`` are stored as buffers, so they follow the module
    across devices and are included in its state dict.

    Args:
        mean: Per-channel means; converted with ``torch.tensor`` unless
            already a tensor.
        std: Per-channel standard deviations; same conversion as ``mean``.
    """

    def __init__(self, mean, std):
        super().__init__()
        for buffer_name, values in (("mean", mean), ("std", std)):
            if not isinstance(values, torch.Tensor):
                values = torch.tensor(values)
            self.register_buffer(buffer_name, values)

    def forward(self, tensor):
        """Return ``tensor`` normalized with the stored statistics."""
        return self.normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        """Show the statistics in the module's printed representation."""
        return "mean={}, std={}".format(self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.functional.normalize"""
        # The channel axis is assumed to be dim=1: the statistics are
        # reshaped to (1, C, 1, 1) so they broadcast over the other axes.
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        centered = tensor.sub(channel_mean)
        return centered.div(channel_std)


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: Score tensor; the classes are ranked along dim 1.
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, holding the
        percentage of samples whose target is among the ``k`` best scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # Transpose so that row i holds every sample's (i+1)-th best guess.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout, so
        # ``view(-1)`` raises for k > 1; ``reshape`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def run_commands(gpus, commands, call=False, dir="commands", shuffle=True, delay=0.5):
    """Distribute shell commands over GPUs by writing one script per GPU.

    Commands are dealt round-robin to the GPUs; every command is prefixed
    with ``CUDA_VISIBLE_DEVICES=<gpu>`` and written to ``<dir>/run<i>.sh``.
    A helper script ``stop_<dir>.sh`` that kills the launched ``bash``
    processes is written to the current working directory.

    Args:
        gpus: List of GPU identifiers.  Shuffled **in place** when
            ``shuffle`` is true.
        commands: List of command strings.  Shuffled **in place** when
            ``shuffle`` is true.  Nothing happens when it is empty.
        call: When true, each script is started in the background with
            ``bash`` right after it is written.
        dir: Directory for the per-GPU scripts.  An existing directory of
            that name is deleted first.
        shuffle: Whether to shuffle ``commands`` and ``gpus`` beforehand.
        delay: Seconds to sleep after launching each script.
    """
    if len(commands) == 0:
        return
    if os.path.exists(dir):
        shutil.rmtree(dir)
    if shuffle:
        random.shuffle(commands)
        random.shuffle(gpus)
    os.makedirs(dir, exist_ok=True)

    # ``with`` guarantees the handles are closed even if a write fails.
    with open("stop_{}.sh".format(dir), "w") as fout:
        print("kill $(ps aux|grep 'bash " + dir + "'|awk '{print $2}')", file=fout)

    n_gpu = len(gpus)
    for i, gpu in enumerate(gpus):
        # Round-robin assignment: GPU i gets commands i, i + n_gpu, ...
        i_commands = commands[i::n_gpu]
        if len(i_commands) == 0:
            continue
        prefix = "CUDA_VISIBLE_DEVICES={} ".format(gpu)

        sh_path = os.path.join(dir, "run{}.sh".format(i))
        with open(sh_path, "w") as fout:
            for com in i_commands:
                print(prefix + com, file=fout)
        if call:
            os.system("bash {}&".format(sh_path))
            time.sleep(delay)


def get_loader_from_dataset(dataset, batch_size, seed=1, shuffle=True):
    """Wrap ``dataset`` in a single-process ``DataLoader`` with pinned memory.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        seed: Accepted for call-site compatibility; not used by this function.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``num_workers=0`` and
        ``pin_memory=True``.
    """
    loader_options = {
        "batch_size": batch_size,
        "num_workers": 0,
        "pin_memory": True,
        "shuffle": shuffle,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


def get_unlearn_loader(marked_loader, args):
    """Split a marked dataset into a forget loader and a retain loader.

    Samples whose target is negative form the forget set; their labels are
    mapped back with ``-target - 1``.  Samples with a non-negative target form
    the retain set.  Both subsets are deep copies, so the dataset behind
    ``marked_loader`` is left untouched.

    Args:
        marked_loader: Loader whose ``dataset`` exposes ``data`` and
            ``targets`` supporting boolean-mask indexing and element-wise
            comparison (presumably NumPy arrays - confirm with the loader
            factory).
        args: Options object providing ``batch_size`` and ``seed``.

    Returns:
        ``(forget_loader, retain_loader)``, both shuffling.
    """
    source = marked_loader.dataset

    def shuffled_loader(subset):
        return get_loader_from_dataset(
            subset, batch_size=args.batch_size, seed=args.seed, shuffle=True
        )

    # Forget set: negative targets encode the original label as -(label + 1).
    forget_dataset = copy.deepcopy(source)
    is_forget = forget_dataset.targets < 0
    forget_dataset.data = forget_dataset.data[is_forget]
    forget_dataset.targets = -forget_dataset.targets[is_forget] - 1
    forget_loader = shuffled_loader(forget_dataset)

    # Retain set: everything that was not marked.
    retain_dataset = copy.deepcopy(source)
    is_retain = retain_dataset.targets >= 0
    retain_dataset.data = retain_dataset.data[is_retain]
    retain_dataset.targets = retain_dataset.targets[is_retain]
    retain_loader = shuffled_loader(retain_dataset)

    print("datasets length: ", len(forget_dataset), len(retain_dataset))
    return forget_loader, retain_loader


def get_poisoned_loader(poison_loader, unpoison_loader, test_loader, poison_func, args):
    """Create loaders over poisoned copies of the training and test data.

    ``poison_func`` is applied to deep copies of the datasets behind
    ``poison_loader`` and ``test_loader``; the originals are not modified.

    Args:
        poison_loader: Loader whose dataset is to be poisoned.
        unpoison_loader: Loader over the clean samples; returned unchanged.
        test_loader: Loader whose dataset yields the poisoned test set.
        poison_func: Callable ``(data, targets) -> (data, targets)``.
        args: Options object providing ``batch_size`` and ``seed``.

    Returns:
        ``(poisoned_loader, unpoison_loader, poisoned_full_loader,
        poisoned_test_loader)`` where the "full" loader shuffles over the
        concatenation of the clean and the poisoned training data and the
        other new loaders keep their order.
    """
    def build_loader(data_source, shuffle):
        return get_loader_from_dataset(
            data_source, batch_size=args.batch_size, seed=args.seed, shuffle=shuffle
        )

    poison_dataset = copy.deepcopy(poison_loader.dataset)
    poison_test_dataset = copy.deepcopy(test_loader.dataset)

    # Poison the training copy first, then the test copy.
    for clone in (poison_dataset, poison_test_dataset):
        clone.data, clone.targets = poison_func(clone.data, clone.targets)

    clean_plus_poison = torch.utils.data.ConcatDataset(
        [unpoison_loader.dataset, poison_dataset]
    )

    poisoned_loader = build_loader(poison_dataset, False)
    poisoned_full_loader = build_loader(clean_plus_poison, True)
    poisoned_test_loader = build_loader(poison_test_dataset, False)

    return poisoned_loader, unpoison_loader, poisoned_full_loader, poisoned_test_loader


def mark_bn_as_trainable(model: nn.Module) -> None:
    """Enable gradients for the weight and bias of every affine BatchNorm2d.

    Args:
        model: Module whose sub-modules are scanned; modified in place.
    """
    affine_bn_layers = (
        layer
        for layer in model.modules()
        if isinstance(layer, nn.BatchNorm2d) and layer.affine
    )
    for layer in affine_bn_layers:
        layer.weight.requires_grad = True
        layer.bias.requires_grad = True


def freeze_loraB(model: nn.Module) -> None:
    """Disable gradients for every parameter whose name contains ``lora_B``.

    Args:
        model: Module whose named parameters are scanned; modified in place.
    """
    for name, param in model.named_parameters():
        if "lora_B" not in name:
            continue
        param.requires_grad = False


def unfreeze_loraB(model: nn.Module) -> None:
    """Enable gradients for every parameter whose name contains ``lora_B``.

    Args:
        model: Module whose named parameters are scanned; modified in place.
    """
    for name, param in model.named_parameters():
        if "lora_B" not in name:
            continue
        param.requires_grad = True