import torch
import numpy as np
import torch.optim as optim
from Utils import metrics
from torchvision import datasets, transforms
from Layers import layers
from Models import *
from Pruners import pruners
from Utils import custom_datasets


def device(gpu):
    """Return the torch device to run on.

    Args:
        gpu: Index of the CUDA device to select when CUDA is available.

    Returns:
        ``torch.device("cuda:<gpu>")`` if CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda:{}".format(gpu))


def dimension(dataset):
    """Return the input shape and number of classes for a named dataset.

    Args:
        dataset: One of 'mnist', 'cifar10', 'cifar100', 'tiny-imagenet',
            'imagenet' or 'speed_test' ('speed_test' uses the same
            dimensions as 'imagenet').

    Returns:
        A tuple ``(input_shape, num_classes)`` where ``input_shape`` is a
        3-tuple of ints and ``num_classes`` is an int.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    dimensions = {
        'mnist': ((1, 28, 28), 10),
        'cifar10': ((3, 32, 32), 10),
        'cifar100': ((3, 32, 32), 100),
        'tiny-imagenet': ((3, 64, 64), 200),
        'imagenet': ((3, 224, 224), 1000),
        'speed_test': ((3, 224, 224), 1000),
    }
    try:
        return dimensions[dataset]
    except (KeyError, TypeError):
        # Previously an unknown name surfaced as an UnboundLocalError on the
        # return statement; fail with an explicit message instead.
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(dataset, sorted(dimensions))
        ) from None


def get_transform(size, padding, mean, std, preprocess):
    """Compose the torchvision transform pipeline for a dataset.

    Args:
        size: Crop size passed to ``transforms.RandomCrop``.
        padding: Padding passed to ``transforms.RandomCrop``.
        mean: Mean passed to ``transforms.Normalize``.
        std: Standard deviation passed to ``transforms.Normalize``.
        preprocess: When truthy, a random crop and a random horizontal flip
            are placed in front of the tensor conversion.

    Returns:
        A ``transforms.Compose`` ending with ``ToTensor`` and ``Normalize``.
    """
    augmentation = (
        [
            transforms.RandomCrop(size=size, padding=padding),
            transforms.RandomHorizontalFlip(),
        ]
        if preprocess
        else []
    )
    pipeline = augmentation + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(pipeline)


def dataloader(dataset, batch_size, train, workers, length=None, data_dir=None):
    """Build a ``torch.utils.data.DataLoader`` for a named dataset.

    Args:
        dataset: One of 'mnist', 'cifar10', 'cifar100', 'tiny-imagenet',
            'imagenet' or 'speed_test'. A non-string value is treated as an
            already-constructed dataset and wrapped unchanged.
        batch_size: Batch size handed to the ``DataLoader``.
        train: Selects the training split and training-time augmentation.
            Shuffling is enabled only when this is exactly ``True``.
        workers: Number of loader workers; only applied (together with
            ``pin_memory=True``) when CUDA is available.
        length: If not ``None``, the loader draws from a random subset of at
            most this many samples.
        data_dir: Folder containing ``train`` and ``val`` sub-folders;
            required for 'imagenet', unused otherwise.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``dataset`` is an unknown name, or if 'imagenet' is
            requested without ``data_dir``.
    """
    # Dataset
    if dataset == 'mnist':
        mean, std = (0.1307, ), (0.3081, )
        transform = get_transform(size=28, padding=0, mean=mean, std=std, preprocess=False)
        data = datasets.MNIST('Data', train=train, download=True, transform=transform)
    elif dataset == 'cifar10':
        mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
        # With a falsy ``train`` get_transform yields only ToTensor + Normalize,
        # which is exactly the evaluation pipeline.
        transform = get_transform(size=32, padding=4, mean=mean, std=std, preprocess=train)
        data = datasets.CIFAR10('Data', train=train, download=False, transform=transform)
    elif dataset == 'cifar100':
        mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
        transform = get_transform(size=32, padding=4, mean=mean, std=std, preprocess=train)
        data = datasets.CIFAR100('Data', train=train, download=True, transform=transform)
    elif dataset == 'tiny-imagenet':
        mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
        transform = get_transform(size=64, padding=4, mean=mean, std=std, preprocess=train)
        data = custom_datasets.TINYIMAGENET('Data', train=train, download=True, transform=transform)
    elif dataset == 'imagenet':
        if data_dir is None:
            raise ValueError("data_dir must be provided for the 'imagenet' dataset")
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        if train:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                # transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                # transforms.RandomGrayscale(p=0.2),
                # transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        else:
            transform = transforms.Compose(
                [transforms.Resize(256), transforms.CenterCrop(224),
                 transforms.ToTensor(), transforms.Normalize(mean, std)])
        split = 'train' if train else 'val'
        data = datasets.ImageFolder(data_dir + '/' + split, transform=transform)
    elif dataset == 'speed_test':
        data = custom_datasets.dummy_dataset()
    elif isinstance(dataset, str):
        # An unrecognised name used to be wrapped in a DataLoader as if the
        # string itself were the dataset; reject it explicitly instead.
        raise ValueError("Unknown dataset {!r}".format(dataset))
    else:
        # Not a name: keep the historical pass-through of dataset objects.
        data = dataset

    # Dataloader
    use_cuda = torch.cuda.is_available()
    kwargs = {'num_workers': workers, 'pin_memory': True} if use_cuda else {}
    shuffle = train is True
    if length is not None:
        indices = torch.randperm(len(data))[:length]
        data = torch.utils.data.Subset(data, indices)

    return torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, **kwargs)


