from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from datasets import DatasetDict
from transformers import BatchEncoding, PreTrainedTokenizer


# Names of the context variants that `construct_context` can build. In the
# notes below, "prefix" means the `prefix_length` tokens directly before the
# sampled target token.
ContextType = Literal[
    # Unmodified sequence and attention mask.
    "full_prefix",
    # Sequence cropped to the prefix plus the target token.
    "local_prefix",
    # Full sequence; attention mask is 1 only on the prefix and the target.
    "masked_prefix",
    # Tokens before the prefix replaced by random alphabet tokens.
    "random_prefix",
    # Tokens before the prefix replaced by the first alphabet token.
    "in_alpha_prefix",
    # Tokens before the prefix replaced by the first "distinct" token.
    "oo_alpha_prefix",
    # Tokens before the target randomised, then `prefix_length` randomly
    # chosen positions restored to their original tokens.
    "scattered_prefix",
    # Every token of the sequence replaced by a random alphabet token.
    "all_random",
]


@dataclass
class ContextProbingConfig:
    """Settings that control how `construct_context` samples targets."""

    # Seed for the `torch.Generator` that drives all random draws.
    seed: int
    # Number of tokens directly before a target that form its prefix.
    prefix_length: int
    # Number of target positions drawn (with replacement) per sequence.
    num_samples_per_string: int


@dataclass
class ProbingContexts:
    """Batch of probing contexts as returned by `construct_context`.

    All fields hold one entry per sampled target, in the same order.
    """

    # Decoded text of the prefix tokens (target token excluded).
    texts: list[str]
    # Stacked "input_ids" and "attention_mask" of the contexts.
    encoded_contexts: BatchEncoding
    # Position of the target token inside each context. NOTE: "dixs" is a
    # typo for "idxs"; the name is kept because callers depend on it.
    context_target_dixs: torch.Tensor
    # Token id found at the target position of the original sequence.
    target_tokens_ids: torch.Tensor
    # Position of the target token inside the original sequence.
    target_tokens_idxs: list[int]


def _sample_alphabet_tokens(
    alphabet_token_ids: torch.Tensor,
    num_tokens: int,
    rng: torch.Generator,
) -> torch.Tensor:
    """Sample ``num_tokens`` alphabet ids uniformly with replacement."""
    sampled_positions = torch.randint(
        len(alphabet_token_ids),
        (num_tokens,),
        generator=rng,
    )
    return alphabet_token_ids[sampled_positions]


