import logging
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from inference import PredictionConfig, predict
from tqdm import tqdm
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizer

from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from utils.context_probing import (
    compute_entropy,
    compute_target_token_probs,
    compute_top_k_token_probs,
    is_prediction_correct,
    sample_random_replacement,
)
from utils.data.random_strings import RandomStringConfig, get_random_strings
from utils.encoding import encode_strings_characterwise
from utils.finetuning.finetune import (
    FinetuningConfig,
    get_finetuned_model_tokenizer,
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether torch reports an available CUDA device. Not referenced by any of
# the functions visible in this file.
HAS_CUDA = torch.cuda.is_available()

# Name passed to the @experiment decorator on pp_experiment.
EXP_NAME = "prefix_performance"
# Short form of the experiment name; presumably consumed by code outside
# this file -- TODO confirm.
EXP_ABBR = "pp"


# Upper bound on the number of token positions that are probed per sequence
# (see get_probing_token_indices).
MAX_TOKEN_SAMPLES = 256


@dataclass
class PrefixEvalConfig:
    """Settings that control how prefix lengths are evaluated.

    Consumed by eval_prefix_lengths and get_replacements.
    """

    # Base seed: used to pick the probed token positions and to seed the
    # sampling of replacement tokens.
    seed: int
    # Number of replacement sequences that are generated; each one yields a
    # separate sample for every prefix length that does not cover the whole
    # context.
    num_samples_per_prefix: int
    # Add or remove tokens in the non-prefix part, relative value
    # (1.0 keeps the original size, > 1 grows it, < 1 shrinks it).
    relative_non_prefix_size: float = 1.0
    # One of "rand_id", "const_id", "rand_ood", "const_ood"; any other value
    # makes get_replacements raise a ValueError.
    replacement_strategy: str = "rand_id"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Top-level configuration of the prefix performance experiment.

    pp_experiment additionally reads `local_rank` from the config, which is
    presumably declared on LLMExperimentConfig -- TODO confirm.
    """

    # Experiment-level seed; not read by the code visible in this file.
    seed: int
    # Describes the random string data; passed to get_random_strings and to
    # get_finetuned_model_tokenizer.
    data: RandomStringConfig
    # Passed to get_finetuned_model_tokenizer.
    fine_tuning: FinetuningConfig
    # Passed to eval_prefix_lengths.
    prefix_testing: PrefixEvalConfig


@dataclass
class ExperimentResult:
    """Outputs of pp_experiment."""

    # The `training_log` attribute of the fine-tuning result; may be None.
    training_logs: Optional[pd.DataFrame]
    # The single string the model was evaluated on.
    sequence: str
    # Output of eval_prefix_lengths, indexed by (token_idx, prefix_length).
    prefix_performance: pd.DataFrame


@experiment(EXP_NAME)
def pp_experiment(
    config: ExperimentConfig,
    description: ExperimentTaskDescription,
) -> ExperimentResult:
    """Run the prefix performance experiment on a single string.

    Loads the string data described by ``config.data`` (exactly one string
    is expected), obtains the fine-tuned model and tokenizer, and evaluates
    the model for different prefix lengths on that string.

    Args:
        config: Full experiment configuration.
        description: Task description supplied to the experiment; it is not
            read inside this function.

    Returns:
        The training log of the fine-tuning result, the evaluated string and
        the per-token / per-prefix-length performance table.
    """
    data = get_random_strings(config.data)
    assert len(data.strings) == 1

    finetuned = get_finetuned_model_tokenizer(
        config.fine_tuning,
        config.data,
        local_rank=config.local_rank,
    )

    target_string = data.strings[0]
    prefix_performance = eval_prefix_lengths(
        config=config.prefix_testing,
        model=finetuned.model,
        tokenizer=finetuned.tokenizer,
        sequence=target_string,
        alphabet=data.alphabet,
        local_rank=config.local_rank,
    )
    print("prefix_performance:", prefix_performance)

    return ExperimentResult(
        training_logs=finetuned.training_log,
        sequence=target_string,
        prefix_performance=prefix_performance,
    )


def eval_prefix_lengths(
    config: PrefixEvalConfig,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    sequence: str,
    alphabet: str,
    local_rank: int,
) -> pd.DataFrame:
    """Evaluate the model for many prefix lengths at many token positions.

    The sequence is encoded character-wise, replacement token sequences are
    generated once, and every probed token position is then evaluated with
    eval_token.

    Args:
        config: Prefix evaluation settings.
        model: Model that is evaluated.
        tokenizer: Tokenizer belonging to the model.
        sequence: String whose tokens are predicted.
        alphabet: Characters used to derive the replacement tokens.
        local_rank: Forwarded to the prediction configuration.

    Returns:
        The concatenated per-token results with a two-level index named
        ``("token_idx", "prefix_length")``.
    """
    logger.info("Evaluating prefix lengths")

    encoded_sequence = encode_strings_characterwise(tokenizer, [sequence])
    sequence_token_ids = encoded_sequence.input_ids[0].numpy()

    num_chars = len(sequence)
    replacements = get_replacements(config, tokenizer, num_chars, alphabet)
    prediction_config = PredictionConfig(
        trim_last_token=False,
        local_rank=local_rank,
    )
    token_indices = get_probing_token_indices(num_chars, config.seed)

    # One result frame per probed token position, in probing order
    per_token_frames = [
        eval_token(
            model,
            tokenizer,
            prediction_config=prediction_config,
            sequence_token_ids=sequence_token_ids,
            replacements=replacements,
            target_token_idx=position,
            target_token_id=sequence_token_ids[position],
            relative_size=config.relative_non_prefix_size,
        )
        for position in tqdm(token_indices)
    ]

    # The token positions become the outer index level
    performance = pd.concat(
        per_token_frames,
        axis="index",
        keys=token_indices,
    )
    performance.index.names = ["token_idx", "prefix_length"]
    return performance


def get_replacements(
    config: PrefixEvalConfig,
    tokenizer: PreTrainedTokenizer,
    replacement_length: int,
    alphabet: str,
) -> list[np.ndarray]:
    """Create the replacement token sequences for the non-prefix context.

    ``config.num_samples_per_prefix`` replacement sequences are produced.
    The strategy name selects the token pool (``*_id``: tokens of the
    alphabet, ``*_ood``: non-special vocabulary tokens outside the alphabet)
    and the mode (``rand_*``: the whole pool is handed to
    sample_random_replacement, ``const_*``: a single token id drawn from the
    pool is handed over for each replacement).

    Args:
        config: Prefix evaluation settings (seed, strategy, sample count,
            relative non-prefix size).
        tokenizer: Tokenizer used to encode the alphabet and to determine
            the vocabulary and special tokens.
        replacement_length: Base length requested for every replacement.
        alphabet: Characters whose token ids form the in-distribution pool.

    Returns:
        One replacement per sample. If the relative non-prefix size is
        larger than 1, every replacement is extended at the front by
        ``ceil((relative_size - 1) * replacement_length)`` extra tokens.

    Raises:
        ValueError: If the replacement strategy is unknown.
    """
    alphabet_token_ids = (
        encode_strings_characterwise(tokenizer, [alphabet]).input_ids[0].numpy()
    )
    regular_token_ids = set(range(tokenizer.vocab_size)) - set(
        tokenizer.all_special_ids
    )
    ood_token_ids = np.array(
        list(regular_token_ids - set(alphabet_token_ids))
    )

    rng = np.random.default_rng(config.seed)
    strategy = config.replacement_strategy
    num_samples = config.num_samples_per_prefix
    if strategy not in ("rand_id", "const_id", "rand_ood", "const_ood"):
        raise ValueError(
            f"Unknown replacement strategy: {config.replacement_strategy}"
        )

    if strategy in ("rand_id", "const_id"):
        pool = alphabet_token_ids
    else:
        pool = ood_token_ids
    if strategy in ("rand_id", "rand_ood"):
        # Every replacement may use the whole pool
        candidate_ids = [pool] * num_samples
    else:
        # Draw one constant token for each replacement sample
        candidate_ids = [rng.choice(pool, size=1) for _ in range(num_samples)]

    replacements = [
        sample_random_replacement(
            replacement_length,
            ids,
            seed=config.seed + sample_idx,
        )
        for sample_idx, ids in enumerate(candidate_ids)
    ]
    if not (config.relative_non_prefix_size > 1):
        return replacements

    # The context grows beyond the original size, so additional replacement
    # tokens are put in front of every replacement. The seeds are offset so
    # that they differ from the ones used above.
    extra_length = int(
        np.ceil((config.relative_non_prefix_size - 1) * replacement_length)
    )
    seed_base = config.seed + config.num_samples_per_prefix
    extended = []
    for sample_idx, (base, ids) in enumerate(
        zip(replacements, candidate_ids)
    ):
        extra = sample_random_replacement(
            extra_length,
            ids,
            seed=seed_base + sample_idx,
        )
        extended.append(np.concatenate((extra, base), axis=0))
    return extended


def eval_token(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prediction_config: PredictionConfig,
    sequence_token_ids: np.ndarray,
    replacements: list[np.ndarray],
    target_token_idx: int,
    target_token_id: torch.Tensor,
    relative_size: float,
) -> pd.DataFrame:
    """Evaluate the prediction of one token for different prefix lengths.

    For every prefix length the original tokens directly preceding the
    target are kept, while the tokens before that prefix are taken from the
    start of each replacement (scaled in number by ``relative_size``). If
    the prefix covers the whole context, a single sample without any
    replacement tokens is used.

    Args:
        model: Model that is evaluated.
        tokenizer: Tokenizer belonging to the model; provides the pad token.
        prediction_config: Configuration forwarded to ``predict``.
        sequence_token_ids: Token ids of the full sequence.
        replacements: Replacement token sequences for the non-prefix part.
        target_token_idx: Position of the token to predict.
        target_token_id: Id of the token to predict.
        relative_size: Relative size of the non-prefix part of the context.

    Returns:
        The aggregated predictions, one row per prefix length, indexed by
        prefix length.
    """
    prefix_lengths = get_prefix_lengths(target_token_idx)
    # Largest size the prompt can reach, in case tokens are being added
    max_context_size = max(
        target_token_idx, round(relative_size * target_token_idx)
    )

    contexts = []
    logit_positions = []
    group_starts = []
    for prefix_length in prefix_lengths:
        group_starts.append(len(contexts))
        prefix_start = target_token_idx - prefix_length
        prefix = sequence_token_ids[prefix_start:target_token_idx]
        prior_size = round(relative_size * prefix_start)
        used_size = prior_size + prefix_length

        # Dummy tokens for right padding. Without them the encoding
        # function would add left padding, which would shift the positions
        # and use attention masks, which we want to avoid.
        filler = np.full(
            max_context_size - used_size, tokenizer.pad_token_id, dtype=int
        )

        if prefix_start == 0:
            # The prefix is the entire context, nothing to replace
            contexts.append(np.concatenate((prefix, filler), axis=0))
            logit_positions.append(len(prefix) - 1)
            continue

        # Fill the part in front of the prefix from each replacement
        for replacement in replacements:
            contexts.append(
                np.concatenate(
                    (replacement[:prior_size], prefix, filler), axis=0
                )
            )
            logit_positions.append(used_size - 1)
    assert all(len(sample) == max_context_size for sample in contexts)

    input_ids = torch.tensor(np.stack(contexts))
    attention_mask = torch.ones(len(contexts), max_context_size)
    encoding = BatchEncoding(
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
    )
    output = predict(model, encoding, prediction_config)

    # Pick the logits at the position of the last prefix token
    all_logits = output["logits"]
    target_logits = torch.stack(
        [
            all_logits[row, position]
            for row, position in enumerate(logit_positions)
        ]
    )

    result = aggregate_predictions(
        tokenizer=tokenizer,
        logits=target_logits,
        batch_starts=group_starts,
        target_token_id=target_token_id,
    )
    result.index = prefix_lengths
    return result


def aggregate_predictions(
    tokenizer: PreTrainedTokenizer,
    logits: torch.Tensor,
    batch_starts: list[int],
    target_token_id: torch.Tensor,
) -> pd.DataFrame:
    """Summarize the logits of each group of samples in one table row.

    A group consists of the rows ``logits[start:end]``, where ``start`` and
    ``end`` are consecutive entries of ``batch_starts`` (the last group ends
    at ``len(logits)``).

    Args:
        tokenizer: Used to decode the top-ranked token ids.
        logits: Logits at the target position, one row per sample.
        batch_starts: Row index at which each group starts.
        target_token_id: Id of the token that should be predicted.

    Returns:
        A frame with one row per group and the columns ``correct_samples``,
        ``target_prob``, ``top_1_token``, ``top_1_token_prob``,
        ``top_2_token``, ``top_2_token_prob`` and ``entropy``. The
        correctness and the target probability are averaged over the
        samples of a group; a top-k entry is NaN if
        compute_top_k_token_probs returned fewer tokens than its rank.
    """

    def _as_numpy(values):
        return values.float().detach().cpu().numpy()

    correct = _as_numpy(is_prediction_correct(logits, target_token_id))
    target_probs = _as_numpy(
        compute_target_token_probs(logits, target_token_id)
    )

    top_k = 2
    mean_correct = []
    mean_target_probs = []
    ranked_tokens = [[] for _ in range(top_k)]
    ranked_probs = [[] for _ in range(top_k)]
    entropies = []

    group_ends = batch_starts[1:] + [len(logits)]
    for lo, hi in zip(batch_starts, group_ends):
        # One mean value per group, i.e. per prefix length
        mean_correct.append(correct[lo:hi].mean())
        mean_target_probs.append(target_probs[lo:hi].mean())

        group_logits = logits[lo:hi]
        top_tokens, top_probs = compute_top_k_token_probs(group_logits, top_k)
        num_available = len(top_tokens)
        for rank in range(top_k):
            if rank < num_available:
                ranked_tokens[rank].append(
                    tokenizer.decode(top_tokens[rank].item())
                )
                ranked_probs[rank].append(top_probs[rank].item())
            else:
                ranked_tokens[rank].append(np.nan)
                ranked_probs[rank].append(np.nan)

        entropies.append(compute_entropy(group_logits).item())

    columns = {
        "correct_samples": mean_correct,
        "target_prob": mean_target_probs,
        "top_1_token": ranked_tokens[0],
        "top_1_token_prob": ranked_probs[0],
        "top_2_token": ranked_tokens[1],
        "top_2_token_prob": ranked_probs[1],
        "entropy": entropies,
    }
    return pd.DataFrame(columns)


def get_probing_token_indices(
    sequence_length: int,
    seed: int,
    max_samples: Optional[int] = None,
) -> np.ndarray:
    """Select the token positions of a sequence that are probed.

    Position 0 is never probed, so the candidates are
    ``1 .. sequence_length - 1``. If there are at most ``max_samples``
    candidates, all of them are returned; otherwise ``max_samples`` distinct
    positions are drawn with a generator seeded by ``seed``.

    Args:
        sequence_length: Number of tokens in the sequence.
        seed: Seed of the random generator used for subsampling.
        max_samples: Maximum number of positions to return. Defaults to the
            module-level MAX_TOKEN_SAMPLES.

    Returns:
        The selected positions in ascending order.
    """
    if max_samples is None:
        max_samples = MAX_TOKEN_SAMPLES

    candidates = np.arange(1, sequence_length)
    # Compare against the number of candidates rather than the sequence
    # length: drawing as many positions as there are candidates would only
    # shuffle and re-sort the full set.
    if len(candidates) <= max_samples:
        return candidates

    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(candidates, size=max_samples, replace=False))


def get_prefix_lengths(token_idx: int) -> np.ndarray:
    """Return the prefix lengths that are evaluated for a token position.

    The lengths 1 to 20 are sampled densely. For positions beyond 21 the
    lengths ``21 + k**2`` (k = 1, 2, ...) that are smaller than the position
    are added. Lengths larger than ``token_idx`` are dropped and
    ``token_idx`` itself, i.e. the full context, is always the last entry.

    Args:
        token_idx: Position of the token to predict; expected to be >= 1.

    Returns:
        The prefix lengths in ascending order.
    """
    dense_end = 21

    parts = [np.arange(1, dense_end, 1)]
    if token_idx > dense_end:
        num_steps = int(np.ceil(np.sqrt(token_idx - dense_end)))
        parts.append(dense_end + np.arange(1, num_steps, 1) ** 2)

    lengths = np.concatenate(parts)
    lengths = lengths[lengths <= token_idx]
    if lengths[-1] == token_idx:
        return lengths
    # Always check the full context
    return np.concatenate([lengths, [token_idx]])
