import copy
import os
import random

# from advertorch.utils import NormalizeByChannelMeanStd
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from .dataset import TinyImageNet,cifar10_dataloaders,svhn_dataloaders
# from imagenet import prepare_data
from .resnet import resnet18
def prepare_dataset(args):
    """Build the model and every data loader used by the experiment.

    The loaders come from ``setup_model_dataset``; the marked training loader
    is split into ``'forget'`` / ``'retain'`` loaders. When
    ``args.class_to_replace`` is set, the test loader is split the same way
    (keys ``'test_forget'`` / ``'test_retain'``) and the ``'test'`` entry is
    rebuilt from a copy whose negative (marked) labels are mapped back with
    ``abs(label + 1)``.

    Every loader is finally passed through ``dataset_convert_to_test``.

    Args:
        args: Namespace read for ``dataset`` and ``class_to_replace`` here,
            and forwarded to ``setup_model_dataset``.

    Returns:
        tuple: ``(dataset_all, model)`` where ``dataset_all`` maps split names
        to data loaders.
    """
    model, train_loader_full, val_loader, test_loader, marked_loader = (
        setup_model_dataset(args)
    )

    dataset_all = {
        'train': train_loader_full,
        'val': val_loader,
        'test': test_loader,
    }
    dataset_all.update(
        split_train_dataset(marked_loader, dataset_name=args.dataset)
    )

    if args.class_to_replace is not None:
        test_splits = split_train_dataset(
            test_loader, marked_loader_name='test_', dataset_name=args.dataset
        )
        dataset_all.update(test_splits)

        # Rebuild a full test set with the marking undone: a marked label is
        # stored as a negative value and abs(label + 1) recovers the class.
        restored_test = copy.deepcopy(test_loader.dataset)
        label_attr = 'targets' if hasattr(restored_test, 'targets') else 'labels'
        labels = getattr(restored_test, label_attr)
        negative = labels < 0
        labels[negative] = np.abs(labels[negative] + 1)
        dataset_all['test'] = replace_loader_dataset(
            restored_test, test_loader.batch_size
        )

    for loader in dataset_all.values():
        dataset_convert_to_test(loader)
    return dataset_all, model

