"""Tools for mechanistic analysis of components."""
import dataclasses
import re
from typing import Callable, FrozenSet, List, Union

import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em.fishers import per_example
from em.tools.clustering import vat
from em.tools.nmf import nmf_common
from em.util import flat_pack

from em.analysis import bert_nmf_analysis2 as bna2

###############################################################################

# Matches a token made up only of digits, optionally preceded by up to two '#'
# characters (presumably the WordPiece "##" continuation prefix -- TODO confirm).
_PURE_DIGITS_TOKEN_REGEX = re.compile(r'^#?#?\d+$')


class ExampleRegexes:
    """Namespace holding compiled regexes for example text; do not instantiate."""

    # Matches questions such as "is 7 prime?" or "is 12 a composite number?".
    # The digits are captured as `number` and the adjective as `adj`.
    LITERAL_PRIMALITY_REGEX = re.compile(
        r'^is (?P<number>\d+)'
        r'(?: a)?'
        r' (?P<adj>composite|prime)'
        r'(?: number)?'
        r'\?$'
    )

###############################################################################


@dataclasses.dataclass
class MechanisticContext:
    """Bundle of everything needed for mechanistic analysis of NMF components.

    Attributes:
        decomp: The NMF decomposition. `W` is indexed as `[example, component]`
            and `H` is indexed by component along its first axis.
        pe_fishers_data: Per-example data. This class reads its `input_ids`,
            `labels` and `predicted_logits`, each indexed by example.
        tokenizer: Tokenizer used to decode `input_ids` back into text.
        finetuned_variables: Variables whose shapes define the flat packing
            and that are handed to the component localizer.
    """

    decomp: nmf_common.NmfDecomposition
    pe_fishers_data: per_example.PerExampleFlatFishers

    tokenizer: PreTrainedTokenizer

    finetuned_variables: List[tf.Variable]

    def __post_init__(self):
        self.packer = flat_pack.FlatPacker([v.shape for v in self.finetuned_variables])
        self.localizer = bna2.ComponentLocalizationInfo(
            variables=self.finetuned_variables)

        # Set up some attributes that will be used to cache stuff later.
        self._pure_digit_token_ids = None

    @property
    def n_components(self) -> int:
        """Number of components, i.e. the size of the second axis of `W`."""
        return self.decomp.W.shape[1]

    @property
    def n_examples(self) -> int:
        """Number of examples, i.e. the size of the first axis of `W`."""
        return self.decomp.W.shape[0]

    ###########################################################################

    def get_pure_digits_token_ids(self) -> FrozenSet[int]:
        """Return the ids of vocabulary tokens matching `_PURE_DIGITS_TOKEN_REGEX`.

        The frozenset is built on first use and cached on the instance.
        """
        if self._pure_digit_token_ids is None:
            self._pure_digit_token_ids = frozenset(
                token_id
                for token, token_id in self.tokenizer.vocab.items()
                if _PURE_DIGITS_TOKEN_REGEX.search(token)
            )
        return self._pure_digit_token_ids

    ###########################################################################

    def get_component_top_example_indices(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return indices of the `n_examples` examples with the largest `W` coefficient.

        Indices are ordered as returned by `tf.math.top_k` on the component's
        column of `W`.
        """
        _, inds = tf.math.top_k(self.decomp.W[:, component_index], k=n_examples)
        return inds.numpy()

    def get_example_str(self, example_index: int) -> str:
        """Decode an example to text with pad/cls/sep tokens and outer whitespace removed."""
        tokenizer = self.tokenizer
        example = tokenizer.decode(self.pe_fishers_data.input_ids[example_index])
        special_tokens = (tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token)
        for special_token in special_tokens:
            # A tokenizer may leave a special token unset (None); skip it
            # instead of failing inside str.replace.
            if special_token:
                example = example.replace(special_token, '')
        return example.strip()

    ###########################################################################

    def get_fraction_of_example_masses_on_components(self, example_indices, component_indices):
        """Return, per example, the share of its `W` row sum lying on the given components.

        Args:
            example_indices: Index into the first axis of `W` selecting the examples.
            component_indices: Index into the second axis of `W` selecting the components.

        Returns:
            One fraction per selected example.
        """
        coeffs = self.decomp.W[example_indices]
        total_masses = coeffs.sum(axis=-1)
        component_masses = coeffs[:, component_indices].sum(axis=-1)
        # The small epsilon avoids dividing by zero for examples with no mass.
        return component_masses / (total_masses + 1e-12)

    ###########################################################################

    def compute_vat_permutation(self, component_indices=None):
        """Compute the VAT reordering permutation for (a subset of) the components.

        If `component_indices` is not None, the permutation is over positions
        within `component_indices`, i.e. its entries range from 0 to
        len(component_indices) - 1 rather than being global component indices.
        """
        if component_indices is None:
            H = self.decomp.H
        else:
            H = self.decomp.H[component_indices]

        # Treated as cosine similarity; this relies on the rows of H having
        # unit norm (see compute_similarity_matrix).
        cos_sim_matrix = H @ H.T
        cos_dissim_matrix = 1 - cos_sim_matrix
        _, permutation = vat.vat_reorder_dissimilarity_matrix(cos_dissim_matrix)
        return permutation

    ###########################################################################

    def get_components_matching_regex(
        self,
        regex: Union[str, re.Pattern],
        check_top_k: int,
        min_match_fraction: float = 1.0,
    ) -> List[int]:
        """Return components whose top examples' text matches `regex`.

        Args:
            regex: Pattern (string or compiled) searched for in each example string.
            check_top_k: Number of top examples to inspect per component.
            min_match_fraction: Minimum fraction of those examples that must match.

        Returns:
            Indices of the matching components, in increasing order.
        """
        if isinstance(regex, str):
            regex = re.compile(regex)

        def matches_regex(_, example_index: int) -> bool:
            return regex.search(self.get_example_str(example_index)) is not None

        return self.get_components_by_top_examples(
            matches_regex,
            check_top_k=check_top_k,
            min_match_fraction=min_match_fraction,
        )

    def get_components_by_top_examples(
        self,
        condition_fn: Callable[['MechanisticContext', int], bool],
        check_top_k: int,
        min_match_fraction: float = 1.0,
    ) -> List[int]:
        """Return components for which enough of the top examples satisfy a condition.

        Args:
            condition_fn: Called with this instance and an example index (a
                plain Python int). Should return a truthy value if the example
                "matches" the condition.
            check_top_k: Number of top examples to inspect per component.
            min_match_fraction: Minimum fraction of the inspected examples
                that must match for the component to be returned.

        Returns:
            Indices of the matching components, in increasing order.

        Raises:
            ValueError: If `check_top_k` is less than 1.
        """
        if check_top_k < 1:
            raise ValueError(f'check_top_k must be at least 1, got {check_top_k}.')

        _, all_top_inds = tf.math.top_k(self.decomp.W.T, k=check_top_k)
        # Convert once so that condition_fn receives real ints rather than
        # scalar tensors, and so the Python loop does not slice tensors.
        all_top_inds = all_top_inds.numpy()

        matching_comp_inds = []
        for comp_ind, comp_top_inds in enumerate(all_top_inds):
            count = sum(1 for ind in comp_top_inds if condition_fn(self, int(ind)))

            # The exact-match case is handled specially due to paranoia about
            # division weirdness.
            all_matched = min_match_fraction == 1.0 and count == check_top_k
            if all_matched or count / check_top_k >= min_match_fraction:
                matching_comp_inds.append(comp_ind)

        return matching_comp_inds

    ###########################################################################

    def _get_preprocessed_H(self, component_indices=None, reorder=True):
        """Return `H` restricted to `component_indices` and optionally VAT-reordered."""
        if component_indices is None:
            H = self.decomp.H
        else:
            H = self.decomp.H[component_indices]

        if reorder:
            # The permutation is relative to the already-restricted H.
            permutation = self.compute_vat_permutation(component_indices)
            H = H[permutation]

        return H

    def compute_similarity_matrix(self, component_indices=None, reorder=True):
        """Return the Gram matrix `H @ H.T` of the selected (optionally reordered) components."""
        H = self._get_preprocessed_H(component_indices, reorder)
        # Assumes that H has been normalized such that components have unit norm.
        return H @ H.T

    def compute_fractions_per_variable(self, component_indices=None, reorder=True):
        """Return an array stacking `localizer.fraction_per_variable` for each selected component."""
        H = self._get_preprocessed_H(component_indices, reorder)
        return np.array([
            self.localizer.fraction_per_variable(H[i])
            for i in range(H.shape[0])
        ])

    def print_top_examples(
        self,
        component: int,
        n_examples: int,
    ):
        """Print label, argmax prediction and text for a component's top examples."""
        pe_fishers_data = self.pe_fishers_data
        for ind in self.get_component_top_example_indices(component, n_examples):
            # Use a plain int so indexing works regardless of container type.
            ind = int(ind)
            label = pe_fishers_data.labels[ind]
            prediction = np.argmax(pe_fishers_data.predicted_logits[ind])
            example = self.get_example_str(ind)
            print(f'{label}, {prediction}: {example}')
