import torch 
from torch import nn
import torchvision

import numpy as np

from pruning import dataloading, utils, prune_loops
from classification import train
from tqdm import tqdm

from types import SimpleNamespace
import argparse

import os



def decor_prune_loop_whole_net_var_ratio_networks(model_factory, Rs, Ds, ratios, device, means=None, pivots=None, args=None):
    """Build one pruned network per ratio, using a per-layer variance cutoff.

    For each ratio a fresh network is created with ``model_factory``, and its
    prunable layers are swapped in via ``utils.replace_layers``. Then, for
    every consecutive pair of prunable layers (``i - 1``, ``i``), the last ``m``
    output channels of layer ``i - 1`` are pruned, together with the matching
    input channels of layer ``i``. ``m`` is given by
    ``utils.get_n_nodes_for_variance_cutoff(Ds[i], ratio)``.

    Args:
        model_factory: zero-argument callable returning a new network.
        Rs: per-layer matrices forwarded to ``utils.replace_layers``.
        Ds: per-layer values handed to ``utils.get_n_nodes_for_variance_cutoff``;
            indexed like the prunable layers (index 0 is never read here).
        ratios: iterable of ratios; one network is produced per entry.
        device: device that receives the network and the prune indices.
        means, pivots: forwarded unchanged to ``utils.replace_layers``.
        args: options forwarded to ``utils.replace_layers``. When omitted, a
            fresh empty ``SimpleNamespace`` is created for this call.

    Returns:
        list of pruned networks, in the same order as ``ratios``.
    """
    # Create the namespace per call: a default instance built once at
    # definition time would carry any attributes set on it into later calls.
    if args is None:
        args = SimpleNamespace()

    networks = []
    for ratio in tqdm(ratios):
        net = model_factory()
        net.to(device)

        prunable_layers = utils.replace_layers(net, Rs, means, pivots, args)

        for post_idx in range(1, len(prunable_layers)):
            # How many trailing channels to drop between the two layers.
            m = utils.get_n_nodes_for_variance_cutoff(Ds[post_idx], ratio)

            mod_pre = prunable_layers[post_idx - 1]
            mod_post = prunable_layers[post_idx]

            # Size of mod_pre's weight along dim 0, i.e. its output channels.
            n_inputs = len(mod_pre.layer.weight)

            # The last m channel indices: n_inputs - m, ..., n_inputs - 1.
            indices = torch.arange(n_inputs - m, n_inputs, device=device)

            mod_post.prune_inputs(indices)
            mod_pre.prune_outputs(indices)

        networks.append(net)
    return networks

def prune_loop_whole_net_equal_amount(model_factory, Rs, ratios, device, means=None, pivots=None, args=None):
    """Build one pruned network per ratio, removing the same fraction per layer.

    For each ratio a fresh network is created with ``model_factory``, and its
    prunable layers are swapped in via ``utils.replace_layers``. Then, for
    every consecutive pair of prunable layers, the last ``int(ratio * n)``
    output channels of the earlier layer are pruned (``n`` is that layer's
    output count), together with the matching input channels of the later
    layer. The outputs of the final prunable layer are never pruned.

    Args:
        model_factory: zero-argument callable returning a new network.
        Rs: per-layer matrices forwarded to ``utils.replace_layers``.
        ratios: iterable of fractions in ``[0, 1]``; one network per entry.
        device: device that receives the network and the prune indices.
        means, pivots: forwarded unchanged to ``utils.replace_layers``.
        args: options forwarded to ``utils.replace_layers``. When omitted, a
            fresh empty ``SimpleNamespace`` is created for this call.

    Returns:
        list of pruned networks, in the same order as ``ratios``.

    Raises:
        ValueError: if a ratio lies outside ``[0, 1]``.
    """
    # Create the namespace per call: a default instance built once at
    # definition time would carry any attributes set on it into later calls.
    if args is None:
        args = SimpleNamespace()

    networks = []
    for ratio in tqdm(ratios):
        # A ratio above 1 would make m exceed the channel count and produce
        # negative indices that silently wrap around; reject it up front.
        if not 0 <= ratio <= 1:
            raise ValueError(f'pruning ratio must be in [0, 1], got {ratio!r}')

        net = model_factory()
        net.to(device)

        prunable_layers = utils.replace_layers(net, Rs, means, pivots, args)

        for pre_idx in range(len(prunable_layers) - 1):
            mod_pre = prunable_layers[pre_idx]
            mod_post = prunable_layers[pre_idx + 1]

            # Size of mod_pre's weight along dim 0, i.e. its output channels.
            n_inputs = len(mod_pre.layer.weight)

            m = int(ratio * n_inputs)
            # The last m channel indices: n_inputs - m, ..., n_inputs - 1.
            indices = torch.arange(n_inputs - m, n_inputs, device=device)

            mod_post.prune_inputs(indices)
            mod_pre.prune_outputs(indices)

        networks.append(net)
    return networks


