"""
    setup model and datasets
"""


import copy
import os
import random

import shutil
import sys
import time

import numpy as np
import torch
from torchvision import transforms

from dataset import *
from dataset import TinyImageNet
from delta_imagenet import prepare_data
from models import *
from timm import create_model
import timm
# Public names exported by ``from <this module> import *``.
# NOTE: other helpers defined in this module (e.g. ``load_checkpoint``,
# ``NormalizeByChannelMeanStd``, ``run_commands`` and the loader builders) are
# not listed here, so they are only reachable through explicit imports.
__all__ = [
    "setup_model_dataset",
    "AverageMeter",
    "warmup_lr",
    "save_checkpoint",
    "setup_seed",
    "accuracy",
]


def warmup_lr(epoch, step, optimizer, one_epoch_step, args):
    """Linearly ramp every param group's learning rate up to ``args.lr``.

    The rate is ``args.lr * (epoch * one_epoch_step + step) /
    (args.warmup * one_epoch_step)``, capped at ``args.lr``.

    Args:
        epoch: index of the current epoch.
        step: index of the current step within the epoch.
        optimizer: object exposing ``param_groups`` (a list of dicts whose
            ``"lr"`` entry is overwritten).
        one_epoch_step: number of steps per epoch.
        args: namespace providing ``lr`` (target rate) and ``warmup``
            (warmup length in epochs).
    """
    overall_steps = args.warmup * one_epoch_step
    current_steps = epoch * one_epoch_step + step

    if overall_steps <= 0:
        # No warmup span (warmup == 0 or an empty epoch): nothing to ramp, and
        # dividing by it would raise ZeroDivisionError.
        lr = args.lr
    else:
        lr = min(args.lr * current_steps / overall_steps, args.lr)

    for group in optimizer.param_groups:
        group["lr"] = lr

def save_checkpoint(
    state, is_SA_best, save_path, pruning, epoch, filename="checkpoint.pth.tar"
):
    """Serialize ``state`` with ``torch.save`` and optionally mirror it as the best model.

    The checkpoint is written to ``<save_path>/<pruning><epoch>_<filename>``.
    When ``is_SA_best`` is truthy, that file is also copied to
    ``<save_path>/<pruning>model_SA_best.pth.tar``.

    NOTE(review): ``load_checkpoint`` reads ``<pruning><filename>`` (no epoch
    component), so it does not find files written by this function under the
    default names — confirm whether that is intended.
    """
    stem = str(pruning)
    checkpoint_file = os.path.join(
        save_path, "".join([stem, str(epoch), "_", filename])
    )
    torch.save(state, checkpoint_file)

    if not is_SA_best:
        return
    best_file = os.path.join(save_path, stem + "model_SA_best.pth.tar")
    shutil.copyfile(checkpoint_file, best_file)

def load_checkpoint(device, save_path, pruning, filename="checkpoint.pth.tar"):
    """Load ``<save_path>/<pruning><filename>`` if it exists, else return ``None``.

    ``device`` is forwarded to ``torch.load`` as ``map_location``.
    """
    candidate = os.path.join(save_path, str(pruning) + filename)
    if not os.path.exists(candidate):
        return None
    return torch.load(candidate, map_location=device)


class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def dataset_convert_to_train(dataset):
    """Install training-time augmentation on the innermost wrapped dataset.

    Follows the ``.dataset`` attribute chain (as exposed by wrappers such as
    ``torch.utils.data.Subset``) down to the root dataset. It then replaces that
    root's ``transform`` in place with: 32x32 random crop (4-pixel padding),
    random horizontal flip, ``ToTensor``. Every wrapper sharing the root sees
    the change. Nothing is returned.
    """
    root = dataset
    while hasattr(root, "dataset"):
        root = root.dataset

    root.transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    # NOTE(review): ``train`` is set to False even in the *train* conversion
    # (identical to dataset_convert_to_test) — confirm this is intended.
    root.train = False


