import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Union, cast

import pandas as pd
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from ..data.random_strings import SEEDS, RandomStringConfig, get_random_strings


logger = logging.getLogger(__name__)

# Directory containing this module, and the directory three levels above it.
FILE_DIR = Path(__file__).parent
ROOT_PATH = FILE_DIR.parent.parent.parent

# The DeepSpeed config file is selected through the CONTAINER_TYPE
# environment variable: "devel" picks the offloading config, any other
# value (default "runtime") picks the regular one.
CONTAINER_TYPE = os.environ.get("CONTAINER_TYPE", "runtime")
if CONTAINER_TYPE != "devel":
    DEEPSPEED_CONFIG = str(FILE_DIR / "ds_memorization_config.json")
else:
    logger.info("Using deepspeed offloading for finetuning")
    DEEPSPEED_CONFIG = str(FILE_DIR / "ds_memorization_offload_config.json")

# First path component of every model storage sub-directory.
MODELS_DIR = "models"

# Per-model learning rates. They were tested in the MHR experiment with a
# linear learning rate schedule.
LEARNING_RATE_OVERRIDES = {
    "gpt2": 1e-3,
    "gpt2-xl": 1e-4,
}
# Fallback for all other models; works well for Pythia models of size 1bn
# and above.
DEFAULT_LEARNING_RATE = 1e-5


@dataclass
class FinetuningConfig:
    """Settings describing a single finetuning run."""

    # Index into SEEDS that selects the training seed.
    seed_id: int
    # Identifier of the model to load and finetune.
    model_id: str
    # data: Union[None, str, RandomStringConfig]
    # Number of training epochs; also part of the storage directory name.
    epochs: int
    # Determines the model id under which results are stored.
    # NOTE(review): presumably also meant to choose between pretrained and
    # freshly initialised weights -- confirm against the model loading code.
    pretrained: bool = True
    # Passed to the training config as `save_final_checkpoint`.
    save_model: bool = False
    # Passed to the model config as its base directory.
    base_model_dir: Path | None = ROOT_PATH / "base_models"


@dataclass
class FinetuningResult:
    """Model, tokenizer and logs belonging to one finetuning run.

    `storage_sub_dir` is the location of the run relative to a storage
    base directory.
    """

    storage_sub_dir: Path
    model: PreTrainedModel
    tokenizer: PreTrainedTokenizer
    training_log: pd.DataFrame
    loss_log: pd.DataFrame
    memorization_log: pd.DataFrame

    def save(self, storage_base_dir: Path) -> None:
        """Write the three logs below `storage_base_dir / storage_sub_dir`.

        Only the logs are written by this method; the model and the
        tokenizer are not saved here.
        """
        target_dir = storage_base_dir / self.storage_sub_dir
        target_dir.mkdir(parents=True, exist_ok=True)
        logs_by_type = (
            ("training", self.training_log),
            ("loss", self.loss_log),
            ("memorization", self.memorization_log),
        )
        for log_type, log in logs_by_type:
            _save_log(log, target_dir, log_type)

    @classmethod
    def load(
        cls, storage_base_dir: Path, storage_sub_dir
    ) -> "FinetuningResult":
        """Restore a result from `storage_base_dir / storage_sub_dir`.

        Model and tokenizer are loaded with `from_pretrained` from that
        directory, followed by the training, loss and memorization logs.
        """
        run_dir = storage_base_dir / storage_sub_dir
        # Keyword arguments are evaluated top to bottom, so the model is
        # loaded first and the logs last.
        return cls(
            storage_sub_dir=storage_sub_dir,
            model=AutoModelForCausalLM.from_pretrained(run_dir),
            tokenizer=AutoTokenizer.from_pretrained(run_dir),
            training_log=load_log(run_dir, "training"),
            loss_log=load_log(run_dir, "loss"),
            memorization_log=load_log(run_dir, "memorization"),
        )


