import torch
from torch import nn
import torchvision

import torch_pruning as tp
from pruning import utils, dataloading
from classification import train

import numpy as np
import scipy

import os
import argparse

device = torch.device('cuda')

parser = argparse.ArgumentParser(description="Pruning Parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# `required` + `choices` make argparse reject missing or invalid values up
# front with a usage message, instead of failing later in the script
# (e.g. `np.round(None)` when --x is omitted).
parser.add_argument("--order", type=str, required=True, choices=['SAW', 'ZCA'], help="SAW or ZCA")
parser.add_argument("--heuristic", type=str, required=True, choices=['var', 'equal'], help="var or equal")
parser.add_argument('--approximation', action=argparse.BooleanOptionalAction, default=False,
                    help="also store the LDL factor M_inv on each layer and tag outputs with 'approx_'")
parser.add_argument("--x", type=float, required=True, help="pruning ratio (rounded to 2 decimals)")
args = parser.parse_args()
config = vars(args)

MODEL_NAME = 'resnet50'
ORDERING = config['order']
GLOBAL_HEURISTIC = config['heuristic']
APPROXIMATION = config['approximation']
X = config['x']

print(MODEL_NAME, ORDERING, GLOBAL_HEURISTIC, APPROXIMATION, X)

# MODEL_NAME is hard-coded above rather than parsed, so it still needs a check.
if MODEL_NAME not in ['resnet50', 'vgg16']:
    raise NotImplementedError(f'{MODEL_NAME} does not exist.')

utils.fix_settings(seed=13, fltype=torch.float32, allow_grad=True)


class ZCAImportance(tp.importance.Importance):
    """Rank input channels by the diagonal of the ZCA whitening matrix.

    For the first Linear/Conv2d in a group whose *input* dimension is being
    pruned, the score of channel ``i`` is ``1 / ((C + eps*I)^{-1/2})_{ii}^2``.
    Here ``C`` is the cross-correlation matrix attached to that layer as
    ``layer.C`` and ``eps = 1e-12``.
    """

    def __call__(self, group, **kwargs):
        """Return a 1-D float64 score tensor on ``layer.C``'s device, or ``None``.

        ``None`` is returned when no dependency in ``group`` prunes the input
        channels/features of a Linear or Conv2d layer.

        Raises:
            NotImplementedError: if the matching layer has no ``C`` attribute.
        """
        for dep, _ in group:
            layer = dep.target.module
            if not isinstance(layer, (nn.Linear, nn.Conv2d)):
                continue
            if dep.handler.__name__ not in ('prune_in_channels', 'prune_in_features'):
                continue
            if not hasattr(layer, 'C'):
                raise NotImplementedError("To compute ZCA importance each parameterized layer should hold a cross-correlation matrix 'C'")

            C = layer.C.detach()
            C_np = C.cpu().numpy().astype(np.float64)
            # Tiny jitter keeps sqrtm/inv defined for a near-singular C.
            root = scipy.linalg.sqrtm(C_np + np.eye(len(C_np)) * 1e-12)
            # sqrtm yields a complex array when round-off leaves C with slightly
            # negative eigenvalues; a complex score would make torch.argsort
            # fail downstream, so keep only the real part.
            root = np.real(root)
            inv_root_diag = np.diag(np.linalg.inv(root))
            score = 1.0 / (inv_root_diag * inv_root_diag)
            return torch.tensor(score).to(C.device)
        return None
    
class VarImportance(tp.importance.Importance):
    """Score input channels by the remaining (tail) mass of the LDL diagonal.

    The layer's ``D`` vector is normalised to sum to 1. Channel ``k`` (in the
    order used to build ``D``) gets the mass of positions ``k..n-1``. The scores
    are then re-indexed with ``layer.pivots``, which is expected to undo that
    ordering so they line up with the layer's original channel indices.
    """

    def __call__(self, group, **kwargs):
        """Return a 1-D score tensor on the layer's weight device, or ``None``.

        ``None`` is returned when no dependency in ``group`` prunes the input
        channels/features of a Linear or Conv2d layer.

        Raises:
            NotImplementedError: if the matching layer has no ``D`` attribute.
        """
        for dep, _ in group:
            layer = dep.target.module
            if not isinstance(layer, (nn.Linear, nn.Conv2d)):
                continue
            if dep.handler.__name__ not in ('prune_in_channels', 'prune_in_features'):
                continue
            if not hasattr(layer, 'D'):
                raise NotImplementedError("To compute variance importance each parameterized layer should hold a diagonal matrix 'D'")

            d = layer.D.detach().cpu().numpy()
            d = d / np.sum(d)
            # Reverse cumulative sum: tail[k] = sum(d[k:]). The .copy() drops the
            # negative stride, which torch.tensor cannot consume.
            tail = np.cumsum(d[::-1])[::-1].copy()
            score = torch.tensor(tail).to(layer.weight.device)
            return score[layer.pivots]  # revert ordering
        return None

def ldl(C):
    """Compute a pivoted LDL^T factorisation of the symmetric matrix ``C``.

    Wraps :func:`scipy.linalg.ldl`, which returns ``lu, d, perm`` with
    ``C = lu @ d @ lu.T`` and ``lu[perm]`` lower triangular. Hence
    ``C[perm][:, perm] == lu[perm] @ d @ lu[perm].T``.

    Args:
        C: symmetric 2-D tensor; it is detached and copied to CPU for SciPy.

    Returns:
        ``(R_inv, diag)``. ``R_inv`` is the triangular factor ``lu[perm]``;
        ``diag`` is the diagonal of ``d``. Both are tensors on ``C``'s device.
    """
    lu, d, perm = scipy.linalg.ldl(C.detach().cpu().numpy())
    tri = lu[perm, :]
    # `d` is already expressed in the frame of lu[perm]; it must NOT be
    # permuted. Row-permuting the 2-D matrix before np.diag picked up
    # off-diagonal zeros whenever pivoting occurred.
    # NOTE(review): if Bunch-Kaufman chose 2x2 pivot blocks, `d` has
    # off-diagonal entries that np.diag drops.
    R_inv = torch.tensor(tri).to(C.device)
    return R_inv, torch.tensor(np.diag(d), device=C.device)


# Pruning ratio handed to the pruner, rounded to two decimals.
ratio = np.round(X, decimals=2)

# Pretrained torchvision model factories keyed by MODEL_NAME.
_MODEL_BUILDERS = {
    'resnet50': lambda: torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1),
    'vgg16': lambda: torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT),
}
if MODEL_NAME not in _MODEL_BUILDERS:
    raise NotImplementedError()