def dataset_convert_to_test(dataset, args=None):
    """Install deterministic evaluation transforms on the innermost wrapped dataset.

    Follows the ``.dataset`` attribute chain down to the root dataset. It then
    replaces that root's ``transform`` in place and sets ``train = False``.
    Nothing is returned.

    Args:
        dataset: a dataset, possibly nested inside wrappers exposing ``.dataset``.
        args: optional run configuration. When ``args.dataset == "TinyImagenet"``
            an empty ``Compose`` is installed (no ``ToTensor``) — presumably that
            dataset already yields tensors; TODO confirm. Otherwise, including
            when ``args`` is ``None`` (the default, which previously raised
            AttributeError), the transform is ``ToTensor()`` only.
    """
    if args is not None and args.dataset == "TinyImagenet":
        test_transform = transforms.Compose([])
    else:
        test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    root = dataset
    while hasattr(root, "dataset"):
        root = root.dataset
    root.transform = test_transform
    root.train = False


# Keys of the dict returned by ``prepare_data``, in the order in which
# ``setup_model_dataset`` returns them (after the model) for each mode.
_TRAIN_SET_KEYS = ("train", "train_for_test", "val")
_CLASSWISE_SET_KEYS = (
    "retain",
    "forget",
    "retain_for_test",
    "forget_for_test",
    "val",
    "val_retain",
    "val_forget",
    "retain_adv",
    "forget_adv",
    "val_adv",
    "val_retain_adv",
    "val_forget_adv",
)
_INDEXWISE_SET_KEYS = (
    "retain",
    "forget",
    "retain_for_test",
    "forget_for_test",
    "val",
    "retain_adv",
    "forget_adv",
    "val_adv",
)

# Supported datasets: (per-channel mean, per-channel std, imagenet flag).
# The imagenet flag is forwarded to ResNet-18 and enables the ViT backbone.
_DATASET_CONFIG = {
    "cifar10": ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616), False),
    "imagenet10": ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), True),
}


def _build_model(args, classes, imagenet):
    """Instantiate ``args.arch`` with a ``classes``-way classifier.

    Raises:
        ValueError: if ``args.arch`` is not supported for this dataset.
    """
    if args.arch == "resnet18":
        return model_dict[args.arch](num_classes=classes, imagenet=imagenet)
    if args.arch == "vgg16_bn":
        return model_dict[args.arch](num_classes=classes)
    if args.arch == "vit" and imagenet:
        model = timm.create_model("vit_base_patch16_224", pretrained=True)
        # Swap the pretrained classification head for a ``classes``-way one.
        model.head = torch.nn.Linear(model.head.in_features, classes)
        return model
    raise ValueError(
        "Architecture {!r} is not supported for dataset {!r}".format(
            args.arch, args.dataset
        )
    )


def _prepare_sets(args):
    """Run ``prepare_data`` for ``args.dataset`` in the mode selected by ``args``.

    Modes:
        * neither ``class_to_replace`` nor ``num_indexes_to_replace`` set:
          plain train/val split;
        * only ``class_to_replace`` set: class-wise unlearning;
        * only ``num_indexes_to_replace`` set: index-wise unlearning.

    Returns:
        ``(sets, keys)``: the dict from ``prepare_data`` and the tuple of its
        keys to return, in order.

    Raises:
        ValueError: if both unlearning options are set.
    """
    dataset = args.dataset
    class_wise = args.class_to_replace is not None
    index_wise = args.num_indexes_to_replace is not None

    if class_wise and index_wise:
        raise ValueError(
            "class_to_replace and num_indexes_to_replace are mutually exclusive"
        )

    if not class_wise and not index_wise:
        if dataset == "cifar10":
            sets = prepare_data(
                dataset="cifar10",
                batch_size=args.batch_size,
                class_to_replace=args.class_to_replace,
                indexes_to_replace=args.indexes_to_replace,
                seed=args.train_seed,
                single=args.single,
                adv=args.adv,
                cor=args.cor,
                cor_type=args.cor_type,
                level=args.level,
                phase=args.phase,
                data_path=args.data,
                arch=args.arch,
                percent=args.percent,
            )
        else:
            sets = prepare_data(
                dataset="imagenet10",
                batch_size=args.batch_size,
                class_to_replace=args.class_to_replace,
                indexes_to_replace=args.indexes_to_replace,
                single=args.single,
                phase=args.phase,
                data_path=args.data,
            )
        return sets, _TRAIN_SET_KEYS

    if class_wise:
        # Class-wise unlearning: both datasets use the same arguments.
        sets = prepare_data(
            dataset=dataset,
            batch_size=args.batch_size,
            class_to_replace=args.class_to_replace,
            indexes_to_replace=args.indexes_to_replace,
            adv=args.adv,
            cor=args.cor,
            cor_type=args.cor_type,
            level=args.level,
            phase=args.phase,
            data_path=args.data,
        )
        return sets, _CLASSWISE_SET_KEYS

    # Index-wise unlearning; only CIFAR-10 forwards ``percent``.
    extra = {"percent": args.percent} if dataset == "cifar10" else {}
    sets = prepare_data(
        dataset=dataset,
        batch_size=args.batch_size,
        class_to_replace=args.class_to_replace,
        indexes_to_replace=args.indexes_to_replace,
        num_indexes_to_replace=args.num_indexes_to_replace,
        seed=args.seed,
        adv=args.adv,
        cor=args.cor,
        cor_type=args.cor_type,
        level=args.level,
        phase=args.phase,
        data_path=args.data,
        **extra,
    )
    return sets, _INDEXWISE_SET_KEYS


