import torch
import torch.nn as nn
import collections
from typing import List, Callable
import torch
import torch.nn.functional as F
import collections


def get_attributes(x: nn.Module, attributes: str):
    """
    Resolve a period-separated attribute path starting from ``x``.

    e.g. get_attributes(model, 'transformer.encoder.layer') returns the same
    object as model.transformer.encoder.layer

    Any missing segment raises the AttributeError produced by ``getattr``.
    """
    current = x
    path = attributes.split(".")
    for segment in path:
        current = getattr(current, segment)
    return current


def set_attribute_recursive(x: nn.Module, attributes: "str", new_attribute: nn.Module):
    """
    Assign ``new_attribute`` to the last segment of a period-separated path.

    All segments but the last are resolved with ``getattr`` starting at ``x``;
    the last one is then set on that owner object.
    e.g. set_attribute_recursive(model, 'transformer.encoder.layer', NewLayer)
        makes model.transformer.encoder.layer be NewLayer
    """
    *parents, leaf = attributes.split(".")
    owner = x
    for name in parents:
        owner = getattr(owner, name)
    setattr(owner, leaf, new_attribute)


def get_attn_layer(
    model: nn.Module,
    layer_idx: int,
    transformer_layers_attr,
    layer_attrs: str,
):
    """
    Gets the module at `layer_attrs` (e.g. the attention block) inside one transformer layer.

    `model`: torch.nn.Module
      a torch.nn.Module
    `layer_idx`: int
      which transformer layer to access
    `transformer_layers_attr`: str
      chain of attributes (separated by periods) that access the transformer layers within `model`.
      The transformer layers are expected to be indexable - i.e a Modulelist
    `layer_attrs`: str
      chain of attributes (separated by periods) that access the target module
      (e.g. the attention block) within a transformer layer

    Returns the module found at `layer_attrs` in layer `layer_idx`.
    Raises AssertionError if `layer_idx` is not below the number of layers.
    """
    transformer_layers = get_attributes(model, transformer_layers_attr)
    assert layer_idx < len(
        transformer_layers
    ), f"cannot get layer {layer_idx + 1} of a {len(transformer_layers)} layer model"
    attn_layer = get_attributes(transformer_layers[layer_idx], layer_attrs)
    return attn_layer


class AttentionPatch(nn.Module):
    """
    Wraps a module and edits its output at selected token positions.

    The wrapped module is called first. Its output is then modified in place
    for every ``(layer_idx, head_idx, token_idx)`` triple in
    ``synapse_modifications``: the slice ``output[0, :, token_idx]`` (first
    batch element, every index along dim 1) is multiplied by
    ``enhance_value`` in "enhance" mode, or set to zero in "suppress" mode.
    ``layer_idx`` and ``head_idx`` are unpacked but not used by this class.
    """

    def __init__(
            self,
            attention_module: nn.Module,
            synapse_modifications,
            mode: str = "enhance",
            enhance_value: float = 2.0,
    ):
        """
        Initializes the AttentionPatch module.

        :param attention_module: The module whose output is modified. It is kept as
                                 ``self.attention_module`` so it can be restored later.
        :param synapse_modifications: An iterable of ``(layer_idx, head_idx, token_idx)``
                                      triples. Only ``token_idx`` selects what is modified.
        :param mode: 'enhance' to multiply the selected scores by ``enhance_value``,
                     or 'suppress' to set them to 0.
        :param enhance_value: Multiplier used in 'enhance' mode.
        :raises ValueError: if ``mode`` is neither 'enhance' nor 'suppress'.
        """
        super().__init__()
        if mode not in ("enhance", "suppress"):
            # Reject unknown modes up front; otherwise forward() would silently
            # leave the output untouched and hide the typo.
            raise ValueError(f"mode must be 'enhance' or 'suppress', got {mode!r}")
        self.attention_module = attention_module
        self.synapse_modifications = synapse_modifications
        self.mode = mode
        self.enhance_value = enhance_value

    def forward(self, attention_scores, *args, **kwargs):
        """
        Run the wrapped module and apply the configured edits to its output.

        Extra positional and keyword arguments are forwarded unchanged to the
        wrapped module. The output is indexed as ``output[0, :, token_idx]``,
        so it is assumed to be a tensor with at least 3 dims
        (TODO confirm for the modules this is used on).
        """
        output = self.attention_module(attention_scores, *args, **kwargs)
        for _layer_idx, _head_idx, token_idx in self.synapse_modifications:
            if self.mode == "enhance":
                output[0, :, token_idx] *= self.enhance_value
            elif self.mode == "suppress":
                output[0, :, token_idx] = 0.0
        # Return the edited output so the surrounding model keeps working.
        return output