def weirdcopy(net, incl_bias=True, model_factory=None):
    """Copy the (pruned) layer parameters of ``net`` into a plain template model.

    A template model is built: torchvision VGG16 with its default weights, or
    ``model_factory()`` when given. For the i-th layer returned by
    ``utils.get_prunable_layers(template)``, the layer's weight and bias are
    replaced by clones of tensors ``2*i`` and ``2*i + 1`` of
    ``net.parameters()``. The layer's ``in_/out_features`` (``nn.Linear``) or
    ``in_/out_channels`` (any other layer) are then updated to match the new
    tensor shapes.

    Args:
        net: source network. Its parameters are assumed to alternate
            weight, bias in the same order as the template's prunable layers.
        incl_bias: NOTE(review): currently unused — every layer is read as a
            (weight, bias) pair. Kept for interface compatibility.
        model_factory: optional zero-argument callable that builds the
            template model; defaults to VGG16 with default weights.

    Returns:
        The template model carrying the copied parameters.

    Raises:
        ValueError: if ``net`` has fewer than two parameters per prunable layer
            of the template.
    """
    if model_factory is None:
        ret = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
    else:
        ret = model_factory()
    pbs = utils.get_prunable_layers(ret)

    # Materialize the parameter list once rather than once per layer.
    params = list(net.parameters())
    if len(params) < 2 * len(pbs):
        raise ValueError(
            f'net has {len(params)} parameters, expected at least '
            f'{2 * len(pbs)} (weight + bias for {len(pbs)} prunable layers)'
        )

    for i, ret_layer in enumerate(pbs):
        weight, bias = params[2 * i], params[2 * i + 1]
        ret_layer.weight.data = weight.clone()
        ret_layer.bias.data = bias.clone()
        # Keep the module metadata consistent with the (possibly smaller) tensors.
        if isinstance(ret_layer, nn.Linear):
            ret_layer.in_features = weight.shape[1]
            ret_layer.out_features = weight.shape[0]
        else:
            ret_layer.in_channels = weight.shape[1]
            ret_layer.out_channels = weight.shape[0]

    return ret


