import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from prune import *
from Layers import layers


@torch.no_grad()
def summary(model, scores, flops, prunable):
    r"""Summary of compression results for a model.

    Builds one row per parameter that is registered directly on a module
    (``named_parameters(recurse=False)``), walking ``model.named_modules()``.

    Args:
        model: module whose parameters are summarised.
        scores: mapping from ``id(param)`` to a score tensor.
        flops: nested mapping ``{module name: {param name: flop count}}``,
            e.g. the dictionary returned by ``flop``.  Entries missing from
            the table are reported as ``0`` instead of raising ``KeyError``.
        prunable: callable taking a module and returning whether it is
            considered prunable.

    Returns:
        A ``pandas.DataFrame`` with the columns ``module``, ``param``,
        ``sparsity``, ``size``, ``shape``, ``flops``, ``score mean``,
        ``score variance``, ``score sum``, ``score abs mean``,
        ``score abs variance``, ``score abs sum`` and ``prunable``.
    """
    columns = [
        'module', 'param', 'sparsity', 'size', 'shape', 'flops', 'score mean', 'score variance', 'score sum', 'score abs mean', 'score abs variance',
        'score abs sum', 'prunable'
    ]

    rows = []
    for name, module in model.named_modules():
        # The FLOP table may have no entry for this module (never executed,
        # unsupported type, or skipped by a pointwise-only count).
        module_flops = flops.get(name, {})
        for pname, param in module.named_parameters(recurse=False):
            pruned = prunable(module) and id(param) in scores
            if pruned:
                # Mean of the mask stored on the module as '<pname>_mask'.
                mask = getattr(module, pname + '_mask')
                sparsity = mask.detach().cpu().numpy().mean()
                score = scores[id(param)].detach().cpu().numpy()
            else:
                # Parameters without a score are reported with sparsity 1.0
                # and a single zero score.
                sparsity = 1.0
                score = np.zeros(1)

            # Plain tuple of ints; avoids copying the parameter to the CPU.
            shape = tuple(param.shape)
            abs_score = np.abs(score)
            rows.append([
                name, pname, sparsity,
                np.prod(shape), shape, module_flops.get(pname, 0),
                score.mean(), score.var(), score.sum(),
                abs_score.mean(), abs_score.var(), abs_score.sum(),
                pruned
            ])

    return pd.DataFrame(rows, columns=columns)


@torch.no_grad()
def flop(model, input_shape, device, pw_only=False):
    r"""Count per-parameter FLOPs for one forward pass through ``model``.

    A forward hook is registered on every module yielded by
    ``model.named_modules()``, then an all-ones tensor of shape
    ``[1] + list(input_shape)`` is moved to ``device`` and passed through the
    model.  The hooks are always removed again, even if registration or the
    forward pass raises.

    Args:
        model: module to profile.
        input_shape: shape of a single input sample (without batch dimension).
        device: device the dummy input is moved to.
        pw_only: if true, ``Conv2d`` modules whose ``kernel_size`` is not
            ``(1, 1)`` contribute no counts.

    Returns:
        dict mapping module name to a dict of counts keyed by ``'weight'``
        and/or ``'bias'``.  Only modules whose hook fired have an entry;
        modules that are not Linear / Conv2d / BatchNorm2d (and convolutions
        skipped by ``pw_only``) map to an empty dict.
    """
    total = {}

    def count_flops(name):
        def hook(module, inputs, output):
            counts = {}
            if isinstance(module, (layers.Linear, nn.Linear)):
                counts['weight'] = module.in_features * module.out_features
                if module.bias is not None:
                    counts['bias'] = module.out_features
            if isinstance(module, (layers.Conv2d, nn.Conv2d)):
                if not pw_only or module.kernel_size == (1, 1):
                    # Grouped convs only connect in_channels / groups inputs
                    # to each output channel.
                    in_channels = module.in_channels // module.groups
                    out_channels = module.out_channels
                    kernel_size = int(np.prod(module.kernel_size))
                    # Number of spatial positions in the output map.
                    output_size = output.size(2) * output.size(3)
                    counts['weight'] = in_channels * out_channels * kernel_size * output_size
                    if module.bias is not None:
                        counts['bias'] = out_channels * output_size
            if isinstance(module, (layers.BatchNorm2d, nn.BatchNorm2d)):
                if module.affine:
                    output_size = output.size(2) * output.size(3)
                    counts['weight'] = module.num_features * output_size
                    counts['bias'] = module.num_features * output_size
            # If a module runs more than once, the last call wins.
            total[name] = counts

        return hook

    handles = []
    try:
        for name, module in model.named_modules():
            handles.append(module.register_forward_hook(count_flops(name)))

        dummy_input = torch.ones([1] + list(input_shape)).to(device)
        model(dummy_input)
    finally:
        # Never leave profiling hooks attached to the caller's model.
        for handle in handles:
            handle.remove()

    return total


