import copy
from abc import abstractmethod, ABC

import torch
from torch import nn

from ..model.model import ExpandingLinear
from ..model.utils import get_model_last_layer, unfreeze_all


# def get_weights(model, layers, masks):
#     for layer in layers:
#         weights = model.__getatr__(layer)
#     last_layer = get_model_last_layer(model)
#     if len_choose is None:
#         return last_layer.weight_values
#     if len_choose == 0:
#         return torch.tensor([]) 
    
#     return last_layer.weight_values[-len_choose:] 

# def get_weights(model, len_choose):
#     last_layer = get_model_last_layer(model)
#     if len_choose is None:
#         return last_layer.weight_values
#     if len_choose == 0:
#         return torch.tensor([]) 
    
#     return last_layer.weight_values[-len_choose:] 


# def get_weights_grad(model, len_choose):
#     last_layer = get_model_last_layer(model)
#     grad = last_layer.weight_values.grad
    
#     if len_choose is None:
#         return grad
#     if len_choose == 0:
#         return torch.tensor([]) 
#     return grad[-len_choose:] 


class NonlinearityMetric(ABC):
    """
    Base class for per-edge scoring metrics computed on a model layer.

        Subclasses implement ``calculate``. The implementations in this module
        return a tensor with one score per entry of the selected layer's
        ``weight_values``. Gradient-based metrics use ``loss_fn``; weight-magnitude
        metrics ignore it.
    """

    def __init__(self, loss_fn):
        """
        Stores the loss function used by gradient-based metrics.

            Args:
                loss_fn: Callable invoked as ``loss_fn(y_pred, y_true)``. Gradient
                    metrics call ``.backward()`` on its result, so it must return
                    a scalar tensor for them.
        """
        self.loss_fn = loss_fn

    @abstractmethod
    def calculate(self, model, layer, X_arr, y_arr, embed=False):
        """
        Computes the metric for ``layer`` of ``model``.

            The signature mirrors every concrete implementation in this module
            (the previous declaration listed an extra ``add_embed`` argument that
            no implementation accepted).

            Args:
                model: The model containing ``layer``.
                layer: The layer whose edges are scored.
                X_arr: Input data array (ignored by data-free metrics).
                y_arr: Target data array (ignored by data-free metrics).
                embed: If True, score ``layer.embed_linears[-1]`` instead of
                    ``layer`` itself.

            Returns:
                object: The computed scores; the type depends on the implementation.
        """
        pass

class AbsGradientEdgeMetric(NonlinearityMetric):
    """
    Scores edges by the absolute loss gradient w.r.t. a layer's weights.

        Larger values mean the loss is locally more sensitive to that weight.
    """

    def calculate(self, model, layer, X_arr, y_arr, embed=False):
        """
        Calculates the absolute gradients of a layer's weights for one batch.

            Runs a single forward/backward pass of ``model`` on ``X_arr``, scores
            it with ``self.loss_fn`` against ``y_arr``, and reads
            ``weight_values.grad`` from the selected layer. The pass runs in eval
            mode. Before returning (also on error), gradients on the model are
            cleared and each submodule's train/eval flag is restored to its value
            on entry.

            Args:
                model: The neural network model to evaluate. ``unfreeze_all`` is
                    applied to it and not undone.
                layer: The layer whose weight gradients are read. It must expose
                    ``weight_values``, or ``embed_linears`` when ``embed`` is True.
                X_arr: Input batch passed to ``model``.
                y_arr: Targets passed to ``self.loss_fn`` together with the
                    squeezed model output.
                embed: If True, read the gradient of ``layer.embed_linears[-1]``
                    instead of ``layer`` itself.

            Returns:
                torch.Tensor: ``|dL/dw|``, with the same shape as the selected
                ``weight_values``.

            Raises:
                RuntimeError: If the selected weights received no gradient
                    (e.g. the layer is not on the path from ``X_arr`` to the loss).
        """
        # Remember every submodule's mode so that computing a metric does not
        # silently leave the model in eval mode (dropout/batch-norm disabled)
        # for whatever training follows.
        saved_modes = [(module, module.training) for module in model.modules()]

        unfreeze_all(model)
        model.eval()
        model.zero_grad()
        try:
            y_pred = model(X_arr).squeeze()
            loss = self.loss_fn(y_pred, y_arr)
            loss.backward()

            target = layer.embed_linears[-1] if embed else layer
            grad = target.weight_values.grad
            if grad is None:
                raise RuntimeError(
                    "AbsGradientEdgeMetric: selected weights received no gradient; "
                    "is the layer part of the model's forward pass?"
                )
            # abs() allocates a new tensor, so clearing grads below cannot alter it.
            edge_gradients = grad.abs()
        finally:
            model.zero_grad()
            # Set the flags directly rather than calling model.train(), which
            # would force one mode onto every submodule recursively.
            for module, was_training in saved_modes:
                module.training = was_training
        return edge_gradients

