import logging
import numpy as np
import torch
import random
import datasets
import models
import torch.nn as nn
import attack
# save checkpoint — NOTE(review): no checkpoint code follows this heading; confirm or remove


# logging
def setup_logging(log_file='log.txt', filemode='w'):
    """Send root-logger output to ``log_file`` and also echo it to the console.

    The file side is configured through ``logging.basicConfig`` (full date,
    level name and message). The console side gets its own ``StreamHandler``
    with a shorter ``time | message`` layout. Both accept INFO and above.

    Note: ``basicConfig`` does nothing if the root logger already has
    handlers, but the console handler is added on every call.
    """
    logging.basicConfig(
        filename=log_file,
        filemode=filemode,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s | %(message)s', datefmt='%m/%d %I:%M:%S %p')
    )

    root_logger = logging.getLogger('')
    root_logger.addHandler(stream_handler)


# set seed
def set_seed(seed=42):
    """Seed every random number generator the project relies on.

    Seeds NumPy's global RNG, PyTorch's CPU generator, the PyTorch CUDA
    generators, and Python's built-in ``random`` module. The CUDA generators
    are seeded for the current device and for all devices. PyTorch silently
    ignores the CUDA calls when CUDA is not available, so this is safe on
    CPU-only machines.
    """
    np.random.seed(seed)

    # CPU generator first, then the CUDA ones (current device, then all devices).
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    random.seed(seed)


# compute the accuracy of top-k
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: prediction scores. ``topk`` is taken along dim 1, so this is
            presumably (batch, num_classes) logits; TODO confirm with callers.
        target: ground-truth class indices, one per sample (``target.size(0)``
            is used as the batch size).
        topk: tuple of k values to evaluate; ``max(topk)`` candidates are
            ranked once and reused for every k.

    Returns:
        list of 0-dim float tensors, one per entry of ``topk`` in the same
        order. Each is the percentage of samples whose target is among the
        top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Rank the maxk highest-scoring classes per sample, then transpose so that
    # row i holds every sample's (i+1)-th best guess: shape (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` is derived from a transposed tensor and may be
        # non-contiguous. `.view(-1)` then raises for k > 1, while
        # `.reshape(-1)` copies only when it must.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# running-average meter
class AverageMeter(object):
    """Keep the latest value of a metric together with its weighted running mean.

    Attributes:
        val: most recent value passed to ``update``.
        sum: sum of ``val * n`` over all updates.
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count`` as of the last update (0 after a reset).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


# norm_layer
class Normalize(nn.Module):
    """Per-channel standardisation layer computing ``(input - mean) / std``.

    ``mean`` and ``std`` are registered as buffers, so they are saved in the
    state dict and move with the module. In ``forward`` they are reshaped to
    (1, C, 1, 1) to broadcast over the input. Inputs are therefore presumably
    NCHW image batches with C == ``len(mean)``; confirm with callers.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.Tensor(mean))
        self.register_buffer('std', torch.Tensor(std))
        self.c = len(mean)  # number of channels

    def forward(self, input):
        broadcast_shape = (1, self.c, 1, 1)
        target_device = input.device
        # `.to` is a no-op when the buffers already live on the input's device.
        centre = self.mean.reshape(broadcast_shape).to(target_device)
        scale = self.std.reshape(broadcast_shape).to(target_device)
        return (input - centre) / scale



# dataset
def get_dataset(dataset_cfg, normalize):
    """Build the clean train/test loaders and the matching normalization layer.

    Args:
        dataset_cfg: config object; reads ``name``, ``dir``, ``batch_size``,
            ``mean`` and ``std``.
        normalize: forwarded to every dataloader factory except GTSRB's.

    Returns:
        ``(train_loader, test_loader, norm_layer)`` where ``norm_layer`` is
        ``Normalize(dataset_cfg.mean, dataset_cfg.std)``.

    Raises:
        ValueError: if ``dataset_cfg.name`` is not a supported dataset.
    """
    dataset_name = dataset_cfg.name

    def one_split_per_call(factory, **extra):
        # The factory builds a single split per call, selected by ``train``:
        # request the training split, then the test split.
        return tuple(
            factory(dataset_cfg.dir, dataset_cfg.batch_size, train=is_train, **extra)
            for is_train in (True, False)
        )

    def all_splits_at_once(factory):
        # The factory returns (train, val, test); the validation loader is dropped.
        train, _unused_val, test = factory(dataset_cfg.dir, dataset_cfg.batch_size, normalize=normalize)
        return train, test

    if dataset_name == "MNIST":
        train_loader, test_loader = one_split_per_call(datasets.MNIST_dataloader, normalize=normalize)
    elif dataset_name == "SVHN":
        train_loader, test_loader = all_splits_at_once(datasets.SVHN_dataloader)
    elif dataset_name == "CIFAR10":
        train_loader, test_loader = all_splits_at_once(datasets.CIFAR10_dataloader)
    elif dataset_name == "CIFAR100":
        train_loader, test_loader = all_splits_at_once(datasets.CIFAR100_dataloader)
    elif dataset_name == "TIN":
        train_loader, test_loader = all_splits_at_once(datasets.TIN_dataloader)
    elif dataset_name == "GTSRB":
        train_loader, test_loader = one_split_per_call(datasets.GTSRB_dataloader)
    else:
        raise ValueError("Invalid dataset")

    # Every supported dataset uses the mean/std from its config.
    norm_layer = Normalize(dataset_cfg.mean, dataset_cfg.std)
    return train_loader, test_loader, norm_layer


