import logging
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from inference import (
    PredictionConfig,
    TokenValues,
    compute_token_probs,
    predict,
)
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizer
from transformers.utils import ModelOutput

from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from lib_llm.inference import tokenize
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from utils.data import RandomStringConfig, generate_random_strings
from utils.encoding import encode_data_characterwise


# Name passed to the ``@experiment`` decorator of ``rcs_experiment``.
EXP_NAME = "relevant_context_search"

# Module logger; progress messages of the experiment go through it.
logger = logging.getLogger(__name__)
# True when torch sees a CUDA device at import time (not read elsewhere in
# the visible code).
HAS_CUDA = torch.cuda.is_available()


@dataclass()
class ContextSearchConfig:
    """Settings of the minimal-context search.

    Consumed by ``find_min_token_context_sizes``.
    """

    # Seeds both the numpy RNG that samples target positions and the torch
    # RNG used for random replacement tokens.
    seed: int
    # Number of target positions sampled from each sequence.
    num_samples_per_sequence: int
    # Number of independent random-replacement draws per "shuffled" context
    # type and target token.
    num_shuffle_samples_per_token: int


@dataclass()
class ExperimentConfig(LLMExperimentConfig):
    """Complete configuration of the relevant-context-search experiment.

    Builds on ``LLMExperimentConfig``, which supplies at least
    ``local_rank`` (read by ``rcs_experiment`` but not declared here).
    """

    # Experiment-level seed; not read directly in this module (presumably
    # used by the experiment framework -- TODO confirm).
    seed: int
    # Model to load; ``pretrained`` is switched off when training is skipped.
    model: ModelConfig
    # Settings for generating the random-string dataset.
    data: RandomStringConfig
    # Fine-tuning settings; ``fine_tuning.train`` decides whether to train.
    fine_tuning: TrainingConfig
    # Settings of the minimal-context search.
    context_search: ContextSearchConfig


@dataclass()
class ExperimentResult:
    """Artifacts produced by ``rcs_experiment``.

    ``min_token_context_sizes`` holds one ``min_context_size`` value per
    ``(sequence, token_idx, context_type, sample_idx)`` index entry: the
    smallest number of context tokens with which the target token was still
    predicted correctly, or ``-1`` if it was mispredicted even with the full
    preceding context. Context types either keep the immediately preceding
    tokens (``preceding_*``) or the tokens with the highest head-averaged
    attention in the first, middle and last layer (``attention_{layer}_*``).
    Dropped tokens are either masked out (``*_masked``) or replaced with
    random alphabet tokens (``*_shuffled``).
    """

    # Training logs; currently always ``None``.
    training_logs: Optional[pd.DataFrame]
    # Raw text of every evaluated sequence.
    sequences: list[str]
    # Stacked per-layer attention weights of the evaluation pass (float32).
    attentions: np.ndarray
    # Per-token values from ``compute_token_probs`` (float32); the caller
    # names them ``log_probs``.
    probabilities: np.ndarray
    # Minimal context sizes, see the class docstring.
    min_token_context_sizes: pd.DataFrame


