"""Helper functions: network construction, CIFAR-100 data loading and membership-inference sample preprocessing.

author baiyu
"""
import os
import sys
import re
import datetime

import numpy as np

import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable

from conf import settings


# Preprocessing for single target images (see test_sample_process): convert
# to a tensor, then normalise with the CIFAR-100 training-set statistics
# defined in conf.settings.
# NOTE(review): test_sample_process applies this same transform to every
# image dataset (gtsrb, svhn and stl10 as well), so all of them are
# normalised with CIFAR-100 statistics — confirm that is intended.
transfer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=settings.CIFAR100_TRAIN_MEAN,
            std=settings.CIFAR100_TRAIN_STD,
        ),
    ]
)


# Image classifiers that are all built as ``factory(num_cls=num_cls)``:
# model name -> (module path, factory function name).
_IMAGE_MODEL_FACTORIES = {
    'vgg16': ('models.vgg', 'vgg16_bn'),
    'vgg13': ('models.vgg', 'vgg13_bn'),
    'vgg11': ('models.vgg', 'vgg11_bn'),
    'vgg19': ('models.vgg', 'vgg19_bn'),
    'densenet121': ('models.densenet', 'densenet121'),
    'densenet161': ('models.densenet', 'densenet161'),
    'densenet169': ('models.densenet', 'densenet169'),
    'densenet201': ('models.densenet', 'densenet201'),
    'googlenet': ('models.googlenet', 'googlenet'),
    'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
    'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
    'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
    'xception': ('models.xception', 'xception'),
    'resnet18': ('models.resnet', 'resnet18'),
    'resnet34': ('models.resnet', 'resnet34'),
    'resnet50': ('models.resnet', 'resnet50'),
    'resnet101': ('models.resnet', 'resnet101'),
    'resnet152': ('models.resnet', 'resnet152'),
    'preactresnet18': ('models.preactresnet', 'preactresnet18'),
    'preactresnet34': ('models.preactresnet', 'preactresnet34'),
    'preactresnet50': ('models.preactresnet', 'preactresnet50'),
    'preactresnet101': ('models.preactresnet', 'preactresnet101'),
    'preactresnet152': ('models.preactresnet', 'preactresnet152'),
    'resnext50': ('models.resnext', 'resnext50'),
    'resnext101': ('models.resnext', 'resnext101'),
    'resnext152': ('models.resnext', 'resnext152'),
    'shufflenet': ('models.shufflenet', 'shufflenet'),
    'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
    'squeezenet': ('models.squeezenet', 'squeezenet'),
    'mobilenet': ('models.mobilenet', 'mobilenet'),
    'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
    'nasnet': ('models.nasnet', 'nasnet'),
    'attention56': ('models.attention', 'attention56'),
    'attention92': ('models.attention', 'attention92'),
    'seresnet18': ('models.senet', 'seresnet18'),
    'seresnet34': ('models.senet', 'seresnet34'),
    'seresnet50': ('models.senet', 'seresnet50'),
    'seresnet101': ('models.senet', 'seresnet101'),
    'seresnet152': ('models.senet', 'seresnet152'),
    'wideresnet': ('models.wideresidual', 'wideresnet'),
    'stochasticdepth18': ('models.stochasticdepth', 'stochastic_depth_resnet18'),
    'stochasticdepth34': ('models.stochasticdepth', 'stochastic_depth_resnet34'),
    'stochasticdepth50': ('models.stochasticdepth', 'stochastic_depth_resnet50'),
    'stochasticdepth101': ('models.stochasticdepth', 'stochastic_depth_resnet101'),
}


