import torch
import torch.nn as nn
import torch_pruning as tp
import numpy as np

def global_calibration(args, model_ori, model_p, dataloader, groups_ori, groups_p, keep_idxs, device):
    """Prune ``groups_p`` and jointly calibrate the pruned layers of ``model_p``.

    Forward hooks cache the outputs of every conv / linear / batchnorm layer
    touched by the dependency groups, in both the reference model and the model
    being pruned. Each group in ``groups_p`` is then pruned, and all hooked
    Conv2d / Linear / affine BatchNorm2d modules of ``model_p`` are trained
    together with a single Adam optimizer. The loss is the sum over hooked
    layers of ``||out_p - out_ori[:, keep]||^2 / ||out_ori[:, keep]||^2``.

    Per-layer learning rates: matching modules are visited in
    ``model_p.named_modules()`` order, and the k-th one (0-based) gets
    ``args.lr / (num_hooked - k)``, where ``num_hooked`` is the number of
    distinct hooked layer names in ``groups_p``. Earlier modules therefore
    receive smaller learning rates.

    Args:
        args: namespace providing ``lr`` and ``iters_c`` (passes over ``dataloader``).
        model_ori: reference model; only ever run under ``torch.no_grad()``.
        model_p: model that is pruned and calibrated in place.
        dataloader: iterable of batches whose first element is the input.
        groups_ori: dependency groups of ``model_ori``; used only to choose
            which layers to hook.
        groups_p: dependency groups of ``model_p``; each is pruned with
            ``group.prune()``.
        keep_idxs: dict mapping layer name -> list of original output-channel
            indices still kept. For the first out-channel pruning dependency
            seen per layer, the entries at positions ``idx`` are dropped; the
            list is replaced by a new sorted list, not mutated.
        device: device the inputs are moved to.

    Returns:
        None. Every forward hook registered here is removed before returning,
        even if an exception is raised. If no trainable layer is found, the
        groups are still pruned but no calibration step is run.

    NOTE(review): assumes layer names in ``groups_ori`` and ``groups_p``
    coincide and that channels are on dim 1 of each cached output — confirm
    with callers.
    """

    # Handler predicates are evaluated lazily so ``tp`` is only touched when
    # there are dependencies to inspect.
    def is_hooked(handler):
        return handler in (tp.prune_conv_out_channels, tp.prune_linear_out_channels,
                           tp.prune_conv_in_channels, tp.prune_linear_in_channels,
                           tp.prune_batchnorm_out_channels)

    def is_out_pruning(handler):
        return handler in (tp.prune_conv_out_channels, tp.prune_linear_out_channels,
                           tp.prune_batchnorm_out_channels)

    def cache_output(name, outputs):
        def hook(layer, inp, out):
            outputs[name] = out
        return hook

    output_ori = {}
    output_p = {}
    handles = []
    try:
        hooked_ori = set()
        for group in groups_ori:
            for dep, _ in group:
                if is_hooked(dep.handler) and dep.target._name not in hooked_ori:
                    name = dep.target._name
                    handles.append(
                        dep.target.module.register_forward_hook(cache_output(name, output_ori))
                    )
                    hooked_ori.add(name)

        hooked_p = set()
        updated_out = set()
        for group in groups_p:
            for dep, idx in group:
                if is_hooked(dep.handler) and dep.target._name not in hooked_p:
                    name = dep.target._name
                    handles.append(
                        dep.target.module.register_forward_hook(cache_output(name, output_p))
                    )
                    hooked_p.add(name)
                if is_out_pruning(dep.handler) and dep.target._name not in updated_out:
                    name = dep.target._name
                    updated_out.add(name)
                    kept = keep_idxs[name]
                    # ``idx`` holds positions in the current (already pruned)
                    # layer; map them back to original channel indices.
                    removed = {kept[j] for j in idx}
                    keep_idxs[name] = sorted(set(kept) - removed)
            group.prune()

        criterion = nn.MSELoss(reduction='sum')
        param_group = []
        num_layers = len(hooked_p)
        for name, m in model_p.named_modules():
            trainable = isinstance(m, (nn.Conv2d, nn.Linear)) or (
                isinstance(m, nn.BatchNorm2d) and m.affine
            )
            if trainable and name in hooked_p:
                param_group.append({'params': list(m.parameters()), 'lr': args.lr / num_layers})
                num_layers -= 1
        if not param_group:
            # Nothing to calibrate; torch.optim.Adam would reject an empty list.
            return
        optimizer = torch.optim.Adam(param_group, lr=args.lr)

        model_p.train()
        for _ in range(args.iters_c):
            for batch in dataloader:
                image = batch[0].to(device)
                optimizer.zero_grad()
                with torch.no_grad():
                    model_ori(image)
                # Build the graph even if the caller runs under torch.no_grad().
                with torch.enable_grad():
                    model_p(image)
                    terms = []
                    for name, out in output_p.items():
                        target = output_ori[name][:, keep_idxs[name]]
                        norm = torch.norm(target).item() ** 2
                        terms.append(criterion(out, target) / norm)
                    if not terms:
                        # No hooked layer produced an output: nothing to backprop.
                        continue
                    loss = sum(terms)
                    loss.backward()
                optimizer.step()
    finally:
        for h in handles:
            h.remove()


