import copy
import logging

import torch
import torch.nn as nn
from informal_criteria import winner_diff_criterion, same_winner_criterion, abs_max_criterion, abs_mean_criterion, winner_runner_criterion, compute_kl_divergence

# Registry resolving a criterion name (the string form of the ``metric``
# argument accepted by the pruning routines) to the callable imported from
# ``informal_criteria``.
PRUNE_CRITERIA_FUNCTIONS = dict(
    same_winner=same_winner_criterion,
    abs_mean=abs_mean_criterion,
    abs_max=abs_max_criterion,
    winner_diff=winner_diff_criterion,
    winner_runner=winner_runner_criterion,
)

def informal_mnist_pruning(
        dataset,
        full_net,
        X,
        device,
        prune_by,
        metric,
        patching,
        delta,
        data_dist=None,
        weight_contrib=None,
        bottom_up=True,
        **kwargs
):
    """Greedily prune the ``nn.Linear`` layers of a copy of ``full_net``.

    Components (single weights or whole output neurons) are patched one at a
    time on a deep copy of ``full_net``. After each patch the copy is run on
    ``X`` and ``metric(baseline, new_output, delta)`` is evaluated, where
    ``baseline`` is the output of the unpruned copy on ``X``. A truthy result
    keeps the patch; otherwise the touched parameters are restored exactly.

    Args:
        dataset: Label that only appears in the log message.
        full_net: Network to prune. It is deep-copied and never modified.
        X: Input tensor fed to the network for every check.
        device: Device the copy and ``X`` are moved to.
        prune_by: ``'weights'`` or ``'neurons'``.
        metric: Callable ``metric(baseline, new_output, delta)`` or a key of
            ``PRUNE_CRITERIA_FUNCTIONS``.
        patching: ``'zero'`` or ``'mean'``. With ``'zero'`` the component is
            set to zero (for neurons the bias entry too). With ``'mean'`` a
            weight is zeroed and ``weight * weight_contrib[layer][out, in]``
            is added to the bias, while a neuron has its weights zeroed and
            its bias replaced by ``data_dist[layer][neuron]``.
        delta: Passed through unchanged to ``metric``.
        data_dist: Mapping from layer name to per-neuron values; required for
            mean patching of neurons.
        weight_contrib: Mapping from layer name to a 2-D ``[out, in]``
            indexable of scalars; required for mean patching of weights.
        bottom_up: If true, layers are visited in the reverse of their
            ``named_modules()`` order.
        **kwargs: Accepted and ignored.

    Returns:
        ``(pruned_net, info)`` where ``info`` has the keys ``"active"`` and
        ``"dead"`` (lists of ``(layer_name, neuron_idx)``) and
        ``"granularity"`` (``prune_by``). Only neuron pruning records
        components; for weight pruning both lists stay empty.

    Raises:
        ValueError: On an unknown metric name, an invalid ``prune_by`` or
            ``patching`` value, missing data for mean patching, or mean
            patching of a layer without a bias term.
    """
    logging.info(f"informal {dataset} pruning by: {prune_by}, patching: {patching}, metric: {metric}, delta: {delta}, bottom-up: {bottom_up}, device: {device}")

    if isinstance(metric, str):
        # Keep the requested name so the error message can report it.
        metric_name = metric
        metric = PRUNE_CRITERIA_FUNCTIONS.get(metric_name)
        if metric is None:
            raise ValueError(f"metric '{metric_name}' is not defined in PRUNE_CRITERIA_FUNCTIONS.")

    if prune_by not in ['weights', 'neurons']:
        raise ValueError("Invalid value for prune_by. Use 'weights' or 'neurons'.")
    if patching not in ['zero', 'mean']:
        raise ValueError("Invalid patching method. Use 'zero' or 'mean'.")

    if patching == 'mean':
        if prune_by == 'weights' and weight_contrib is None:
            raise ValueError("patching='mean' requires weight_contrib for weight pruning.")
        if prune_by == 'neurons' and data_dist is None:
            raise ValueError("patching='mean' requires data_dist for neuron pruning.")

    dead_components = []
    active_components = []

    pruned_net = copy.deepcopy(full_net).to(device).eval()
    X = X.to(device)

    with torch.no_grad():
        baseline = pruned_net(X)

    layers = {name: layer for name, layer in pruned_net.named_modules() if isinstance(layer, nn.Linear)}
    layer_names = list(layers)
    if bottom_up:
        layer_names.reverse()

    for name in layer_names:
        layer = layers[name]
        has_bias = layer.bias is not None
        if patching == 'mean' and not has_bias:
            # Mean patching stores the replacement value in the bias.
            raise ValueError(f"patching='mean' requires a bias term, but layer '{name}' has none.")

        if prune_by == 'weights':
            contrib = weight_contrib[name] if patching == 'mean' else None

            for out_idx in range(layer.out_features):
                for in_idx in range(layer.in_features):
                    # ``.data`` clones carry no autograd history.
                    old_val = layer.weight.data[out_idx, in_idx].clone()
                    layer.weight.data[out_idx, in_idx] = 0.0
                    if contrib is not None:
                        # Save the bias so a rejected patch is undone exactly;
                        # "+= c" followed by "-= c" is not exact in floats.
                        old_bias = layer.bias.data[out_idx].clone()
                        layer.bias.data[out_idx] += old_val * contrib[out_idx, in_idx].item()

                    with torch.no_grad():
                        new_val = pruned_net(X)

                    if not metric(baseline, new_val, delta):
                        layer.weight.data[out_idx, in_idx] = old_val
                        if contrib is not None:
                            layer.bias.data[out_idx] = old_bias

        else:  # prune_by == 'neurons'
            for neuron_idx in range(layer.out_features):
                old_weight = layer.weight.data[neuron_idx].clone()
                old_bias = layer.bias.data[neuron_idx].clone() if has_bias else None

                layer.weight.data[neuron_idx, :] = 0.0
                if patching == 'mean':
                    layer.bias.data[neuron_idx] = data_dist[name][neuron_idx].item()
                elif has_bias:
                    layer.bias.data[neuron_idx] = 0.0

                with torch.no_grad():
                    new_val = pruned_net(X)

                if not metric(baseline, new_val, delta):
                    layer.weight.data[neuron_idx] = old_weight
                    if has_bias:
                        layer.bias.data[neuron_idx] = old_bias
                    logging.info(f"not pruned neuron {neuron_idx} in layer {name}")
                    active_components.append((name, neuron_idx))
                else:
                    dead_components.append((name, neuron_idx))
                    logging.info(f"pruned neuron {neuron_idx} in layer {name}")

    return pruned_net, {"active": active_components, "dead": dead_components, "granularity": prune_by}

