import torch
import torch.nn as nn
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets, transforms
import torch.nn.init as init
import os
import numpy as np
import math
from collections import OrderedDict

from models import *


class MyDataset(torch.utils.data.Dataset):
    """Wrap an indexable ``(input, target)`` source with an optional input transform.

    Only the input half of each pair is transformed; targets pass through
    unchanged. This lets two views of the same underlying data (e.g. the two
    halves of a ``random_split``) use different transforms.
    """

    def __init__(self, subset, transform=None):
        # Underlying dataset/Subset that yields (input, target) pairs.
        self.subset = subset
        # Callable applied to each input on access; a falsy value disables it.
        self.transform = transform

    def __getitem__(self, index):
        sample, target = self.subset[index]
        apply = self.transform
        return (apply(sample) if apply else sample), target

    def __len__(self):
        size = len(self.subset)
        return size


def _build_model(args):
    """Instantiate the bare (unwrapped) network for ``args.dataset`` / ``args.arch``.

    Raises:
        NotImplementedError: if the dataset/architecture pair is not supported.
    """
    if args.dataset == 'mnist':
        if args.arch == 'lenet':
            return LeNet()
    elif args.dataset == 'cifar10':
        if args.arch == 'resnet18':
            return ResNet18(args.dataset)
        if args.arch == 'resnet32':
            return resnet32()
    elif args.dataset == 'tiny-imagenet':
        if args.arch == 'mobilenetv2':
            return MobileNetV2(args.dataset)
        if args.arch == 'resnet50':
            return ResNet50(args.dataset)
        if args.arch == 'efficientnet-b0':
            return efficientnet.EfficientNet.from_name(
                'efficientnet-b0', image_size=64, num_classes=200)
    # Every unsupported dataset or architecture falls through to here.
    raise NotImplementedError(
        f"Unsupported dataset/arch combination: {args.dataset}/{args.arch}")


def load_model(args, device, model_path=None, name='', single_gpu=False):
    """Build the model selected by ``args``, move it to ``device``, optionally load weights.

    Args:
        args: namespace with ``dataset``, ``arch`` and ``single_gpu`` attributes.
        device: target passed to ``Module.to``.
        model_path: optional path to a saved state dict. Keys may or may not
            carry the ``module.`` prefix that ``nn.DataParallel`` adds.
        name: optional label inserted into the "Load ... model" message.
        single_gpu: if True, skip ``nn.DataParallel`` even when
            ``args.single_gpu`` is False (required for SNIP).

    Returns:
        The model on ``device``. It is wrapped in ``nn.DataParallel`` unless
        single-GPU mode is active.

    Raises:
        NotImplementedError: for an unsupported dataset/architecture pair.
    """
    model = _build_model(args)

    # Either flag selects single-GPU mode. The same decision must drive both
    # the wrapping here and the checkpoint loading below.
    use_single_gpu = args.single_gpu or single_gpu
    if use_single_gpu:   # Optional Single GPU (required for SNIP)
        model = model.to(device)
    else:
        model = nn.DataParallel(model).to(device)

    if model_path is not None:
        # Deserialize onto the CPU so GPU-saved checkpoints also open on
        # CPU-only hosts; load_state_dict then copies into the on-device params.
        state_dict = torch.load(model_path, map_location='cpu')
        prefix = 'module.'
        # Strip only a *leading* DataParallel prefix; other keys are untouched.
        bare_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        # Load into the unwrapped network so that checkpoints saved with or
        # without DataParallel work in either mode.
        target = model if use_single_gpu else model.module
        target.load_state_dict(bare_state_dict)
        print(f"Load{' ' + name + ' ' if name else ' '}model from {model_path}")

    return model


def load_dataset(args, kwargs):
    """Build train/valid/test DataLoaders for ``args.dataset``.

    A fraction ``args.valid_ratio`` of the training set is randomly held out as
    a validation split. The training split gets the (augmenting) train
    transform; the validation split gets the test transform.

    Args:
        args: namespace providing ``dataset``, ``data_dir``, ``valid_ratio``,
            ``batch_size`` and ``test_batch_size``.
        kwargs: extra keyword arguments forwarded unchanged to every
            ``DataLoader`` constructed here.

    Returns:
        ``(train_loader, valid_loader, test_loader, data_shape, num_classes)``,
        where ``data_shape`` is the ``(C, H, W)`` tuple set per dataset below.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    if args.dataset == 'mnist':
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # The train set is loaded without a transform on purpose: MyDataset
        # applies a per-split transform after random_split below.
        trainset = datasets.MNIST(args.data_dir, train=True, download=True)
        testset = datasets.MNIST(args.data_dir, train=False, transform=transform_test)

        data_shape, num_classes = (1, 28, 28), 10

    elif args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        trainset = datasets.CIFAR10(root=args.data_dir, train=True, download=True)
        testset = datasets.CIFAR10(root=args.data_dir, train=False, transform=transform_test)

        data_shape, num_classes = (3, 32, 32), 10

    elif args.dataset == 'tiny-imagenet':
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        tiny_dir = os.path.join(args.data_dir, 'tiny-imagenet-200')
        # NOTE(review): ImageFolder derives labels from one sub-directory per
        # class; confirm that val/ has been arranged that way.
        trainset = datasets.ImageFolder(os.path.join(tiny_dir, 'train'))
        testset = datasets.ImageFolder(os.path.join(tiny_dir, 'val'), transform=transform_test)

        data_shape, num_classes = (3, 64, 64), 200
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    valid_len = int(len(trainset) * args.valid_ratio)
    train_len = len(trainset) - valid_len

    train_subset, valid_subset = torch.utils.data.random_split(trainset, [train_len, valid_len])
    trainset = MyDataset(train_subset, transform=transform_train)
    validset = MyDataset(valid_subset, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(validset, batch_size=args.test_batch_size,
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
                                              shuffle=False, **kwargs)

    return train_loader, valid_loader, test_loader, data_shape, num_classes


def get_save_name(args, prev_name=False, imp_percent=0):
    """Compose an underscore-joined run name from the configuration in ``args``.

    Layout: ``<dataset>_<arch>[pruning info]_<optim>_S<seed>``. The pruning info
    is present only when ``args.mode == 'prune'``, and consists of:

    * an optional ``alive`` tag (when ``args.all_alive_pruning`` is set);
    * the pruning method;
    * either ``IMP<imp_percent>_<round>`` for iterative magnitude pruning, or
      ``ONE_<prune_percent>`` for one-shot pruning.

    With ``prev_name=True`` the IMP round index is one less than
    ``args.imp_index``, i.e. the name of the preceding round.

    Example: ``cifar10_resnet18_mag_ONE_0.500_sgd_S1``
    """
    pieces = [f"{args.dataset}", f"{args.arch}"]

    if args.mode == 'prune':
        if args.all_alive_pruning:
            pieces.append('alive')
        pieces.append(args.prune_method)

        if not args.imp:
            pieces.append(f'ONE_{args.prune_percent:.3f}')
        else:
            round_idx = args.imp_index - 1 if prev_name else args.imp_index
            pieces.append(f'IMP{imp_percent:.3f}_{round_idx}')

    pieces.append(f'{args.optim}')
    pieces.append(f'S{args.seed}')
    return '_'.join(pieces)


# Print the ratio of the nonzero parameters in the model
def print_pruned(model):
    """Print a per-parameter table of surviving (nonzero) weights plus a summary.

    Args:
        model: an ``nn.Module``; every tensor from ``named_parameters()`` is counted.

    Returns:
        ``(alive, pruned)``: the total number of nonzero and of zero parameters.
    """
    nonzero = total = 0

    print("-" * 80)
    print(f'{"Params":40} | {"# alive / # params":27} | {"Pruned":7}')
    print("-" * 80)
    for name, p in model.named_parameters():
        # Count with torch rather than NumPy: this avoids copying the whole
        # tensor to the host and also works for dtypes NumPy lacks (bfloat16).
        nz_count = int((p.detach() != 0).sum().item())
        total_params = p.numel()
        nonzero += nz_count
        total += total_params
        # Zero-sized parameters would otherwise divide by zero.
        alive_pct = 100 * nz_count / total_params if total_params else 0.0
        print(f'{name:40} | {nz_count:7} / {total_params:7} ({alive_pct:6.2f}%) | {total_params - nz_count :7}')

    pruned = total - nonzero
    # A fully pruned model (nonzero == 0) or a parameterless one (total == 0)
    # must not raise ZeroDivisionError in the summary line.
    compression = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * pruned / total if total else 0.0

    print("-" * 80)
    print(f'alive: {nonzero}, pruned : {pruned}, total: {total}, Compression rate : {compression:.2f}x  ({pruned_pct:.2f}% pruned)')
    print("-" * 80)

    return nonzero, pruned


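# Cumulative pruned percentage after `step` rounds, each removing `rate`% of the remaining weights
# (computed in percent units rather than fractions to limit floating-point error)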
def get_pruning_ratio(rate, step):
    """Return the cumulative pruned percentage after ``step`` pruning rounds.

    Each round removes ``rate`` percent of the weights that survived the
    previous round. The surviving percentage is therefore
    ``100 * (1 - rate/100) ** step``. It is evaluated here as
    ``(100 - rate) ** step / 100 ** (step - 1)`` so that the arithmetic stays
    in percent units.

    >>> get_pruning_ratio(20, 2)
    36.0
    """
    survivors = (100 - rate) ** step / 100 ** (step - 1)
    return 100 - survivors