import copy
import torchvision.transforms as transforms
from utils import evaluate, find_ignored_layers, get_cifar100_data_loader, get_cifar10_data_loader, get_pretrained_model
from prune import prune
from torch_pruning.optimal_transport import OptimalTransport
import torch_pruning as tp
import torch
from torchvision import datasets
import json

if __name__ == '__main__':
    # Pruning sweep. For every prune type, every prunable group (addressed by
    # index), every sparsity level, and with/without the optimal-transport
    # meta-pruner, prune a fresh copy of the pretrained model and record the
    # value returned by `evaluate`. The results file is rewritten after each
    # run so partial progress is kept if the sweep is interrupted.

    dataset = "Cifar10"
    example_inputs = torch.randn(1, 3, 32, 32)
    out_features = 10 if dataset == "Cifar10" else 100
    gpu_id = 0
    backward_pruning = True
    model_name = "vgg11_bn"
    file_name = "./models/vgg11_bn_cifar10_0.checkpoint"

    config = dict(
        dataset=dataset,
        model=model_name,
    )

    prune_types = ["l1", "lamp", "taylor"]

    loaders = get_cifar10_data_loader() if dataset == "Cifar10" else get_cifar100_data_loader()

    model_original, _ = get_pretrained_model(config, file_name)

    DG = tp.DependencyGraph().build_dependency(model_original, example_inputs=example_inputs)

    ignored_layers = find_ignored_layers(model_original=model_original, out_features=out_features)
    # Only the number of groups is needed here: each group is later selected
    # by its position via prune(..., group_idxs=[group_idx]).
    num_groups = sum(1 for _ in DG.get_all_groups_in_order(ignored_layers=ignored_layers))

    output_file_idx = 0
    # NOTE(review): the file name embeds only prune_types[0], yet results for
    # every entry of prune_types are written into this one file — confirm intent.
    output_file_name = f"{model_name}_{dataset}_{backward_pruning}_{prune_types[0]}_{output_file_idx}.json"

    ot = OptimalTransport(gpu_id=gpu_id)
    # None -> plain pruning (recorded as "default"); ot -> OT meta-pruning ("IF").
    meta_pruning_types = [None, ot]
    sparsities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    groups = range(num_groups)

    # Layout: results[prune_type][group_idx][sparsity][meta_prune_type] = evaluate(...)
    # Named `results` (not `dict`) so the builtin `dict` is not shadowed at
    # module scope.
    results = {}
    # Recorded so a reader of the JSON knows how group indices were interpreted.
    results["backward_pruning"] = backward_pruning

    for prune_type in prune_types:
        results[prune_type] = {}
        for group_idx in groups:
            results[prune_type][group_idx] = {}
            for sparsity in sparsities:
                sparsity_results = {}
                results[prune_type][group_idx][sparsity] = sparsity_results
                for meta_prune in meta_pruning_types:
                    meta_prune_type = "default" if meta_prune is None else "IF"
                    # Prune a fresh copy so every run starts from the unpruned model.
                    pruned_model = copy.deepcopy(model_original)
                    prune(
                        pruned_model,
                        loaders,
                        example_inputs,
                        out_features,
                        prune_type,
                        gpu_id,
                        sparsity=sparsity,
                        optimal_transport=meta_prune,
                        backward_pruning=backward_pruning,
                        group_idxs=[group_idx],
                        dimensionality_preserving=False
                    )
                    score = evaluate(pruned_model, loaders, gpu_id=gpu_id)
                    sparsity_results[meta_prune_type] = score

                    print(f"{prune_type} : {group_idx} : {sparsity} : {meta_prune_type} : {score}")

                    # Checkpoint the full results after every run. json.dump
                    # turns the int/float dict keys into strings.
                    with open(output_file_name, "w") as file:
                        json.dump(results, file, indent=4)