def get_finetuned_model_tokenizer(
    config: FinetuningConfig,
    data_config: RandomStringConfig,
    local_rank: int = -1,
) -> FinetuningResult:
    """Return the finetuning result for `config`, training only if needed.

    The result is looked up below `data_config.storage_dir` in a
    sub-directory derived from the model id, the data config name and the
    number of epochs. If that directory already contains a `config.json`,
    the stored result is loaded; otherwise the model is finetuned and the
    resulting logs are saved there.

    Args:
        config: Finetuning settings (model, seed, epochs, ...).
        data_config: Description of the random strings to train on; also
            provides the storage base directory.
        local_rank: Rank forwarded to the training code. Only used when
            training is actually run.

    Returns:
        The loaded or freshly produced `FinetuningResult`.
    """
    # The `f` prefix belongs outside the quotes. With it inside, every
    # untrained model was stored under the same literal directory name
    # "f{config.model_id}_untrained" regardless of its model id.
    model_id_saved = (
        config.model_id if config.pretrained else f"{config.model_id}_untrained"
    )
    storage_sub_dir = _get_storage_sub_dir(
        model_id_saved,
        data_config.name,
        config.epochs,
    )
    storage_dir = data_config.storage_dir / storage_sub_dir
    # A `config.json` in the storage directory marks a stored model.
    if (storage_dir / "config.json").exists():
        training_res = FinetuningResult.load(
            data_config.storage_dir, storage_sub_dir
        )
    else:
        training_res = _finetune_random_model(
            config,
            data_config,
            model_id_saved,
            local_rank=local_rank,
        )
        # The training routine returns a placeholder sub-directory; set the
        # real one before saving the logs.
        training_res.storage_sub_dir = storage_sub_dir
        training_res.save(data_config.storage_dir)
    return training_res


# Names of the logs kept per finetuning run; each one is stored as
# "<storage_dir>/logs/<name>.parquet".
LogType = Literal["training", "loss", "memorization"]


def load_log(storage_dir: Path, log_type: LogType) -> pd.DataFrame:
    """Read the parquet log of type `log_type` stored below `storage_dir`."""
    log_file = _get_log_file(storage_dir, log_type)
    return pd.read_parquet(log_file)


def _save_log(
    log: pd.DataFrame,
    storage_dir: Path,
    log_type: LogType,
) -> None:
    """Write `log` as the parquet log of type `log_type` below `storage_dir`."""
    destination = _get_log_file(storage_dir, log_type)
    # Make sure the directory holding the log files exists before writing.
    destination.parent.mkdir(parents=True, exist_ok=True)
    log.to_parquet(destination)


def _get_log_file(storage_dir: Path, log_type: LogType) -> Path:
    """Return the path `<storage_dir>/logs/<log_type>.parquet`."""
    file_name = f"{log_type}.parquet"
    return storage_dir.joinpath("logs", file_name)


def _get_storage_sub_dir(
    model_id: str,
    data_config_id: str,
    epoch: int,
) -> Path:
    """Build the relative storage directory of a finetuning run.

    The layout is `<MODELS_DIR>/<model_id>/<data_config_id>/epoch_<epoch>`.
    """
    return Path(MODELS_DIR, model_id, data_config_id, f"epoch_{epoch}")


