import logging
import os
from dataclasses import dataclass, replace
from functools import partial
from typing import cast

import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainerCallback

from data.synthetic_strings import (
    generate_random_strings,
    get_config,
    uses_bos_token,
)
from data.synthetic_strings.random import RandomStringConfig, RandomStringData
from data.text_generation import (
    TextGenerationDataConfig,
    load_text_generation_data,
)
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.ops.batch_size import get_batch_size
from lib_llm.training.train import TrainingConfig, train
from lib_project.experiment import ExperimentID, NoSave, NoSaveValue, experiment
from utils.memorization.memorization import MemorizationTrainingResult


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Device discovery: without CUDA the device count falls back to 1.
HAS_CUDA = torch.cuda.is_available()
if HAS_CUDA:
    NUM_DEVICES = torch.cuda.device_count()
else:
    NUM_DEVICES = 1
# Value of the LOCAL_RANK environment variable; -1 means it is not set.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "-1"))

# Experiment name (used to register the experiment) and its abbreviation.
EXP_NAME = "practical_memorization_dynamics"
EXP_ABBREVIATION = "pmd"


@dataclass
class ContextDataConfig:
    """Describes the context dataset that the random string is injected into."""

    # Name of the context dataset and of the variant to load.
    dataset: str
    dataset_variant: str
    # Seed used for loading the context data and for the injection RNG.
    seed: int
    # Number of tokens per context example.
    sequence_length: int
    # Batch size used when sizing and partitioning the context data.
    batch_size: int
    # Multiplier on the number of context examples selected per epoch.
    inject_every_n_steps: int = 1

    @property
    def name(self) -> str:
        """Return a compact identifier encoding dataset and layout settings.

        The variant name is used on its own when it already starts with the
        dataset name; otherwise the two are joined with a hyphen.
        """
        variant = self.dataset_variant
        base = (
            variant
            if variant.startswith(self.dataset)
            else f"{self.dataset}-{variant}"
        )
        length_tag = f"l-{self.sequence_length}"
        schedule_tag = f"b-{self.batch_size}_inj-{self.inject_every_n_steps}"
        return f"{base}_{length_tag}_{schedule_tag}"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for one run of the memorization-dynamics experiment."""

    # Experiment-level seed. It is not read by the code visible in this
    # module; presumably consumed by the experiment framework — TODO confirm.
    seed: int
    # Settings for the random string; passed to create_injected_dataset.
    random_data: RandomStringConfig
    # Settings for the context data the random string is injected into.
    context_data: ContextDataConfig
    # Model to load; passed to load_model_tokenizer.
    model: ModelConfig
    # Training settings. `args.num_train_epochs` is used as the number of
    # epochs when sizing the injected training set.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Outputs returned by `pmd_experiment`."""

    # The generated random string data (may be None per the annotation).
    random_data: RandomStringData | None
    # Log produced by the training run.
    training_log: pd.DataFrame
    # Result of the memorization-dynamics evaluation callback.
    memorization_log: pd.DataFrame
    # Trained model and tokenizer, wrapped in NoSave; presumably this
    # excludes them from persistence — confirm in lib_project.experiment.
    model: NoSave[PreTrainedModel]
    tokenizer: NoSave[PreTrainedTokenizer]