def attn_patch_layer(
    model: nn.Module,
    transformer_layers_attr,
    layer_attrs,
    mode='suppress',
    synapses: List[List[int]] = None,
):
    """
    Wraps the module at `layer_attrs` in every layer named by `synapses` with an `AttentionPatch`.

    `model`: torch.nn.Module
      a torch.nn.Module
    `transformer_layers_attr`: str
      chain of attributes (separated by periods) that access the transformer layers within `model`.
      The transformer layers are expected to be indexable - i.e a Modulelist
    `layer_attrs`: str
      chain of attributes (separated by periods) that access the module to wrap within a transformer layer
    `mode`: str
      'suppress' or 'enhance'; passed through to `AttentionPatch`
    `synapses`: list of (layer_idx, head_idx, token_idx) triples
      Each distinct layer_idx gets exactly one patch, built with only that layer's triples.
      None is treated as an empty list, so nothing is patched.

    Raises NotImplementedError for any other `mode`, and AssertionError if a
    layer_idx is not below the number of layers.
    """
    transformer_layers = get_attributes(model, transformer_layers_attr)
    if mode not in ["suppress", "enhance"]:
        raise NotImplementedError
    if synapses is None:
        synapses = []

    # Group the full triples by layer so each patch only edits the token
    # positions that belong to its own layer.
    synapses_by_layer = collections.defaultdict(list)
    for synapse in synapses:
        layer_idx, _head_idx, _pos = synapse  # enforces the 3-element shape
        synapses_by_layer[layer_idx].append(synapse)

    for layer_idx, layer_synapses in synapses_by_layer.items():
        assert layer_idx < len(
            transformer_layers
        ), f"cannot get layer {layer_idx + 1} of a {len(transformer_layers)} layer model"
        layer = get_attributes(transformer_layers[layer_idx], layer_attrs)
        set_attribute_recursive(
            transformer_layers[layer_idx],
            layer_attrs,
            AttentionPatch(
                layer,
                mode=mode,
                synapse_modifications=layer_synapses,
            ),
        )


def attn_unpatch_layer(
    model: nn.Module,
    layer_idx: int,
    transformer_layers_attr,
    layer_attrs,
):
    """
    Removes the `AttentionPatch` applied by `attn_patch_layer`, restoring the module it wrapped.

    `model`: torch.nn.Module
      a torch.nn.Module
    `layer_idx`: int
      which transformer layer to access
    `transformer_layers_attr`: str
      chain of attributes (separated by periods) that access the transformer layers within `model`.
      The transformer layers are expected to be indexable - i.e a Modulelist
    `layer_attrs`: str
      chain of attributes (separated by periods) that access the patched module within a transformer layer

    Raises AssertionError if `layer_idx` is not below the number of layers, or if
    the module at `layer_attrs` is not an `AttentionPatch`.
    """
    transformer_layers = get_attributes(model, transformer_layers_attr)
    assert layer_idx < len(
        transformer_layers
    ), f"cannot get layer {layer_idx + 1} of a {len(transformer_layers)} layer model"
    layer = get_attributes(transformer_layers[layer_idx], layer_attrs)
    assert isinstance(layer, AttentionPatch)
    set_attribute_recursive(
        transformer_layers[layer_idx],
        layer_attrs,
        # AttentionPatch keeps the wrapped module in `attention_module`.
        layer.attention_module,
    )


def attn_unpatch_layers(
    model: nn.Module,
    layer_indices,
    transformer_layers_attr,
    layer_attrs,
):
    """
    Calls `attn_unpatch_layer` for every index in `layer_indices`.

    `layer_indices`: iterable of int
      transformer layers whose `AttentionPatch` should be removed. Each index must
      currently be patched, otherwise `attn_unpatch_layer` raises AssertionError.
    `transformer_layers_attr`, `layer_attrs`: str
      forwarded unchanged to `attn_unpatch_layer`
    """
    for layer_idx in layer_indices:
        attn_unpatch_layer(model, layer_idx, transformer_layers_attr, layer_attrs)