def get_poisoned_dataset(dataset_cfg, poison_rate, epsilon=8, clean_label=False, attack='PPT'):
    """Build train/test loaders for a poisoned version of a dataset.

    Args:
        dataset_cfg: config object; reads ``name``, ``dir``, ``batch_size``
            and, for CIFAR10 only, ``generator``.
        poison_rate, epsilon, clean_label, attack: forwarded unchanged as
            keyword arguments to ``datasets.<NAME>_poisoned_dataloader``.
            NOTE: the ``attack`` parameter shadows the module-level ``attack``
            import inside this function; the name is kept for keyword callers.

    Returns:
        ``(train_loader, test_loader)``. No normalization layer is returned.

    Raises:
        NotImplementedError: for CIFAR100, which has no poisoned loader here.
        ValueError: if ``dataset_cfg.name`` is not a supported dataset.
    """
    dataset_name = dataset_cfg.name
    poison_kwargs = dict(poison_rate=poison_rate, epsilon=epsilon,
                         clean_label=clean_label, attack=attack)

    if dataset_name == "MNIST":
        loader_fn = datasets.MNIST_poisoned_dataloader
    elif dataset_name == "SVHN":
        loader_fn = datasets.SVHN_poisoned_dataloader
    elif dataset_name == "CIFAR10":
        loader_fn = datasets.CIFAR10_poisoned_dataloader
        # Only the CIFAR10 factory takes a generator from the config.
        poison_kwargs["generator"] = dataset_cfg.generator
    elif dataset_name == "CIFAR100":
        # There is no poisoned CIFAR100 loader. Fail explicitly instead of
        # reaching the return with unbound local variables (UnboundLocalError).
        raise NotImplementedError("Poisoned CIFAR100 dataloader is not implemented")
    elif dataset_name == "TIN":
        loader_fn = datasets.TIN_poisoned_dataloader
    elif dataset_name == "GTSRB":
        loader_fn = datasets.GTSRB_poisoned_dataloader
    else:
        raise ValueError("Invalid dataset")

    train_loader, test_loader = loader_fn(dataset_cfg.dir, dataset_cfg.batch_size, **poison_kwargs)
    return train_loader, test_loader


# model
def get_model(name, num_classes):
    """Instantiate the model architecture registered under ``name``.

    ``num_classes`` is passed to the ResNet18/34, MobileNetV2, VGG11-BN and
    PreActResNet18 constructors. ``netc`` (MNIST), ``svhn_net`` and
    ``efficientB0`` are built with their constructors' defaults and ignore it.

    Raises:
        ValueError: if ``name`` is not a registered model.
    """
    # Builders are lambdas so only the selected constructor is looked up and called.
    builders = {
        "resnet18": lambda: models.ResNet18(num_classes=num_classes),
        "resnet34": lambda: models.ResNet34(num_classes=num_classes),
        "mobilev2": lambda: models.MobileNetV2(num_classes=num_classes),
        "vgg11bn": lambda: models.vgg11_bn(num_classes=num_classes),
        "preresnet18": lambda: models.PreActResNet18(num_classes=num_classes),
        "netc": lambda: models.NetC_MNIST(),  # for MNIST
        "svhn_net": lambda: models.svhn(),
        "efficientB0": lambda: models.EfficientNetB0(),
    }
    for registered_name, build in builders.items():
        if name == registered_name:
            return build()
    raise ValueError("Invalid model")
    

# attack
def get_adv(model, input, target, norm_layer, config):
    """Generate adversarial inputs with the attack selected by ``config.name``.

    Supported names are ``'fgsm'``, ``'fgsm_l2'``, ``'pgd'`` and ``'pgd_l2'``.
    All of them read ``config.random_start``, ``config.epsilon`` and
    ``config.targeted``. The PGD variants also read ``config.iterations`` and
    ``config.restarts``. The result of the chosen ``attack.*`` function is
    returned unchanged.

    Raises:
        ValueError: if ``config.name`` is not a supported attack.
    """
    method = config.name

    if method in ('fgsm', 'fgsm_l2'):
        # Single-step attacks: no iteration count or restarts.
        single_step = attack.fgsm if method == 'fgsm' else attack.fgsm_l2
        return single_step(model, input, target, norm_layer, rs=config.random_start,
                           epsilon=config.epsilon, targeted=config.targeted)

    if method in ('pgd', 'pgd_l2'):
        # Iterative attacks additionally take the step count and restarts.
        iterative = attack.pgd if method == 'pgd' else attack.pgd_l2
        return iterative(model, input, target, norm_layer,
                         targeted=config.targeted,
                         rs=config.random_start,
                         epsilon=config.epsilon,
                         attack_iters=config.iterations,
                         restarts=config.restarts)

    raise ValueError("Wrong attack method")