model = _MODEL_BUILDERS[MODEL_NAME]().to(device)
# The checkpoint's second element is a sequence of per-layer matrices.
_, Cs = torch.load('/PATH/TO/dec.pt', map_location=device)
# Drawn after model construction so the RNG state matches the original flow.
example_inputs = torch.randn(1, 3, 224, 224).to(device)

# Every Linear/Conv2d in named_modules() order, paired positionally with Cs.
# NOTE(review): zip stops at the shorter sequence; extra layers get no `C`.
prunable_layers = [
    (mod_name, mod)
    for mod_name, mod in model.named_modules()
    if isinstance(mod, (nn.Linear, nn.Conv2d))
]
for (_, mod), corr in zip(prunable_layers, Cs):
    mod.C = corr.clone()

# Dependency graph used to enumerate coupled pruning groups.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Channel-ordering criterion: L1 weight magnitude (SAW) or ZCA diagonal (ZCA).
if ORDERING == 'SAW':
    imp = tp.importance.MagnitudeImportance(p=1, normalizer=None, group_reduction="first")
elif ORDERING == 'ZCA':
    imp = ZCAImportance()
else:
    raise NotImplementedError()

# Exclude any Linear with 1000 outputs (the classifier head) from pruning.
ignored_layers = [
    mod for mod in model.modules()
    if isinstance(mod, torch.nn.Linear) and mod.out_features == 1000
]

# For every group (visited last-to-first), rank channels by descending score
# and record the ordering and its inverse on each input-pruned Linear/Conv2d.
_IN_DIM_HANDLERS = ('prune_in_channels', 'prune_in_features')
all_groups = list(DG.get_all_groups(ignored_layers=ignored_layers, root_module_types=[nn.Conv2d, nn.Linear]))
for group in all_groups[::-1]:
    score = imp(group)
    order = torch.argsort(score, descending=True)
    order_inv = torch.argsort(order)

    for dep, _ in group:
        target = dep.target.module
        if isinstance(target, (nn.Linear, nn.Conv2d)) and dep.handler.__name__ in _IN_DIM_HANDLERS:
            target.pivots = order_inv
            target.order = order

# LDL-factor each layer's cross-correlation matrix in the chosen channel order.
# The first prunable layer is skipped (presumably the stem conv fed by image
# channels, which is never input-pruned — confirm).
for _, layer in prunable_layers[1:]:
    reordered_C = layer.C[layer.order][:, layer.order]
    M_inv, D = ldl(reordered_C)
    if APPROXIMATION:
        layer.M_inv = M_inv
    layer.D = D

# 'var' ranks channels globally with VarImportance; otherwise the ordering
# importance `imp` is applied without global pruning.
use_var = GLOBAL_HEURISTIC == 'var'
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=VarImportance() if use_var else imp,
    pruning_ratio=ratio,
    ignored_layers=ignored_layers,
    global_pruning=use_var,
)

pruner.step()

# Evaluate the pruned model before fine-tuning.
train_loader, test_loader = dataloading.load_data(batch_size=256)
loss, acc = utils.measure_perf(model, nn.CrossEntropyLoss(), test_loader, device)
pc, flop = utils.get_network_stats(model, test_loader, device)
print(f'ratio: {ratio} - acc: {acc}, loss: {loss}, flops: {flop}, pc: {pc}')

# Fine-tune; metrics and checkpoints go under ./retraining-results/.
ap = 'approx_' if APPROXIMATION else ''
fn = f'./retraining-results/{MODEL_NAME}_{ORDERING}_{GLOBAL_HEURISTIC}_{ap}{X}x'
metric_file = fn + '.txt'
train_args = train.get_args()
train_args.metric_file = metric_file
train_args.output_dir = fn + '/'
train_args.device = device
# makedirs also creates ./retraining-results/ itself (os.mkdir raised
# FileNotFoundError when it was missing) and is a no-op on reruns.
os.makedirs(train_args.output_dir, exist_ok=True)
print(train_args.output_dir)
train.main(train_args, model)