from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from datasets import DatasetDict
from transformers import BatchEncoding, PreTrainedTokenizer


# Names of the context-construction strategies understood by
# `construct_context`; each value selects one branch of that function.
ContextType = Literal[
    # Input ids and attention mask are used unchanged.
    "full_prefix",
    # Only the prefix_length tokens before the target, plus the target itself.
    "local_prefix",
    # Full input ids, but attention is enabled only on prefix + target.
    "masked_prefix",
    # Tokens before the prefix are replaced by random alphabet tokens.
    "random_prefix",
    # Tokens before the prefix are replaced by alphabet_token_ids[0].
    "in_alpha_prefix",
    # Tokens before the prefix are replaced by distinct_token_ids[0].
    "oo_alpha_prefix",
    # All tokens before the target are randomised, then prefix_length
    # randomly chosen positions are restored to their original tokens.
    "scattered_prefix",
    # Every token of the sequence is replaced by a random alphabet token.
    "all_random",
]


@dataclass
class ContextProbingConfig:
    """Settings that control how `construct_context` samples its targets."""

    # Seed for the torch.Generator that drives all random sampling.
    seed: int
    # Number of tokens directly before the target that form the prefix.
    prefix_length: int
    # Number of target positions drawn from each dataset sequence.
    num_samples_per_string: int


@dataclass
class ProbingContexts:
    """Probing inputs and targets as assembled by `construct_context`."""

    # Decoded prefix (the prefix_length tokens before each target).
    texts: list[str]
    # Stacked "input_ids" and "attention_mask" tensors, one row per sample.
    encoded_contexts: BatchEncoding
    # Position of the target token inside each constructed context.
    # NOTE(review): "dixs" looks like a typo for "idxs"; it is kept because
    # renaming the attribute would break existing callers.
    context_target_dixs: torch.Tensor
    # Token id found at the target position of the original sequence.
    target_tokens_ids: torch.Tensor
    # Position of the target token inside the original sequence.
    target_tokens_idxs: list[int]


