import random
import numpy as np
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from copy import deepcopy
from torch.func import functional_call


def get_train_valid_loader(data_dir, batch_size, download=True):
    """Create CIFAR-10 DataLoaders for training and validation.

    The training loader reads the CIFAR-10 training split, applies random
    horizontal flips and shuffles every epoch. The validation loader reads the
    CIFAR-10 test split in a fixed order. Both pipelines resize to 32x32,
    convert to tensors and normalise with the per-channel CIFAR-10 mean/std.

    Args:
        data_dir: root directory handed to ``datasets.CIFAR10``.
        batch_size: batch size for both loaders.
        download: forwarded to the training split only; the test split is
            opened with ``download=False``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    norm_mean = (0.49139968, 0.48215827, 0.44653124)
    norm_std = (0.24703233, 0.24348505, 0.26158768)

    def _resize_and_normalize():
        # Fresh transform instances for each pipeline.
        return [transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)]

    train_pipeline = transforms.Compose([transforms.RandomHorizontalFlip(), *_resize_and_normalize()])
    valid_pipeline = transforms.Compose(_resize_and_normalize())

    train_data = datasets.CIFAR10(root=data_dir, train=True, transform=train_pipeline, download=download)
    valid_data = datasets.CIFAR10(root=data_dir, train=False, transform=valid_pipeline, download=False)

    return (DataLoader(train_data, batch_size=batch_size, shuffle=True),
            DataLoader(valid_data, batch_size=batch_size, shuffle=False))


def validation(net, valid_loader, criterion, device):
    """Evaluate ``net`` on every batch of ``valid_loader``.

    Args:
        net: model mapping an input batch to class logits.
        valid_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss module (e.g. ``nn.CrossEntropyLoss``). Criteria without
            a ``reduction`` attribute are assumed to average over the batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the loss averaged per sample, and the fraction
        of samples whose arg-max logit equals the label.

    Raises:
        ValueError: if ``valid_loader`` yields no samples.
    """
    # A 'mean'-reduced loss is a batch average. Re-weight it by the batch size
    # so that dividing by the total sample count gives a true per-sample mean.
    averages_batch = getattr(criterion, 'reduction', 'mean') == 'mean'
    total_loss, total_correct, total_count = 0., 0, 0
    was_training = net.training
    net.eval()
    try:
        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for inputs, labels in tqdm(valid_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                batch_size = inputs.shape[0]
                batch_loss = torch.sum(loss).item()
                total_loss += batch_loss * batch_size if averages_batch else batch_loss
                total_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                total_count += batch_size
    finally:
        # Restore whatever mode the caller had rather than forcing train().
        net.train(was_training)
    if total_count == 0:
        raise ValueError("valid_loader yielded no samples")
    return total_loss / total_count, total_correct / total_count

class FakeBatchNorm2d(nn.BatchNorm2d):
    """Stand-in for ``nn.BatchNorm2d`` that only applies the per-channel weight.

    No normalisation, running statistics or bias are used. Channel ``c`` of
    the output is ``weight[c] * x[:, c]``. Subclassing ``nn.BatchNorm2d`` keeps
    the parameter names (``weight``/``bias``) and makes the module still
    satisfy ``isinstance(..., nn.BatchNorm2d)``.
    """

    def __init__(self, num_features: int):
        super(FakeBatchNorm2d, self).__init__(num_features)

    def forward(self, x):
        # Broadcast the (C,) weight over a (N, C, H, W) input.
        channel_scale = self.weight.reshape(1, -1, 1, 1)
        return x * channel_scale

def replace_module(model, target, alternative=None):
    """Replace, in place, every submodule of ``model`` that is a ``target``.

    Each match becomes an independent ``deepcopy`` of ``alternative``.
    ``nn.BatchNorm2d`` matches are special-cased. They become a
    ``FakeBatchNorm2d`` with the same ``num_features``, ignoring
    ``alternative``. Its parameters and buffers are copied from the module
    being replaced, so the trained per-channel scales are kept.

    Args:
        model: root module; modified in place.
        target: class or tuple of classes, as accepted by ``isinstance``.
        alternative: template module for non-BatchNorm2d matches.

    Returns:
        ``model`` (for chaining).
    """
    # Collect first: mutating the tree while named_modules() walks it is unsafe.
    matches = [(name, module) for name, module in model.named_modules()
               if isinstance(module, target)]
    for name, module in matches:
        if not name:
            # The root itself cannot be swapped in place (there is no parent).
            continue
        if isinstance(module, nn.BatchNorm2d):
            # Build a per-module replacement. The original reassigned
            # ``alternative`` here, which leaked into later non-BN matches.
            replacement = FakeBatchNorm2d(module.num_features)
            state = module.state_dict()
            # strict=False tolerates sources without weight/bias (affine=False)
            # or without running stats.
            replacement.load_state_dict(state, strict=False)
            if state:
                replacement.to(next(iter(state.values())).device)
        else:
            replacement = deepcopy(alternative)
        *parent_path, child_name = name.split(".")
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        setattr(parent, child_name, replacement)
    return model

def para_normalization(x):
    """Return ``|x|`` rescaled so that its entries sum to ``log(x.numel())``.

    The result keeps ``x``'s shape and stays differentiable w.r.t. ``x``.
    An all-zero input, such as a fully pruned BatchNorm layer, yields zeros
    instead of NaN.
    """
    x = torch.abs(x)
    total = torch.sum(x)
    # Guard the 0/0 case. When the sum is 0 every entry is 0, so dividing by 1
    # gives zeros. For any positive sum this is a no-op.
    total = torch.where(total > 0, total, torch.ones_like(total))
    return x / total * np.log(x.numel())


class VGG11(nn.Module):
    """VGG-11 with batch normalisation for 3x32x32 inputs (e.g. CIFAR-10).

    ``feature`` is one flat ``nn.Sequential`` of conv -> BN -> ReLU blocks,
    with a 2x2 max-pool after blocks 1, 2, 4, 6 and 8. Indices such as
    ``feature[1]`` (the first BN layer) therefore address fixed layers. Five
    pools shrink a 32x32 input to 1x1, leaving 512 features for the
    three-layer ``classifier``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Conv output widths in order; 'M' inserts a 2x2 / stride-2 max-pool.
        plan = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')
        stages = []
        channels = 3
        for item in plan:
            if item == 'M':
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                stages.extend([
                    nn.Conv2d(channels, item, kernel_size=3, padding=1),
                    nn.BatchNorm2d(item),
                    nn.ReLU(inplace=True),
                ])
                channels = item
        self.feature = nn.Sequential(*stages)

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        features = self.feature(x)
        flat = features.view(features.size(0), -1)
        return self.classifier(flat)

    def l1_norm(self):
        """Return the sum of absolute values of all parameters as a differentiable scalar."""
        return sum((param.abs().sum() for param in self.parameters()), 0.0)

    def connectivity(self):
        """Return the summed output of a linearised copy of the network on an all-ones input.

        In the copy, BatchNorm2d becomes a per-channel scale
        (``FakeBatchNorm2d``), ReLU/Dropout become identity and MaxPool2d
        becomes 2x2 average pooling. All biases are zeroed, and every other
        parameter is replaced by ``para_normalization`` of the corresponding
        parameter of ``self``. The replacement values are built from
        ``self``'s own tensors, so the result is differentiable w.r.t. this
        model's parameters.
        """
        device = next(self.parameters()).device
        skeleton = deepcopy(self)
        # FakeBatchNorm2d(10) is only a placeholder: replace_module builds a
        # correctly sized FakeBatchNorm2d for every BatchNorm2d it finds.
        swaps = (
            (nn.BatchNorm2d, FakeBatchNorm2d(10)),
            ((nn.ReLU, nn.Dropout), nn.Identity()),
            (nn.MaxPool2d, nn.AvgPool2d(kernel_size=2, stride=2)),
        )
        for target, stand_in in swaps:
            skeleton = replace_module(skeleton, target=target, alternative=stand_in)
        skeleton.to(device)
        params = {
            name: torch.zeros_like(param).to(device) if 'bias' in name else para_normalization(param)
            for name, param in self.named_parameters()
        }
        probe = torch.ones((1, 3, 32, 32), device=device).float()
        return torch.sum(functional_call(skeleton, params, (probe,)))

