import torch
import torch_pruning as tp
import numpy as np

def taylor_loss(model, loaders, gpu_id):
    """Populate parameter gradients of ``model`` with a cross-entropy pass.

    Every ``(images, labels)`` batch of ``loaders["test"]`` is run through
    ``model`` and its ``CrossEntropyLoss`` is back-propagated. The per-batch
    gradients therefore accumulate in each parameter's ``.grad``. ``prune``
    calls this before ``pruner.step`` when using ``TaylorImportance``, which
    presumably scores channels from these gradients.

    Args:
        model: ``torch.nn.Module`` whose outputs are fed to
            ``CrossEntropyLoss`` together with ``labels``.
        loaders: mapping with a ``"test"`` key yielding ``(images, labels)``.
        gpu_id: CUDA device index to run on, or ``-1`` to run wherever the
            model and data already are.

    Side effects:
        Discards any gradients already stored on the model and fills
        ``.grad`` from this pass. It always moves the model to the CPU
        before returning, including when an exception is raised.
    """
    loss_func = torch.nn.CrossEntropyLoss()
    on_gpu = gpu_id != -1
    if on_gpu:
        model.cuda(gpu_id)

    # Stale gradients (from training or an earlier prune call) would
    # otherwise be summed into the importance signal.
    model.zero_grad()

    try:
        for images, labels in loaders["test"]:
            if on_gpu:
                # Target the model's device explicitly; a bare ``.cuda()``
                # uses the *current* CUDA device, which may differ from gpu_id.
                images, labels = images.cuda(gpu_id), labels.cuda(gpu_id)
            loss = loss_func(model(images), labels)
            loss.backward()
    finally:
        model.cpu()

def prune(
    model,
    loaders,
    example_inputs,
    out_features,
    prune_type,
    gpu_id,
    sparsity=0.5,
    optimal_transport=None,
    backward_pruning=True,
    group_idxs=None,
    dimensionality_preserving=False
):
    """Structurally prune ``model`` with torch_pruning and return the same object.

    Args:
        model: ``torch.nn.Module`` to prune. It is moved to the CPU if its
            first parameter is on a GPU, and it is returned on the CPU.
        loaders: mapping of data loaders. Only ``loaders["test"]`` is read,
            and only when ``prune_type == "taylor"`` (gradient pass).
        example_inputs: dummy input forwarded to ``tp.pruner.MagnitudePruner``.
        out_features: every ``torch.nn.Linear`` whose ``out_features`` equals
            this value is left unpruned. This is presumably the classifier
            head; confirm with callers.
        prune_type: importance criterion, one of ``"random"``, ``"l1"``,
            ``"l2"``, ``"l_inf"``, ``"taylor"`` or ``"lamp"``.
        gpu_id: CUDA index used for the Taylor gradient pass, or ``-1``.
        sparsity: channel sparsity, forwarded as ``ch_sparsity``.
        optimal_transport, backward_pruning, dimensionality_preserving:
            forwarded unchanged to ``tp.pruner.MagnitudePruner``.
        group_idxs: forwarded unchanged to ``pruner.step``.

    Returns:
        The same ``model`` object that was passed in, after ``pruner.step``.

    Raises:
        ValueError: if ``prune_type`` is not one of the supported names.
    """
    print(f"Structured pruning with type {prune_type} and channel sparsity {sparsity}")

    # Factories are lazy so only the selected criterion is looked up and built.
    importance_factories = {
        "random": lambda: tp.importance.RandomImportance(),
        "l1": lambda: tp.importance.MagnitudeImportance(1),
        "l2": lambda: tp.importance.MagnitudeImportance(2),
        "l_inf": lambda: tp.importance.MagnitudeImportance(np.inf),
        "taylor": lambda: tp.importance.TaylorImportance(),
        "lamp": lambda: tp.importance.LAMPImportance(),
    }
    if prune_type not in importance_factories:
        raise ValueError("Prune type not supported")
    imp = importance_factories[prune_type]()

    # Layers listed here are not pruned, so any Linear producing exactly
    # ``out_features`` outputs keeps its output width.
    ignored_layers = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.Linear) and m.out_features == out_features
    ]

    # The pruner is constructed on the CPU; any GPU work happens inside
    # ``taylor_loss``, which moves the model back to the CPU afterwards.
    if next(model.parameters()).is_cuda:
        model.to("cpu")

    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=1,  # single-shot pruning
        ch_sparsity=sparsity,  # channel sparsity
        ignored_layers=ignored_layers,  # ignored_layers will not be pruned
        optimal_transport=optimal_transport,
        backward_pruning=backward_pruning,
        dimensionality_preserving=dimensionality_preserving
    )

    if prune_type == "taylor":
        # Taylor importance needs parameter gradients before ``step`` runs.
        taylor_loss(model, loaders, gpu_id)

    pruner.step(group_idxs=group_idxs)

    return model