# TODO: refactor this function to avoid the long duplicated code.
def informal_vision_pruning(
        dataset,
        full_net,
        X,
        device,
        prune_by,
        metric,
        patching,
        delta,
        data_dist=None,
        weight_contrib=None,
        bottom_up=True,
        verbose=True,
        **kwargs
):
    """Greedily prune the ``nn.Conv2d`` layers of a copy of ``full_net``.

    Works on a deep copy of ``full_net``. Conv layers whose module name
    contains ``downsample`` or ``shortcut`` (case-insensitive) are skipped.
    After each patch the copy is run on ``X`` and
    ``metric(baseline, new_output, delta)`` is evaluated, where ``baseline``
    is the output of the unpruned copy. A truthy result keeps the patch;
    otherwise it is undone.

    Args:
        dataset: Label that only appears in the log message.
        full_net: Network to prune. It is deep-copied and never modified.
        X: Input tensor fed to the network for every check.
        device: Device the copy and ``X`` are moved to.
        prune_by: ``'conv_channels'`` (one output filter at a time) or
            ``'conv_heads'`` (a whole ``Conv2d`` at a time).
        metric: Callable ``metric(baseline, new_output, delta)`` or a key of
            ``PRUNE_CRITERIA_FUNCTIONS``.
        patching: ``'zero'`` zeroes the weights (and bias, if any).
            ``'mean'`` is only available for ``'conv_channels'``: a forward
            hook overwrites the channel's output with
            ``data_dist[layer][channel]``. The hook stays registered on the
            returned network when the channel is pruned.
        delta: Passed through unchanged to ``metric``.
        data_dist: Mapping from layer name to per-channel ``(H, W)`` tensors;
            required for mean patching.
        weight_contrib: Unused by this function.
        bottom_up: If true, layers are visited in the reverse of their
            ``named_modules()`` order.
        verbose: If true, log each keep/prune decision.
        **kwargs: Accepted and ignored.

    Returns:
        ``(pruned_net, info)`` where ``info`` has the keys ``"active"``,
        ``"dead"`` and ``"granularity"`` (``prune_by``). List entries are
        ``(layer_name, filter_idx)`` for ``'conv_channels'`` and
        ``layer_name`` for ``'conv_heads'``.

    Raises:
        ValueError: On an unknown metric name, an invalid ``prune_by`` or
            ``patching`` value, or mean patching without ``data_dist``.
        NotImplementedError: For mean patching with ``'conv_heads'``.
    """
    logging.info(f"informal {dataset} pruning by: {prune_by}, patching: {patching}, metric: {metric}, delta: {delta}, bottom-up: {bottom_up}, verbose: {verbose}, device: {device}")
    if isinstance(metric, str):
        # Keep the requested name so the error message can report it.
        metric_name = metric
        metric = PRUNE_CRITERIA_FUNCTIONS.get(metric_name)
        if metric is None:
            raise ValueError(f"metric '{metric_name}' is not defined in PRUNE_CRITERIA_FUNCTIONS.")

    if prune_by not in ['conv_channels', 'conv_heads']:
        raise ValueError("Invalid value for prune_by. Use 'conv_channels', or 'conv_heads'.")
    if patching not in ['zero', 'mean']:
        raise ValueError("Invalid patching method. Use 'zero' or 'mean'.")
    if patching == 'mean' and prune_by != 'conv_channels':
        raise NotImplementedError("Patching='mean' is only implemented for CIFAR-10 conv_channels .")
    if patching == 'mean' and data_dist is None:
        # Fail early instead of hitting a TypeError inside the pruning loop.
        raise ValueError("patching='mean' requires data_dist for conv_channels pruning.")

    active_components = []
    dead_components = []
    pruned_net = copy.deepcopy(full_net).to(device).eval()
    X = X.to(device)

    with torch.no_grad():
        baseline = pruned_net(X)

    layers = {name: layer for name, layer in pruned_net.named_modules() if isinstance(layer, nn.Conv2d) and 'downsample' not in name.lower() and 'shortcut' not in name.lower()}
    layer_names = list(layers)
    if bottom_up:
        layer_names.reverse()

    for name in layer_names:
        layer = layers[name]
        has_bias = layer.bias is not None

        if prune_by == 'conv_channels':
            for out_idx in range(layer.weight.shape[0]):
                hook = None
                old_weight = old_bias = None
                if patching == 'zero':
                    # ``.data`` clones carry no autograd history.
                    old_weight = layer.weight.data[out_idx].clone()
                    if has_bias:
                        old_bias = layer.bias.data[out_idx].clone()
                    layer.weight.data[out_idx] = 0.0
                    if has_bias:
                        layer.bias.data[out_idx] = 0.0
                else:  # mean patching
                    mean_map = data_dist[name][out_idx]  # (H, W)

                    # Bind idx and mean_map as defaults so each hook keeps
                    # its own values (closures would bind late).
                    def hook_fn(module, inp, out, idx=out_idx, m=mean_map):
                        B, _, H, W = out.shape
                        out_clone = out.clone()
                        broadcast = m.to(out.device).unsqueeze(0)  # (1, H, W)
                        out_clone[:, idx, :, :] = broadcast.expand(B, H, W)
                        return out_clone

                    hook = layer.register_forward_hook(hook_fn)

                with torch.no_grad():
                    new_val = pruned_net(X)

                if not metric(baseline, new_val, delta):
                    # Pruning rejected: undo the patch.
                    if hook is not None:
                        hook.remove()
                    else:
                        layer.weight.data[out_idx] = old_weight
                        if has_bias:
                            layer.bias.data[out_idx] = old_bias
                    if verbose:
                        logging.info(f"[NOT PRUNED] Layer: {name}, Filter Index: {out_idx}")
                    active_components.append((name, out_idx))
                else:
                    if verbose:
                        logging.info(f"[PRUNED] Layer: {name}, Filter Index: {out_idx}")
                    dead_components.append((name, out_idx))

        else:  # prune_by == 'conv_heads'; validation above guarantees zero patching
            # Remove an entire conv head (whole Conv2d) at once.
            old_weight = layer.weight.data.clone()
            old_bias = layer.bias.data.clone() if has_bias else None

            layer.weight.data[:] = 0.0
            if has_bias:
                layer.bias.data[:] = 0.0

            with torch.no_grad():
                new_val = pruned_net(X)

            if not metric(baseline, new_val, delta):
                layer.weight.data[:] = old_weight
                if has_bias:
                    layer.bias.data[:] = old_bias
                if verbose:
                    logging.info(f"[NOT PRUNED] Full conv head {name}")
                active_components.append(name)
            else:
                if verbose:
                    logging.info(f"[PRUNED] Full conv head {name}")
                dead_components.append(name)

    return pruned_net, {"active": active_components, "dead": dead_components, "granularity": prune_by}