def setup_model_dataset(args):
    """Seed the RNGs, then build the model and datasets described by ``args``.

    Supported datasets are ``"cifar10"`` and ``"imagenet10"`` (10 classes each).
    The model receives a ``normalize`` attribute holding the dataset's
    per-channel normalization layer.

    Returns:
        * ``(model, train_set, train_set_for_test, val_set)`` when neither
          ``args.class_to_replace`` nor ``args.num_indexes_to_replace`` is set;
        * ``(model, retain, forget, retain_for_test, forget_for_test, val,
          val_retain, val_forget, retain_adv, forget_adv, val_adv,
          val_retain_adv, val_forget_adv)`` for class-wise unlearning;
        * ``(model, retain, forget, retain_for_test, forget_for_test, val,
          retain_adv, forget_adv, val_adv)`` for index-wise unlearning.

    Raises:
        ValueError: unsupported dataset or architecture, or both unlearning
            options set at once.
    """
    if args.dataset not in _DATASET_CONFIG:
        raise ValueError("Dataset not supprot yet !")
    mean, std, imagenet = _DATASET_CONFIG[args.dataset]

    setup_seed(args.train_seed)
    model = _build_model(args, classes=10, imagenet=imagenet)
    # NOTE(review): for arch="vit" the timm model presumably ignores this
    # attribute — confirm whether normalization is applied elsewhere.
    model.normalize = NormalizeByChannelMeanStd(mean=list(mean), std=list(std))

    sets, keys = _prepare_sets(args)
    return (model, *(sets[key] for key in keys))


def setup_seed(seed):
    """Seed PyTorch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` and prints the seed.
    """
    print(f"setup random seed = {seed}")
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