def _finetune_random_model(
    config: FinetuningConfig,
    data_config: RandomStringConfig,
    model_id_saved: str,
    local_rank: int,
) -> FinetuningResult:
    """Finetune a model on the random strings described by `data_config`.

    Loads the model and tokenizer, generates and character-wise encodes the
    random strings, trains with a per-step sequence evaluation callback and
    collects the resulting logs.

    Args:
        config: Finetuning settings (model id, seed id, epochs, ...).
        data_config: Description of the random strings to memorize.
        model_id_saved: Model name used in the task description and passed
            to the training routine.
        local_rank: Rank of the current process; must be >= 0.

    Returns:
        A `FinetuningResult` whose `storage_sub_dir` is an empty placeholder
        path; the caller is expected to set the real location.

    Raises:
        AssertionError: If `local_rank` is negative, `data_config` is not a
            `RandomStringConfig`, or training produced no training log.
    """
    # Import here to make the code usable without the lib_llm and lib_dl
    # packages
    from lib_dl.models.io.description import TaskDescription
    from lib_llm.eval.sequences import SequenceEvaluationTask, SequenceMetric
    from lib_llm.models import ModelConfig, load_model_tokenizer
    from lib_llm.training import TrainingArguments, TrainingConfig, train

    from ..encoding import encode_data_characterwise

    assert local_rank > -1, "local_rank must be >= 0"
    model_config = ModelConfig(
        model_id=config.model_id,
        base_dir=config.base_model_dir,
        # Honour the config flag. It was hard-coded to True before, so runs
        # labelled "_untrained" were still requested with pretrained weights.
        pretrained=config.pretrained,
    )
    model, tokenizer = load_model_tokenizer(model_config)

    assert isinstance(data_config, RandomStringConfig)
    data = get_random_strings(data_config)
    dataset = data.dataset()
    encoded_dataset = encode_data_characterwise(tokenizer, dataset)
    logger.info(f"Generated {len(encoded_dataset['test'])} sequences")

    description = TaskDescription(
        task_prefixes=[
            data_config.storage_dir.stem,
            MODELS_DIR,
        ],
        model=model_id_saved,
        dataset=data_config.name,
    )
    training_config = TrainingConfig(
        seed=SEEDS[config.seed_id],
        args=TrainingArguments(
            deepspeed=DEEPSPEED_CONFIG,
            num_train_epochs=config.epochs,
            eval_steps=1,
            # All strings are put into a single batch per device.
            per_device_train_batch_size=len(data.strings),
            logging_steps=1,
            save_strategy="no",
        ),
        optimizer=_get_optimizer_config(model_config.model_id_not_none),
        save_final_checkpoint=config.save_model,
        wandb_project_name="llms_rsm",
    )
    metrics: dict[str, SequenceMetric] = get_sequence_metrics(
        tokenizer,
        data_config,
        data.alphabet,
        data.strings,
    )
    # The sequences are already encoded, hence no tokenizer is passed.
    # NOTE(review): inferred from `tokenizer=None` and the encoded test
    # split being passed in -- confirm against SequenceEvaluationTask.
    character_eval_task = SequenceEvaluationTask(
        metrics,
        sequences=encoded_dataset["test"],
        tokenizer=None,
        local_rank=local_rank,
    )

    training_res = train(
        description,
        (model_id_saved, model),
        (data_config.name, encoded_dataset),
        config=training_config.with_local_rank(local_rank),
        tokenizer=tokenizer,
        callbacks=[character_eval_task],
        data_already_preprocessed=True,
        set_action=False,
    )
    model = training_res.model
    model.name_or_path = model_config.model_id_not_none
    training_log = training_res.training_log
    assert training_log is not None
    eval_logs = postprocess_memorization_log(character_eval_task.result())

    return FinetuningResult(
        # Placeholder; the real sub-directory is set by the caller.
        storage_sub_dir=Path(),
        model=model,
        tokenizer=tokenizer,
        training_log=training_log,
        loss_log=eval_logs["loss"],
        memorization_log=eval_logs["memorization"],
    )


def _get_optimizer_config(model_id: str):
    """Build the optimizer config for `model_id`.

    The learning rate is taken from LEARNING_RATE_OVERRIDES when the model
    has an entry there and falls back to DEFAULT_LEARNING_RATE otherwise.
    The schedule ("lin") and the warmup (0.05) are the same for all models.
    """
    from lib_llm.training import OptimizerConfig

    lr = LEARNING_RATE_OVERRIDES.get(model_id)
    if lr is None:
        lr = DEFAULT_LEARNING_RATE
    else:
        logger.info(f"Using learning rate {lr} override for model {model_id}")

    # These values gave the fastest convergence when a 1bn Pythia model
    # memorized a 512 token string over a 26 character alphabet.
    # Learning rates [1e-4, 1e-5, 5e-6, 1e-6, 5e-7] and the schedules
    # [linear, cos, const] were compared.
    optimizer_settings = {
        "learning_rate": lr,
        "schedule": "lin",
        "warmup": 0.05,
    }
    return OptimizerConfig(**optimizer_settings)


