import torch
from nesim.utils.hook import ForwardHook
from nesim.utils.getting_modules import get_module_by_name
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoConfig
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

class CategorySelectivity:
    """
    Per-neuron category-selectivity metrics computed from recorded activations.

    ``hook_outputs[layer_name][category]`` is expected to be a list of tensors,
    each of shape ``(1, seq_len, emb)`` (3 dims and batch size 1 are asserted
    for every item that is scored).
    """

    def __init__(
        self,
        dataset: dict,
        hook_outputs: dict,
    ):
        """
        Store the recorded activations and derive categories, layers and the
        range of valid neuron indices.

        Args:
            dataset: mapping of category name -> inputs; only its keys are used
                here (they define ``self.categories``).
            hook_outputs: mapping ``layer_name -> category -> list of tensors``.
        """
        self.dataset = dataset
        self.hook_outputs = hook_outputs

        self.categories = list(self.dataset.keys())
        self.layer_names = list(self.hook_outputs.keys())

        # Valid indices come from the last dim of the first recorded tensor of
        # the first layer/category; presumably all layers share this width.
        self.valid_neuron_indices = range(
            0, self.hook_outputs[self.layer_names[0]][self.categories[0]][0].shape[-1]
        )

    @staticmethod
    def _item_activation(item, neuron_idx: int, mode: str):
        """Reduce one recorded tensor ``(1, seq_len, emb)`` to one value for
        ``neuron_idx``: L2 norm over the sequence ("norm") or its mean ("mean").

        Raises:
            AssertionError: if ``item`` is not 3-D or its batch size is not 1.
            ValueError: if ``mode`` is not "norm" or "mean".
        """
        assert (
            item.ndim == 3
        ), "Expected a tensor with 3 dims: (batch, seq, embedding)"
        assert (
            item.shape[0] == 1
        ), f"Expected a batch size of 1 but got: {item.shape[0]}"

        neuron_trace = item[:, :, neuron_idx]  # (1, seq_len)
        if mode == "norm":
            return neuron_trace.norm(dim=1)
        if mode == "mean":
            return neuron_trace.mean()
        raise ValueError(f"Invalid mode: {mode}")

    def konkle(
        self,
        neuron_idx: int,
        target_class: str,
        other_classes: List[str],
        layer_name: str,
        mode = "norm"
    ):
        """
        Selectivity index of one neuron for ``target_class`` versus ``other_classes``:

            (mean_target - mean_other) / (mean_target + mean_other)

        where each mean is taken over per-item activations (reduced with
        ``mode``). The "other" mean pools every item of every other class, so
        classes with different sample counts do not skew the result.
        With ``mode="norm"`` activations are non-negative, so the index lies in
        [-1, 1]; with ``mode="mean"`` negative activations can push it outside.

        Args:
            neuron_idx: index into the last (embedding) dimension.
            target_class: category whose selectivity is measured.
            other_classes: categories pooled into the comparison group.
            layer_name: key into ``hook_outputs``.
            mode: "norm" (L2 norm over sequence) or "mean" (mean over sequence).

        Returns:
            The selectivity index as a Python float.

        Raises:
            AssertionError: invalid ``neuron_idx`` / ``layer_name`` or bad item shape.
            ValueError: invalid ``mode``, or no activations on either side.
        """
        assert (
            neuron_idx in self.valid_neuron_indices
        ), f"Invalid neuron idx: {neuron_idx}"
        assert layer_name in self.layer_names, f"No layer exists {layer_name}"
        # Validate eagerly: previously an invalid mode went unnoticed when the
        # activation lists were empty.
        if mode not in ("norm", "mean"):
            raise ValueError(f"Invalid mode: {mode}")

        layer_outputs = self.hook_outputs[layer_name]
        target_activations = [
            self._item_activation(item, neuron_idx, mode)
            for item in layer_outputs[target_class]
        ]
        other_activations = [
            self._item_activation(item, neuron_idx, mode)
            for other_class in other_classes
            for item in layer_outputs[other_class]
        ]

        if not target_activations:
            raise ValueError(
                f"No activations recorded for target class: {target_class}"
            )
        if not other_activations:
            raise ValueError(
                f"No activations recorded for other classes: {other_classes}"
            )

        # True means (not sums), so the index does not depend on how many
        # items each side contributes.
        mean_target_activation = sum(target_activations) / len(target_activations)
        mean_other_activation = sum(other_activations) / len(other_activations)
        # From Doshi and Konkle paper XXXX 7:48 mm:ss
        selectivity = (mean_target_activation - mean_other_activation) / (
            mean_target_activation + mean_other_activation
        )
        return selectivity.item()

    def softmax_score(self, neuron_idx: int, target_class: str, layer_name: str):
        """
        Softmax probability mass that one neuron assigns to ``target_class``.

        For every category in ``self.categories`` the neuron's activation is
        averaged over the sequence of each item, then over the items; the
        softmax of these per-category means is taken and the entry for
        ``target_class`` is returned.

        Returns:
            A Python float in (0, 1).

        Raises:
            AssertionError: invalid ``neuron_idx`` / ``layer_name`` or bad item shape.
            ValueError: ``target_class`` is not a known category, or some
                category has no recorded activations.
        """
        assert (
            neuron_idx in self.valid_neuron_indices
        ), f"Invalid neuron idx: {neuron_idx}"
        assert layer_name in self.layer_names, f"No layer exists {layer_name}"

        layer_outputs = self.hook_outputs[layer_name]
        class_means = []
        for category in self.categories:
            activations = [
                self._item_activation(item, neuron_idx, "mean")
                for item in layer_outputs[category]
            ]
            if not activations:
                raise ValueError(
                    f"No activations recorded for category: {category}"
                )
            class_means.append((sum(activations) / len(activations)).item())

        probabilities = torch.tensor(class_means).softmax(-1)
        return probabilities[self.categories.index(target_class)].item()


