import torch 
from torch import nn
import torchvision

import numpy as np

from pruning import dataloading, utils, prune_loops
from classification import train
from tqdm import tqdm

from types import SimpleNamespace
import argparse

import os



def decor_prune_loop_whole_net_var_ratio_networks(model_factory, Rs, Ds, ratios, device, means=None, pivots=None, args=None):
    """Build one pruned network per entry of ``ratios`` using a variance cutoff.

    For every ratio a fresh model is obtained from ``model_factory``, moved to
    ``device`` and handed to ``utils.replace_layers`` together with ``Rs``,
    ``means``, ``pivots`` and ``args``.  For each consecutive pair of the
    returned prunable layers, ``utils.get_n_nodes_for_variance_cutoff`` is
    called with ``Ds[i]`` and the ratio to obtain a count ``m``; the ``m``
    highest indices along the first weight dimension of the earlier layer are
    then passed to ``prune_inputs`` of the later layer and ``prune_outputs``
    of the earlier layer.

    Args:
        model_factory: Zero-argument callable returning a new model.
        Rs: Forwarded unchanged to ``utils.replace_layers``.
        Ds: Indexable collection; ``Ds[i]`` is read for ``i >= 1`` only
            (``Ds[0]`` is never accessed here).
        ratios: Iterable of cutoff values, one network is built per value.
        device: Device the model and the index tensor are moved to.
        means: Forwarded unchanged to ``utils.replace_layers``.
        pivots: Forwarded unchanged to ``utils.replace_layers``.
        args: Options object forwarded to ``utils.replace_layers``.  When
            omitted, a new empty ``SimpleNamespace`` is created for this call.

    Returns:
        A list with one pruned network per ratio, in iteration order.
    """
    # Create the default per call instead of in the signature, so that state
    # is never shared between separate calls through a single default object.
    if args is None:
        args = SimpleNamespace()

    networks = []
    # Wrap the iterable itself (not ``enumerate``) so tqdm can read its length
    # and display a total; the enumeration index was unused.
    for ratio in tqdm(ratios):
        net = model_factory()
        net.to(device)

        prunable_layers = utils.replace_layers(net, Rs, means, pivots, args)

        # Walk over consecutive (previous, current) layer pairs.
        for m_indx in range(1, len(prunable_layers)):
            m = utils.get_n_nodes_for_variance_cutoff(Ds[m_indx], ratio)

            mod_pre = prunable_layers[m_indx - 1]
            mod_post = prunable_layers[m_indx]

            # Size of the first weight dimension of the earlier layer.
            n_inputs = len(mod_pre.layer.weight)

            # Indices n_inputs - m, ..., n_inputs - 1 (the trailing m entries).
            indices = torch.arange(-m, 0).to(device) + n_inputs

            mod_post.prune_inputs(indices)
            mod_pre.prune_outputs(indices)

        networks.append(net)
    return networks

def prune_loop_whole_net_equal_amount(model_factory, Rs, ratios, device, means=None, pivots=None, args=None):
    """Build one pruned network per entry of ``ratios``, pruning the same fraction everywhere.

    For every ratio a fresh model is obtained from ``model_factory``, moved to
    ``device`` and handed to ``utils.replace_layers`` together with ``Rs``,
    ``means``, ``pivots`` and ``args``.  For each consecutive pair of the
    returned prunable layers, ``m = int(ratio * n)`` is computed, where ``n``
    is the size of the first weight dimension of the earlier layer; the ``m``
    highest indices are then passed to ``prune_inputs`` of the later layer and
    ``prune_outputs`` of the earlier layer.

    Args:
        model_factory: Zero-argument callable returning a new model.
        Rs: Forwarded unchanged to ``utils.replace_layers``.
        ratios: Iterable of fractions, one network is built per value.
        device: Device the model and the index tensor are moved to.
        means: Forwarded unchanged to ``utils.replace_layers``.
        pivots: Forwarded unchanged to ``utils.replace_layers``.
        args: Options object forwarded to ``utils.replace_layers``.  When
            omitted, a new empty ``SimpleNamespace`` is created for this call.

    Returns:
        A list with one pruned network per ratio, in iteration order.
    """
    # Create the default per call instead of in the signature, so that state
    # is never shared between separate calls through a single default object.
    if args is None:
        args = SimpleNamespace()

    networks = []
    # Wrap the iterable itself (not ``enumerate``) so tqdm can read its length
    # and display a total; the enumeration index was unused.
    for ratio in tqdm(ratios):
        net = model_factory()
        net.to(device)

        prunable_layers = utils.replace_layers(net, Rs, means, pivots, args)

        # Walk over consecutive (current, next) layer pairs.
        for m_indx in range(len(prunable_layers) - 1):
            mod_pre = prunable_layers[m_indx]
            mod_post = prunable_layers[m_indx + 1]

            # Size of the first weight dimension of the earlier layer.
            n_inputs = len(mod_pre.layer.weight)

            # Number of trailing entries to prune (truncated towards zero).
            m = int(ratio * n_inputs)
            # Indices n_inputs - m, ..., n_inputs - 1.
            indices = torch.arange(-m, 0).to(device) + n_inputs

            mod_post.prune_inputs(indices)
            mod_pre.prune_outputs(indices)

        networks.append(net)
    return networks


