import torch 
from torch import nn
import torchvision

import numpy as np

from pruning import dataloading, utils, prune_loops

import argparse

import os

# Force synchronous CUDA kernel launches so that device-side errors are
# reported at the call that caused them (debugging aid; slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Pruning Parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, help="Model specifier, e.g. 'VGG16' or 'ResNet18'", required=True)
parser.add_argument("--device", type=str, help="PyTorch device", required=True)
parser.add_argument("--output_path", default='./', help="Path to the output files")
parser.add_argument(
    "--pruning_algo",
    type=str,
    help="Pruning algorithm specifier - 'dec'/'dec-dm'/'saw'/'simple'/'aw'/'pfa_en'",
    required=True,
)
parser.add_argument("--extra_args", type=str, default='', help="Additions such as 'shuffle-saw', 'whole-accwise', 'whole-sawwise', etc.")
parser.add_argument("--seed", type=int, default=0, help="RNG seed")
parser.add_argument("--batch_size", default=128, type=int, help="Batchsize for measuring network stats and evaluation")
# `args.data_path` is read below; without this option the script aborted with
# an AttributeError immediately after parsing.
parser.add_argument("--data_path", type=str, default=None, help="Path to the dataset")
args = parser.parse_args()
config = vars(args)

# Removal-based pruning is disabled for ResNet models and for the 'dec-dm'
# algorithm.  The model specifier is compared case-insensitively because it is
# only normalised to lower case further down, so 'ResNet18' must be recognised
# here as well.
args.prune_by_removal = (
    not args.model.lower().startswith('resnet')
    and args.pruning_algo != 'dec-dm'
)


data_path = args.data_path

# Fall back to the CPU only when a CUDA device was requested but CUDA is not
# usable.  `torch.cuda.is_available` has to be *called*: the bare function
# object is always truthy, which silently disabled the fallback.
requested_device = args.device
if requested_device.startswith('cuda') and not torch.cuda.is_available():
    requested_device = 'cpu'
device = torch.device(requested_device)

fltype = torch.float32
bs = args.batch_size

# Project helpers: both receive the seed / the parsed arguments.
# NOTE(review): `args.stats_path` and `args.results_path` are used later in
# this script but never set here - presumably `utils.make_paths` adds them;
# confirm against the helper.
utils.fix_settings(seed=args.seed, fltype=fltype)
utils.make_paths(args)

# Grid of pruning ratios evaluated by every prune loop: np.arange over the
# half-open range [0, 1).  A 0.01 step is the alternative that was tried
# (np.arange(0, 1, 0.01)).
RATIO_STEP = 0.0125
pruning_ratios = np.arange(0, 1, RATIO_STEP)

# Classification loss handed to the prune loops.
loss_fn = torch.nn.CrossEntropyLoss()

# Train / test loaders, both built with the evaluation batch size.
loaders = dataloading.load_data(bs)
train_dataloader, test_dataloader = loaders

# model selection
# Maps the lower-cased model specifier to (constructor name, weights-enum name)
# inside `torchvision.models`.  Names are resolved with getattr only for the
# selected entry, so nothing is touched for the models that are not requested.
_SUPPORTED_MODELS = {
    'vgg11': ('vgg11', 'VGG11_Weights'),
    'vgg16': ('vgg16', 'VGG16_Weights'),
    'vgg16-bn': ('vgg16_bn', 'VGG16_BN_Weights'),
    'vgg19': ('vgg19', 'VGG19_Weights'),
    'resnet18': ('resnet18', 'ResNet18_Weights'),
    'resnet50': ('resnet50', 'ResNet50_Weights'),
    'resnet101': ('resnet101', 'ResNet101_Weights'),
}

# From here on the specifier is always lower case.
args.model = args.model.lower()
if args.model not in _SUPPORTED_MODELS:
    # Tell the user what was rejected and what is available instead of raising
    # an empty NotImplementedError.
    raise NotImplementedError(
        f"Model '{args.model}' is not supported; choose one of {sorted(_SUPPORTED_MODELS)}"
    )

_builder_name, _weights_name = _SUPPORTED_MODELS[args.model]
_default_weights = getattr(torchvision.models, _weights_name).DEFAULT
net = getattr(torchvision.models, _builder_name)(weights=_default_weights)

# nn.Module.to moves the parameters in place and returns the module itself.
net = net.to(device)