@experiment(EXP_NAME)
def rcs_experiment(
    config: ExperimentConfig,
    description: ExperimentTaskDescription,
) -> ExperimentResult:
    """Run the relevant-context-search experiment end to end.

    1. Load the model and tokenizer described by ``config.model``.
    2. Generate random strings, encode them character-wise and hand the
       encoded dataset (without its ``text`` column) to ``train``.
    3. Run the resulting model once over the ``test`` split with attention
       outputs enabled, collecting attentions and token probabilities.
    4. For sampled target tokens, search for the smallest context that still
       yields a correct prediction (``find_min_token_context_sizes``).

    Args:
        config: Experiment configuration. ``config.model.pretrained`` is
            mutated in place when ``config.fine_tuning.train`` is false.
        description: Task description forwarded to ``train``.

    Returns:
        An ``ExperimentResult``; ``training_logs`` is currently always
        ``None``.
    """
    if not config.fine_tuning.train:
        # Set pretrained to false, since we're later loading the already
        # fine-tuned checkpoint
        config.model.pretrained = False
    model, tokenizer = load_model_tokenizer(
        config.model,
    )

    dataset = generate_random_strings(config.data)
    encoded_dataset = encode_data_characterwise(tokenizer, dataset.data)
    # Training only gets the encoded columns; ``text`` stays available in
    # ``encoded_dataset`` for the evaluation below.
    training_dataset = encoded_dataset.remove_columns(["text"])
    # NOTE(review): this counts only the ``test`` split.
    logger.info(f"Generated {len(encoded_dataset['test'])} sequences")

    # ``data_already_preprocessed=True`` because the data is encoded above.
    # Presumably ``train`` loads the existing fine-tuned checkpoint when
    # training is disabled (see the comment at the top) -- TODO confirm.
    training_res = train(
        description,
        (config.model.model_id_not_none, model),
        ("random_strings", training_dataset),
        config=config.fine_tuning.with_local_rank(config.local_rank),
        tokenizer=tokenizer,
        callbacks=[],
        data_already_preprocessed=True,
    )
    model = training_res.model
    # NOTE(review): presumably restores the model id on the returned model
    # -- TODO confirm why ``train`` does not preserve it.
    model.name_or_path = config.model.model_id_not_none

    # One token id per alphabet character (``max_length=1``, then the
    # length dimension is squeezed); used as random replacement tokens in
    # the context search.
    alphabet_token_ids = tokenize(
        tokenizer,
        list(dataset.alphabet_characters),
        max_length=1,
    ).input_ids.squeeze(1)
    eval_data = encoded_dataset["test"]
    # Make the forward pass return attention weights, which are read as
    # ``output["attentions"]`` below.
    model.config.output_attentions = True
    output, log_probs = compute_predictions(
        model,
        eval_data,
        config.local_rank,
    )
    # Move every layer's attention to the CPU and stack them along a new
    # leading layer dimension (assumed [layer, batch, head, query, key] --
    # TODO confirm against the model's output format).
    attentions = torch.stack(
        [
            layer_attention.detach().cpu()
            for layer_attention in output["attentions"]
        ]
    )
    min_context_sizes = find_min_token_context_sizes(
        config.context_search,
        model,
        data=eval_data,
        attentions=attentions,
        alphabet_token_ids=alphabet_token_ids,
        local_rank=config.local_rank,
    )

    return ExperimentResult(
        # Training logs are currently not returned.
        # training_logs=training_res.training_logs,
        training_logs=None,
        sequences=eval_data["text"],
        attentions=attentions.to(dtype=torch.float32).numpy(),
        probabilities=(
            log_probs.values.to(dtype=torch.float32).detach().cpu().numpy()
        ),
        min_token_context_sizes=min_context_sizes,
    )


def compute_predictions(
    model: PreTrainedModel,
    data: Dataset,
    local_rank: int,
) -> tuple[ModelOutput, TokenValues]:
    """Run ``model`` over all of ``data`` in one call and score its tokens.

    Prediction uses ``trim_last_token=True``, presumably because the last
    token has no following target -- TODO confirm in ``inference.predict``.

    Args:
        model: Model to evaluate.
        data: Dataset with ``input_ids`` and ``attention_mask`` columns.
        local_rank: Forwarded to ``PredictionConfig``.

    Returns:
        The raw ``predict`` output and the token values computed by
        ``compute_token_probs`` from its logits.
    """
    prediction_config = PredictionConfig(
        trim_last_token=True,
        local_rank=local_rank,
    )
    encoding = BatchEncoding(
        dict(
            input_ids=data["input_ids"],
            attention_mask=data["attention_mask"],
        )
    )
    prediction = predict(model, encoding, prediction_config)
    token_values = compute_token_probs(prediction["logits"], encoding)
    return prediction, token_values