def _build_context(
    context_type: ContextType,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    target_idx: int,
    prefix_length: int,
    alphabet_token_ids: torch.Tensor,
    distinct_token_ids: torch.Tensor,
    rng: torch.Generator,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Build the context for one sampled target position.

    Returns the context token ids, the matching attention mask and the
    index of the target token inside the returned context. Random draws
    consume ``rng``. Raises ``ValueError`` for an unknown ``context_type``.
    """
    prefix_start_idx = target_idx - prefix_length

    if context_type == "full_prefix":
        return input_ids, attention_mask, target_idx

    if context_type == "local_prefix":
        # Crop to the prefix plus the target, so the target becomes the
        # last token of the context.
        window = slice(prefix_start_idx, target_idx + 1)
        context_input_ids = input_ids[window]
        return (
            context_input_ids,
            attention_mask[window],
            len(context_input_ids) - 1,
        )

    if context_type == "masked_prefix":
        # Keep all tokens at their original positions, but attend only to
        # the prefix and the target.
        context_attention_mask = torch.zeros_like(input_ids)
        context_attention_mask[prefix_start_idx : target_idx + 1] = 1
        return input_ids, context_attention_mask, target_idx

    if context_type == "random_prefix":
        # Replace everything before the prefix with random alphabet tokens.
        context_input_ids = input_ids.clone()
        context_input_ids[:prefix_start_idx] = _sample_alphabet_tokens(
            alphabet_token_ids, prefix_start_idx, rng
        )
        return context_input_ids, attention_mask, target_idx

    if context_type == "in_alpha_prefix":
        # Replace everything before the prefix with one constant token
        # taken from the alphabet of the string, i.e. "a".
        context_input_ids = input_ids.clone()
        context_input_ids[:prefix_start_idx] = alphabet_token_ids[0]
        return context_input_ids, attention_mask, target_idx

    if context_type == "oo_alpha_prefix":
        # Replace everything before the prefix with one constant token
        # that is not part of the alphabet of the string, i.e. "0".
        context_input_ids = input_ids.clone()
        context_input_ids[:prefix_start_idx] = distinct_token_ids[0]
        return context_input_ids, attention_mask, target_idx

    if context_type == "scattered_prefix":
        # Randomise every token before the target, then restore
        # prefix_length of them at randomly chosen positions.
        context_input_ids = input_ids.clone()
        context_input_ids[:target_idx] = _sample_alphabet_tokens(
            alphabet_token_ids, target_idx, rng
        )
        # NOTE(review): positions are drawn from [0, target_idx), which can
        # include positions whose attention mask is 0 - confirm intended.
        retain_idxs = torch.randperm(target_idx, generator=rng)[
            :prefix_length
        ]
        context_input_ids[retain_idxs] = input_ids[retain_idxs]
        return context_input_ids, attention_mask, target_idx

    if context_type == "all_random":
        # Baseline: every position of the sequence (the target and the
        # tokens after it included) becomes a random alphabet token.
        context_input_ids = _sample_alphabet_tokens(
            alphabet_token_ids, len(input_ids), rng
        )
        return context_input_ids, attention_mask, target_idx

    raise ValueError(f"Unknown context type {context_type}")


def construct_context(
    config: ContextProbingConfig,
    context_type: ContextType,
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
    alphabet_token_ids: torch.Tensor,
    distinct_token_ids: torch.Tensor,
) -> ProbingContexts:
    """Sample target tokens from the test split and build their contexts.

    For every sequence in ``dataset["test"]``, ``num_samples_per_string``
    target positions are drawn (with replacement) such that at least
    ``prefix_length`` attended tokens precede each target. One context of
    kind ``context_type`` is built per target. All random draws come from
    a generator seeded with ``config.seed``.

    Args:
        config: Seed, prefix length and number of samples per sequence.
        context_type: Which context variant to build.
        tokenizer: Used to decode the prefix tokens into text.
        dataset: Must contain a "test" split with "input_ids" and
            "attention_mask" columns. Rows may be tensors or anything
            ``torch.as_tensor`` accepts. Attended tokens are assumed to
            sit at the end of each row (padding first), because the first
            attended position is computed as ``len - attention_mask.sum()``.
        alphabet_token_ids: Token ids used for random/constant in-alphabet
            replacements.
        distinct_token_ids: Token ids outside the alphabet; only the first
            one is used.

    Returns:
        The contexts stacked into a batch, together with the decoded
        prefixes, the target token ids and the target positions.

    Raises:
        ValueError: If ``context_type`` is unknown, or if a sequence has
            fewer than ``prefix_length + 1`` attended tokens.
        RuntimeError: From ``torch.stack`` if no context was produced or
            the contexts differ in length.
    """
    rng = torch.Generator()
    rng.manual_seed(config.seed)

    data = dataset["test"]
    prefix_length = config.prefix_length

    probing_texts: list[str] = []
    probing_input_ids: list[torch.Tensor] = []
    probing_attention_masks: list[torch.Tensor] = []
    probing_context_target_idxs: list[int] = []
    probing_target_ids: list[torch.Tensor] = []
    probing_target_idxs: list[int] = []

    for raw_input_ids, raw_attention_mask in zip(
        data["input_ids"],
        data["attention_mask"],
    ):
        # No-op for tensors; lets list/numpy formatted datasets work too.
        input_ids = torch.as_tensor(raw_input_ids)
        attention_mask = torch.as_tensor(raw_attention_mask)

        num_attended = int(attention_mask.sum())
        sequence_start = len(attention_mask) - num_attended
        first_target_idx = sequence_start + prefix_length
        if first_target_idx >= len(input_ids):
            # torch.randint would otherwise fail with an opaque
            # "from must be less than to" error.
            raise ValueError(
                f"Cannot sample a target: the sequence has {num_attended} "
                f"attended tokens but a prefix of length {prefix_length} "
                f"needs at least {prefix_length + 1}."
            )

        sampled_target_idxs = torch.randint(
            first_target_idx,
            len(input_ids),
            (config.num_samples_per_string,),
            generator=rng,
        )
        for target_idx in sampled_target_idxs.tolist():
            (
                context_input_ids,
                context_attention_mask,
                context_target_idx,
            ) = _build_context(
                context_type,
                input_ids,
                attention_mask,
                target_idx,
                prefix_length,
                alphabet_token_ids,
                distinct_token_ids,
                rng,
            )

            prefix_start_idx = target_idx - prefix_length
            probing_texts.append(
                tokenizer.decode(input_ids[prefix_start_idx:target_idx])
            )
            probing_input_ids.append(context_input_ids)
            probing_attention_masks.append(context_attention_mask)
            probing_context_target_idxs.append(context_target_idx)
            probing_target_ids.append(input_ids[target_idx])
            probing_target_idxs.append(target_idx)

    return ProbingContexts(
        texts=probing_texts,
        encoded_contexts=BatchEncoding(
            {
                "input_ids": torch.stack(probing_input_ids),
                "attention_mask": torch.stack(probing_attention_masks),
            }
        ),
        context_target_dixs=torch.tensor(probing_context_target_idxs),
        # Each entry is a 0-dim tensor; stacking keeps the original dtype.
        target_tokens_ids=torch.stack(probing_target_ids),
        target_tokens_idxs=probing_target_idxs,
    )
