import os
import pickle

import torch as t
from acdc.induction.utils import get_all_induction_things

from hypo_interp.paths import ROOT
from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit

##################
# Helper functions
##################


def get_true_circuit() -> Circuit:
    """
    Load the canonical induction circuit edges from disk.

    The edges were discovered using ACDC with zero ablation, as specified
    in the paper and notebook, and can be reproduced with the notebook in
    the acdc repo at acdc/notebooks/colabs/ACDC_Main_Demo.ipynb.

    Returns the object unpickled from ``canonical_edges.pkl`` located under
    ``ROOT/hypo_interp/tasks/induction``.
    """
    edges_file = os.path.join(
        ROOT,
        "hypo_interp",
        "tasks",
        "induction",
        "canonical_edges.pkl",
    )
    with open(edges_file, "rb") as handle:
        return pickle.load(handle)


#################
# Main class
#################


class InductionTask(MechInterpTask):
    """
    Induction task from the paper.

    Wraps the model, datasets, metric and mask returned by
    ``get_all_induction_things`` (requested with ``metric="nll"``) and
    exposes them through the ``MechInterpTask`` interface.
    """

    def __init__(
        self,
        zero_ablation: bool = True,
        device: str = "cpu",
        seq_len: int = 50,
        num_examples: int = 40,
    ):
        """
        zero_ablation:
            Forwarded to the ``MechInterpTask`` constructor; the resulting
            ``self._zero_ablation`` is then passed to ``_make_experiment``.
        device:
            Forwarded both to the ``MechInterpTask`` constructor and to
            ``get_all_induction_things``.
        seq_len:
            Maximum length of the sequences to use in the dataset.
        num_examples:
            Number of examples to use in the dataset.
        """
        super().__init__(zero_ablation=zero_ablation, device=device)

        # load in a tl_model and grab some data
        all_induction_things = get_all_induction_things(
            num_examples=num_examples,
            seq_len=seq_len,
            device=device,
            metric="nll",
        )

        # Init abstract class attributes
        self._validation_metric = all_induction_things.validation_metric
        self._ablate_dataset = all_induction_things.validation_patch_data
        self._base_dataset = all_induction_things.validation_data
        # NOTE(review): `use_pos_embed` is not assigned in this class;
        # presumably it is provided by MechInterpTask -- confirm.
        self._experiment = self._make_experiment(
            base_dataset=self._base_dataset,
            ablate_dataset=self._ablate_dataset,
            model=all_induction_things.tl_model,
            validation_metric=self._validation_metric,
            zero_ablation=self._zero_ablation,
            use_pos_embed=self.use_pos_embed,
        )

        # Other attributes relevant for score
        self._validation_mask = all_induction_things.validation_mask
        self._validation_labels = all_induction_things.validation_labels
        self._validate_attributes()

    def score(self, per_prompt: bool = False) -> "tuple[t.Tensor, t.Tensor]":
        """
        Returns the score of the current circuit together with the logits.

        The score is the per-token loss of the experiment's model on the
        base dataset, averaged over the positions selected by the
        validation mask (the parts where induction is happening).

        There might be fewer prompts returned than there are in the dataset
        because some examples don't have an instance of induction: their
        average is NaN and they are dropped.

        per_prompt:
            If True, the first element of the result holds one average loss
            per remaining prompt; otherwise it is the mean over those
            prompts.

        Returns a ``(score, logits)`` tuple, where ``logits`` is the first
        output of the model call.
        """
        logits, loss = self._experiment.model(
            self._base_dataset, return_type="both", loss_per_token=True
        )
        # The mask is cropped by one position along the last axis before it
        # is multiplied with the loss.
        # NOTE(review): presumably because the per-token loss has one fewer
        # position than the input sequence -- confirm.
        induction_mask = self._validation_mask[:, :-1].int()

        # We get the loss in the parts where induction is happening
        total_loss = (loss * induction_mask).sum(dim=-1)
        avg_loss_per_prompt = total_loss / induction_mask.sum(dim=-1)

        # Drop prompts whose average is NaN (e.g. 0 / 0 when the mask
        # selects no position for that prompt).
        valid_losses = avg_loss_per_prompt[~t.isnan(avg_loss_per_prompt)]

        if per_prompt:
            return valid_losses, logits
        return valid_losses.mean(), logits

    @property
    def _canonical_circuit(self) -> Circuit:
        """Canonical circuit for this task, as loaded by ``get_true_circuit``."""
        circuit: Circuit = get_true_circuit()
        return circuit