def construct_context(  # noqa: C901
    config: ContextProbingConfig,
    context_type: ContextType,
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
    alphabet_token_ids: torch.Tensor,
    distinct_token_ids: torch.Tensor,
) -> ProbingContexts:
    """Build probing contexts for randomly sampled target positions.

    For every sequence in ``dataset["test"]``, ``num_samples_per_string``
    target positions are drawn with a generator seeded from ``config.seed``.
    For each target, a context is built according to ``context_type``; the
    "prefix" is the ``prefix_length`` tokens directly preceding the target.

    Args:
        config: Seed, prefix length and number of samples per sequence.
        context_type: Strategy used to build each context.
        tokenizer: Only used to decode the prefix into text.
        dataset: Its "test" split is read through the "input_ids" and
            "attention_mask" columns.
            NOTE(review): rows are presumably equal-length 1-D torch
            tensors (the code calls ``.sum()``, ``.clone()`` and
            ``torch.stack`` on them) -- confirm the dataset format.
        alphabet_token_ids: Token ids that random replacements are drawn
            from; the first entry is the constant for "in_alpha_prefix".
        distinct_token_ids: The first entry is the constant replacement
            token for "oo_alpha_prefix".

    Returns:
        A ``ProbingContexts`` holding the decoded prefixes, the stacked
        context tensors, the target position within each context, the
        target token ids and the target positions in the original sequence.

    Raises:
        ValueError: If ``context_type`` is not one of the handled values.
    """
    rng = torch.Generator()
    rng.manual_seed(config.seed)

    data = dataset["test"]
    prefix_length = config.prefix_length

    probing_texts = []
    probing_input_ids = []
    probing_attention_masks = []
    probing_context_target_idxs = []
    probing_target_ids = []
    probing_target_idxs = []
    for input_ids, attention_mask in zip(
        data["input_ids"],
        data["attention_mask"],
    ):
        # Number of masked-out positions in the sequence.
        # NOTE(review): using this as the start index assumes the padding
        # sits on the left -- confirm against the tokenization setup.
        sequence_start = len(attention_mask) - attention_mask.sum()
        # Targets are drawn so that a full prefix fits between the start
        # of the sequence and the target; the upper bound is exclusive.
        for target_idx in torch.randint(
            sequence_start + prefix_length,
            len(input_ids),
            (config.num_samples_per_string,),
            generator=rng,
        ):
            target_idx = int(target_idx.item())
            # By default the target keeps its position; only
            # "local_prefix" overrides this because it slices the sequence.
            context_target_idx = target_idx
            prefix_start_idx = target_idx - prefix_length

            if context_type == "full_prefix":
                # Use the sequence exactly as it is.
                context_input_ids = input_ids
                context_attention_mask = attention_mask
            elif context_type == "local_prefix":
                # Cut out prefix + target; the target ends up last.
                context_input_ids = input_ids[prefix_start_idx : target_idx + 1]
                context_attention_mask = attention_mask[
                    prefix_start_idx : target_idx + 1
                ]
                context_target_idx = len(context_input_ids) - 1
            elif context_type == "masked_prefix":
                context_input_ids = input_ids
                context_attention_mask = torch.zeros_like(input_ids)
                # Mask out all previous tokens, i.e. only use the prefix
                # but keep the positions the same
                context_attention_mask[prefix_start_idx : target_idx + 1] = 1
                # Positions after the target stay at 0 because the mask
                # was initialised with zeros.
            elif context_type == "random_prefix":
                # Replace everything before the prefix with tokens sampled
                # uniformly (with replacement) from the alphabet.
                context_input_ids = input_ids.clone()
                context_attention_mask = attention_mask
                replacement_token_indices = torch.randint(
                    len(alphabet_token_ids),
                    (prefix_start_idx,),
                    generator=rng,
                )
                replacement_token_ids = alphabet_token_ids[
                    replacement_token_indices
                ]
                context_input_ids[:prefix_start_idx] = replacement_token_ids
            elif context_type == "in_alpha_prefix":
                # Set all tokens other than the prefix to be a constant
                # token chosen from the alphabet of the string, i.e. "a"
                context_attention_mask = attention_mask
                context_input_ids = input_ids.clone()
                context_input_ids[:prefix_start_idx] = alphabet_token_ids[0]
            elif context_type == "oo_alpha_prefix":
                # Set all tokens other than the prefix to be a constant
                # token chosen not from the alphabet of the string, i.e. "0"
                context_attention_mask = attention_mask
                context_input_ids = input_ids.clone()
                context_input_ids[:prefix_start_idx] = distinct_token_ids[0]
            elif context_type == "scattered_prefix":
                # Keep a prefix of prefix_length scattered randomly
                # throughout the string. Replace the other tokens with
                # randomly sampled ones.
                context_attention_mask = attention_mask
                context_input_ids = input_ids.clone()
                # Set all tokens up to the target token to random ones
                replacement_token_indices = torch.randint(
                    len(alphabet_token_ids),
                    (target_idx,),
                    generator=rng,
                )
                replacement_token_ids = alphabet_token_ids[
                    replacement_token_indices
                ]
                context_input_ids[:target_idx] = replacement_token_ids
                # Replace prefix_length tokens at random positions in the
                # prefix with their original values
                # NOTE(review): positions are drawn from [0, target_idx),
                # which can include positions before sequence_start --
                # confirm that this is intended.
                retain_idxs = torch.randperm(target_idx, generator=rng)[
                    :prefix_length
                ]
                context_input_ids[retain_idxs] = input_ids[retain_idxs]
            elif context_type == "all_random":
                # Baseline that replaces every token of the sequence
                # (including the target position) with random ones
                context_attention_mask = attention_mask
                replacement_token_indices = torch.randint(
                    len(alphabet_token_ids),
                    (len(input_ids),),
                    generator=rng,
                )
                context_input_ids = alphabet_token_ids[
                    replacement_token_indices
                ]
            else:
                raise ValueError(f"Unknown context type {context_type}")

            # The text is always the original prefix (without the target),
            # independent of the context type.
            probing_texts.append(
                tokenizer.decode(input_ids[prefix_start_idx:target_idx])
            )
            probing_input_ids.append(context_input_ids)
            probing_attention_masks.append(context_attention_mask)
            probing_context_target_idxs.append(context_target_idx)
            # The target id is always read from the unmodified sequence.
            probing_target_ids.append(input_ids[target_idx])
            probing_target_idxs.append(target_idx)
    return ProbingContexts(
        texts=probing_texts,
        encoded_contexts=BatchEncoding(
            {
                "input_ids": torch.stack(probing_input_ids),
                "attention_mask": torch.stack(probing_attention_masks),
            }
        ),
        context_target_dixs=torch.tensor(probing_context_target_idxs),
        target_tokens_ids=torch.tensor(probing_target_ids),
        target_tokens_idxs=probing_target_idxs,
    )