def stats(model, skip_last):
    r"""Tally weight element counts by layer category.

    Walks ``model.named_modules()`` and sums ``weight.numel()`` for
    ``layers.Conv2d`` modules (split by kernel size into 1x1, 3x3 and
    everything else) and for ``layers.BatchNorm2d`` modules.  Convolutions
    whose kernel is neither 1x1 nor 3x3 are also reported on stdout.

    Args:
        model: module to inspect.
        skip_last: if true, the weight count of the last 1x1 convolution
            encountered is moved from the 1x1 total to the "other" total.
            Nothing is moved when the model has no 1x1 convolution.

    Returns:
        Tuple ``(conv1x1, conv3x3, batchnorm, other_conv)`` of element counts.
    """
    conv1x1, conv3x3, batchnorm, other_conv = 0, 0, 0, 0
    last_conv1x1 = None  # weight count of the most recent 1x1 conv, if any
    for _, module in model.named_modules():
        if isinstance(module, layers.Conv2d):
            count = module.weight.numel()
            if module.kernel_size == (1, 1):
                conv1x1 += count
                last_conv1x1 = count
            elif module.kernel_size == (3, 3):
                conv3x3 += count
            else:
                print('Conv2d with kernel =', module.kernel_size, 'has', count, 'params')
                other_conv += count
        elif isinstance(module, layers.BatchNorm2d):
            batchnorm += module.weight.numel()

    # Guard against models without any 1x1 conv (previously an IndexError).
    if skip_last and last_conv1x1 is not None:
        conv1x1 -= last_conv1x1
        other_conv += last_conv1x1
    return conv1x1, conv3x3, batchnorm, other_conv


# def conservation(model, scores, batchnorm, residual):
#     r"""Summary of conservation results for a model.
#     """
#     rows = []
#     bias_flux = 0.0
#     mu = 0.0
#     for name, module in reversed(list(model.named_modules())):
#         if prunable(module, batchnorm, residual):
#             weight_flux = 0.0
#             for pname, param in module.named_parameters(recurse=False):

#                 # Get score
#                 score = scores[id(param)].detach().cpu().numpy()

#                 # Adjust batchnorm bias score for mean and variance
#                 if isinstance(module, (layers.Linear, layers.Conv2d)) and pname == "bias":
#                     bias = param.detach().cpu().numpy()
#                     score *= (bias - mu) / bias
#                     mu = 0.0
#                 if isinstance(module, (layers.BatchNorm1d, layers.BatchNorm2d)) and pname == "bias":
#                     mu = module.running_mean.detach().cpu().numpy()

#                 # Add flux
#                 if pname == "weight":
#                     weight_flux += score.sum()
#                 if pname == "bias":
#                     bias_flux += score.sum()
#             layer_flux = weight_flux
#             if not isinstance(module, (layers.Identity1d, layers.Identity2d)):
#                 layer_flux += bias_flux
#             rows.append([name, layer_flux])
#     columns = ['module', 'score flux']

#     return pd.DataFrame(rows, columns=columns)