class AverageMeter:
    """Keep the most recent value of a metric and its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, average, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: Newly observed value.
            n: Weight of the observation (number of samples it stands for).
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
def replace_loader_dataset(dataset, batch_size, seed=1, shuffle=True):
    """Wrap ``dataset`` in a new ``DataLoader`` after re-seeding the RNGs.

    Args:
        dataset: Dataset object to load from.
        batch_size: Batch size of the new loader.
        seed: Seed handed to ``setup_seed`` before the loader is created.
        shuffle: Whether the loader shuffles its samples.

    Returns:
        torch.utils.data.DataLoader: Loader using 4 workers and pinned memory.
    """
    # Seeding happens before the loader is constructed.
    setup_seed(seed)
    loader_options = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        shuffle=shuffle,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)
def split_train_dataset_svhn(marked_loader: DataLoader, seed: int = 2, marked_loader_name=''):
    """Split a marked loader into "forget" and "retain" loaders.

    Samples whose label is negative are the ones marked for forgetting; their
    label is mapped back with ``-label - 1``. Samples with a non-negative label
    are retained unchanged. Labels are read from the dataset's ``targets``
    attribute when it exists and from ``labels`` otherwise; samples are read
    from ``data``.

    Args:
        marked_loader: Loader whose dataset carries the marked labels. The
            dataset itself is not modified (deep copies are filtered).
        seed: Seed used when building each of the two new loaders.
        marked_loader_name: Prefix prepended to the returned dictionary keys.

    Returns:
        dict: ``{prefix + 'forget': forget_loader, prefix + 'retain': retain_loader}``,
        both built with the batch size of ``marked_loader``.
    """
    source = marked_loader.dataset
    # Decide once where the labels live instead of probing with bare excepts,
    # which would also hide unrelated errors.
    label_attr = 'targets' if hasattr(source, 'targets') else 'labels'
    # replace_loader_dataset requires a batch size; reuse the source loader's.
    batch_size = marked_loader.batch_size

    forget_dataset = copy.deepcopy(source)
    forget_labels = getattr(forget_dataset, label_attr)
    forget_mask = forget_labels < 0
    forget_dataset.data = forget_dataset.data[forget_mask]
    setattr(forget_dataset, label_attr, -forget_labels[forget_mask] - 1)
    forget_loader = replace_loader_dataset(
        forget_dataset, batch_size, seed=seed, shuffle=True
    )

    retain_dataset = copy.deepcopy(source)
    retain_labels = getattr(retain_dataset, label_attr)
    retain_mask = retain_labels >= 0
    retain_dataset.data = retain_dataset.data[retain_mask]
    setattr(retain_dataset, label_attr, retain_labels[retain_mask])
    retain_loader = replace_loader_dataset(
        retain_dataset, batch_size, seed=seed, shuffle=True
    )

    # The two masks are complementary, so no sample may be lost or duplicated.
    assert len(forget_dataset) + len(retain_dataset) == len(marked_loader.dataset)
    return {
        marked_loader_name + 'forget': forget_loader,
        marked_loader_name + 'retain': retain_loader,
    }
def split_train_dataset_cifar10(marked_loader: DataLoader, seed: int = 2, marked_loader_name=''):
    """Split a marked loader into "forget" and "retain" loaders.

    Samples whose entry in ``targets`` is negative are the ones marked for
    forgetting; their label is mapped back with ``-target - 1``. Samples with
    a non-negative target are retained unchanged. Samples are read from the
    dataset's ``data`` attribute when it exists and from ``imgs`` otherwise.

    Args:
        marked_loader: Loader whose dataset carries the marked targets. The
            dataset itself is not modified (deep copies are filtered).
        seed: Seed used when building each of the two new loaders.
        marked_loader_name: Prefix prepended to the returned dictionary keys.

    Returns:
        dict: ``{prefix + 'forget': forget_loader, prefix + 'retain': retain_loader}``,
        both built with the batch size of ``marked_loader``.
    """
    source = marked_loader.dataset
    # Pick the sample container explicitly. A blanket try/except around the
    # whole split would retry on a half-modified copy and hide the real error.
    sample_attr = 'data' if hasattr(source, 'data') else 'imgs'
    batch_size = marked_loader.batch_size

    forget_dataset = copy.deepcopy(source)
    forget_mask = forget_dataset.targets < 0
    setattr(
        forget_dataset,
        sample_attr,
        getattr(forget_dataset, sample_attr)[forget_mask],
    )
    forget_dataset.targets = -forget_dataset.targets[forget_mask] - 1
    forget_loader = replace_loader_dataset(
        forget_dataset, batch_size, seed=seed, shuffle=True
    )

    retain_dataset = copy.deepcopy(source)
    retain_mask = retain_dataset.targets >= 0
    setattr(
        retain_dataset,
        sample_attr,
        getattr(retain_dataset, sample_attr)[retain_mask],
    )
    retain_dataset.targets = retain_dataset.targets[retain_mask]
    retain_loader = replace_loader_dataset(
        retain_dataset, batch_size, seed=seed, shuffle=True
    )

    # The two masks are complementary, so no sample may be lost or duplicated.
    assert len(forget_dataset) + len(retain_dataset) == len(
        marked_loader.dataset
    )
    return {
        marked_loader_name + 'forget': forget_loader,
        marked_loader_name + 'retain': retain_loader,
    }
def split_train_dataset(marked_loader:DataLoader,seed:int =2,marked_loader_name='',dataset_name='cifar10'):
    """Dispatch to the forget/retain splitter that matches ``dataset_name``.

    Args:
        marked_loader: Loader whose dataset carries the marked labels.
        seed: Seed forwarded to the dataset-specific splitter.
        marked_loader_name: Key prefix forwarded to the splitter.
        dataset_name: ``'cifar10'`` or ``'svhn'``.

    Returns:
        Whatever the selected splitter returns.

    Raises:
        NotImplementedError: If ``dataset_name`` is neither supported value.
    """
    if dataset_name == 'svhn':
        splitter = split_train_dataset_svhn
    elif dataset_name == 'cifar10':
        splitter = split_train_dataset_cifar10
    else:
        raise NotImplementedError(f'dataset {dataset_name} not implemented')
    return splitter(
        marked_loader, seed=seed, marked_loader_name=marked_loader_name
    )
def dataset_convert_to_test(dataset):
    """Switch the innermost wrapped dataset to a plain ``ToTensor`` pipeline.

    The argument is unwrapped by following ``.dataset`` attributes until an
    object without one is reached. That object gets its ``transform`` replaced
    by a ``Compose`` holding only ``ToTensor`` and its ``train`` flag set to
    ``False``. The change is made in place; nothing is returned.

    Args:
        dataset: A dataset, or any wrapper exposing the wrapped object through
            a ``dataset`` attribute.
    """
    innermost = dataset
    while hasattr(innermost, "dataset"):
        innermost = innermost.dataset

    innermost.transform = transforms.Compose([transforms.ToTensor()])
    innermost.train = False
class NormalizeByChannelMeanStd(torch.nn.Module):
    """Module that subtracts a per-channel mean and divides by a per-channel std.

    ``mean`` and ``std`` are stored as buffers of the module.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Anything that is not already a tensor is converted with torch.tensor.
        mean_tensor = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_tensor = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_tensor)
        self.register_buffer("std", std_tensor)

    def forward(self, tensor):
        """Return ``tensor`` normalized with the stored mean and std."""
        return self.normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        return f"mean={self.mean}, std={self.std}"

    def normalize_fn(self, tensor, mean, std):
        """Differentiable counterpart of torchvision's functional normalize."""
        # The statistics are broadcast along dim=1, assumed to be the channel
        # axis of ``tensor``.
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        return (tensor - channel_mean) / channel_std
def setup_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators.

    Also prints the seed and turns on cuDNN deterministic mode.

    Args:
        seed: Seed value given to every generator.
    """
    print(f"setup random seed = {seed}")
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def setup_model_dataset(args):
    """Create the ResNet-18 model and the data loaders for ``args.dataset``.

    Two sets of loaders are requested from the dataset-specific factory: a
    plain one (its train and validation loaders are used) and a "marked" one
    built with ``only_mark=True`` and the forget-set selection options (its
    train and test loaders are used).

    The model is created after seeding with ``args.train_seed`` (which
    defaults to ``args.seed`` and is written back to ``args`` when it is
    ``None``), the seed is applied once more after construction, and the
    dataset's normalization module is attached as ``model.normalize``.

    Args:
        args: Namespace providing ``dataset``, ``arch``, ``batch_size``,
            ``data``, ``workers``, ``class_to_replace``,
            ``num_indexes_to_replace``, ``indexes_to_replace``, ``seed``,
            ``train_seed``, ``imagenet_arch`` and, for ``"cifar10"``,
            ``no_aug``.

    Returns:
        tuple: ``(model, train_full_loader, val_loader, test_loader,
        marked_loader)``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``"cifar10"`` nor ``"svhn"``.
    """
    # NOTE: only cifar10 and svhn are wired up. Any other dataset needs its
    # own loader factory and normalization statistics added here.
    if args.dataset == "cifar10":
        classes = 10
        mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
        make_loaders = cifar10_dataloaders
        # Only the CIFAR-10 marked loader receives the no-augmentation flag.
        marked_extra = {"no_aug": args.no_aug}
    elif args.dataset == "svhn":
        classes = 10
        mean, std = [0.4377, 0.4438, 0.4728], [0.1201, 0.1231, 0.1052]
        make_loaders = svhn_dataloaders
        marked_extra = {}
    else:
        raise ValueError("Dataset not supprot yet !")

    normalization = NormalizeByChannelMeanStd(mean=mean, std=std)

    train_full_loader, val_loader, _ = make_loaders(
        arch=args.arch,
        batch_size=args.batch_size,
        data_dir=args.data,
        num_workers=args.workers,
    )
    marked_loader, _, test_loader = make_loaders(
        arch=args.arch,
        batch_size=args.batch_size,
        data_dir=args.data,
        num_workers=args.workers,
        class_to_replace=args.class_to_replace,
        num_indexes_to_replace=args.num_indexes_to_replace,
        indexes_to_replace=args.indexes_to_replace,
        seed=args.seed,
        only_mark=True,
        shuffle=True,
        **marked_extra,
    )

    if args.train_seed is None:
        args.train_seed = args.seed
    setup_seed(args.train_seed)

    # The svhn branch used to return without ever building a model; both
    # datasets now share this construction path.
    if args.imagenet_arch:
        model = resnet18(num_classes=classes, imagenet=True)
    else:
        model = resnet18(num_classes=classes)
    # Seed again so the RNG state after this call does not depend on how many
    # random numbers the model initialization consumed.
    setup_seed(args.train_seed)
    model.normalize = normalization
    return model, train_full_loader, val_loader, test_loader, marked_loader


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor; the top-k search runs along dim 1.
        target: Tensor of true class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        list: One tensor per entry of ``topk`` holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the highest scores per sample, transposed so that row i
    # holds every sample's (i+1)-th best guess.
    _, top_indices = output.topk(largest_k, 1, True, True)
    top_indices = top_indices.t()
    expanded_target = target.view(1, -1).expand_as(top_indices)
    hits = top_indices.eq(expanded_target).contiguous()

    percent_per_hit = 100.0 / num_samples
    return [
        hits[:k].view(-1).float().sum(0).mul_(percent_per_hit) for k in topk
    ]