def weirdcopy_sd(sd, incl_bias=True):
    """Return a torchvision VGG16 whose prunable layers carry the tensors in ``sd``.

    A VGG16 with the default torchvision weights is instantiated, and the
    layers returned by ``utils.get_prunable_layers`` are overwritten in order:
    layer ``i`` takes the entries stored under the ``2*i``-th and
    ``2*i + 1``-th keys of ``sd`` as its weight and bias.  The layer's
    ``in_features``/``out_features`` (for ``nn.Linear``) or
    ``in_channels``/``out_channels`` (otherwise) are updated from the shape of
    the copied weight.

    Args:
        sd: Mapping from key to tensor, ordered as weight, bias, weight, ...
        incl_bias: Not read by this function.  The original inline note says
            "incl_bias=False for resnet etc.".

    Returns:
        The VGG16 model with the copied parameters.
    """
    vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
    target_layers = utils.get_prunable_layers(vgg)
    sd_keys = list(sd.keys())

    for idx, layer in enumerate(target_layers):
        # Keys come in (weight, bias) pairs, one pair per prunable layer.
        w_key, b_key = sd_keys[2 * idx:2 * idx + 2]
        new_weight = sd[w_key]
        new_bias = sd[b_key]

        layer.weight.data = new_weight.clone()
        layer.bias.data = new_bias.clone()

        # Keep the layer's size attributes in sync with the new weight shape.
        n_in = new_weight.shape[1]
        n_out = new_weight.shape[0]
        if isinstance(layer, nn.Linear):
            layer.in_features = n_in
            layer.out_features = n_out
        else:
            layer.in_channels = n_in
            layer.out_channels = n_out

    return vgg

def weirdcopy(net, incl_bias=True):
    """Return a torchvision VGG16 whose prunable layers carry ``net``'s parameters.

    A VGG16 with the default torchvision weights is instantiated, and the
    layers returned by ``utils.get_prunable_layers`` are overwritten in order:
    layer ``i`` takes the ``2*i``-th and ``2*i + 1``-th entries of
    ``net.parameters()`` as its weight and bias.  The layer's
    ``in_features``/``out_features`` (for ``nn.Linear``) or
    ``in_channels``/``out_channels`` (otherwise) are updated from the shape of
    the copied weight.

    Args:
        net: Module whose parameters are ordered as weight, bias, weight, ...
        incl_bias: Not read by this function.  The original inline note says
            "incl_bias=False for resnet etc.".

    Returns:
        The VGG16 model with the copied parameters.
    """
    ret = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
    pbs = utils.get_prunable_layers(ret)

    # Materialise the parameter list once instead of on every iteration.
    params = list(net.parameters())

    for i, ret_layer in enumerate(pbs):
        weight, bias = params[(i * 2):(i * 2) + 2]

        # detach() before clone() so the copy is not recorded by autograd;
        # only the values are needed when assigning to ``.data``.
        ret_layer.weight.data = weight.detach().clone()
        ret_layer.bias.data = bias.detach().clone()

        # Keep the layer's size attributes in sync with the new weight shape.
        if isinstance(ret_layer, nn.Linear):
            ret_layer.in_features = weight.shape[1]
            ret_layer.out_features = weight.shape[0]
        else:
            ret_layer.in_channels = weight.shape[1]
            ret_layer.out_channels = weight.shape[0]

    return ret

# ---------------------------------------------------------------------------
# Script body: load a pre-pruned ThiNet VGG16 checkpoint for the requested
# ratio, report its test metrics and hand a copy to the retraining routine.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Pruning Parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, help="Model specifier, e.g. 'VGG16' or 'ResNet18'", required=True)
parser.add_argument("--output_path", default='./', help="Path to the output files")
parser.add_argument("--pruning_algo", type=str, help="Pruning algorithm specifier - 'dec'/'dec-dm'/'swm'", required=True)
parser.add_argument("--extra_args", type=str, default='', help="Additions such as 'shuffle-saw', 'whole-accwise', 'whole-sawwise', etc.")
parser.add_argument("--seed", type=int, default=0, help="RNG seed")
parser.add_argument("--ratio", type=float, required=True)
args = parser.parse_args()
config = vars(args)

# The --model help text advertises mixed-case names such as 'ResNet18', so the
# ResNet check is done case-insensitively (a plain comparison against
# 'resnet' would not recognise 'ResNet18').
args.prune_by_removal = not args.model.lower().startswith('resnet') and args.pruning_algo != 'dec-dm'
args.pre_dims=None

device = torch.device('cuda')
utils.fix_settings(seed=args.seed, fltype=torch.float32, allow_grad=True)
utils.make_paths(args)


net_factory = lambda: torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

ratios = [args.ratio]
print(ratios)


# NOTE(review): this instance is replaced by the checkpoint-based model a few
# lines below; the call is kept unchanged in case seeded runs depend on it.
net = net_factory()
# pre-prune net
sd = torch.load(f'PATH/TO/thinet/model-cr={args.ratio}.pt', map_location=device)
print(sd)
net = weirdcopy_sd(sd)
_ = net.to(device)
nets = [net]

# Base name shared by the metric file and the retraining output directory.
fn = f'/vgg16/retraining-results/ThiNet_ratio={args.ratio}'
metric_file = fn+'.txt'

for ratio, net in zip(ratios, nets):
    # Retraining operates on a separate copy of the loaded network.
    pruned_net = weirdcopy(net)
    _ = pruned_net.to(device)

    train_loader, test_loader = dataloading.load_data(batch_size=256)
    # Metrics before retraining, measured on the loaded network.
    loss, acc = utils.measure_perf(net, nn.CrossEntropyLoss(), test_loader, device)
    pc, flop = utils.get_network_stats(net, test_loader, device)
    print(f'ratio: {ratio} - acc: {acc}, loss: {loss}, flops: {flop}, pc: {pc}')

    train_args = train.get_args()
    train_args.metric_file = metric_file
    train_args.output_dir = fn + '/'
    # makedirs with exist_ok: create missing parent directories and do not
    # fail when the script is re-run for a ratio whose directory exists.
    os.makedirs(train_args.output_dir, exist_ok=True)
    print(train_args.output_dir)
    train.main(train_args, pruned_net)