class NormalizeByChannelMeanStd(torch.nn.Module):
    """Differentiable per-channel ``(x - mean) / std`` normalization layer.

    ``mean`` and ``std`` may be tensors or anything ``torch.tensor`` accepts.
    They are registered as buffers, so they move with the module and appear
    in its ``state_dict``.
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_t = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_t)
        self.register_buffer("std", std_t)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        return f"mean={self.mean}, std={self.std}"

    def normalize_fn(self, tensor, mean, std):
        """Differentiable counterpart of torchvision's functional ``normalize``.

        The statistics are broadcast as ``(1, C, 1, 1)``, i.e. the channel
        dimension of ``tensor`` is assumed to be dim 1.
        """
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        return (tensor - channel_mean) / channel_std


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: scores of shape ``(batch, num_classes)``; top-k is taken over dim 1.
        target: ground-truth class indices, one per sample (``batch`` elements).
        topk: the ``k`` values to evaluate; ``max(topk)`` must not exceed
            ``num_classes`` (``Tensor.topk`` raises otherwise).

    Returns:
        A list with one 0-d float tensor per entry of ``topk``. Each holds the
        percentage of samples whose target is among the ``k`` highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (batch, maxk) -> (maxk, batch); row i holds each sample's (i+1)-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of ``pred``,
        # so ``view(-1)`` can fail for k > 1; ``reshape`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def run_commands(gpus, commands, call=False, dir="commands", shuffle=True, delay=0.5):
    """Distribute shell ``commands`` round-robin over ``gpus`` as per-GPU bash scripts.

    ``dir`` is deleted if it exists and then recreated. After optional
    shuffling, commands ``i, i + len(gpus), ...`` go to ``<dir>/run<i>.sh``,
    each line prefixed with ``CUDA_VISIBLE_DEVICES=<gpu> ``. A
    ``stop_<dir>.sh`` script, written in the current working directory, kills
    processes matching ``bash <dir>``. If ``call`` is true, each script is
    launched in the background via ``os.system``, ``delay`` seconds apart.

    The caller's ``gpus`` and ``commands`` are not modified: shuffling works on
    copies (the original shuffled them in place), and any iterable of GPU ids
    is accepted.
    """
    if len(commands) == 0:
        return
    commands = list(commands)
    gpus = list(gpus)

    if os.path.exists(dir):
        shutil.rmtree(dir)
    if shuffle:
        random.shuffle(commands)
        random.shuffle(gpus)
    os.makedirs(dir, exist_ok=True)

    with open("stop_{}.sh".format(dir), "w") as fout:
        print("kill $(ps aux|grep 'bash " + dir + "'|awk '{print $2}')", file=fout)

    n_gpu = len(gpus)
    for i, gpu in enumerate(gpus):
        i_commands = commands[i::n_gpu]
        if not i_commands:
            continue
        prefix = "CUDA_VISIBLE_DEVICES={} ".format(gpu)

        sh_path = os.path.join(dir, "run{}.sh".format(i))
        with open(sh_path, "w") as fout:
            for com in i_commands:
                print(prefix + com, file=fout)
        if call:
            os.system("bash {}&".format(sh_path))
            time.sleep(delay)
            time.sleep(delay)


def get_loader_from_dataset(dataset, batch_size, seed=1, shuffle=True):
    """Wrap ``dataset`` in a single-process ``DataLoader`` with pinned memory.

    NOTE: ``seed`` is accepted for call-site compatibility but is not used;
    no ``generator`` is passed, so shuffling presumably draws from PyTorch's
    global RNG.
    """
    loader_options = {
        "batch_size": batch_size,
        "num_workers": 0,
        "pin_memory": True,
        "shuffle": shuffle,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


def get_unlearn_loader(marked_loader, args):
    """Split a "marked" loader's dataset into forget and retain loaders.

    Samples with a negative target form the forget set; their label is
    recovered as ``-target - 1``. Samples with a non-negative target form the
    retain set, with labels unchanged. Each split is built from its own deep
    copy of ``marked_loader.dataset`` and wrapped with shuffling enabled.

    Assumes ``dataset.data`` and ``dataset.targets`` support boolean-mask
    indexing and element-wise comparison (e.g. NumPy arrays) — TODO confirm.

    Returns:
        ``(forget_loader, retain_loader)``.
    """

    def _split(select, relabel):
        # Work on a deep copy so the marked dataset itself is left untouched.
        subset = copy.deepcopy(marked_loader.dataset)
        mask = select(subset.targets)
        subset.data = subset.data[mask]
        subset.targets = relabel(subset.targets[mask])
        return get_loader_from_dataset(
            subset, batch_size=args.batch_size, seed=args.seed, shuffle=True
        )

    forget_loader = _split(lambda t: t < 0, lambda t: -t - 1)
    retain_loader = _split(lambda t: t >= 0, lambda t: t)
    return forget_loader, retain_loader


def get_poisoned_loader(poison_loader, unpoison_loader, test_loader, poison_func, args):
    """Build the loaders used for a data-poisoning experiment.

    ``poison_func(data, targets) -> (data, targets)`` is applied to deep copies
    of ``poison_loader.dataset`` and then ``test_loader.dataset``, so the
    originals are untouched.

    Returns:
        ``(poisoned_loader, unpoison_loader, poisoned_full_loader,
        poisoned_test_loader)``. ``poisoned_full_loader`` iterates (shuffled)
        over ``unpoison_loader.dataset`` concatenated with the poisoned train
        copy. The two other new loaders are not shuffled. ``unpoison_loader``
        is passed through unchanged.
    """
    poisoned_train = copy.deepcopy(poison_loader.dataset)
    poisoned_test = copy.deepcopy(test_loader.dataset)

    # Train copy first, then test copy, in case poison_func is stateful.
    for ds in (poisoned_train, poisoned_test):
        ds.data, ds.targets = poison_func(ds.data, ds.targets)

    combined = torch.utils.data.ConcatDataset(
        [unpoison_loader.dataset, poisoned_train]
    )

    def _wrap(ds, shuffle):
        return get_loader_from_dataset(
            ds, batch_size=args.batch_size, seed=args.seed, shuffle=shuffle
        )

    return (
        _wrap(poisoned_train, False),
        unpoison_loader,
        _wrap(combined, True),
        _wrap(poisoned_test, False),
    )