# ---- Command-line interface and global setup -----------------------------
parser = argparse.ArgumentParser(description="Pruning Parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, help="Model specifier, e.g. 'VGG16' or 'ResNet18'", required=True)
parser.add_argument("--output_path", default='./', help="Path to the output files")
parser.add_argument("--pruning_algo", type=str, help="Pruning algorithm specifier - 'dec'/'dec-dm'/'swm'", required=True)
parser.add_argument("--extra_args", type=str, default='', help="Additions such as 'shuffle-saw', 'whole-accwise', 'whole-sawwise', etc.")
parser.add_argument("--seed", type=int, default=0, help="RNG seed")
parser.add_argument("--ratio", type=float, required=True)
args = parser.parse_args()
# vars() returns args.__dict__ itself, so later attribute writes show up here too.
config = vars(args)

# prune_by_removal is False for ResNet models and for the 'dec-dm' algorithm.
# Compare case-insensitively: args.model is only lower-cased further below,
# and the --model help text itself suggests 'ResNet18'.
args.prune_by_removal = not args.model.lower().startswith('resnet') and args.pruning_algo != 'dec-dm'
args.pre_dims = None

device = torch.device('cuda')
utils.fix_settings(seed=args.seed, fltype=torch.float32, allow_grad=True)
utils.make_paths(args)

# Pick a zero-argument constructor for the requested architecture. Each call
# returns a fresh torchvision model loaded with torchvision's default weights.
args.model = args.model.lower()
if args.model == 'vgg11':
    net_factory = lambda: torchvision.models.vgg11(weights=torchvision.models.VGG11_Weights.DEFAULT)
elif args.model == 'vgg16':
    net_factory = lambda: torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
elif args.model == 'vgg19':
    net_factory = lambda: torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
else:
    raise NotImplementedError(
        f"unsupported --model {args.model!r}; expected one of 'vgg11', 'vgg16', 'vgg19'"
    )

# Input dimensions for the network's layers, stored on args for later helpers.
pre_dims = utils.get_input_dims(device, args)
args.pre_dims = pre_dims

# Only the single ratio given on the command line is run per invocation.
ratios = [args.ratio]
print(ratios)

# Per-layer input statistics: `means` and square matrices `Cs` (presumably
# covariances — confirm in utils.get_input_stats).
means, Cs = utils.get_input_stats(None, None, device, False, args)  # setting net to None requires its input_stats to already exist


# NOTE(review): this `net` is not read before the retraining loop rebinds the
# name. Constructing a torchvision model may consume the seeded RNG, so
# removing these lines could change downstream results — confirm before deleting.
net = net_factory()
_ = net.to(device)

# One channel ordering per matrix; the entry for index 0 (the network input)
# is dropped.
layerwise_ordering = utils.ZCA_ordering(Cs)[1:]

# Permute rows and columns of every matrix except Cs[0] into its ordering.
Cs = [C[order][:, order] for C, order in zip(Cs[1:], layerwise_ordering)]
# Re-pad index 0 so Cs stays aligned with the layer indices. This inserts a
# duplicate of the reordered Cs[1]; the original Cs[0] is discarded.
# NOTE(review): confirm that Cs[0] (and R_invs[0]/Ds[0] derived from it below)
# is never used in a way that depends on its contents or size.
Cs.insert( 0, Cs[0] )

# pivots[i] = (inverse of the ordering applied to layer i's inputs,
#              ordering applied to layer i's outputs), with None at the two
# ends: the first entry has no input reordering, the last no output
# reordering. This assumes Cs[i] describes the inputs of prunable layer i —
# TODO confirm.
inverse_ordering = utils.invert_ordering(layerwise_ordering)
layerwise_ordering.append(None)
inverse_ordering.insert(0, None)
pivots = list(zip(inverse_ordering, layerwise_ordering))

# Disabled alternatives kept for experimentation: an ordering from
# utils.C_ordering, and one from utils.swm_ordering.
# layerwise_ordering = utils.C_ordering(Cs)[1:]
# Cs = [C[order][:, order] for C, order in zip(Cs[1:], layerwise_ordering)]
# Cs.insert( 0, Cs[0] )

# inverse_ordering = utils.invert_ordering(layerwise_ordering)
# layerwise_ordering.append(None)
# inverse_ordering.insert(0, None)
# pivots = list(zip(inverse_ordering, layerwise_ordering))
# net = net_factory()
# _ = net.to(device)

# layerwise_ordering = utils.swm_ordering()[:-1]
# inverse_ordering = [c.clone() for c in layerwise_ordering]
# layerwise_ordering.append(None)
# inverse_ordering.insert(0, None)
# pivots = list(zip(inverse_ordering, layerwise_ordering))

# utils.ldl yields a pair per matrix. The first parts (R_invs) go to
# replace_layers; the second parts (Ds) drive the variance-based prune counts.
R_invs_Ds = [utils.ldl(C) for C in Cs]
R_invs, Ds = zip(*R_invs_Ds)

# Choose the pruning schedule from the '-'-separated --extra_args tokens:
#   'uniform' -> prune the same fraction of channels between every layer pair,
#   'varwise' -> prune per layer according to a variance cutoff.
# Both use net_factory so the pruned architecture matches the one whose input
# statistics were loaded above (previously VGG16 was hard-coded here).
algo_extras = args.extra_args.split('-')
if 'uniform' in algo_extras:
    nets = prune_loop_whole_net_equal_amount(net_factory, R_invs, ratios, device, None, pivots, args)
elif 'varwise' in algo_extras:
    nets = decor_prune_loop_whole_net_var_ratio_networks(
        net_factory,
        R_invs, Ds, ratios, device, means, pivots, args
    )
else:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so callers can detect the bad flag. The site-only `exit()` exited with 0.
    raise SystemExit(f'incorrect algo_extra {args.extra_args}, {algo_extras}')

# Results live under /<model>/retraining-results/ (the model name was
# previously hard-coded as vgg16).
# NOTE(review): this is an absolute path at the filesystem root and ignores
# --output_path — confirm this is intended.
fn = f'/{args.model}/retraining-results/SNP-zca_ratio={args.ratio}_seed={args.seed}'
metric_file = fn + '.txt'

for ratio, net in zip(ratios, nets):
    # Rebuild a plain torchvision model from the pruned parameters of `net`.
    pruned_net = weirdcopy(net)
    _ = pruned_net.to(device)

    train_loader, test_loader = dataloading.load_data(batch_size=256)
    loss, acc = utils.measure_perf(net, nn.CrossEntropyLoss(), test_loader, device)
    pc, flop = utils.get_network_stats(net, test_loader, device)
    print(f'ratio: {ratio} - acc: {acc}, loss: {loss}, flops: {flop}, pc: {pc}')

    train_args = train.get_args()
    train_args.metric_file = metric_file
    train_args.output_dir = fn + '/'
    # makedirs also creates missing parent directories, where os.mkdir would
    # fail. It still raises FileExistsError instead of reusing an existing
    # run directory.
    os.makedirs(train_args.output_dir)
    print(train_args.output_dir)
    train.main(train_args, pruned_net)