def get_network_assmb(model_name, args):
    """Build the network for ``args.dataset`` and optionally move it to the GPU.

    Args:
        model_name: requested architecture name. It is only honoured when
            the dataset config's ``'model'`` entry is ``'image_model'``;
            otherwise that entry replaces it.
        args: parsed arguments. Reads ``args.dataset`` (key into
            ``settings.DATASET_CFG``) and ``args.gpu``. For ``*CNN*`` models it
            also reads ``args.defense_arg`` and ``args.defense``.

    Returns:
        The constructed ``net``, moved to CUDA when ``args.gpu`` is truthy.

    Exits:
        Prints a message and calls ``sys.exit(1)`` for an unsupported name.
    """
    import importlib

    cfg = settings.DATASET_CFG[args.dataset]
    num_cls = cfg['num_cls']
    model_type = cfg['model']
    # Unless the dataset config says 'image_model', its 'model' entry
    # overrides the requested model_name.
    if model_type != 'image_model':
        model_name = model_type

    if model_name == 'ColumnFC':
        from models.ColumnFC import ColumnFC
        net = ColumnFC(input_dim=cfg['input_dim'], output_dim=num_cls)
    elif model_name == 'mia_fc':
        from models.MIAFC import MIAFC
        # Input is the one-hot label concatenated with the prediction
        # confidence vector, hence num_cls * 2 features; 2 output classes.
        net = MIAFC(input_dim=num_cls * 2, output_dim=2)
    elif model_name in _IMAGE_MODEL_FACTORIES:
        module_path, factory_name = _IMAGE_MODEL_FACTORIES[model_name]
        # Import lazily so only the selected architecture's module is loaded.
        factory = getattr(importlib.import_module(module_path), factory_name)
        net = factory(num_cls=num_cls)
    elif 'CNN' in model_name:
        # Substring match, checked after all exact names (as before).
        from models.CNN import CNN
        net = CNN(model_name, args.dataset, sigma=args.defense_arg,
                  ldl_defense=args.defense == 'ldl')
    else:
        print('the network name you have entered is not supported yet')
        # Exit non-zero: a bare sys.exit() reports success to the caller/shell.
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net

def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True, root='./data'):
    """Return a DataLoader over the augmented CIFAR-100 training split.

    Augmentation: random 32x32 crop with 4px padding, random horizontal
    flip and random rotation of up to 15 degrees, followed by tensor
    conversion and normalisation.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``
            (typically the CIFAR-100 training mean).
        std: per-channel std passed to ``transforms.Normalize``.
        batch_size: dataloader batch size.
        num_workers: number of dataloader worker processes.
        shuffle: whether to shuffle the samples.
        root: directory passed to ``torchvision.datasets.CIFAR100``
            (``download=True``); defaults to ``'./data'``.

    Returns:
        torch ``DataLoader`` over the CIFAR-100 training set.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    cifar100_training = torchvision.datasets.CIFAR100(
        root=root, train=True, download=True, transform=transform_train)
    return DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)


def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True, root='./data'):
    """Return a DataLoader over the CIFAR-100 test split (no augmentation).

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        batch_size: dataloader batch size.
        num_workers: number of dataloader worker processes.
        shuffle: whether to shuffle the samples.
            NOTE(review): defaults to True even for the test split; kept
            as-is for backward compatibility.
        root: directory passed to ``torchvision.datasets.CIFAR100``
            (``download=True``); defaults to ``'./data'``.

    Returns:
        torch ``DataLoader`` over the CIFAR-100 test set.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    cifar100_test = torchvision.datasets.CIFAR100(
        root=root, train=False, download=True, transform=transform_test)
    return DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)