def obtain_hook_outputs(
    model, layer_names: List[str], dataset: dict, tokenizer, device: str
):
    """
    Run every sentence of ``dataset`` through ``model`` (without gradients)
    and collect, per layer and per category, what each ``ForwardHook`` exposes
    as ``.output`` after the forward pass.

    Args:
        model: model whose submodules are hooked (looked up by name).
        layer_names: names resolved with ``get_module_by_name``.
        dataset: mapping category -> list of sentences.
        tokenizer: called as ``tokenizer(sentence, return_tensors="pt")``.
        device: device the encoded input is moved to.

    Returns:
        schema -> Dict

        {
            "layer1": {
                "category1": List[TensorType] of len(dataset["category1"])
            }
        }

    All hooks are closed before returning, including when an exception is
    raised during registration, tokenization or the forward pass.
    """
    output_dict = {
        layer_name: {key: [] for key in dataset.keys()} for layer_name in layer_names
    }

    hooks = {}
    with torch.no_grad():
        try:
            # INITIATE HOOKS (inside try so partially registered hooks are
            # still closed if a later layer lookup fails)
            for layer_name in layer_names:
                hook_module = get_module_by_name(model, layer_name)
                hooks[layer_name] = ForwardHook(hook_module)

            # COMPUTE HOOKS AND STORE VALUES
            for key, input_sentence_list in dataset.items():
                for sentence in input_sentence_list:
                    encoded_input = tokenizer(sentence, return_tensors="pt").to(device)
                    # The return value is not needed; the hooks capture what we keep.
                    model(**encoded_input)

                    for layer_name, hook in hooks.items():
                        output_dict[layer_name][key].append(hook.output)
        finally:
            for hook in hooks.values():
                hook.close()

    return output_dict


def hex_to_rgb(hex_color):
    """Convert a hex color string to an ``(r, g, b)`` tuple of floats in [0, 1].

    A leading ``#`` is optional. Accepts six-digit ``RRGGBB`` strings and the
    CSS shorthand ``RGB`` / ``RGBA``, where each digit is doubled, so
    ``"#f80"`` equals ``"#ff8800"``. Any alpha digits (``RRGGBBAA`` or the
    shorthand ``A``) are ignored; only the RGB channels are returned.
    """
    hex_color = hex_color.lstrip('#')
    if len(hex_color) in (3, 4):
        # Expand shorthand: "f80" -> "ff8800".
        hex_color = "".join(digit * 2 for digit in hex_color)
    return tuple(int(hex_color[i:i+2], 16) / 255.0 for i in (0, 2, 4))

def exponential_colormap(color: str, cmap_name="viridis", exponent=1.0):
    """
    Generate a 256-entry colormap that transitions from a soft pastel blue
    (low) to the specified color (high), with the blend position warped by a
    power curve.

    Entry ``i`` uses ``t = (i / 255) ** exponent`` and the RGB value
    ``(1 - t) * pastel_blue + t * color``. With ``exponent > 1`` more of the
    range stays near pastel blue; with ``exponent < 1`` the map reaches
    ``color`` sooner. The default ``exponent=1.0`` gives a linear blend.

    Parameters:
    - color: A hex color value for the high end of the colormap (parsed by
      ``hex_to_rgb``).
    - cmap_name: Colormap sampled at 256 integer indices; only its alpha
      channel is kept, since the RGB channels are overwritten.
    - exponent: Power applied to the blend position; should be > 0.

    Returns:
    - A matplotlib ListedColormap with 256 colors.
    """
    base_cmap = plt.get_cmap(cmap_name)
    # RGBA table; columns 0-2 are replaced below, column 3 (alpha) is kept.
    colors = base_cmap(np.arange(256))
    # Column vector so it broadcasts against the (3,) RGB endpoints.
    t = np.power(np.linspace(0, 1, 256), exponent)[:, None]

    # Convert the input color from hex to RGB
    high_rgb = np.array(hex_to_rgb(color))
    pastel_blue_rgb = np.array([173, 216, 230]) / 255.0  # Soft pastel blue

    # Convex combination: the weights (1 - t) and t always sum to 1, so every
    # channel stays between the two endpoint colors and inside [0, 1].
    # Weighting by x**e and (1 - x)**e separately does not, which darkens
    # midtones for e > 1 and overshoots 1.0 for e < 1.
    colors[:, :3] = (1.0 - t) * pastel_blue_rgb + t * high_rgb

    return mcolors.ListedColormap(colors)

def apply_ratan_matplotlib_thing():
    """Apply global matplotlib settings used when exporting figures.

    Turns ``figure.autolayout`` off and sets both ``pdf.fonttype`` and
    ``ps.fonttype`` to 42 (TrueType), so fonts in saved PDF/PS files are
    embedded as TrueType. This mutates process-wide ``rcParams`` and affects
    every figure created afterwards.
    """
    import matplotlib as mpl

    export_settings = {
        'figure.autolayout': False,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    }
    mpl.rcParams.update(export_settings)