def prune(net, ratio):
    """Zero, in place, the lowest-saliency BatchNorm channel weights of ``net``.

    Saliency uses the same linearised "skeleton" as ``VGG11.connectivity``:
    BN becomes a per-channel scale, ReLU/Dropout become identity and
    max-pool becomes average-pool. Biases are zeroed, and every other
    parameter is ``para_normalization`` of the corresponding *trained*
    parameter of ``net``. A channel's score is ``grad * weight`` of
    ``log(sum(skeleton(ones)))``.

    Exactly ``int(N * ratio)`` channels are pruned, where ``N`` counts the
    channels of all candidate BN layers. These are the channels with the
    smallest scores, ranked globally across layers; ties are broken by
    position. Only BN ``weight`` entries are zeroed; BN biases are left
    untouched. ``net.feature[1]`` (the first BN layer) is never a candidate.

    Args:
        net: VGG11-like model exposing ``feature[1]``; modified in place.
        ratio: fraction of candidate BN channels to prune.

    Returns:
        A list of bool tensors, one per candidate BN layer in ``net.modules()``
        order, with True where the channel was pruned. Callers apply masks in
        this same order.
    """
    device = next(net.parameters()).device
    x = torch.ones((1, 3, 32, 32), device=device).float()

    skeleton = deepcopy(net)
    skeleton = replace_module(skeleton, target=nn.BatchNorm2d, alternative=FakeBatchNorm2d(10))
    skeleton = replace_module(skeleton, target=(nn.ReLU, nn.Dropout), alternative=nn.Identity())
    skeleton = replace_module(skeleton, target=nn.MaxPool2d, alternative=nn.AvgPool2d(kernel_size=2, stride=2))
    skeleton.to(device)

    # Read values from the trained ``net``, as VGG11.connectivity does. A freshly
    # built FakeBatchNorm2d would otherwise carry default weights of 1 and the
    # trained BN scales would be ignored.
    trained = dict(net.named_parameters())
    for name, param in skeleton.named_parameters():
        source = trained[name].detach()
        if 'bias' in name:
            param.data = torch.zeros_like(source).to(device)
        else:
            param.data = para_normalization(source)
        param.grad = None
    obj = torch.log(torch.sum(skeleton(x)))
    obj.backward()

    # Pair each candidate BN layer of ``net`` with its counterpart in the skeleton.
    # replace_module keeps the module tree shape, so modules() stay aligned.
    candidates = [(module, module_s) for module, module_s in zip(net.modules(), skeleton.modules())
                  if module is not net.feature[1] and isinstance(module, nn.BatchNorm2d)]
    if not candidates:
        return []

    scores = [(module_s.weight.grad * module_s.weight).detach().view(-1) for _, module_s in candidates]
    all_scores = torch.cat(scores)
    num_to_reset = int(all_scores.numel() * ratio)
    pruned = torch.zeros_like(all_scores, dtype=torch.bool)
    if num_to_reset > 0:
        # Select exactly num_to_reset channels. The original ``< threshold``
        # test excluded the threshold element itself, pruning one too few.
        _, order = torch.sort(all_scores, stable=True)
        pruned[order[:num_to_reset]] = True

    mask_set = []
    for (module, _), mask in zip(candidates, torch.split(pruned, [s.numel() for s in scores])):
        mask = mask.view_as(module.weight).clone()
        module.weight.data[mask] = 0.
        mask_set.append(mask)
    return mask_set