def model(model_architecture, dataset):
    """Look up the model entry registered for an architecture and dataset.

    Args:
        model_architecture: Architecture key, e.g. 'resnet20' or 'resnet50'.
        dataset: 'cifar10', 'imagenet' or 'speed_test'; 'speed_test' uses
            the 'imagenet' table.

    Returns:
        The object registered under ``model_architecture`` for the dataset.
        It is returned as-is, not called.

    Raises:
        KeyError: If the dataset or the architecture is not registered.
    """
    registry = {
        'cifar10': {
            'resnet20': cifar_resnet.resnet20_cifar10,
            'resnet56': cifar_resnet.resnet56_cifar10,
            'shufflenetV1g3': cifar_shufflenet.shufflenet_g3,
            'shufflenetV1g8': cifar_shufflenet.shufflenet_g8,
            'mobilenetV2': cifar_mobilenetv2.mobilenetV2,
            'vgg16bn': cifar_vgg.vgg16_bn,
            'mlp_fc': cifar_mlp.mlp_fc6,
            'mlp_conv': cifar_mlp.mlp_conv6,
        },
        'imagenet': {
            'resnet18': imagenet_resnet.resnet18,
            'resnet34': imagenet_resnet.resnet34,
            'resnet50': imagenet_resnet.resnet50,
            'shufflenetV1g3': imagenet_shufflenet.shufflenet_g3,
            'shufflenetV1g8': imagenet_shufflenet.shufflenet_g8,
            'mobilenetV2': imagenet_mobilenetv2.mobilenetV2,
            'effcientnetb0': imagenet_effnet.efficientnet_b0,
            'effcientnetb2': imagenet_effnet.efficientnet_b2,
        },
    }
    # The speed test reuses the ImageNet model table.
    table_name = 'imagenet' if dataset == 'speed_test' else dataset
    return registry[table_name][model_architecture]


def pruner(method):
    """Return the pruner registered for a pruning method name.

    Args:
        method: One of 'synflow', 'snip', 'grasp', 'lottery', 'opt_params',
            'opt_flops' or 'opt_both'. The three 'opt_*' names resolve to
            the same entry as 'synflow'.

    Returns:
        The matching attribute of the ``pruners`` module.

    Raises:
        KeyError: If ``method`` is not a registered name.
    """
    registry = {
        'synflow': pruners.SynFlow,
        'snip': pruners.SNIP,
        'grasp': pruners.GraSP,
        'lottery': pruners.Mag,
    }
    # The 'opt_*' methods all reuse the SynFlow pruner.
    for alias in ('opt_params', 'opt_flops', 'opt_both'):
        registry[alias] = pruners.SynFlow
    return registry[method]


def optimizer(optimizer):
    """Return the optimizer class and extra keyword arguments for a name.

    Args:
        optimizer: One of 'adam', 'sgd', 'momentum' or 'rms'.

    Returns:
        A tuple ``(optimizer_class, kwargs)``. ``kwargs`` is empty except
        for 'momentum', which selects SGD with momentum 0.9 and Nesterov
        enabled.

    Raises:
        KeyError: If ``optimizer`` is not a registered name.
    """
    nesterov_kwargs = dict(momentum=0.9, nesterov=True)
    registry = dict(
        adam=(optim.Adam, {}),
        sgd=(optim.SGD, {}),
        momentum=(optim.SGD, nesterov_kwargs),
        rms=(optim.RMSprop, {}),
    )
    return registry[optimizer]


