import torch
import torch.nn as nn
import torch_pruning as tp
import copy

def sparse_train(args, model_ori, model_p, dataloader, groups_ori, groups_p, keep_idxs, device):
    """Train ``model_p`` to reproduce selected layer outputs of ``model_ori``.

    Forward hooks cache the outputs of the conv/linear/batchnorm layers that
    appear in the given torch_pruning groups, on both models.

    For every batch:

    * ``model_ori`` runs under ``torch.no_grad()``; ``model_p`` runs with autograd.
    * The loss sums, over every hooked layer of ``model_p``, the summed squared
      error between the pruned output and ``output_ori[name][:, keep_idxs[name]]``.
      Each term is divided by the squared norm of the original output.
    * After ``backward()``, gradient entries of the weight slices selected by
      the ``groups_p`` indices are overwritten with
      ``(args.reg + j * args.inc) * w * args.lr``. Here ``j`` is the outer
      iteration, so the term grows linearly over iterations.

    Args:
        args: Namespace read for ``lr``, ``iters_s``, ``reg`` and ``inc``.
        model_ori: Reference model whose hooked outputs are the targets.
        model_p: Model being optimised (Adam over all of its parameters). It is
            put in ``train()`` mode for the loop and ``eval()`` mode on normal
            return.
        dataloader: Iterable of batches; ``batch[0]`` is fed to both models.
        groups_ori: Iterable of groups of ``(dep, idx)`` pairs for
            ``model_ori``; only used to choose which layers to hook.
        groups_p: Iterable of groups of ``(dep, idx)`` pairs for ``model_p``.
            Used to choose layers to hook and the weight indices whose
            gradients are overridden.
        keep_idxs: Mapping from layer name to the indices (along dim 1) of the
            original layer output that the pruned output is compared against.
        device: Device the input batch is moved to.

    All forward hooks registered here are removed before the function exits,
    including when an exception is raised.
    """

    def cache_output(name, outputs):
        """Return a forward hook that stores the module output in ``outputs[name]``."""
        def hook(module, inputs, output):
            outputs[name] = output
        return hook

    # Layers reached through any of these handlers get an output hook.
    # NOTE(review): BatchNorm layers are hooked only via prune_batchnorm_out_channels,
    # while input indices are also collected via prune_batchnorm_in_channels.
    hooked_handlers = (
        tp.prune_conv_out_channels,
        tp.prune_linear_out_channels,
        tp.prune_conv_in_channels,
        tp.prune_linear_in_channels,
        tp.prune_batchnorm_out_channels,
    )
    out_handlers = (
        tp.prune_conv_out_channels,
        tp.prune_linear_out_channels,
        tp.prune_batchnorm_out_channels,
    )
    in_handlers = (
        tp.prune_conv_in_channels,
        tp.prune_linear_in_channels,
        tp.prune_batchnorm_in_channels,
    )

    output_ori = {}
    output_p = {}
    handles_ori = []
    handles_p = []
    try:
        # Hook every relevant layer of the original model exactly once.
        visited_ori_layer = set()
        for group in groups_ori:
            for dep, _ in group:
                name = dep.target._name
                if dep.handler in hooked_handlers and name not in visited_ori_layer:
                    handles_ori.append(
                        dep.target.module.register_forward_hook(cache_output(name, output_ori))
                    )
                    visited_ori_layer.add(name)

        # Hook every relevant layer of the pruned model exactly once, and keep
        # the first output-/input-dimension index list seen for each layer.
        visited_p_layer = set()
        out_pruned = {}
        input_pruned = {}
        for group in groups_p:
            for dep, idx in group:
                name = dep.target._name
                if dep.handler in hooked_handlers and name not in visited_p_layer:
                    handles_p.append(
                        dep.target.module.register_forward_hook(cache_output(name, output_p))
                    )
                    visited_p_layer.add(name)
                if dep.handler in out_handlers and name not in out_pruned:
                    out_pruned[name] = idx
                if dep.handler in in_handlers and name not in input_pruned:
                    input_pruned[name] = idx

        # An empty index list makes the gradient override below a no-op for
        # that dimension of the layer.
        for name in visited_p_layer:
            out_pruned.setdefault(name, [])
            input_pruned.setdefault(name, [])

        criterion = nn.MSELoss(reduction='sum')
        optimizer = torch.optim.Adam(model_p.parameters(), lr=args.lr)

        model_p.train()
        for j in range(args.iters_s):
            # Regularisation strength increases linearly with the outer iteration.
            reg = args.reg + j * args.inc
            for batch in dataloader:
                image = batch[0].to(device)
                optimizer.zero_grad()
                # These forward passes fill output_ori / output_p via the hooks.
                # NOTE(review): model_ori's train/eval mode is left to the caller; in
                # train mode its BatchNorm running statistics are updated here.
                with torch.no_grad():
                    model_ori(image)
                model_p(image)
                loss = 0
                for name in output_p:
                    # NOTE(review): an all-zero original output makes this 0 and the
                    # division below non-finite.
                    norm = torch.norm(output_ori[name]).item() ** 2
                    loss += criterion(output_p[name], output_ori[name][:, keep_idxs[name]]) / norm
                loss.backward()

                # Replace the gradients of the selected weight slices with the
                # regularisation term.
                for name, m in model_p.named_modules():
                    if name not in visited_p_layer:
                        continue
                    if isinstance(m, (nn.Conv2d, nn.Linear)):
                        override_inputs = True
                    elif isinstance(m, nn.BatchNorm2d) and m.affine:
                        override_inputs = False
                    else:
                        continue
                    grad = m.weight.grad
                    if grad is None:
                        # No gradient for this weight (e.g. it is frozen); nothing to override.
                        continue
                    weight = m.weight.data
                    rows = out_pruned[name]
                    grad.data[rows] = reg * weight[rows] * args.lr
                    if override_inputs:
                        cols = input_pruned[name]
                        grad.data[:, cols] = reg * weight[:, cols] * args.lr
                optimizer.step()

        model_p.eval()
    finally:
        # Always detach the hooks so neither model keeps caching outputs.
        for h in handles_p:
            h.remove()
        for h in handles_ori:
            h.remove()
