"""
This module defines the abstract base class for a mechanistic-interpretability
(mech interp) task.

The ``MechInterpTask`` class specifies everything required to run a particular
model on a task and works as a thin wrapper around the
``acdc.TLACDCExperiment`` class.
"""

import warnings
from functools import cached_property, lru_cache
from typing import Union

import torch as t
from acdc.acdc_utils import TorchIndex
from acdc.TLACDCExperiment import TLACDCExperiment
from jaxtyping import Float
from torch import Tensor

import hypo_interp.utils.utils as utils
from hypo_interp.types_ import Circuit, EdgeType

#########################################
# Helper functions
#########################################


def _set_experiment_to_circuit(
    experiment: TLACDCExperiment, circuit: Circuit, invert: bool = False
):
    """
    Sets the experiment to the given circuit. This function changes
    the state of the experiment object and returns that same object.

    experiment:
        The experiment to modify in place. Its model hooks are reset and
        re-installed, and the ``present`` flag of every edge is rewritten.
    circuit:
        Iterable of ``((n_to, idx_to, n_from, idx_from), effect_size)``
        pairs. Node names containing ``hook_resid_mid`` are rewritten to
        ``hook_mlp_in``. Tuple or list indices are wrapped in ``TorchIndex``
        so they match the keys of ``experiment.corr.all_edges()``.
    invert:
        If False, exactly the edges listed in ``circuit`` are marked present.
        If True, every edge *except* the listed ones is marked present.

    Raises:
        KeyError: if an edge of ``circuit`` is not an edge of the
        experiment's graph.
    """

    # This adds every edge in the circuit to the model
    # see tutorial notebooks/colabs/ACDC_Editing_Edges_Demo.ipynb
    # in acdc repo for more info
    experiment.model.reset_hooks(including_permanent=True)

    experiment.setup_model_hooks(
        add_sender_hooks=True,
        add_receiver_hooks=True,
        doing_acdc_runs=False,
    )

    # Put every edge into the "background" state first: nothing present
    # when invert=False, everything present when invert=True. The circuit
    # edges are then flipped to the opposite state below.
    all_edges = experiment.corr.all_edges()
    for edge in all_edges.values():
        edge.present = invert

    for (n_to, idx_to, n_from, idx_from), _effect_size in circuit:
        n_to = n_to.replace("hook_resid_mid", "hook_mlp_in")
        n_from = n_from.replace("hook_resid_mid", "hook_mlp_in")

        # Lists show up e.g. when a circuit was round-tripped through JSON;
        # they are unhashable and must be wrapped exactly like tuples.
        if isinstance(idx_to, (tuple, list)):
            idx_to = TorchIndex(idx_to)
        if isinstance(idx_from, (tuple, list)):
            idx_from = TorchIndex(idx_from)

        key = (n_to, idx_to, n_from, idx_from)
        try:
            edge = all_edges[key]
        except KeyError:
            raise KeyError(
                f"Edge {key} from the circuit is not an edge of the experiment graph."
            ) from None
        edge.present = not invert

    return experiment


def _get_circuit_from_experiment(
    experiment: TLACDCExperiment, return_edge_types=False
) -> Circuit:
    """
    Reads the current circuit back out of an experiment.

    An edge is kept when it is marked present or when its type is
    ``EdgeType.PLACEHOLDER``. Kept edges are reported as
    ``(edge_key, effect_size)`` pairs, in the iteration order of
    ``experiment.corr.all_edges()``.

    When ``return_edge_types`` is true, the result is the pair
    ``(edges, edge_types)``, where ``edge_types[i]`` is the ``edge_type``
    of ``edges[i]``. Otherwise only ``edges`` is returned.
    """
    kept = [
        (key, edge)
        for key, edge in experiment.corr.all_edges().items()
        if edge.present or edge.edge_type == EdgeType.PLACEHOLDER
    ]
    edges = [(key, edge.effect_size) for key, edge in kept]

    if not return_edge_types:
        return edges
    return edges, [edge.edge_type for _key, edge in kept]


#########################################
# Main class
#########################################