@experiment(EXP_NAME)
def pmd_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Run one memorization-dynamics experiment end to end.

    Loads the model and tokenizer, builds the context dataset with the
    random string injected, trains on it, and returns the training and
    memorization logs together with the trained model and tokenizer.

    Args:
        config: Full experiment configuration.
        experiment_id: Identifier forwarded to the training routine.

    Returns:
        The generated random data, both logs, and the model and tokenizer
        wrapped in `NoSaveValue`.

    Raises:
        ValueError: If the configured number of epochs is not a whole number.
    """
    model, tokenizer = load_model_tokenizer(config.model)
    tokenizer.model_max_length = 2048

    # typing.cast() performs no runtime conversion, so a float epoch count
    # (e.g. 3.0) would reach range() downstream and fail. Convert explicitly
    # and reject fractional values instead of silently truncating them.
    configured_epochs = config.training.args.num_train_epochs
    num_epochs = int(configured_epochs)
    if num_epochs != configured_epochs:
        raise ValueError(
            "num_train_epochs must be a whole number, got "
            f"{configured_epochs!r}"
        )

    data_name, random_data, injected_data = create_injected_dataset(
        config.random_data,
        config.context_data,
        num_epochs=num_epochs,
        tokenizer=tokenizer,
    )
    training_res = train_model(
        experiment_id,
        config.training,
        (config.model.model_id_not_none, model),
        tokenizer,
        data_name,
        random_data,
        injected_data,
    )
    training_log = training_res.training_log
    assert training_log is not None
    memorization_log = training_res.memorization_log
    assert memorization_log is not None

    # Return the narrowed locals so the checks above actually guard the
    # values placed into the result.
    return ExperimentResult(
        random_data=random_data,
        training_log=training_log,
        memorization_log=memorization_log,
        model=NoSaveValue(training_res.model),
        tokenizer=NoSaveValue(training_res.tokenizer),
    )


def create_injected_dataset(
    random_data_config: RandomStringConfig,
    context_data_config: ContextDataConfig,
    num_epochs: int,
    tokenizer: PreTrainedTokenizer,
) -> tuple[str, RandomStringData, DatasetDict]:
    """Create context data into which one random string is injected repeatedly.

    The training split is cut down to
    ``batch_size * inject_every_n_steps * num_epochs`` context examples and
    partitioned into consecutive injection blocks of
    ``batch_size * inject_every_n_steps`` examples. The same random string
    overwrites part of exactly one example per block, so it occurs
    ``num_epochs`` times in the training split. The validation and test
    splits are passed through unchanged.

    Args:
        random_data_config: Settings for the random string; must describe
            exactly one partition.
        context_data_config: Settings for the context dataset and for the
            injection layout.
        num_epochs: Number of injection blocks, i.e. how many times the
            random string appears in the training split.
        tokenizer: Tokenizer used to generate the random string and to load
            the context data.

    Returns:
        A tuple of the combined data name, the generated random string data,
        and the dataset dict with the injected training split.

    Raises:
        ValueError: If ``batch_size`` or ``inject_every_n_steps`` is not
            positive.
    """

    assert random_data_config.num_partitions == 1
    random_data = generate_random_strings(random_data_config, tokenizer)

    batch_size = context_data_config.batch_size
    inject_every_n_steps = context_data_config.inject_every_n_steps
    if batch_size <= 0 or inject_every_n_steps <= 0:
        raise ValueError(
            "batch_size and inject_every_n_steps must be positive, got "
            f"{batch_size} and {inject_every_n_steps}"
        )
    # One injection per `inject_every_n_steps` batches. Using the bare batch
    # size here would inject the string `inject_every_n_steps` times more
    # often than `num_epochs` asks for.
    injection_block_size = batch_size * inject_every_n_steps

    context_config = TextGenerationDataConfig(
        dataset=context_data_config.dataset,
        variant=context_data_config.dataset_variant,
        seed=context_data_config.seed,
        sequence_length=context_data_config.sequence_length,
    )
    context_data = load_text_generation_data(
        context_config,
        tokenizer,
    )

    rand_input_ids = torch.from_numpy(random_data.raw_token_ids[0])
    rng = np.random.default_rng(context_data_config.seed)
    (
        insertion_margin,
        first_insertion_token,
        rand_input_ids,
    ) = _get_insertion_pos(
        rand_input_ids,
        uses_bos_token(random_data_config.tokenizer_type),
        context_data_config.sequence_length,
    )
    insert_random_string = partial(
        _insert_random_string,
        rng=rng,
        rand_input_ids=rand_input_ids,
        batch_size=injection_block_size,
        insertion_margin=insertion_margin,
        first_insertion_token=first_insertion_token,
    )

    # Total number of context examples for the entire training run
    num_training_examples = injection_block_size * num_epochs
    print("num_training_examples:", num_training_examples)
    print("num dataset examples:", len(context_data["train"]))

    # Dataset.map hands the callback chunks of `batch_size` rows (1000 by
    # default) and the callback restarts its block partitioning in every
    # chunk. Use a chunk size that is a multiple of the injection block so
    # that blocks never straddle or get cut short at a chunk boundary.
    default_map_batch_size = 1000
    map_batch_size = injection_block_size * max(
        1, default_map_batch_size // injection_block_size
    )
    injected_training_data = (
        context_data["train"]
        .select(range(num_training_examples))
        .map(
            insert_random_string,
            batched=True,
            batch_size=map_batch_size,
        )
    )
    injected_data = DatasetDict(
        {
            "train": injected_training_data,
            "validation": context_data["validation"],
            "test": context_data["test"],
        }
    )

    data_name = f"{random_data_config.name}_{context_data_config.name}"
    return data_name, random_data, injected_data


def _get_insertion_pos(
    rand_input_ids: torch.Tensor,
    uses_bos_token: bool,
    context_length: int,
) -> tuple[int, int, torch.Tensor]:
    # Don't overwrite the start of string tokens if it's used by the tokenizer
    first_insertion_token = 1 if uses_bos_token else 0
    if len(rand_input_ids) < context_length or not uses_bos_token:
        rand_ids_len = len(rand_input_ids)
    else:
        rand_ids_len = context_length - 1
    insertion_margin = context_length - (rand_ids_len + first_insertion_token)
    return (
        insertion_margin,
        first_insertion_token,
        rand_input_ids[:rand_ids_len],
    )


def _insert_random_string(
    examples: dict,
    *,
    rng: np.random.Generator,
    rand_input_ids: torch.Tensor,
    batch_size: int,
    insertion_margin: int,
    first_insertion_token: int,
) -> dict:
    input_ids = examples["input_ids"]
    for block in range(0, len(input_ids), batch_size):
        block_end = min(block + batch_size, len(input_ids))
        insertion_pos = rng.integers(block, block_end)
        if insertion_margin <= first_insertion_token:
            insertion_start = first_insertion_token
        else:
            insertion_start = rng.integers(
                first_insertion_token, insertion_margin
            )
        insertion_end = insertion_start + len(rand_input_ids)
        input_ids[insertion_pos] = torch.cat(
            [
                input_ids[insertion_pos][:insertion_start],
                rand_input_ids,
                input_ids[insertion_pos][insertion_end:],
            ]
        )
    return examples


def train_model(
    experiment_id: ExperimentID,
    training_config: TrainingConfig,
    model_info: tuple[str, PreTrainedModel],
    tokenizer: PreTrainedTokenizer,
    data_name: str,
    random_data: RandomStringData,
    injected_data: DatasetDict,
) -> MemorizationTrainingResult:
    """Train on the injected data while tracking memorization of the string.

    Args:
        experiment_id: Identifier forwarded to `train`.
        training_config: Base training settings; epochs and batch sizes are
            overridden before training.
        model_info: Pair of model identifier and model instance.
        tokenizer: Tokenizer forwarded to `train` and stored in the result.
        data_name: Name under which the injected data is passed to `train`.
        random_data: Random string data used by the memorization callback.
        injected_data: Dataset dict with the injected training split.

    Returns:
        The trained model, tokenizer, random data, training log and the
        memorization log collected by the evaluation callback.
    """
    memorization_eval_data = random_data.dataset()["test"]
    memorization_task = memorization_dynamics_metrics(
        random_data.alphabet_tokens,
        random_data.alphabet_token_ids,
        memorization_eval_data,
    )
    eval_tasks: list[TrainerCallback] = [memorization_task]

    # The second dimension of the test split's input ids is the sequence
    # length used to pick an inference batch size.
    test_input_ids = cast(torch.Tensor, injected_data["test"]["input_ids"])
    sequence_length = test_input_ids.shape[1]
    print("sequence_length:", sequence_length)

    effective_rank = LOCAL_RANK if LOCAL_RANK >= 0 else 0
    eval_batch_size = get_batch_size(
        model_info[1],
        sequence_length,
        local_rank=effective_rank,
        action="inference",
    )
    # The configured train batch size is split across the available devices.
    batch_per_device = (
        training_config.args.per_device_train_batch_size / NUM_DEVICES
    )
    train_batch_size = int(np.ceil(batch_per_device))

    print("eval_batch_size:", eval_batch_size)
    # We control the number of epochs via the size of the context
    # that we repeatedly inject the same string into
    single_pass_args = replace(
        training_config.args,
        num_train_epochs=1,
        per_device_eval_batch_size=eval_batch_size,
        per_device_train_batch_size=train_batch_size,
    )
    training_config = replace(training_config, args=single_pass_args)

    training_res = train(
        experiment_id,
        model_info,
        (data_name, injected_data),
        config=training_config,
        tokenizer=tokenizer,
        callbacks=eval_tasks,
        data_already_preprocessed=True,
    )
    training_log = training_res.training_log
    assert training_log is not None

    return MemorizationTrainingResult(
        model=training_res.model,
        tokenizer=tokenizer,
        data=random_data,
        training_log=training_log,
        memorization_log=memorization_task.result(),
    )
