import logging
from dataclasses import dataclass

import pandas as pd
import torch
from transformers import BatchEncoding, PreTrainedTokenizer

from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from lib_llm.eval.sequences import (  # SequenceMetric,; EntropyMetric,
    SequenceEvaluationTask,
    SequenceMetric,
)
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from utils.data.random_strings import (
    RandomStringConfig,
    RandomStringData,
    get_random_strings,
)

# from utils.data.wiki import load_wiki_text
from utils.encoding import encode_data_characterwise, encode_data_naturally
from utils.finetuning.finetune import (
    get_sequence_metrics,
    postprocess_memorization_log,
)


# from ..named_entity_detection.data import load_random_names, load_scientists


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether CUDA is available, evaluated once at import time.
# NOTE(review): not read anywhere in the visible part of this file —
# confirm whether other modules import it before removing.
HAS_CUDA = torch.cuda.is_available()

# Name under which `mhr_experiment` is registered via `@experiment`.
EXP_NAME = "memorization_hyperparam_rel"

# - Plot mean entropy of the character distribution at each position
# - Plot mean KL divergence (KLD) of the character distribution at each
#   position from the previous position
# - Plot mean pairwise KLD of the character distributions in different
#   sequences
# - Plot the probability of each character at each position
# - Look at the distributions both over only letters and over all
#   possible tokens


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration consumed by `mhr_experiment`.

    Adds the experiment-specific settings below to the fields inherited
    from `LLMExperimentConfig`.
    """

    # NOTE(review): not read by `mhr_experiment` in this file; presumably
    # consumed by the experiment framework or callers — confirm.
    seed: int
    # Passed to `load_model_tokenizer` to obtain the model and tokenizer.
    model: ModelConfig
    # Passed to `load_data` (and on to `get_sequence_metrics`).
    data: RandomStringConfig
    # Training settings handed to `train`, after `with_local_rank(...)`.
    fine_tuning: TrainingConfig
    # True: encode with `encode_data_characterwise`. False currently ends in
    # a NotImplementedError inside `mhr_experiment`.
    characterwise_tokenization: bool


@dataclass
class ExperimentResult:
    """Tables returned by `mhr_experiment`."""

    # The `training_log` of the result returned by `train`.
    training_history: pd.DataFrame
    # The "loss" entry of the post-processed evaluation log.
    loss: pd.DataFrame
    # The "memorization" entry of the post-processed evaluation log.
    token_distributions: pd.DataFrame


@experiment(EXP_NAME)
def mhr_experiment(
    config: ExperimentConfig,
    description: ExperimentTaskDescription,
) -> ExperimentResult:
    """Fine-tune a model on random strings and collect the evaluation logs.

    The random-string dataset is encoded character-wise, a
    `SequenceEvaluationTask` over its ``"test"`` entry is passed to `train`
    as a callback, and the task's result is post-processed with
    `postprocess_memorization_log` once training has finished.

    Args:
        config: Model, data and fine-tuning settings for this run.
        description: Task description forwarded to `train`.

    Returns:
        The training log together with the ``"loss"`` and ``"memorization"``
        entries of the post-processed evaluation log.

    Raises:
        NotImplementedError: If ``config.characterwise_tokenization`` is
            false; that path is not implemented yet.
        AssertionError: If `train` returned no training log.
    """
    if not config.characterwise_tokenization:
        # Fail fast, before the model is loaded and the data is encoded:
        # this path has always ended in this error. Once the alphabet is
        # handled, `encode_data_naturally` is the encoder intended here.
        raise NotImplementedError("TODO: Set the alphabet accordingly")

    model, tokenizer = load_model_tokenizer(
        config.model,
    )
    data = load_data(config.data)
    dataset = data.dataset()
    # Encode exactly once (the dataset used to be encoded twice).
    encoded_dataset = encode_data_characterwise(tokenizer, dataset)
    # Training gets the encoded data without the "text" column; the
    # evaluation task below keeps using the full encoded dataset.
    training_dataset = encoded_dataset.remove_columns(["text"])
    logger.info("Generated %d sequences", len(encoded_dataset["test"]))

    metrics: dict[str, SequenceMetric] = get_sequence_metrics(
        tokenizer,
        config.data,
        data.alphabet,
        data.strings,
    )
    character_eval_task = SequenceEvaluationTask(
        metrics,
        sequences=encoded_dataset["test"],
        tokenizer=None,
        local_rank=config.local_rank,
    )
    training_res = train(
        description,
        (config.model.model_id_not_none, model),
        ("random_strings", training_dataset),
        config=config.fine_tuning.with_local_rank(config.local_rank),
        tokenizer=tokenizer,
        callbacks=[character_eval_task],
        data_already_preprocessed=True,
    )

    eval_logs = postprocess_memorization_log(character_eval_task.result())
    # Report through the module logger rather than a bare print.
    logger.info("Memorization log:\n%s", eval_logs["memorization"])
    training_logs = training_res.training_log
    assert training_logs is not None, "train() returned no training log"
    return ExperimentResult(
        training_history=training_logs,
        loss=eval_logs["loss"],
        token_distributions=eval_logs["memorization"],
    )


def load_data(config: RandomStringConfig) -> RandomStringData:
    """Return the random-string data described by ``config``.

    Thin wrapper around `get_random_strings`.
    """
    # TODO: support further data sources selected by a data-type setting
    # (random names, scientist names, wiki text). Each of them still needs
    # its alphabet to be set accordingly before it can be used here.
    random_string_data = get_random_strings(config)
    return random_string_data
