import torch 
from torch import nn
import torchvision

import numpy as np

from pruning import dataloading, utils, prune_loops
from classification import train
from tqdm import tqdm

from types import SimpleNamespace
import argparse

import os


def get_prunable_layers(net):
    """Collect every ``nn.Linear`` and ``nn.Conv2d`` module contained in ``net``.

    The module tree is walked depth-first, visiting children in the order
    their parent yields them. A matching layer is recorded and not descended
    into; every other module is searched recursively. ``net`` itself is never
    added to the result, only its descendants are inspected.

    Args:
        net: Root ``nn.Module`` to search.

    Returns:
        list: The matching layer objects themselves (not copies), in
        traversal order.
    """
    prunable_layers = []

    def _find_layers(module):
        for mod in module.children():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                prunable_layers.append(mod)
            else:
                # children() simply yields nothing for leaf modules, so the
                # recursion ends on its own without an emptiness check.
                _find_layers(mod)

    _find_layers(net)
    return prunable_layers


def weirdcopy(net, incl_bias=True):  # incl_bias=False for resnet etc.
    """Build a fresh VGG16 carrying the (possibly pruned) parameters of ``net``.

    A new torchvision VGG16 is instantiated with its default pretrained
    weights. For the i-th layer returned by ``utils.get_prunable_layers``,
    the weight and bias are replaced by clones of parameters ``2*i`` and
    ``2*i + 1`` of ``net``. The layer's ``in_*``/``out_*`` attributes are then
    updated to match the shape of the copied weight.

    This relies on ``net.parameters()`` yielding (weight, bias) pairs in the
    same order as the prunable layers of the new model; a missing pair makes
    the unpacking raise ``ValueError``.

    Args:
        net: Source model whose parameters are copied.
        incl_bias: Accepted for interface compatibility; it is not read by
            this function, a bias is always copied.

    Returns:
        The newly created VGG16 with the copied parameters.
    """
    ret = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

    # Materialise the parameter list once; rebuilding it for every layer
    # repeated the full parameter walk on each iteration.
    params = list(net.parameters())

    for i, ret_layer in enumerate(utils.get_prunable_layers(ret)):
        weight, bias = params[2 * i:2 * i + 2]
        ret_layer.weight.data = weight.clone()
        ret_layer.bias.data = bias.clone()

        # Keep the layer metadata consistent with the new weight shape:
        # dimension 0 is the output size, dimension 1 the input size.
        if isinstance(ret_layer, nn.Linear):
            ret_layer.in_features = weight.shape[1]
            ret_layer.out_features = weight.shape[0]
        else:
            ret_layer.in_channels = weight.shape[1]
            ret_layer.out_channels = weight.shape[0]

    return ret


# Command-line interface of the script. `args` is the parsed namespace and
# `config` is a dict view of the same namespace (vars() returns its __dict__,
# so attributes added to `args` later also show up in `config`).
parser = argparse.ArgumentParser(
    description="Pruning Parser",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--model", type=str, help="Model specifier, e.g. 'VGG16' or 'ResNet18'", required=True)
parser.add_argument("--output_path", default='./', help="Path to the output files")
parser.add_argument("--pruning_algo", type=str, help="Pruning algorithm specifier - 'dec'/'dec-dm'/'swm'", required=True)
parser.add_argument("--extra_args", type=str, default='', help="Additions such as 'shuffle-saw', 'whole-accwise', 'whole-sawwise', etc.")
parser.add_argument("--seed", type=int, default=0, help="RNG seed")
parser.add_argument(
    "--ratio",
    type=float,
    required=True,
    help="Threshold on the normalised cumulative eigenvalue sum (smallest "
         "eigenvalues first) that determines how many channels are pruned "
         "per layer",
)
args = parser.parse_args()
config = vars(args)

# Derived options stored on the namespace for code outside this section.
# The flag is true unless the model is a ResNet or the algorithm is 'dec-dm'.
# The model name is compared case-insensitively because the --model help text
# itself gives 'ResNet18' as an example, which a lowercase-only prefix check
# would not recognise.
args.prune_by_removal = (
    not args.model.lower().startswith('resnet')
    and args.pruning_algo != 'dec-dm'
)
args.pre_dims = None

# Prefer the GPU, but fall back to the CPU instead of failing on the first
# `.to(device)` when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
utils.fix_settings(seed=args.seed, fltype=torch.float32, allow_grad=True)
utils.make_paths(args)


def net_factory():
    """Return a new torchvision VGG16 initialised with its default pretrained weights."""
    return torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)