def run(mode, lamda, weight_decay=0.):
    """Train VGG11 on CIFAR-10, then evaluate pruning + fine-tuning at several ratios.

    Args:
        mode: 0 -> plain cross-entropy; 1 -> cross-entropy + ``lamda * net.l1_norm()``;
            any other value -> cross-entropy - ``lamda * log(net.connectivity())``.
        lamda: regularisation strength.
        weight_decay: Adam weight decay for the main training phase.

    Prints per-epoch train/validation statistics. Then, for each pruning
    ratio in ``np.arange(0.5, 1, 0.025)``, it prints validation metrics and
    log-connectivity before and after fine-tuning a pruned copy of the model.
    """
    # Fall back to CPU so the script still runs without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 512
    epochs = 20
    epochs_fine = 5

    net = VGG11().to(device)
    train_loader, valid_loader = get_train_valid_loader(data_dir='../dataset/', batch_size=batch_size)
    optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    criterion = nn.CrossEntropyLoss()  # default reduction='mean' -> batch-averaged loss

    net.train()
    for epoch in range(scheduler.last_epoch, epochs):
        train_loss, train_correct, train_count = 0., 0, 0
        log_connect, l1_norm = 0., 0.
        for inputs, labels in tqdm(train_loader, total=len(train_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            # Both regularisers are logged every step, but only the one being
            # optimised needs an autograd graph.
            with torch.set_grad_enabled(mode == 1):
                l1 = net.l1_norm()
            with torch.set_grad_enabled(mode not in (0, 1)):
                con = net.connectivity()

            if mode == 0:
                total_loss = loss
            elif mode == 1:
                total_loss = loss + lamda * l1
            else:
                total_loss = loss - lamda * torch.log(con)
            total_loss.backward()

            n_batch = len(labels)
            # ``loss`` is a batch mean; re-weight by the batch size so the epoch
            # value below is a per-sample mean.
            train_loss += loss.item() * n_batch
            train_correct += (torch.argmax(outputs, 1) == labels).sum().item()
            train_count += n_batch
            log_connect += torch.log(con).item()
            l1_norm += l1.item()

            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            optimizer.step()

        scheduler.step()

        train_loss /= train_count
        train_accuracy = train_correct / train_count
        log_connect /= len(train_loader)
        l1_norm /= len(train_loader)
        valid_loss, valid_accuracy = validation(net, valid_loader, criterion, device)
        # Report the epoch-averaged L1 norm (previously the last batch's value was printed).
        log_info = f'Train Epoch:{epoch:3d} || train loss:{train_loss:.2e} train accuracy:{train_accuracy*100:.2f}% ' + \
                   f'valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% lr:{scheduler.get_last_lr()[0]:.2e} ' + \
                   f'log connect:{log_connect:.4e} l1 norm:{l1_norm:.4e}'
        print(log_info)

    # test for different pruning ratio
    for ratio in np.arange(0.5, 1, 2.5e-2):
        prune_net = deepcopy(net)
        mask_set = prune(prune_net, ratio)
        valid_loss, valid_accuracy = validation(prune_net, valid_loader, criterion, device)
        with torch.no_grad():
            log_connect = torch.log(prune_net.connectivity()).item()
        log_info = f'Pruning ratio:{ratio*100:.2f}% || valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% connectivity:{log_connect:.4e}'
        print(log_info)

        # fine-tuning
        optimizer_fine = optim.Adam(prune_net.parameters(), lr=1e-4)
        scheduler_fine = optim.lr_scheduler.StepLR(optimizer_fine, step_size=10, gamma=0.8)
        prune_net.train()
        # Same BN layers, in the same order, as the masks returned by prune().
        pruned_bns = [module for module in prune_net.modules()
                      if module is not prune_net.feature[1] and isinstance(module, nn.BatchNorm2d)]
        for epoch in range(epochs_fine):
            for inputs, labels in tqdm(train_loader, total=len(train_loader)):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer_fine.zero_grad()
                outputs = prune_net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                # Zero the gradients of pruned channels. This fine-tuning Adam
                # has no weight decay, so their weights stay exactly zero.
                for module, mask in zip(pruned_bns, mask_set):
                    module.weight.grad[mask] = 0.
                torch.nn.utils.clip_grad_norm_(prune_net.parameters(), max_norm=1.0)
                optimizer_fine.step()
            scheduler_fine.step()

        with torch.no_grad():
            log_connect = torch.log(prune_net.connectivity()).item()
        valid_loss, valid_accuracy = validation(prune_net, valid_loader, criterion, device)
        log_info = f'Fine-tuning || valid loss:{valid_loss:.4e} valid accuracy:{valid_accuracy*100:.2f}% connectivity:{log_connect:.4e}'
        print(log_info)


if __name__ == "__main__":
    # Seed the PyTorch, NumPy and stdlib RNGs, in that order, with 0.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(0)
    # mode=2 selects the connectivity regulariser (loss - lamda * log(connectivity)).
    run(mode=2, lamda=1e-1, weight_decay=1e-3)