def compute_mean_std(cifar100_dataset):
    """Compute the per-channel mean and std of an image dataset.

    Args:
        cifar100_dataset: a sized, indexable dataset (e.g. a
            ``torch.utils.data.Dataset``) whose items are pairs with the
            image at ``item[1]``, laid out as (H, W, C) with at least three
            channels. Only channels 0, 1 and 2 (R, G, B) are used.
            NOTE(review): torchvision's ``CIFAR100`` (used by the loaders in
            this module) yields ``(image, target)``, i.e. the image is at
            index 0 — confirm which dataset this function is meant for.

    Returns:
        ``(mean, std)``: two 3-tuples with one value per channel, computed
        over every pixel of every image.

    Raises:
        ValueError: if the dataset is empty.
    """
    num_items = len(cifar100_dataset)
    if num_items == 0:
        raise ValueError('compute_mean_std needs a non-empty dataset')

    # Fetch each sample once: __getitem__ may decode/transform the sample,
    # so it should not be repeated for every channel.
    images = np.stack([np.asarray(cifar100_dataset[i][1]) for i in range(num_items)])
    # images is (N, H, W, C); each channel slice pools all pixels of all images.
    channels = [images[..., c] for c in range(3)]
    mean = tuple(np.mean(channel) for channel in channels)
    std = tuple(np.std(channel) for channel in channels)

    return mean, std


class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    For every param group the learning rate is
    ``base_lr * last_epoch / total_iters``, so it ramps linearly from 0 to
    ``base_lr`` over ``total_iters`` scheduler steps.

    Args:
        optimizer: wrapped optimizer (e.g. SGD).
        total_iters: number of scheduler steps in the warm-up phase.
        last_epoch: index of the last step taken; -1 starts from scratch.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Assigned before the base constructor, which performs an initial
        # step that already calls get_lr() and therefore needs total_iters.
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Scale each base learning rate by the warm-up progress."""
        # The epsilon keeps the division defined when total_iters == 0.
        # NOTE(review): with total_iters == 0 the lr becomes
        # base_lr * last_epoch / 1e-8 after the first step.
        denominator = self.total_iters + 1e-8
        steps_done = self.last_epoch
        return [initial * steps_done / denominator for initial in self.base_lrs]


def test_sample_process(args, target_index):
    """Load and preprocess one test sample for membership inference.

    Args:
        args: parsed arguments; ``args.dataset`` selects the dataset and is
            substituted into ``settings.TEST_DATA_PATH`` and
            ``settings.TEST_LABELS_PATH``.
        target_index: index of the target sample in those arrays.

    Returns:
        ``(inp, ori, ground_label, target_label)`` where
            inp: float CUDA tensor with a leading batch dim of 1 and
                ``requires_grad=True``;
            ori: ``inp.clone()``;
            ground_label: the sample's label as read from the labels file;
            target_label: the same label as a 1-element int64 CUDA tensor.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    image_datasets = ('CIFAR10', 'CIFAR100', 'gtsrb', 'svhn', 'stl10')
    tabular_datasets = ('location', 'texas100', 'purchase100')
    dataset = args.dataset
    # Exact tuple membership (not a substring test on a space-separated
    # string), so fragments such as 'CIFAR' or '10' are rejected here.
    if dataset not in image_datasets and dataset not in tabular_datasets:
        raise ValueError(f'unsupported dataset for test_sample_process: {dataset!r}')

    target_data = np.load(settings.TEST_DATA_PATH.format(dataset=dataset))[target_index]
    # Ground-truth and target label are the same entry; read the file once.
    ground_label = np.load(settings.TEST_LABELS_PATH.format(dataset=dataset))[target_index]

    if dataset in image_datasets:
        if dataset in ('CIFAR10', 'CIFAR100'):
            # The row is flat: three consecutive 1024-value (32x32) planes,
            # sliced below as R, G, B and stacked into an HxWxC image.
            r = target_data[:1024].reshape(32, 32)
            g = target_data[1024:2048].reshape(32, 32)
            b = target_data[2048:].reshape(32, 32)
            inp = np.dstack((r, g, b))
        else:
            inp = target_data
        inp = transfer(inp.astype(np.uint8))
    else:
        inp = torch.tensor(target_data)

    # Leaf tensor with gradients enabled (replaces the deprecated Variable).
    inp = inp.type(torch.FloatTensor).cuda().unsqueeze(0).requires_grad_(True)
    ori = inp.clone()

    target_label = torch.from_numpy(np.array([ground_label])).type(torch.LongTensor).cuda()

    return inp, ori, ground_label, target_label
