import logging
from dataclasses import dataclass
from typing import Literal, Optional

import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizer

from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from lib_llm.eval.sequences import (  # SequenceMetric,; EntropyMetric,
    SequenceEvaluationTask,
    SequenceMetric,
)
from lib_llm.inference import (
    PredictionConfig,
    TokenValues,
    compute_token_probs,
    predict,
    tokenize,
)
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from utils.context_probing import (
    ContextProbingConfig,
    ContextType,
    ProbingContexts,
    construct_context,
)
from utils.data import RandomStringConfig, generate_random_strings
from utils.encoding import encode_data_characterwise


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether torch reports an available CUDA device. Not referenced in the
# visible part of this file.
HAS_CUDA = torch.cuda.is_available()

# Name passed to the ``@experiment`` decorator on ``cs_experiment``.
EXP_NAME = "context_size"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration consumed by ``cs_experiment``.

    Extends ``LLMExperimentConfig``. ``cs_experiment`` also reads
    ``config.local_rank``, which is not declared here and is presumably
    inherited from that base class (confirm in ``defs``).
    """

    # Random seed. Not read anywhere in the visible part of this file.
    seed: int
    # Passed to ``load_model_tokenizer``; ``cs_experiment`` sets its
    # ``pretrained`` flag to False when ``fine_tuning.train`` is falsy.
    model: ModelConfig
    # Passed to ``generate_random_strings`` to build the dataset.
    data: RandomStringConfig
    # Passed to ``train`` (after ``with_local_rank``); its ``train`` flag is
    # also read directly by ``cs_experiment``.
    fine_tuning: TrainingConfig
    # Passed to ``construct_context`` for every evaluated context type.
    context_probing: ContextProbingConfig


@dataclass
class ExperimentResult:
    """Values returned by ``cs_experiment``."""

    # ``training_log`` of the object returned by ``train``; may be None.
    training_logs: Optional[pd.DataFrame]
    # Per-context-type evaluation frames concatenated by
    # ``_combine_context_results`` (index levels: "context_type", "sequence").
    context_results: pd.DataFrame


@experiment(EXP_NAME)
def cs_experiment(
    config: ExperimentConfig,
    description: ExperimentTaskDescription,
) -> ExperimentResult:
    """Train on random strings, then probe the model with several contexts.

    Steps, in order:

    1. Load model and tokenizer from ``config.model``.
    2. Generate random strings from ``config.data`` and encode them
       character-wise.
    3. Call ``train`` on the encoded data (without the ``"text"`` column) and
       continue with the model it returns.
    4. Compute reference values on the ``"test"`` split with the full
       sequences as input.
    5. For each context type, build the probing contexts and evaluate them
       against the reference.
    6. Concatenate the per-context results into one frame.

    Args:
        config: Experiment configuration. NOTE: when
            ``config.fine_tuning.train`` is falsy, ``config.model.pretrained``
            is set to ``False`` in place.
        description: Task description forwarded to ``train``.

    Returns:
        The training log returned by ``train`` and the combined per-context
        evaluation results.

    Raises:
        AssertionError: If the "sequence" index of a context type's results
            differs from that of the first evaluated context type.
    """
    if not config.fine_tuning.train:
        # We load the trained model later, so no need to load weights here
        config.model.pretrained = False
    model, tokenizer = load_model_tokenizer(
        config.model,
    )
    dataset = generate_random_strings(config.data)
    encoded_dataset = encode_data_characterwise(tokenizer, dataset.data)
    training_dataset = encoded_dataset.remove_columns(["text"])
    logger.info(f"Generated {len(encoded_dataset['test'])} sequences")

    training_res = train(
        description,
        (config.model.model_id_not_none, model),
        ("random_strings", training_dataset),
        config=config.fine_tuning.with_local_rank(config.local_rank),
        tokenizer=tokenizer,
        callbacks=[],
        data_already_preprocessed=True,
    )
    model = training_res.model

    reference_log_probs = compute_reference_log_probs(
        model,
        encoded_dataset["test"],
        config.local_rank,
    )
    # Token ids of the alphabet characters, one token per character.
    alphabet_token_ids = tokenize(
        tokenizer,
        list(dataset.alphabet_characters),
        max_length=1,
    ).input_ids.squeeze(1)
    # Token id of the character "0", handed to ``construct_context`` as the
    # "distinct" token.
    distinct_token_ids = tokenize(
        tokenizer,
        ["0"],
        max_length=1,
    ).input_ids.squeeze(1)
    context_types: list[ContextType] = [
        "full_prefix",
        "local_prefix",
        "masked_prefix",
        "random_prefix",
        "scattered_prefix",
        "in_alpha_prefix",
        "oo_alpha_prefix",
        "all_random",
    ]
    context_results: dict[str, pd.DataFrame] = {}
    target_token_idxs = None
    for context_type in context_types:
        logger.info(f"Evaluating context type {context_type}")
        context = construct_context(
            config.context_probing,
            context_type,
            tokenizer=tokenizer,
            dataset=encoded_dataset,
            alphabet_token_ids=alphabet_token_ids,
            distinct_token_ids=distinct_token_ids,
        )
        context_result = evaluate_contexts(
            model,
            context,
            reference_log_probs=reference_log_probs,
            local_rank=config.local_rank,
        )
        context_results[context_type] = context_result
        # Make sure we're always computing probabilities for the same
        # target tokens. This is an explicit raise rather than ``assert`` so
        # the check is not stripped when Python runs with ``-O``.
        sequence_idxs = context_result.index.get_level_values("sequence")
        if target_token_idxs is None:
            target_token_idxs = sequence_idxs
        elif not target_token_idxs.equals(sequence_idxs):
            raise AssertionError(
                f"Context type {context_type!r} was evaluated on different "
                "sequences than the preceding context types"
            )
    combined_context_results = _combine_context_results(context_results)
    # Report through the module logger instead of a bare print so the output
    # follows the configured logging setup like the other messages here.
    logger.info("combined_context_results\n%s", combined_context_results)

    return ExperimentResult(
        training_logs=training_res.training_log,
        context_results=combined_context_results,
    )


def compute_reference_log_probs(
    model: PreTrainedModel,
    data: Dataset,
    local_rank: int,
) -> TokenValues:
    """Run ``model`` on ``data`` and return the derived per-token values.

    The ``input_ids`` and ``attention_mask`` columns of ``data`` are wrapped
    in a ``BatchEncoding`` and passed to ``predict`` with
    ``trim_last_token=True``. The resulting logits are then converted by
    ``compute_token_probs``.

    Args:
        model: Model to run the prediction with.
        data: Dataset providing ``input_ids`` and ``attention_mask`` columns.
        local_rank: Forwarded to ``PredictionConfig``.

    Returns:
        Whatever ``compute_token_probs`` yields for the predicted logits.
    """
    prediction_config = PredictionConfig(
        trim_last_token=True,
        local_rank=local_rank,
    )
    encoded_inputs = BatchEncoding(
        {column: data[column] for column in ("input_ids", "attention_mask")}
    )
    prediction = predict(model, encoded_inputs, prediction_config)
    return compute_token_probs(prediction["logits"], encoded_inputs)


def evaluate_contexts(
    model: PreTrainedModel,
    contexts: ProbingContexts,
    reference_log_probs: TokenValues,
    local_rank: int,
) -> pd.DataFrame:
    """Evaluate ``model`` on the probing contexts against the reference.

    Two metrics are computed through a ``SequenceEvaluationTask``:
    ``"correct"`` (``ContextCorrectnessMetric``) and ``"kld"``
    (``ContextKLDMetric``).

    Args:
        model: Model to evaluate.
        contexts: Probing contexts: encoded inputs plus target positions and
            target token ids.
        reference_log_probs: Reference values; ``.values`` is expanded along
            its first dimension to one row per target index. Presumably
            shaped ``(1, seq, vocab)`` -- confirm against the caller.
        local_rank: Forwarded to the evaluation task.

    Returns:
        The frame produced by the evaluation task, with its index named
        ``"sequence"``.
    """
    # Look up the reference values at each target position.
    num_targets = len(contexts.target_tokens_idxs)
    expanded_reference = reference_log_probs.values.expand(num_targets, -1, -1)
    reference_positions = torch.tensor(
        contexts.target_tokens_idxs, dtype=torch.long
    )
    reference_at_targets = _get_target_token_log_probs(
        expanded_reference,
        reference_positions,
    )

    correctness_metric = ContextCorrectnessMetric(
        contexts.context_target_dixs, contexts.target_tokens_ids
    )
    kld_metric = ContextKLDMetric(
        contexts.context_target_dixs,
        reference_at_targets,
    )
    metrics: dict[str, SequenceMetric] = {
        "correct": correctness_metric,
        "kld": kld_metric,
    }

    labelled_sequences = (
        contexts.target_tokens_idxs,
        contexts.encoded_contexts,
    )
    task = SequenceEvaluationTask(
        metrics,
        sequences=labelled_sequences,
        tokenizer=None,
        local_rank=local_rank,
    )
    results = task.evaluate(model)
    results.index.names = ["sequence"]
    return results


class ContextCorrectnessMetric(SequenceMetric):
    """Flags whether the highest-scoring token at each target position
    equals the expected token id.

    ``update`` overwrites the ``correct_predictions`` state rather than
    accumulating into it, so ``compute`` reflects the latest ``update`` only.
    """

    def __init__(
        self,
        target_token_idxs: torch.Tensor,
        target_token_ids: torch.Tensor,
    ):
        """Store the target positions and expected ids and register state.

        Args:
            target_token_idxs: Position of the target token per sequence.
            target_token_ids: Expected token id per sequence.
        """
        super().__init__(["token_probs"])
        self.target_token_idxs = target_token_idxs
        self.target_token_ids = target_token_ids

        self.correct_predictions: torch.Tensor
        # The initial state has the same shape and dtype as the target
        # positions.
        initial_state = torch.zeros_like(self.target_token_idxs)
        self.add_state(
            "correct_predictions",
            default=initial_state,
            dist_reduce_fx="sum",
        )

    def update(self, token_probs: TokenValues) -> None:
        """Record which sequences have the expected token as arg-max."""
        values_at_targets = _get_target_token_log_probs(
            token_probs.values,
            self.target_token_idxs,
        )
        predicted_ids = torch.argmax(values_at_targets, dim=-1)
        self.correct_predictions = predicted_ids == self.target_token_ids

    def compute(self) -> torch.Tensor:
        """Return the flags stored by the most recent ``update``."""
        return self.correct_predictions


class ContextKLDMetric(SequenceMetric):
    """Per-sequence divergence between context and reference predictions.

    For each sequence, ``update`` computes
    ``sum(exp(x) * (x - ref))`` over the last dimension. Here ``x`` are the
    values at the target position and ``ref`` the stored reference values.
    That is the KL divergence from the reference to the context distribution
    if both hold log-probabilities (assumed from the naming -- confirm).

    ``update`` overwrites the ``kld`` state rather than accumulating into it.
    """

    def __init__(
        self,
        target_token_idxs: torch.Tensor,
        target_token_log_probs: torch.Tensor,
    ):
        """Store the target positions and reference values, register state.

        Args:
            target_token_idxs: Position of the target token per sequence.
            target_token_log_probs: Reference values at the target positions.
        """
        super().__init__(["token_probs"])
        self.target_token_idxs = target_token_idxs
        self.target_token_log_probs = target_token_log_probs

        self.kld: torch.Tensor
        # NOTE(review): the initial state takes the dtype of the target
        # positions; ``update`` replaces it with the computed values.
        initial_state = torch.zeros_like(self.target_token_idxs)
        self.add_state(
            "kld",
            default=initial_state,
            dist_reduce_fx="sum",
        )

    def update(self, token_probs: TokenValues) -> None:
        """Compute the divergence for every sequence and store it."""
        context_log_probs = _get_target_token_log_probs(
            token_probs.values,
            self.target_token_idxs,
        )
        log_ratio = context_log_probs - self.target_token_log_probs
        weighted_log_ratio = torch.exp(context_log_probs) * log_ratio
        self.kld = torch.sum(weighted_log_ratio, dim=-1)

    def compute(self) -> torch.Tensor:
        """Return the values stored by the most recent ``update``."""
        return self.kld


def _get_target_token_log_probs(
    token_probs: torch.Tensor,
    target_token_idxs: torch.Tensor,
) -> torch.Tensor:
    """Pick, for each sequence, the values predicted for its target token.

    Row ``i`` of the result is ``token_probs[i, target_token_idxs[i] - 1]``.
    The selection is a single vectorised indexing operation, so an empty
    ``target_token_idxs`` yields an empty result instead of an error.

    Args:
        token_probs: Per-position values; the first dimension indexes
            sequences and the second indexes positions.
        target_token_idxs: One target position per sequence (1-D).

    Returns:
        Tensor with one row per entry of ``target_token_idxs``.
    """
    positions = torch.as_tensor(
        target_token_idxs,
        dtype=torch.long,
        device=token_probs.device,
    )
    rows = torch.arange(positions.numel(), device=token_probs.device)
    # We subtract 1 from the token idxs, because there are only
    # probabilities starting from the second token.
    # I.e. the probability for the i-th token is at index i - 1.
    return token_probs[rows, positions - 1]


def _combine_context_results(
    context_results: dict[str, pd.DataFrame],
) -> pd.DataFrame:
    """Stack the per-context frames into one frame keyed by context type.

    Args:
        context_results: Mapping from context type to its result frame.

    Returns:
        The concatenated frame. Its outer index level ``"context_type"``
        holds the mapping keys and its inner level is named ``"sequence"``.
    """
    context_names = list(context_results.keys())
    frames = [context_results[name] for name in context_names]
    index_level_names = ["context_type", "sequence"]
    return pd.concat(frames, keys=context_names, names=index_level_names)