def find_min_token_context_sizes(
    config: ContextSearchConfig,
    model: PreTrainedModel,
    data: Dataset,
    attentions: torch.Tensor,
    alphabet_token_ids: torch.Tensor,
    local_rank: int,
) -> pd.DataFrame:
    """Find, per sampled target token, the smallest context that predicts it.

    For every sequence, up to ``config.num_samples_per_sequence`` target
    positions are sampled without replacement. Position 0 is never sampled
    because it has no context. For each target, a batched binary search
    runs over several context types. Each type ranks the preceding
    positions by a score and keeps only the top ``context_size`` of them:

    - ``preceding_*``: rank by position, i.e. keep the most recent tokens.
    - ``attention_{layer}_*``: rank by the head-averaged attention that the
      position predicting the target pays to earlier positions, for the
      first, middle and last layer.

    ``*_masked`` hides the dropped tokens through the attention mask.
    ``*_shuffled`` replaces them with random alphabet tokens, using
    ``config.num_shuffle_samples_per_token`` separate draws.

    Args:
        config: Search settings (seed, number of targets and draws).
        model: Model used for next-token predictions.
        data: Encoded sequences with ``text`` and ``input_ids`` columns;
            ``input_ids`` are expected to be torch tensors (they are sliced
            and cloned downstream).
        attentions: Attentions indexed as
            ``[layer, sequence, head, query, key]``. Assumed to come from a
            pass with the last token trimmed, so query row ``i`` is the
            position predicting token ``i + 1``.
        alphabet_token_ids: Candidate replacement token ids.
        local_rank: Forwarded to ``PredictionConfig``.

    Returns:
        DataFrame with one ``min_context_size`` column indexed by
        ``(sequence, token_idx, context_type, sample_idx)``. A value of
        ``-1`` means the target was mispredicted even with the full
        preceding context.
    """
    rng = np.random.default_rng(config.seed)
    torch.random.manual_seed(config.seed)
    pred_config = PredictionConfig(
        trim_last_token=False,
        local_rank=local_rank,
    )
    num_layers = attentions.shape[0]
    # Use the first, middle and last layer. ``num_layers // 2`` is the middle
    # for any layer count; ``round(num_layers / 2)`` rounds half to even and
    # picks the last layer for 3 layers. ``dict.fromkeys`` removes duplicates
    # for very shallow models while keeping the order.
    layer_indices = list(dict.fromkeys([0, num_layers // 2, num_layers - 1]))

    # NOTE: results are keyed by sequence text, so sequences with identical
    # text overwrite each other's entries.
    min_context_sizes: dict[tuple, int] = {}
    for sequence_idx, sequence in enumerate(data):
        sequence_text = sequence["text"]
        sequence_input_ids = sequence["input_ids"]
        logger.info("Processing sequence %s", sequence_text)

        # Position 0 has no context, so only positions 1..len-1 are targets.
        num_candidates = len(sequence_input_ids) - 1
        if num_candidates < 1:
            logger.warning(
                "Skipping sequence %s: it has no predictable tokens",
                sequence_text,
            )
            continue
        num_samples = min(config.num_samples_per_sequence, num_candidates)
        if num_samples < config.num_samples_per_sequence:
            logger.warning(
                "Sequence %s has only %d predictable tokens; sampling all "
                "of them instead of %d",
                sequence_text,
                num_candidates,
                config.num_samples_per_sequence,
            )

        # Head-averaged attention of the selected layers, once per sequence.
        layer_attention = {
            layer_idx: attentions[layer_idx, sequence_idx].mean(dim=0)
            for layer_idx in layer_indices
        }

        sampled_positions = rng.choice(
            num_candidates,
            size=num_samples,
            replace=False,
        )
        for sample_num, sampled_position in enumerate(sampled_positions, start=1):
            # Shift by one: the samples index the predictable tokens 1..len-1.
            token_idx = int(sampled_position) + 1
            logger.info(
                "Token idx: %d, %d/%d", token_idx, sample_num, num_samples
            )
            input_ids = sequence_input_ids[:token_idx]
            target_token_id = int(sequence_input_ids[token_idx])

            search_states = _init_search_states(
                _build_context_types(token_idx, layer_attention),
                seed=config.seed,
                max_context_size=token_idx,
                num_input_ids_samples=config.num_shuffle_samples_per_token,
                orig_input_ids=input_ids,
                alphabet_token_ids=alphabet_token_ids,
            )
            _run_context_searches(
                model,
                pred_config,
                search_states,
                input_ids,
                target_token_id,
                max_context_size=token_idx,
            )

            for (ss_name, ss_sample_idx), search_state in search_states.items():
                assert search_state.context_size <= token_idx
                # If the target is still mispredicted, the search must have
                # tried the largest possible context.
                assert (
                    search_state.is_correct
                    or search_state.context_size == token_idx
                )
                min_context_sizes[
                    (sequence_text, token_idx, ss_name, ss_sample_idx)
                ] = search_state.context_size if search_state.is_correct else -1

    return pd.DataFrame(
        list(min_context_sizes.values()),
        columns=["min_context_size"],
        index=pd.MultiIndex.from_tuples(
            list(min_context_sizes.keys()),
            names=["sequence", "token_idx", "context_type", "sample_idx"],
        ),
        dtype=int,
    )


def _build_context_types(
    token_idx: int,
    layer_attention: dict[int, torch.Tensor],
) -> dict[str, tuple[torch.Tensor, bool]]:
    """Map each context-type name to its (position scores, is_masked) pair."""
    # Later positions score higher, so the top-k positions are the k tokens
    # right before the target.
    position_scores = torch.arange(token_idx, dtype=torch.float32)
    context_types = {
        "preceding_masked": (position_scores, True),
        "preceding_shuffled": (position_scores, False),
    }
    for layer_idx, head_mean in layer_attention.items():
        # Query row ``token_idx - 1`` is the position predicting the target;
        # keys ``:token_idx`` cover every preceding token.
        scores = head_mean[token_idx - 1, :token_idx]
        context_types[f"attention_{layer_idx}_masked"] = (scores, True)
        context_types[f"attention_{layer_idx}_shuffled"] = (scores, False)
    return context_types


def _run_context_searches(
    model: PreTrainedModel,
    pred_config: PredictionConfig,
    search_states: dict[tuple[str, int], "MinContextSearchState"],
    input_ids: torch.Tensor,
    target_token_id: int,
    max_context_size: int,
) -> None:
    """Advance all search states in batched steps until every one converged."""
    # Only unfinished states are evaluated. Converged states would be re-run
    # with an unchanged context, which wastes compute.
    active = {key: state for key, state in search_states.items() if not state.is_done}
    while active:
        contexts = _stack_batch_encodings(
            [state.construct_batch_encoding(input_ids) for state in active.values()]
        )
        output = predict(model, contexts, pred_config)
        predicted_ids = output["logits"][:, -1].argmax(dim=-1)
        # ``active`` preserves insertion order, matching the batch order.
        _update_search_states(
            active,
            predicted_ids == target_token_id,
            max_context_size=max_context_size,
        )
        active = {key: state for key, state in active.items() if not state.is_done}


@dataclass()
class MinContextSearchState:
    """Progress of one binary search for the smallest sufficient context.

    A context of size ``k`` keeps the ``k`` positions with the highest
    ``token_scores``. Every other position is either hidden through the
    attention mask (``is_masked``) or overwritten with the token at the same
    position in ``replacement_token_ids``.
    """

    # Number of highest-scoring positions currently kept as context.
    context_size: int
    # Step applied by the most recent update; 0 means the search converged.
    last_delta: int
    # One score per context position; higher scores are kept first.
    token_scores: torch.Tensor
    # Per-position substitutes for dropped tokens (used only when not masked).
    replacement_token_ids: torch.Tensor
    # Hide dropped tokens via the attention mask instead of replacing them.
    is_masked: bool
    # Outcome of the most recent prediction at ``context_size``.
    is_correct: bool = False

    @property
    def is_done(self) -> bool:
        """Whether the search has converged (its last step was zero)."""
        return self.last_delta == 0

    def construct_batch_encoding(
        self,
        orig_input_ids: torch.Tensor,
    ) -> BatchEncoding:
        """Build the (unbatched) model input for the current context size.

        Args:
            orig_input_ids: Token ids preceding the target, indexed along
                their first dimension (presumably 1-D).

        Returns:
            ``BatchEncoding`` with ``input_ids`` and ``attention_mask``, both
            shaped like ``orig_input_ids``.
        """
        ranking = torch.argsort(self.token_scores, descending=True)
        dropped_positions = ranking[self.context_size:]
        attention_mask = torch.ones_like(orig_input_ids)
        if self.is_masked:
            # Keep the tokens, but stop the model from attending to them.
            attention_mask[dropped_positions] = 0
            input_ids = orig_input_ids
        else:
            # Attend to every position, but overwrite the dropped tokens.
            input_ids = orig_input_ids.clone()
            input_ids[dropped_positions] = self.replacement_token_ids[
                dropped_positions
            ]
        return BatchEncoding(
            dict(input_ids=input_ids, attention_mask=attention_mask)
        )


def _init_search_states(
    context_type_infos: dict[str, tuple[torch.Tensor, bool]],
    seed: int,
    max_context_size: int,
    num_input_ids_samples: int,
    orig_input_ids: torch.Tensor,
    alphabet_token_ids: torch.Tensor,
) -> dict[tuple[str, int], MinContextSearchState]:
    """Create the initial binary-search state for every context type.

    Masked context types get a single state keyed ``(context_type, 0)``.
    Non-masked ("shuffled") types get ``num_input_ids_samples`` states keyed
    ``(context_type, i)``. Each of those has its own random draw of
    ``max_context_size`` replacement tokens from ``alphabet_token_ids``.

    Every state starts at half of ``max_context_size`` (at least 1), with a
    first step of the same size, as in a binary search.

    The random generator is re-seeded from ``seed`` on every call. Calls
    with the same seed, sizes and context types therefore draw identical
    replacements.

    Args:
        context_type_infos: Context-type name -> (token scores, is_masked).
        seed: Seed for the replacement-token draws.
        max_context_size: Number of context positions (the target index).
        num_input_ids_samples: Replacement draws per non-masked type.
        orig_input_ids: Token ids preceding the target.
        alphabet_token_ids: Candidate replacement token ids.

    Returns:
        Mapping from ``(context_type, sample_idx)`` to its search state.
    """
    # Binary search starts in the middle of the possible context sizes.
    start_context_size = max(1, max_context_size // 2)

    # Sample with a CPU generator. ``torch.randint`` creates its output on
    # the CPU unless told otherwise, and a generator on another device (e.g.
    # CUDA input ids) would make it raise. Results are moved afterwards.
    generator = torch.Generator().manual_seed(seed)

    search_states = {}
    for context_type, (token_scores, is_masked) in context_type_infos.items():
        if is_masked:
            # Masked contexts hide tokens instead of replacing them; the
            # original ids only fill the (unused) replacement slot.
            search_states[(context_type, 0)] = MinContextSearchState(
                context_size=start_context_size,
                last_delta=start_context_size,
                token_scores=token_scores,
                replacement_token_ids=orig_input_ids,
                is_masked=is_masked,
            )
            continue

        # Random replacement tokens, one independent draw per sample.
        for sample_idx in range(num_input_ids_samples):
            replacement_indices = torch.randint(
                len(alphabet_token_ids),
                size=(max_context_size,),
                generator=generator,
            )
            replacement_token_ids = alphabet_token_ids[
                replacement_indices.to(alphabet_token_ids.device)
            ].to(orig_input_ids.device)
            search_states[(context_type, sample_idx)] = MinContextSearchState(
                context_size=start_context_size,
                last_delta=start_context_size,
                token_scores=token_scores,
                replacement_token_ids=replacement_token_ids,
                is_masked=is_masked,
            )
    return search_states


def _stack_batch_encodings(
    batch_encodings: list[BatchEncoding],
) -> BatchEncoding:
    """Combine per-example encodings into one batch.

    Each field of the first encoding is stacked with ``torch.stack`` along a
    new leading batch dimension. All encodings must therefore provide that
    field with the same shape, and the list must be non-empty.
    """
    field_names = batch_encodings[0].keys()
    batched = {}
    for field in field_names:
        per_example = [encoding[field] for encoding in batch_encodings]
        batched[field] = torch.stack(per_example)
    return BatchEncoding(batched)


def _update_search_states(
    search_states: dict[tuple, MinContextSearchState],
    predictions_correct: torch.Tensor,
    max_context_size: int,
) -> None:
    """Apply one binary-search step to every state, in place.

    States are paired positionally with ``predictions_correct``, so the
    iteration order of ``search_states`` must match the batch order used to
    produce the predictions. ``zip`` stops at the shorter of the two.

    While the previous step was larger than 1, the step is halved: the
    context shrinks after a correct prediction and grows after an incorrect
    one. After that, the search walks one token at a time:

    - it stops at ``max_context_size``, or at size 1 after a correct
      prediction;
    - after an incorrect prediction it adds one token;
    - after a correct prediction it keeps shrinking only if the previous
      step shrank, and otherwise stops.

    A step of 0 marks the state as converged.
    """
    for state, correct_flag in zip(search_states.values(), predictions_correct):
        correct = bool(correct_flag.item())
        step = abs(state.last_delta)
        size = state.context_size

        if step > 1:
            # Halving phase of the binary search.
            half_step = step // 2
            delta = -half_step if correct else half_step
        elif size == max_context_size or (correct and size == 1):
            # Nothing left to try in the required direction.
            delta = 0
        elif not correct:
            delta = 1
        elif state.last_delta < 0:
            # Still correct while shrinking: try one token less.
            delta = -1
        else:
            # Correct right after growing: this is the minimal size.
            delta = 0

        state.last_delta = delta
        state.context_size = size + delta
        state.is_correct = correct