def generate_perturbations(X, epsilon, num_samples=50):
    """Create uniformly perturbed copies of every input in a batch.

    Each input is repeated ``num_samples`` times along dimension 0, noise
    drawn uniformly from ``[-epsilon, epsilon)`` is added element-wise, and
    the result is clamped to ``[0, 1]``.

    Parameters:
    - X (torch.Tensor): Floating-point batch of shape [B, ...], e.g. [B, D]
      or [1, D]. The first dimension is the batch dimension.
    - epsilon (float): Maximum element-wise perturbation.
    - num_samples (int): Number of perturbed copies generated per input.

    Returns:
    - torch.Tensor: Perturbed inputs of shape [B * num_samples, ...]; the
      copies of one input are stored contiguously.

    Raises:
    - ValueError: If ``X`` has fewer than two dimensions.
    """
    if X.dim() < 2:
        raise ValueError(
            f"X must have a leading batch dimension (shape [B, ...]), got shape {tuple(X.shape)}."
        )
    x_repeated = X.repeat_interleave(num_samples, dim=0)  # [B * num_samples, ...]
    # rand_like samples from [0, 1); rescale to [-epsilon, epsilon).
    perturbations = torch.rand_like(x_repeated) * 2 * epsilon - epsilon
    return torch.clamp(x_repeated + perturbations, 0, 1)