def get_sequence_metrics(
    tokenizer: PreTrainedTokenizer,
    data_config: RandomStringConfig,
    alphabet: str,
    target_strings: list[str],
) -> dict:
    """Build the evaluation metrics for a set of equally long strings.

    The returned dict holds a "loss" and a "correct" metric, both computed
    on the character-wise encoding of `target_strings`, plus one
    probability metric per character of `alphabet`, keyed by that
    character, which tracks the character's probability over the sequence.
    """
    from lib_llm.eval.sequences import (
        CorrectnessMetric,
        LossMetric,
        SequenceMetric,
        SequenceProbMetric,
    )

    from ..encoding import encode_strings_characterwise

    expected_length = len(target_strings[0])
    assert all(len(string) == expected_length for string in target_strings)
    target_encoding = encode_strings_characterwise(tokenizer, target_strings)

    metrics: dict[str, SequenceMetric] = {}
    metrics["loss"] = LossMetric(target_encoding)
    metrics["correct"] = CorrectnessMetric(target_encoding)

    # Sequence probabilities are estimated by feeding token sequences of
    # the same length as the sequence of token probabilities the model
    # produces for the input sequence. Hence, for every character whose
    # token probability is tracked, a sequence consisting only of that
    # character's token is built, as long as the input/training sequence.
    # E.g. for the 26 latin characters the probabilities of the sequences
    # {"a": ["a", "a", ...], "b": ["b", "b", ...], ...} are estimated.
    for char in alphabet:
        char_token_ids = tokenizer.encode(char, add_special_tokens=False)
        repetitions = data_config.num_tokens // data_config.num_partitions
        input_ids = torch.tensor(char_token_ids * repetitions)
        attention_mask = torch.ones_like(input_ids)
        metrics[char] = SequenceProbMetric(
            BatchEncoding(
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                }
            )
        )
    return metrics


def postprocess_memorization_log(
    eval_result: pd.DataFrame,
) -> dict[str, pd.DataFrame]:
    """Split the evaluation result into a loss log and a memorization log.

    Returns a dict with two entries: "loss" holds only the "loss" column of
    `eval_result`; "memorization" holds all remaining columns, expanded so
    that every sequence position gets its own row.
    """
    loss_log = cast(pd.DataFrame, eval_result[["loss"]])
    memorization_log = eval_result.drop(columns=["loss"])
    # The expansion below groups by these two index level names.
    # NOTE(review): assigning `index.names` happens in place on the index
    # object; if that object is shared with `eval_result`, the caller's
    # frame is renamed as well -- confirm this is acceptable.
    memorization_log.index.names = ["epoch", "string"]
    characterwise_mem_log = _match_token_probs_to_sequence_pos(memorization_log)
    return {
        "loss": loss_log,
        "memorization": characterwise_mem_log,
    }


def _match_token_probs_to_sequence_pos(
    token_distributions: pd.DataFrame,
) -> pd.DataFrame:
    """Out of the box the character probability metrics store a list
    of probabilities in the dataframe cells for each tracked character,
    with an entry per sequence positions.
    This function converts the lists stored in each cell to an additional
    dataframe column, such that each cell only stores a scalar value.
    """
    return pd.concat(
        {
            epoch: pd.concat(
                [
                    pd.DataFrame(
                        {
                            char: char_probs
                            for char, char_probs in zip(
                                sequence_df.columns,
                                sequence_df.values.flatten(),
                            )
                        },
                        index=pd.Index(
                            list(str(sequence)[1:]),
                            name="character",
                        ),
                        dtype="float",
                    )
                    for sequence, sequence_df in epoch_df.droplevel(
                        "epoch"
                    ).groupby("string")
                ],
                axis=0,
                keys=epoch_df.index.get_level_values("string"),
            )
            for epoch, epoch_df in token_distributions.groupby("epoch")
        },
        axis=0,
        keys=token_distributions.index.get_level_values("epoch"),
    )