class MechInterpTask:
    """
    Abstract class that establishes the configuration
    for a particular circuit.

    Subclasses are expected to populate the datasets, the validation metric
    and the underlying ``TLACDCExperiment`` during their own ``__init__``
    (checked by ``_validate_attributes``). They also implement ``score``
    and, where applicable, ``_canonical_circuit``.
    """

    def __init__(
        self,
        zero_ablation: bool,
        device: str,
        use_pos_embed: bool = False,
    ):
        """
        This method specifies what attributes should be set by the
        children of this class by the time the init method is done.
        Should be called at the start of the child class's init method.

        This can be done via the specified function in this method
        or by setting the attribute directly in the child class.

        zero_ablation:
            Whether to use zero ablation or ablate with random
            corrupted dataset.
        device:
            Device to run the experiment on.
        use_pos_embed:
            Whether to use positional embedding.
        """
        self.device = device
        self.use_pos_embed = use_pos_embed

        self._zero_ablation = zero_ablation
        self._validation_metric = None
        self._base_dataset = None
        self._ablate_dataset = None
        self._experiment: TLACDCExperiment = None

        # Doesn't need to be initialized until the circuit is set.
        self._circuit: Circuit = None

    @property
    def num_edges(self) -> int:
        """
        Returns the number of edges in the circuit.

        Raises:
            Exception: if no circuit has been set yet.
        """
        if self._circuit is None:
            raise Exception("Circuit not set yet. Run `set_circuit` first.")

        return len(self._circuit)

    def _fresh_experiment(self):
        """
        Builds a new experiment from this task's current datasets, model,
        validation metric and ablation settings. Only the model is read from
        ``self._experiment``; that object is not modified.
        """
        return self._make_experiment(
            self._base_dataset,
            self._ablate_dataset,
            self._experiment.model,
            self._validation_metric,
            self._zero_ablation,
            self.use_pos_embed,
        )

    # NOTE: the three derived circuits below are cached per instance with
    # ``cached_property``. The previous ``@property @lru_cache`` combination
    # kept a module-level strong reference to every instance (and hence to
    # its model and datasets), so those could never be garbage-collected.

    @cached_property
    def complete_circuit(self) -> Circuit:
        """
        Returns the edges in the circuit of the full model.

        Computed once per instance from a freshly built experiment and then
        cached on the instance.
        """
        return _get_circuit_from_experiment(self._fresh_experiment())

    @cached_property
    def edge_to_type(self):
        """
        Returns a dictionary mapping edges to their type. Edges are specified
        in the form
        (out_node_name, out_shape, in_node_name, in_shape) where out_shape and in_shape
        are TorchIndex objects.

        The tuples use the same layout as the keys of
        ``experiment.corr.all_edges()``, with both indices normalised to
        ``TorchIndex``. Computed once per instance and cached on it.
        """
        circuit, edge_types = _get_circuit_from_experiment(
            self._fresh_experiment(), return_edge_types=True
        )

        edge_to_type = {}
        for (edge, _effect_size), edge_type in zip(circuit, edge_types):
            out_node_name, out_shape, in_node_name, in_shape = edge
            if not isinstance(out_shape, TorchIndex):
                out_shape = TorchIndex(out_shape)
            if not isinstance(in_shape, TorchIndex):
                in_shape = TorchIndex(in_shape)
            edge_to_type[(out_node_name, out_shape, in_node_name, in_shape)] = edge_type
        return edge_to_type

    @cached_property
    def canonical_circuit(self) -> Circuit:
        """
        Gets the canonical circuit according to the
        paper in question. Some MechInterpTask subclasses
        may not have canonical circuits.

        This is different from _canonical_circuit in that
        we have some additional logic to get all of the edges
        in the circuit rather than just the ones that are
        specified by the saved version (acdc shenanigans).
        The result is computed once per instance and cached on it.

        Raises:
            NotImplementedError: if the subclass has no canonical circuit.
        """
        # Read the (possibly incomplete) canonical circuit first, so tasks
        # without one fail before an expensive experiment is built.
        incomplete_canonical_circuit = self._canonical_circuit
        experiment = _set_experiment_to_circuit(
            self._fresh_experiment(), incomplete_canonical_circuit, invert=False
        )
        return _get_circuit_from_experiment(experiment)

    @property
    def _canonical_circuit(self) -> Circuit:
        """
        Gets the canonical circuit according to the
        paper in question. Some MechInterpTask subclasses
        may not have canonical circuits. It doesn't matter
        if the canonical circuit is incomplete or not.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """
        Resets the circuit to the complete circuit by replacing the
        experiment with a freshly built one.
        """
        self._experiment = self._fresh_experiment()
        self._circuit = self.complete_circuit

    def set_circuit(self, circuit: Circuit, invert: bool = False):
        """
        Applies ``circuit`` to this task's experiment.

        With ``invert=False`` exactly the edges of ``circuit`` are kept.
        With ``invert=True`` every edge except those of ``circuit`` is kept.

        Raises:
            Exception: if the experiment has not been initialized.
        """
        if self._experiment is None:
            raise Exception(
                "Experiment not initialized yet. Run `_init_experiment` first."
            )

        # Next line for tracking purposes. It doesn't
        # actually set the circuit in the model.
        self._circuit = circuit

        self._experiment = _set_experiment_to_circuit(
            self._experiment, circuit, invert=invert
        )

    def _make_experiment(
        self,
        base_dataset,
        ablate_dataset,
        model,
        validation_metric,
        zero_ablation,
        use_pos_embed,
    ):
        """
        Initializes the experiment object and adds all hooks
        relevant to the model. The ACDC threshold is fixed at 0.0 and
        verbose output is disabled.
        """
        experiment = TLACDCExperiment(
            ds=base_dataset,
            ref_ds=ablate_dataset,
            threshold=0.0,
            model=model,
            metric=validation_metric,
            zero_ablation=zero_ablation,
            use_pos_embed=use_pos_embed,
            verbose=False,
        )
        return utils.add_all_hooks(experiment)

    def _validate_attributes(self):
        """
        Checks that everything that should be set by the
        end of the init method is set.

        Raises:
            Exception: naming every required attribute that is still None.
        """
        required = {
            "_zero_ablation": self._zero_ablation,
            "_validation_metric": self._validation_metric,
            "_base_dataset": self._base_dataset,
            "_ablate_dataset": self._ablate_dataset,
            "_experiment": self._experiment,
        }
        missing = [name for name, value in required.items() if value is None]

        if missing:
            raise Exception(
                f"Not all attributes are set. Missing: {', '.join(missing)}"
            )

    def score(self, per_prompt: bool = False, **kwargs):
        """Scores the current circuit; must be implemented by subclasses."""
        raise NotImplementedError

    def eval_metric(
        self,
        original_score: Union[Float[Tensor, "batch"], Float[Tensor, ""]],  # noqa
        candidate_score: Union[Float[Tensor, "batch"], Float[Tensor, ""]],  # noqa
        use_mean: bool = False,
        distance: str = "l2",
    ) -> t.Tensor:
        """
        Returns the mean distance between the original and candidate scores.

        original_score, candidate_score:
            Per-prompt scores (leading batch dimension) or 0-d scalar scores.
        use_mean:
            If True, both scores are averaged *before* the distance is taken
            (a warning is emitted to flag this). Otherwise the distance is
            computed element-wise and then averaged.
        distance:
            "l2" for the squared difference, "l1" for the absolute difference.

        Raises:
            ValueError: for an unknown ``distance``.
            AssertionError: if, with ``use_mean=False``, the leading (batch)
                dimensions of the two scores differ.
        """
        if use_mean:
            warnings.warn("we are first taking the mean, then take the difference")
            candidate_score = candidate_score.float().mean()
            original_score = original_score.float().mean()
        else:
            # Per-prompt comparison: the batch dimensions must agree. Comparing
            # shape[:1] instead of shape[0] also accepts 0-d scalar scores,
            # which the type hints allow but shape[0] would reject with an
            # IndexError.
            assert candidate_score.shape[:1] == original_score.shape[:1]

        if distance == "l2":
            diff = (candidate_score - original_score) ** 2
        elif distance == "l1":
            diff = t.abs(candidate_score - original_score)
        else:
            raise ValueError(f"Unknown distance {distance}, try input l1 or l2")
        return diff.float().mean()