def sample_random_string_replacement(
    length: int,
    alphabet: str,
    seed: int,
) -> str:
    """Draw a random string of ``length`` characters from ``alphabet``.

    Character positions are sampled with replacement by
    ``sample_random_replacement`` (seeded with ``seed``) and then mapped
    back to the characters of ``alphabet``.
    """
    char_positions = sample_random_replacement(length, len(alphabet), seed)
    return "".join(alphabet[position] for position in char_positions)


def sample_random_replacement(
    length: int,
    alphabet_token_ids: int | list[int] | np.ndarray,
    seed: int,
) -> np.ndarray:
    rng = np.random.default_rng(seed)
    replacement_token_ids = rng.choice(
        alphabet_token_ids,
        (length,),
        replace=True,
    )
    return replacement_token_ids


def is_prediction_correct(
    logits: torch.Tensor, target_token_ids: torch.Tensor
) -> torch.Tensor:
    """Mark where the highest-scoring token equals the target token.

    The argmax is taken over the last dimension of ``logits`` and compared
    element-wise with ``target_token_ids``.
    """
    return torch.argmax(logits, dim=-1) == target_token_ids


# def compute_majority_token(
#     logits: torch.Tensor,
# ) -> int:
#     max_prob_token_id = logits.argmax(dim=-1)
#     majority_correct_token = max_prob_token_id.mode().values.item()
#     return majority_correct_token


def compute_target_token_probs(
    logits: torch.Tensor, target_token_ids: torch.Tensor
) -> torch.Tensor:
    """Return the softmax probabilities in the columns ``target_token_ids``.

    The softmax is taken over the last dimension of ``logits``; afterwards
    the columns given by ``target_token_ids`` are selected for every row.

    NOTE(review): for a 1-D ``target_token_ids`` with M entries the result
    has shape (rows, M), i.e. every row is paired with every target id. If
    one probability per row is wanted instead, the caller has to pick the
    matching entries -- confirm which behaviour callers expect.
    """
    return logits.softmax(dim=-1)[:, target_token_ids]


def compute_top_k_token_probs(
    logits: torch.Tensor,
    k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``k`` tokens that are most often the argmax prediction.

    The ranking comes from ``_compute_plurality_frequency_probs``; the
    first ``k`` token ids and their relative frequencies are returned.
    Fewer than ``k`` entries are returned if fewer distinct tokens occur.
    """
    ranked_ids, ranked_probs = _compute_plurality_frequency_probs(logits)
    top_ids = ranked_ids[:k]
    top_probs = ranked_probs[:k]
    return top_ids, top_probs


def compute_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the distribution of argmax predictions.

    The distribution is the relative frequency with which each token is
    the argmax prediction (see ``_compute_plurality_frequency_probs``),
    not the softmax of ``logits``. The natural logarithm is used.
    """
    frequency_probs = _compute_plurality_frequency_probs(logits)[1]
    return -(frequency_probs * frequency_probs.log()).sum(dim=-1)


def _compute_plurality_frequency_probs(
    logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rank tokens by how often they are the argmax prediction.

    Returns the ids of all tokens that are the argmax at least once and
    their relative frequencies, both ordered by descending frequency.
    """
    # Winning token of every row.
    winners = logits.argmax(dim=-1)

    # counts[t] is the number of rows whose winner is token t.
    counts = torch.bincount(winners)

    # One sort call yields the ordered counts and the matching token ids.
    ranked_counts, ranked_ids = torch.sort(counts, descending=True)

    # Drop tokens that never win, i.e. have a count of zero.
    seen = ranked_counts > 0
    ranked_ids = ranked_ids[seen]
    ranked_counts = ranked_counts[seen]

    # Turn the counts into relative frequencies.
    return ranked_ids, ranked_counts / ranked_counts.sum()