# Selected algorithm and its optional modifiers.  '--extra_args' is a
# dash-separated list, e.g. 'shuffle-saw' -> ['shuffle', 'saw'].
algo, algo_extras = args.pruning_algo, args.extra_args.split('-')

# Query the input dimensions once and store them on `args` as well
# (presumably so that helpers which only receive `args` can read them - confirm).
args.pre_dims = pre_dims = utils.get_input_dims(device, args)

# ---------------------------------------------------------------------------
# Helpers for the algorithm dispatch below
# ---------------------------------------------------------------------------

# Statistics file used by the 'pfa_en' algorithm for each supported model.
_PFA_EN_STATS = {
    'vgg11': '/vgg11/stats/dec.pt',
    'vgg19': '/vgg19/stats/dec.pt',
    'vgg16': '/vgg16/stats/dec_dm.pt',
    'resnet50': '/resnet50/stats/dec.pt',
}


def _make_pivots(inverse_ordering, layerwise_ordering):
    """Zip two per-layer ordering lists into the ``pivots`` list for the prune loops.

    ``inverse_ordering`` is shifted one position to the right (a ``None`` is
    put in front) and ``layerwise_ordering`` is padded with a trailing
    ``None``, so entry ``i`` of the result is
    ``(inverse_ordering[i - 1], layerwise_ordering[i])``.  Neither input list
    is modified.
    """
    return list(zip([None, *inverse_ordering], [*layerwise_ordering, None]))


def _shuffle_stats(Cs, model, extras):
    """Apply the ordering selected by the 'shuffle' extras to the statistics.

    Args:
        Cs: sequence of square per-layer matrices (presumably the input
            covariances returned by ``utils.get_input_stats`` - confirm).
        model: the network, only needed for the 'saw' ordering.
        extras: modifier tokens; must contain one of 'saw', 'ZCA' or 'rand'.

    Returns:
        ``(reordered_Cs, pivots)`` where every matrix except the first has its
        rows and columns permuted by the matching ordering.

    Raises:
        NotImplementedError: if no known ordering token is present.
    """
    # One ordering per matrix in Cs[1:].  The slices mirror each other in all
    # branches of this script: the 'saw' ordering drops its last entry, the
    # statistics-based orderings drop their first entry.
    # NOTE(review): the 'dec' path previously used `saw_ordering(net)` without
    # `[:-1]` and `ZCA_ordering(Cs)[:1]`; both were aligned with the 'simple'
    # path - confirm against utils.
    if 'saw' in extras:
        layerwise_ordering = utils.saw_ordering(model)[:-1]
    elif 'ZCA' in extras:
        layerwise_ordering = utils.ZCA_ordering(Cs)[1:]
    elif 'rand' in extras:
        layerwise_ordering = utils.random_ordering(Cs)[1:]
    else:
        raise NotImplementedError(
            "'shuffle' needs one of the orderings 'saw', 'ZCA' or 'rand' in --extra_args"
        )

    # Keep the ORIGINAL first matrix in front.  The old code rebound `Cs`
    # before `Cs.insert(0, Cs[0])`, which duplicated the permuted second
    # matrix instead.
    reordered = [Cs[0]]
    reordered.extend(C[order][:, order] for C, order in zip(Cs[1:], layerwise_ordering))

    inverse_ordering = utils.invert_ordering(layerwise_ordering)
    return reordered, _make_pivots(inverse_ordering, layerwise_ordering)


def _prune_with_ordering_only(pivots):
    """Run the whole-network or layerwise loop without decorrelation data.

    Used by the 'simple' and 'saw' algorithms, which pass ``None`` for both
    the ``R_invs`` and the ``means`` arguments of the prune loops.
    """
    if 'whole' in algo_extras:  # whole network pruning
        loop = prune_loops.prune_loop_whole_net_equal_amount
    else:  # layerwise pruning
        loop = prune_loops.layerwise_prune_loop
    loop(net, test_dataloader, None, pruning_ratios, loss_fn, device, None, pivots, args)


