from typing import Tuple

import torch as t
from acdc.docstring.utils import get_all_docstring_things, get_docstring_subgraph_true_edges
from acdc.TLACDCEdge import TorchIndex

from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit


class DocstringTask(MechInterpTask):
    """
    Mechanistic-interpretability task built on ACDC's docstring utilities.

    The model, datasets, validation metrics and validation mask all come from
    ``acdc.docstring.utils.get_all_docstring_things``; the canonical circuit
    comes from ``acdc.docstring.utils.get_docstring_subgraph_true_edges``.
    """

    def __init__(
        self,
        zero_ablation: bool = False,
        device: str = "cuda",
        seq_len: int = 50,
        num_examples: int = 100,
    ):
        """
        zero_ablation:
            Forwarded to the base class constructor; the resulting
            ``self._zero_ablation`` is passed on to ``_make_experiment``.
        device:
            Device string forwarded to the base class constructor and to
            ``get_all_docstring_things``.
        seq_len:
            Maximum length of the sequences to use in the dataset.
        num_examples:
            Number of examples to use in the dataset.
        """
        super().__init__(zero_ablation=zero_ablation, device=device)

        # Both loads share every argument except `return_one_element`.
        shared_kwargs = dict(
            num_examples=num_examples,
            seq_len=seq_len,
            metric_name="docstring_metric",
            device=device,
            correct_incorrect_wandb=False,
        )

        # Load in a tl_model and grab some data.
        all_docstring_things = get_all_docstring_things(
            return_one_element=True, **shared_kwargs
        )
        # Second load, used only for its validation metric, which `score`
        # applies when `per_prompt=True`.
        # NOTE(review): this presumably builds the model and data a second
        # time just to obtain a differently configured metric -- confirm
        # whether the extra load can be avoided.
        all_docstring_things_false = get_all_docstring_things(
            return_one_element=False, **shared_kwargs
        )
        self._validation_metric_per_prompt = all_docstring_things_false.validation_metric

        # Init abstract class attributes
        self._validation_metric = all_docstring_things.validation_metric
        self._ablate_dataset = all_docstring_things.validation_patch_data
        self._base_dataset = all_docstring_things.validation_data
        # `self._zero_ablation` and `self.use_pos_embed` are not set in this
        # class; presumably they are provided by MechInterpTask -- confirm.
        self._experiment = self._make_experiment(
            base_dataset=self._base_dataset,
            ablate_dataset=self._ablate_dataset,
            model=all_docstring_things.tl_model,
            validation_metric=self._validation_metric,
            zero_ablation=self._zero_ablation,
            use_pos_embed=self.use_pos_embed,
        )

        # Other attributes relevant for score
        self._validation_mask = all_docstring_things.validation_mask
        self._validate_attributes()

    def score(self, per_prompt: bool = False) -> Tuple[t.Tensor, t.Tensor]:
        """
        Returns the score of the current circuit together with the logits.

        per_prompt:
            If True, use the metric obtained with ``return_one_element=False``;
            otherwise use the metric obtained with ``return_one_element=True``.

        Returns a 2-tuple ``(metric_value, logits[:, 1:, :])``: the selected
        validation metric evaluated on the logits of the experiment's model
        over the base dataset, and those logits with index 0 along dimension 1
        dropped.
        """
        logits = self._experiment.model(self._base_dataset, return_type="logits")
        metric = (
            self._validation_metric_per_prompt if per_prompt else self._validation_metric
        )
        return metric(logits=logits), logits[:, 1:, :]

    @property
    def _canonical_circuit(self) -> Circuit:
        """
        Canonical circuit derived from ``get_docstring_subgraph_true_edges``.

        Each ``(key, value)`` item of the returned mapping becomes one entry:
        elements 1 and 3 of the key are wrapped in ``TorchIndex``, elements 0
        and 2 are kept as they are, and the value is passed through unchanged.
        """
        true_edges = get_docstring_subgraph_true_edges()
        circuit: Circuit = [
            ((key[0], TorchIndex(key[1]), key[2], TorchIndex(key[3])), value)
            for key, value in true_edges.items()
        ]
        return circuit
