import os 
from collections import OrderedDict

import torch  
import torch.nn as nn
import models.wrn as resnet
from models.wrn import BasicBlock, LambdaLayer


# Configuration for measuring per-BasicBlock weight norms (input to pruning).
# ``loc`` is the map_location passed to torch.load; ``device`` is derived from
# it so the model and the loaded checkpoint always target the same GPU.
loc = 'cuda:1'
device = torch.device(loc)
dataset = 'cifar10'  # the main loop handles 'cifar10' and 'cifar100'
# Architecture name -> checkpoint path. Only names starting with 'resnet'
# are processed by the main loop.
nets_name_path = {
    'resnet56': './results_resnet/save_resnet56/checkpoint.th',
}

# Opened in append mode so results from repeated runs accumulate in one file.
logs = open('./logs_adaptivegrouplasso.txt', 'a+', encoding='utf-8')

def _build_model(net_name, dataset):
    """Instantiate architecture ``net_name`` from ``models.wrn`` for ``dataset``.

    Raises:
        ValueError: if ``dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    if dataset == 'cifar10':
        return resnet.__dict__[net_name]()
    if dataset == 'cifar100':
        # CIFAR-100 has 100 classes; building with 10 would make the
        # classifier shape disagree with a CIFAR-100 checkpoint.
        return resnet.__dict__[net_name](num_classes=100)
    raise ValueError('unsupported dataset: %s' % dataset)


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading 'module.' removed from keys.

    Such prefixes presumably come from saving a model wrapped in
    ``nn.DataParallel``. Only the leading prefix is stripped, so keys that
    merely contain 'module' elsewhere are kept intact.
    """
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def _block_weight_norms(model):
    """Return the rounded L2 norm of conv weights/biases for each recorded block.

    Walks ``model.named_modules()`` in order. Each ``BasicBlock`` resets the
    running sum of squares; every subsequent ``nn.Conv2d`` adds its squared
    weights (and bias, if any). A norm is recorded when a module whose name
    contains 'shortcut' is reached, unless that module is a ``LambdaLayer``.
    Convs seen before the first ``BasicBlock`` are skipped.
    """
    norms = []
    # Counts entered BasicBlocks; only ever compared against 0 (to skip the
    # convs that precede the first block).
    s = 0
    with torch.no_grad():  # read-only statistics; no autograd graph needed
        for name, module in model.named_modules():
            if isinstance(module, BasicBlock):
                s += 1
                block_penalty = 0.0
            if isinstance(module, nn.Conv2d):
                if s == 0:
                    continue
                block_penalty += torch.sum(module.weight.pow(2))
                if module.bias is not None:
                    block_penalty += torch.sum(module.bias.pow(2))
            # NOTE(review): if a shortcut is a container with children (e.g. a
            # projection conv), child names also contain 'shortcut' and the
            # same block may be recorded more than once - verify against
            # models.wrn.
            if 'shortcut' in name:
                if isinstance(module, LambdaLayer):
                    # Blocks with a LambdaLayer shortcut are not recorded.
                    # NOTE(review): the purpose of this decrement is unclear.
                    s -= 1
                else:
                    norms.append(torch.sqrt(block_penalty))
    return [round(float(i), 2) for i in norms]


# For every configured checkpoint: log its identity, load it into a fresh
# model, and append that model's per-block weight norms to the log. The log
# file is closed even if loading fails.
try:
    for net_name, model_path in nets_name_path.items():
        logs.write('dataset:%s; \narch:%s; \nmodel path:%s \n' % (dataset, net_name, model_path))
        if not net_name.startswith('resnet'):
            continue
        model = _build_model(net_name, dataset)
        model.to(device)

        checkpoint = torch.load(model_path, map_location=loc)
        model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))

        block_weights = _block_weight_norms(model)
        logs.write(str(block_weights) + '\n\n')
finally:
    logs.close()