net = net_factory()
_ = net.to(device)


# Ratios to process; only the single value given on the command line.
ratios = [args.ratio]

# The stats file unpacks into two entries, of which only the second is used:
# a sequence of matrices (presumably one per prunable layer - confirm against
# the code that writes the file).
# NOTE(review): torch.load unpickles the file, so only load trusted files.
_, Cs = torch.load(r'/PATH/TO/vgg16/stats/dec_dm.pt', map_location='cpu') # TODO

# Eigenvalue spectrum of every matrix, in ascending order. The pruning loop
# further down accumulates these starting from the smallest eigenvalue.
Ds = [
    torch.linalg.eigh(C).eigenvalues.sort(descending=False).values
    for C in Cs
]

# For every requested ratio, create a fresh network and shrink its prunable
# layers channel-wise, based on the eigenvalue spectra in `Ds` and the
# matrices in `Cs`.
nets = []
for ratio in ratios:

    net = net_factory()

    # rqs[k]: how many channels are dropped for the k-th entry of `Ds`.
    rqs = []
    for D in Ds:
        D_norm = D / torch.sum(D)

        # Running sum over the ascending eigenvalues, i.e. accumulated
        # starting from the smallest one.
        cumulative_sum = torch.cumsum(D_norm, dim=0)

        # Index of the first running sum that reaches the ratio.
        # NOTE(review): if no entry reaches `ratio` (ratio > 1, or rounding
        # keeps the total just below it) argmax returns 0, so nothing is
        # pruned for that layer - confirm this is the intended fallback.
        required_count = np.argmax(cumulative_sum >= ratio)
        rqs.append(required_count.item())

    cpl = get_prunable_layers(net)
    for i, (C, rq) in enumerate(zip(Cs[1:], rqs[1:]), start=1):
        tl = cpl[i]      # layer whose input channels are reduced
        ol = cpl[i-1]    # preceding layer whose output channels are reduced

        # Score column j by sum_i |C[i, j]| / C[j, j], order the channels by
        # descending score and drop the first `rq` of them.
        idcs = torch.argsort(torch.abs(C).sum(dim=0) / torch.diag(C), descending=True)
        idcs = idcs[rq:]

        # Indexing with an index tensor always produces a new tensor, so no
        # explicit clone is needed before or after the selection.
        w = tl.weight.data
        if isinstance(tl, nn.Linear) and isinstance(ol, nn.Conv2d):
            # First fully connected layer after the convolutions (index 13
            # with 512 channels for VGG16): its input is the flattened
            # feature map, so the columns are regrouped per channel before
            # whole channels are selected.
            n_prev_output_channels = ol.weight.shape[0]
            squared_kernel_size = w.shape[1] // n_prev_output_channels
            w = torch.reshape(w, (len(w), -1, squared_kernel_size))
            w = w[:, idcs]
            w = torch.flatten(w, start_dim=1, end_dim=2)
        else:
            w = w[:, idcs]
        tl.weight.data = w

        # Remove the matching output channels (and biases) of the producer.
        w = ol.weight.data[idcs]
        b = ol.bias.data[idcs]
        ol.weight.data = w
        ol.bias.data = b
        print(w.shape)

    _ = net.to(device)
    nets.append(net)

# Base name shared by the metric file and the output directory.
# NOTE(review): this path is hard-coded and does not read args.output_path.
fn = f'/vgg16/retraining-results/PFA-EN_ratio={args.ratio}'
metric_file = fn+'.txt'

# Evaluate every pruned network once, then hand a structurally consistent
# copy of it to the retraining routine.
for ratio, net in zip(ratios, nets):
    pruned_net = weirdcopy(net)
    _ = pruned_net.to(device)

    train_loader, test_loader = dataloading.load_data(batch_size=256)
    # NOTE(review): the metrics below are measured on `net`, not on the
    # `pruned_net` copy that is retrained - confirm this is intended.
    loss, acc = utils.measure_perf(net, nn.CrossEntropyLoss(), test_loader, device)
    pc, flop = utils.get_network_stats(net, test_loader, device)
    print(f'ratio: {ratio} - acc: {acc}, loss: {loss}, flops: {flop}, pc: {pc}')

    train_args = train.get_args()
    train_args.metric_file = metric_file
    train_args.output_dir = fn + '/'
    # Create missing parent directories too and tolerate an existing
    # directory, so a rerun does not abort after the evaluation above has
    # already been paid for.
    os.makedirs(train_args.output_dir, exist_ok=True)
    print(train_args.output_dir)
    train.main(train_args, pruned_net)