# ---------------------------------------------------------------------------
# Algorithm dispatch - a single if/elif chain, so exactly one branch runs.
# (Previously 'pfa_en' sat in a separate `if`, so after finishing it fell
# through to the final `else` and raised NotImplementedError.)
# ---------------------------------------------------------------------------
if algo == 'pfa_en':
    if args.model not in _PFA_EN_STATS:
        raise NotImplementedError(
            f"'pfa_en' has no statistics file for model '{args.model}'; "
            f"available: {sorted(_PFA_EN_STATS)}"
        )
    # NOTE: torch.load unpickles the file - only load trusted statistics files.
    _, Cs = torch.load(_PFA_EN_STATS[args.model], map_location=device)

    layerwise_ordering = utils.C_ordering(Cs)[1:]
    # NOTE(review): here (and for 'saw') the "inverse" ordering is a plain copy
    # of the ordering, not utils.invert_ordering(...) as in the shuffle path -
    # kept as is, confirm that this is intended.
    pivots = _make_pivots([c.clone() for c in layerwise_ordering], layerwise_ordering)

    # Eigenvalues of every matrix, clamped at zero and sorted ascending.
    Ds = []
    for C in Cs:
        eigenvalues, _ = torch.linalg.eigh(C)
        sorted_values, _ = torch.sort(torch.clamp_min(eigenvalues, 0.), descending=False)
        Ds.append(sorted_values)

    prune_loops.pfa_en_prune_loop(
        net,
        test_dataloader,
        Ds,
        pruning_ratios,
        loss_fn,
        device,
        pivots,
        args
    )
elif algo == 'simple':
    # NOTE: torch.load unpickles the file - only load trusted statistics files.
    _, Cs = torch.load(args.stats_path + 'dec.pt', map_location=device)

    pivots = None
    if 'shuffle' in algo_extras:
        Cs, pivots = _shuffle_stats(Cs, net, algo_extras)

    _prune_with_ordering_only(pivots)
elif algo == 'saw':
    layerwise_ordering = utils.saw_ordering(net)[:-1]
    pivots = _make_pivots([c.clone() for c in layerwise_ordering], layerwise_ordering)

    _prune_with_ordering_only(pivots)
elif algo in ('dec', 'dec-dm'):
    demean = algo == 'dec-dm'

    means, Cs = utils.get_input_stats(net, train_dataloader, device, demean, args)

    pivots = None
    if 'shuffle' in algo_extras:
        Cs, pivots = _shuffle_stats(Cs, net, algo_extras)

    # utils.ldl returns a pair per matrix; split the pairs into two tuples.
    R_invs, Ds = zip(*(utils.ldl(C) for C in Cs))

    if 'whole' in algo_extras:
        if 'accwise' in algo_extras or 'sawwise' in algo_extras:
            # Both variants use the same loop and differ only in how the
            # network-wide order is obtained; 'accwise' takes precedence.
            if 'accwise' in algo_extras:
                whole_net_order = utils.network_wide_acc(
                    net, np.loadtxt(args.results_path + 'dec--accuracies.txt')
                )
            else:
                whole_net_order = utils.network_wide_saw(net)
            prune_loops.decor_prune_loop_whole_net(
                net,
                test_dataloader,
                R_invs,
                pruning_ratios,
                loss_fn,
                device,
                whole_net_order[:, 0].astype(int),
                means,
                pivots,
                args
            )
        elif 'varwise' in algo_extras:
            prune_loops.decor_prune_loop_whole_net_var_ratio(
                net,
                test_dataloader,
                R_invs,
                Ds,
                pruning_ratios,
                loss_fn,
                device,
                means,
                pivots,
                args
            )
        elif 'uniform' in algo_extras:
            prune_loops.prune_loop_whole_net_equal_amount(
                net,
                test_dataloader,
                R_invs,
                pruning_ratios,
                loss_fn,
                device,
                means,
                pivots,
                args
            )
        else:
            raise NotImplementedError(
                "'whole' needs one of 'accwise', 'sawwise', 'varwise' or 'uniform' in --extra_args"
            )

    else:  # decor: layerwise pruning possibly with shuffling
        prune_loops.layerwise_prune_loop(
            net,
            test_dataloader,
            R_invs,
            pruning_ratios,
            loss_fn,
            device,
            means,
            pivots,
            args
        )
elif algo == 'aw':
    prune_loops.weight_magnitude_prune(net, test_dataloader, pruning_ratios, loss_fn, device, args)
else:
    raise NotImplementedError(f"Selected pruning algorithm {args.pruning_algo} doesn't exist yet")