def layer_calibration(args, model_ori, model_p, dataloader, groups_ori, groups_p, keep_idxs, input_idxs, device):
    """Prune ``groups_p`` and calibrate each pruned conv/linear layer on its own.

    Forward hooks on ``model_ori`` cache both the input and the output of
    every hooked layer. Each group in ``groups_p`` is then pruned, and every
    conv/linear layer in it gets its own Adam optimizer over ``weight`` only.
    On each batch, every layer is fed the original model's cached input
    (restricted to ``input_idxs[name]``). Its output is regressed onto the
    original output restricted to ``keep_idxs[name]``, with loss
    ``||out - target||^2 / ||target||^2``. Layers are therefore fitted
    independently; ``model_p`` itself is never run.

    Args:
        args: namespace providing ``lr`` and ``iters_c`` (passes over ``dataloader``).
        model_ori: reference model; run under ``torch.no_grad()`` to fill the caches.
        model_p: unused here; its layers are reached through ``groups_p``.
        dataloader: iterable of batches whose first element is the input.
        groups_ori: dependency groups of ``model_ori``; used only to choose
            which layers to hook.
        groups_p: dependency groups of the pruned model; each is pruned with
            ``group.prune()``.
        keep_idxs: dict layer name -> list of original output-channel indices
            still kept. It is replaced by a new sorted list after out-channel
            pruning.
        input_idxs: dict layer name -> list of original input-channel indices
            still kept. It is replaced by a new sorted list after in-channel
            pruning.
        device: device the inputs are moved to.

    Returns:
        None. Every forward hook registered here is removed before returning,
        even if an exception is raised.

    NOTE(review): only ``weight`` is optimized. ``backward()`` also populates
    ``bias.grad``, which is never zeroed here. Assumes layer names in
    ``groups_ori`` and ``groups_p`` coincide and channels are on dim 1 —
    confirm with callers.
    """

    # Handler predicates are evaluated lazily so ``tp`` is only touched when
    # there are dependencies to inspect.
    def is_hooked(handler):
        return handler in (tp.prune_conv_out_channels, tp.prune_linear_out_channels,
                           tp.prune_conv_in_channels, tp.prune_linear_in_channels,
                           tp.prune_batchnorm_out_channels)

    def is_weighted(handler):
        return handler in (tp.prune_conv_out_channels, tp.prune_linear_out_channels,
                           tp.prune_conv_in_channels, tp.prune_linear_in_channels)

    def is_out_pruning(handler):
        return handler in (tp.prune_conv_out_channels, tp.prune_linear_out_channels,
                           tp.prune_batchnorm_out_channels)

    def is_in_pruning(handler):
        return handler in (tp.prune_conv_in_channels, tp.prune_linear_in_channels,
                           tp.prune_batchnorm_in_channels)

    def cache_output(name, outputs):
        def hook(layer, inp, out):
            outputs[name] = out
        return hook

    def cache_input(name, inputs):
        def hook(layer, inp, out):
            inputs[name] = inp[0]
        return hook

    output_ori = {}
    input_ori = {}
    # All hooks live on model_ori; one list so every one is removed.
    handles = []
    try:
        hooked = set()
        for group in groups_ori:
            for dep, _ in group:
                if is_hooked(dep.handler) and dep.target._name not in hooked:
                    name = dep.target._name
                    module = dep.target.module
                    handles.append(module.register_forward_hook(cache_output(name, output_ori)))
                    handles.append(module.register_forward_hook(cache_input(name, input_ori)))
                    hooked.add(name)

        layers = {}
        updated_out = set()
        updated_in = set()
        for group in groups_p:
            for dep, idx in group:
                if is_weighted(dep.handler) and dep.target._name not in layers:
                    layers[dep.target._name] = dep.target.module
                # ``idx`` holds positions in the current (already pruned) layer;
                # map them back to original channel indices before removing.
                if is_out_pruning(dep.handler) and dep.target._name not in updated_out:
                    name = dep.target._name
                    updated_out.add(name)
                    kept = keep_idxs[name]
                    removed = {kept[j] for j in idx}
                    keep_idxs[name] = sorted(set(kept) - removed)
                elif is_in_pruning(dep.handler) and dep.target._name not in updated_in:
                    name = dep.target._name
                    updated_in.add(name)
                    kept = input_idxs[name]
                    removed = {kept[j] for j in idx}
                    input_idxs[name] = sorted(set(kept) - removed)
            group.prune()

        criterion = nn.MSELoss(reduction='sum')
        optimizers = {
            name: torch.optim.Adam([layer.weight], lr=args.lr)
            for name, layer in layers.items()
        }

        for _ in range(args.iters_c):
            for batch in dataloader:
                image = batch[0].to(device)
                with torch.no_grad():
                    model_ori(image)
                for name, layer in layers.items():
                    optimizer = optimizers[name]
                    optimizer.zero_grad()
                    # Forward, loss and backward must all see grad mode, even
                    # when the caller runs under torch.no_grad().
                    with torch.enable_grad():
                        output = layer(input_ori[name][:, input_idxs[name]])
                        target = output_ori[name][:, keep_idxs[name]]
                        norm = torch.norm(target).item() ** 2
                        loss = criterion(output, target) / norm
                        loss.backward()
                    optimizer.step()
    finally:
        for h in handles:
            h.remove()