class ReversedAbsGradientEdgeMetric(NonlinearityMetric):
    """
    Scores edges by the reciprocal of the absolute loss gradient.

        The score is ``1 / (|dL/dw| + eps)``, so the weights the loss is least
        sensitive to receive the largest scores. ``eps`` keeps the division
        finite when a gradient is exactly zero.
    """

    def calculate(self, model, layer,  X_arr, y_arr, embed=False, eps=1e-8):
        """
        Calculates inverse absolute gradients of a layer's weights for one batch.

            Runs a single forward/backward pass of ``model`` on ``X_arr``, scores
            it with ``self.loss_fn`` against ``y_arr``, and inverts the absolute
            ``weight_values.grad`` of the selected layer. The pass runs in eval
            mode. Before returning (also on error), gradients on the model are
            cleared and each submodule's train/eval flag is restored to its value
            on entry.

            Args:
                model: The neural network model. ``unfreeze_all`` is applied to
                    it and not undone.
                layer: The layer whose weight gradients are read. It must expose
                    ``weight_values``, or ``embed_linears`` when ``embed`` is True.
                X_arr: Input batch passed to ``model``.
                y_arr: Targets passed to ``self.loss_fn`` together with the
                    squeezed model output.
                embed: If True, read the gradient of ``layer.embed_linears[-1]``
                    instead of ``layer`` itself.
                eps: Stabilizer added to ``|grad|`` before inversion
                    (default ``1e-8``, the previously hard-coded value).

            Returns:
                torch.Tensor: ``1 / (|dL/dw| + eps)``, with the same shape as the
                selected ``weight_values``.

            Raises:
                RuntimeError: If the selected weights received no gradient
                    (e.g. the layer is not on the path from ``X_arr`` to the loss).
        """
        # Remember every submodule's mode so that computing a metric does not
        # silently leave the model in eval mode for subsequent training.
        saved_modes = [(module, module.training) for module in model.modules()]

        unfreeze_all(model)
        model.eval()
        model.zero_grad()
        try:
            y_pred = model(X_arr).squeeze()
            loss = self.loss_fn(y_pred, y_arr)
            loss.backward()

            target = layer.embed_linears[-1] if embed else layer
            grad = target.weight_values.grad
            if grad is None:
                raise RuntimeError(
                    "ReversedAbsGradientEdgeMetric: selected weights received no "
                    "gradient; is the layer part of the model's forward pass?"
                )
            # Produces a fresh tensor, so clearing grads below cannot alter it.
            edge_gradients = 1 / (grad.abs() + eps)
        finally:
            model.zero_grad()
            # Set the flags directly rather than calling model.train(), which
            # would force one mode onto every submodule recursively.
            for module, was_training in saved_modes:
                module.training = was_training
        return edge_gradients


# class SNIPMetric(NonlinearityMetric):
#     def calculate(self, model, X_arr, y_arr, len_choose): #todo len_choose
#         model = copy.deepcopy(model)
#         unfreeze_all(model)
#         model.eval()

#         for layer in model.modules():
#             if isinstance(layer, (nn.Linear, ExpandingLinear)):
#                 w = layer.weight if not isinstance(layer, ExpandingLinear) else layer.weight_values
#                 layer.weight_mask = nn.Parameter(torch.ones_like(w))
#                 if isinstance(layer, ExpandingLinear):
#                     nn.init.normal_(w, mean=0, std=0.01)
#                 else:
#                     nn.init.xavier_normal_(w)
#                 w.requires_grad = False

#         model.zero_grad()
#         outputs = model(X_arr).squeeze()
#         loss = self.loss_fn(outputs, y_arr)
#         loss.backward()

#         edge_gradients = get_model_last_layer(model).weight_mask.grad.abs()
#         model.zero_grad()
#         return edge_gradients


class MagnitudeL1Metric(NonlinearityMetric):
    """
    Scores edges by the absolute value (L1 magnitude) of each weight.

        Purely weight-based: no data, forward pass or loss is involved.
    """

    def calculate(self, model, layer,  X_arr=None, y_arr=None, embed=False):
        """
        Returns the element-wise absolute values of a layer's weights.

            Args:
                model: Unused; accepted for interface compatibility with other metrics.
                layer: Layer exposing ``weight_values``, or ``embed_linears`` when
                    ``embed`` is True.
                X_arr: Unused.
                y_arr: Unused.
                embed: If True, read ``layer.embed_linears[-1]`` instead of ``layer``.

            Returns:
                torch.Tensor: ``|w|``, with the same shape as the selected
                ``weight_values``.
        """
        source = layer.embed_linears[-1] if embed else layer
        weights = source.weight_values
        return weights.abs()


class MagnitudeL2Metric(NonlinearityMetric):
    """
    Scores edges by the squared value of each weight.

        Purely weight-based: no data, forward pass or loss is involved. The
        result is element-wise ``w ** 2``, not a reduced norm.
    """

    def calculate(self, model, layer,  X_arr=None, y_arr=None, embed=False):
        """
        Returns the element-wise squares of a layer's weights.

            Args:
                model: Unused; accepted for interface compatibility with other metrics.
                layer: Layer exposing ``weight_values``, or ``embed_linears`` when
                    ``embed`` is True.
                X_arr: Unused.
                y_arr: Unused.
                embed: If True, read ``layer.embed_linears[-1]`` instead of ``layer``.

            Returns:
                torch.Tensor: ``w ** 2``, with the same shape as the selected
                ``weight_values``.
        """
        source = layer.embed_linears[-1] if embed else layer
        weights = source.weight_values
        return torch.pow(weights, 2)


# class PerturbationSensitivityEdgeMetric(NonlinearityMetric):
#     def __init__(self, loss_fn, epsilon=1e-2):
#         super().__init__(loss_fn)
#         self.epsilon = epsilon

#     def calculate(self, model, X_arr, y_arr, len_choose): #todo len_choose
#         model.eval()
#         original_output = model(X_arr).detach()
#         last_layer = get_model_last_layer(model)
#         sensitivities = torch.zeros_like(last_layer.weight_values)

#         for idx in range(last_layer.weight_values.size(0)):
#             with torch.no_grad():
#                 original_value = last_layer.weight_values[idx].item()
#                 last_layer.weight_values[idx] += self.epsilon
#                 perturbed_output = model(X_arr)
#                 sensitivity = (perturbed_output - original_output).abs().mean().item()
#                 sensitivities[idx] = sensitivity
#                 last_layer.weight_values[idx] = original_value
#         return sensitivities