def informal_mnist_kl_div_pruning(
    full_net,
    x,
    device,
    tau,
    prune_by,
    patching,
    data_dist=None,
    weight_contrib=None,
    bottom_up=True,
    perturbation_epsilon=0.3,
    perturbation_samples=100,
    **kwargs):
    """Greedily prune ``nn.Linear`` layers using a divergence threshold.

    A deep copy of ``full_net`` is evaluated on randomly perturbed versions
    of ``x`` (see ``generate_perturbations``). Components are patched one at
    a time; a patch is kept when it raises the value of
    ``compute_kl_divergence(baseline_outputs, outputs)`` by less than ``tau``
    relative to the current, already partially pruned network. Otherwise the
    touched parameters are restored exactly.

    Args:
        full_net: Network to prune. It is deep-copied and never modified.
        x: Input batch the perturbations are generated from.
        device: Device the copy and the perturbed inputs are moved to.
        tau: Upper bound (exclusive) on the divergence increase allowed for a
            single pruning step.
        prune_by: ``'weights'`` or ``'neurons'``.
        patching: ``'zero'`` or ``'mean'``. With ``'zero'`` the component is
            set to zero (for neurons the bias entry too). With ``'mean'`` a
            weight is zeroed and ``weight * weight_contrib[layer][out, in]``
            is added to the bias, while a neuron has its weights zeroed and
            its bias replaced by ``data_dist[layer][neuron]``.
        data_dist: Mapping from layer name to per-neuron values; required for
            mean patching of neurons.
        weight_contrib: Mapping from layer name to a 2-D ``[out, in]``
            indexable of scalars; required for mean patching of weights.
        bottom_up: If true, layers are visited in the reverse of their
            ``named_modules()`` order.
        perturbation_epsilon: Maximum element-wise perturbation passed to
            ``generate_perturbations``.
        perturbation_samples: Number of perturbed copies per input passed to
            ``generate_perturbations``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(pruned_net, info)`` where ``info`` has the keys ``"active"`` and
        ``"dead"`` (lists of ``(layer_name, neuron_idx)``) and
        ``"granularity"`` (``prune_by``). Only neuron pruning records
        components; for weight pruning both lists stay empty.

    Raises:
        ValueError: On an invalid ``prune_by`` or ``patching`` value, missing
            data for mean patching, or mean patching of a layer without a
            bias term.
    """
    if prune_by not in ['weights', 'neurons']:
        raise ValueError("Invalid value for prune_by. Use 'weights' or 'neurons'.")
    if patching not in ['zero', 'mean']:
        raise ValueError("Invalid patching method. Use 'zero' or 'mean'.")

    # Validate required data for mean patching.
    if patching == 'mean':
        if prune_by == 'weights' and weight_contrib is None:
            raise ValueError("patching='mean' requires weight_contrib for weight pruning.")
        if prune_by == 'neurons' and data_dist is None:
            raise ValueError("patching='mean' requires data_dist for neuron pruning.")

    active_components = []
    dead_components = []

    perturbed_inputs = generate_perturbations(x, perturbation_epsilon, perturbation_samples).to(device)
    pruned_net = copy.deepcopy(full_net).to(device).eval()

    # Reference outputs of the unpruned copy.
    with torch.no_grad():
        baseline_outputs = pruned_net(perturbed_inputs)

    # Resolve the module table once instead of once per layer.
    modules = dict(pruned_net.named_modules())
    layers_order = [name for name, layer in modules.items() if isinstance(layer, nn.Linear)]
    if bottom_up:
        layers_order.reverse()

    # Divergence of the current network state. The network only changes when
    # a patch is accepted (rejected patches are restored exactly), so this is
    # carried forward instead of recomputed with an extra forward pass per
    # candidate. It starts from the unpruned copy.
    kl_div_cur = compute_kl_divergence(baseline_outputs, baseline_outputs)

    for name in layers_order:
        layer = modules[name]
        has_bias = layer.bias is not None
        if patching == 'mean' and not has_bias:
            # Mean patching stores the replacement value in the bias.
            raise ValueError(f"patching='mean' requires a bias term, but layer '{name}' has none.")

        if prune_by == 'weights':
            contrib = weight_contrib[name] if patching == 'mean' else None

            for out_idx in range(layer.out_features):
                for in_idx in range(layer.in_features):
                    # ``.data`` clones carry no autograd history.
                    old_val = layer.weight.data[out_idx, in_idx].clone()
                    layer.weight.data[out_idx, in_idx] = 0.0
                    if contrib is not None:
                        # Save the bias so a rejected patch is undone exactly;
                        # "+= c" followed by "-= c" is not exact in floats.
                        old_bias = layer.bias.data[out_idx].clone()
                        layer.bias.data[out_idx] += old_val * contrib[out_idx, in_idx].item()

                    with torch.no_grad():
                        new_outputs = pruned_net(perturbed_inputs)
                    kl_div_new = compute_kl_divergence(baseline_outputs, new_outputs)

                    if kl_div_new - kl_div_cur < tau:
                        kl_div_cur = kl_div_new  # weight removed permanently
                    else:
                        layer.weight.data[out_idx, in_idx] = old_val
                        if contrib is not None:
                            layer.bias.data[out_idx] = old_bias

        else:  # prune_by == 'neurons'
            for neuron_idx in range(layer.out_features):
                old_weight = layer.weight.data[neuron_idx].clone()
                old_bias = layer.bias.data[neuron_idx].clone() if has_bias else None

                layer.weight.data[neuron_idx, :] = 0.0
                if patching == 'mean':
                    layer.bias.data[neuron_idx] = data_dist[name][neuron_idx].item()
                elif has_bias:
                    layer.bias.data[neuron_idx] = 0.0

                with torch.no_grad():
                    new_outputs = pruned_net(perturbed_inputs)
                kl_div_new = compute_kl_divergence(baseline_outputs, new_outputs)

                if kl_div_new - kl_div_cur < tau:
                    dead_components.append((name, neuron_idx))
                    kl_div_cur = kl_div_new  # neuron removed permanently
                else:
                    layer.weight.data[neuron_idx] = old_weight
                    if has_bias:
                        layer.bias.data[neuron_idx] = old_bias
                    active_components.append((name, neuron_idx))

    return pruned_net, {"active": active_components, "dead": dead_components, "granularity": prune_by}
