from typing import Tuple

import torch
from acdc.tracr_task.utils import (
    get_all_tracr_things,
    get_tracr_proportion_edges,
    get_tracr_reverse_edges,
)

from hypo_interp.tasks.mech_interp_task import MechInterpTask
from hypo_interp.types_ import Circuit


##################
# Helper functions
##################
def get_reverse_circuit() -> Circuit:
    """
    Load the true circuit for the reverse task.

    The edges come from ``get_tracr_reverse_edges`` (imported from
    ``acdc.tracr_task.utils``); its mapping is flattened into a list of
    ``(key, value)`` pairs, in the mapping's iteration order.
    """
    edge_map = get_tracr_reverse_edges()
    return [(edge, flag) for edge, flag in edge_map.items()]


def get_proportion_circuit() -> Circuit:
    """
    Load the true circuit for the proportion task.

    The edges come from ``get_tracr_proportion_edges`` (imported from
    ``acdc.tracr_task.utils``); its mapping is flattened into a list of
    ``(key, value)`` pairs, in the mapping's iteration order.
    """
    edge_map = get_tracr_proportion_edges()
    return [(edge, flag) for edge, flag in edge_map.items()]


class _TracrTask(MechInterpTask):
    """
    Shared base for the tracr tasks from the paper.

    Fetches the model, datasets and validation metric for ``task`` through
    ``get_all_tracr_things`` and wires them into an experiment built by
    ``_make_experiment``.
    """

    def __init__(
        self,
        task: str,
        zero_ablation: bool = False,
        device: str = "cuda",
        seq_len: int = 50,
        num_examples: int = 6,
        metric: str = "l2",
    ):
        """
        task:
            Name of the tracr task, passed to ``get_all_tracr_things``.
        zero_ablation:
            Forwarded to ``MechInterpTask.__init__``.
        device:
            Device string given both to ``MechInterpTask.__init__`` and to
            ``get_all_tracr_things``.
        seq_len:
            Currently unused by this class: nothing in this constructor
            reads it. Kept in the signature so existing callers keep working.
        num_examples:
            Number of examples to use in the dataset.
        metric:
            Metric name, passed as ``metric_name`` to
            ``get_all_tracr_things``.
        """
        super().__init__(
            zero_ablation=zero_ablation,
            device=device,
            use_pos_embed=True,
        )

        # load in a tl_model and grab some data
        all_things = get_all_tracr_things(
            task=task,
            metric_name=metric,
            num_examples=num_examples,
            device=device,
        )

        # Init abstract class attributes
        self._validation_metric = all_things.validation_metric
        self._ablate_dataset = all_things.validation_patch_data
        self._base_dataset = all_things.validation_data
        self._experiment = self._make_experiment(
            base_dataset=self._base_dataset,
            ablate_dataset=self._ablate_dataset,
            model=all_things.tl_model,
            validation_metric=self._validation_metric,
            zero_ablation=self._zero_ablation,
            use_pos_embed=self.use_pos_embed,
        )

        # Validation mask from the loaded task data (not read by score() here)
        self._validation_mask = all_things.validation_mask
        self._validate_attributes()

    def score(self, per_prompt: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the score of the current circuit together with the logits.

        per_prompt:
            If True, the un-reduced loss is reshaped to one row per example
            of the base dataset and averaged within each row. Otherwise the
            loss is averaged over all of its elements.

        Returns a ``(loss, logits)`` pair, where ``logits`` is the model
        output with index 0 along dimension 1 dropped.
        """
        logits = self._experiment.model(self._base_dataset)
        loss = self._validation_metric(logits, return_one_element=False)
        # NOTE(review): index 0 along dim 1 is presumably the BOS position,
        # which is why it is stripped from the returned logits -- confirm.
        trimmed_logits = logits[:, 1:, :]
        if per_prompt:
            bs = self._base_dataset.shape[0]
            return loss.view(bs, -1).mean(-1), trimmed_logits
        return loss.mean(), trimmed_logits


#################
# Main classes
#################
class TracrReverseTask(_TracrTask):
    """
    The tracr "reverse" task from the paper.

    Specialises ``_TracrTask`` by fixing ``task="reverse"`` and exposing the
    matching true circuit as ``canonical_circuit``.
    """

    def __init__(
        self,
        zero_ablation: bool = False,
        device: str = "cuda",
        seq_len: int = 50,
        num_examples: int = 6,
        metric: str = "l2",
    ):
        """
        All arguments are handed on unchanged to ``_TracrTask.__init__``.

        seq_len:
            Forwarded to the base class, which does not currently read it.
        num_examples:
            Number of examples to use in the dataset.
        """
        super().__init__(
            metric=metric,
            num_examples=num_examples,
            seq_len=seq_len,
            device=device,
            zero_ablation=zero_ablation,
            task="reverse",
        )

    @property
    def canonical_circuit(self) -> Circuit:
        # Rebuilt on every access from get_reverse_circuit().
        return get_reverse_circuit()


class TracrProportionTask(_TracrTask):
    """
    The tracr "proportion" task from the paper.

    Specialises ``_TracrTask`` by fixing ``task="proportion"`` and exposing
    the matching true circuit as ``canonical_circuit``.
    """

    def __init__(
        self,
        zero_ablation: bool = False,
        device: str = "cuda",
        seq_len: int = 50,
        num_examples: int = 50,
        metric: str = "l2",
    ):
        """
        All arguments are handed on unchanged to ``_TracrTask.__init__``.

        seq_len:
            Forwarded to the base class, which does not currently read it.
        num_examples:
            Number of examples to use in the dataset.
        """
        super().__init__(
            metric=metric,
            num_examples=num_examples,
            seq_len=seq_len,
            device=device,
            zero_ablation=zero_ablation,
            task="proportion",
        )

    @property
    def canonical_circuit(self) -> Circuit:
        # Rebuilt on every access from get_proportion_circuit().
        return get_proportion_circuit()