def scheduler(args, optimizer, train_loader):
    """Build the learning-rate scheduler selected by ``args``.

    Args:
        args: Object providing ``lr_scheduler``, ``post_epochs``,
            ``lr_step_size``, ``lr_drops`` and ``lr_drop_rate``.
        optimizer: The optimizer whose learning rate is scheduled.
        train_loader: Sized object; its length times ``args.post_epochs``
            is the decay horizon of the 'linear' schedule.

    Returns:
        A ``MultiStepLR`` when ``args.lr_scheduler`` is 'drop' or
        ``args.post_epochs`` is 0, or a ``LambdaLR`` that decays linearly
        to 0 when ``args.lr_scheduler`` is 'linear'.

    Raises:
        ValueError: If ``args.lr_scheduler`` is neither 'drop' nor 'linear'
            and ``args.post_epochs`` is non-zero.
    """
    lr_schedulers = torch.optim.lr_scheduler

    if args.lr_scheduler == 'drop' or args.post_epochs == 0:
        if args.lr_step_size == 0:
            # Explicit milestones supplied by the caller.
            milestones = args.lr_drops
        else:
            # One milestone every ``lr_step_size`` epochs.
            # NOTE(review): epoch 0 passes the modulo test, so it is a
            # milestone as well -- confirm this is intended.
            milestones = [
                epoch for epoch in range(args.post_epochs)
                if epoch % args.lr_step_size == 0
            ]
        return lr_schedulers.MultiStepLR(optimizer, milestones=milestones, gamma=args.lr_drop_rate)

    if args.lr_scheduler == 'linear':
        total_iters = len(train_loader) * args.post_epochs
        return lr_schedulers.LambdaLR(
            optimizer,
            lambda iters: (1.0 - iters / total_iters) if iters <= total_iters else 0,
            last_epoch=-1)

    # Previously this fell through to an UnboundLocalError on return.
    raise ValueError("Unknown lr_scheduler {!r}; expected 'drop' or 'linear'".format(args.lr_scheduler))


def get_params(model, prune_pw_only, prune_linear):
    """Collect the weight element counts of the selected layers.

    Args:
        model: Object exposing ``modules()``.
        prune_pw_only: When truthy, only ``layers.Conv2d`` modules whose
            ``kernel_size`` is ``(1, 1)`` are counted; otherwise every
            ``layers.Conv2d`` module is counted.
        prune_linear: When truthy, ``layers.Linear`` modules are counted
            as well.

    Returns:
        A list with ``weight.numel()`` of each selected module, in the
        order yielded by ``model.modules()``.
    """
    def is_selected(layer):
        # Convolutions: all of them, or only 1x1 kernels when restricted.
        if isinstance(layer, layers.Conv2d) and (layer.kernel_size == (1, 1) or not prune_pw_only):
            return True
        return isinstance(layer, layers.Linear) and prune_linear

    return [layer.weight.numel() for layer in model.modules() if is_selected(layer)]


def get_flops(model, prune_pw_only, prune_linear, input_shape, device):
    """Collect the per-layer 'weight' FLOP entries reported by ``metrics.flop``.

    Args:
        model: Model forwarded to ``metrics.flop``.
        prune_pw_only: Forwarded to ``metrics.flop`` as ``pw_only``.
        prune_linear: When truthy, entries whose name ends with 'fc' or
            'output' are included in addition to those ending with '.conv'.
        input_shape: Input shape forwarded to ``metrics.flop``.
        device: Device forwarded to ``metrics.flop``.

    Returns:
        A list of the 'weight' values of the selected entries, in the
        iteration order of the ``metrics.flop`` result.
    """
    per_layer = metrics.flop(model, input_shape, device, pw_only=prune_pw_only)
    linear_suffixes = ('fc', 'output')

    selected = []
    for name in per_layer:
        wanted = name.endswith('.conv') or (prune_linear and name.endswith(linear_suffixes))
        if not wanted:
            continue
        entry = per_layer[name]
        # Skip entries that report no weight FLOPs.
        if 'weight' in entry:
            selected.append(entry['weight'])
    return selected
