import logging
import os
from dataclasses import dataclass

import pandas as pd
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainerCallback

from data.synthetic_strings.conditional_random import (
    ConditionalRandomStringConfig,
    create_ngram_conditional_dataset,
)
from data.synthetic_strings.random import RandomStringData
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training.train import TrainingConfig, train
from lib_project.experiment import ExperimentID, NoSave, NoSaveValue, experiment
from utils.memorization.memorization import MemorizationTrainingResult


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Hardware / launcher environment, probed once at import time.
HAS_CUDA = torch.cuda.is_available()
if HAS_CUDA:
    NUM_DEVICES = torch.cuda.device_count()
else:
    # Without CUDA the device count is fixed at one.
    NUM_DEVICES = 1
# LOCAL_RANK environment variable parsed as an int; -1 when it is unset.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "-1"))

# Experiment name (handed to the @experiment decorator) and its abbreviation.
EXP_NAME = "conditional_prob_mem_dynamics"
EXP_ABBREVIATION = "cpmd"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for one run of the experiment defined in this module.

    Extends ``LLMExperimentConfig``; any fields inherited from that base
    class are declared outside this file.
    """

    # Random seed. Not read directly in this module -- presumably consumed
    # by the experiment framework or the data config; TODO confirm.
    seed: int
    # Passed to ``create_ngram_conditional_dataset`` to build the dataset.
    data: ConditionalRandomStringConfig
    # Passed to ``load_model_tokenizer``; its ``model_id_not_none`` value is
    # handed to training alongside the loaded model.
    model: ModelConfig
    # Forwarded unchanged to ``train`` as its ``config`` argument.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Values returned by ``cpmd_experiment``."""

    # The dataset object produced by ``create_ngram_conditional_dataset``.
    data: RandomStringData
    # Training log taken from the result of ``train``.
    training_log: pd.DataFrame
    # Output of the memorization-dynamics callback's ``result()``.
    memorization_log: pd.DataFrame
    # Wrapped in ``NoSave`` -- presumably excluded when the result is
    # persisted; confirm against ``lib_project.experiment``.
    model: NoSave[PreTrainedModel]
    # Wrapped in ``NoSave`` for the same (assumed) reason as ``model``.
    tokenizer: NoSave[PreTrainedTokenizer]


@experiment(EXP_NAME)
def cpmd_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Load the model, build the dataset, train, and package the logs.

    Args:
        config: Seed, data, model and training settings for this run.
        experiment_id: Identifier forwarded to the training helper.

    Returns:
        An ``ExperimentResult`` holding the dataset, the training and
        memorization logs, and the ``NoSaveValue``-wrapped model and
        tokenizer.

    Raises:
        AssertionError: If training produced no training log or no
            memorization log.
    """
    model, tokenizer = load_model_tokenizer(config.model)
    # Fix the tokenizer's maximum length before the dataset is created.
    tokenizer.model_max_length = 1024
    data = create_ngram_conditional_dataset(config.data, tokenizer)

    model_info = (config.model.model_id_not_none, model)
    outcome = train_model(
        experiment_id,
        config.training,
        model_info,
        tokenizer,
        data,
    )

    # Both logs are required; the asserts also narrow away ``None``.
    train_log = outcome.training_log
    assert train_log is not None
    mem_log = outcome.memorization_log
    assert mem_log is not None

    return ExperimentResult(
        data=data,
        training_log=train_log,
        memorization_log=mem_log,
        model=NoSaveValue(outcome.model),
        tokenizer=NoSaveValue(outcome.tokenizer),
    )


def train_model(
    experiment_id: ExperimentID,
    training_config: TrainingConfig,
    model_info: tuple[str, PreTrainedModel],
    tokenizer: PreTrainedTokenizer,
    data: RandomStringData,
) -> MemorizationTrainingResult:
    """Train the model on ``data`` with a memorization-dynamics callback.

    Args:
        experiment_id: Identifier forwarded to ``train``.
        training_config: Passed to ``train`` as ``config``.
        model_info: ``(model id, model)`` pair forwarded to ``train``.
        tokenizer: Tokenizer handed to ``train`` and echoed in the result.
        data: Supplies the dataset splits, the alphabet tokens and their
            ids, and the dataset name (``data.config.name``).

    Returns:
        A ``MemorizationTrainingResult`` with the trained model, the
        tokenizer, ``data``, the training log, and the callback's result.

    Raises:
        AssertionError: If ``train`` returned no training log.
    """
    splits = data.dataset()

    # The callback is built from the alphabet and the "test" split; its
    # accumulated output is read back once training has finished.
    memorization_task = memorization_dynamics_metrics(
        data.alphabet_tokens,
        data.alphabet_token_ids,
        splits["test"],
    )
    callbacks: list[TrainerCallback] = [memorization_task]
    named_dataset = (data.config.name, splits)

    run = train(
        experiment_id,
        model_info,
        named_dataset,
        config=training_config,
        tokenizer=tokenizer,
        callbacks=callbacks,
        data_already_preprocessed=True,
    )

    train_log = run.training_log
    assert train_log is not None
    mem_log = memorization_task.result()

    return MemorizationTrainingResult(
        model=run.model,
        tokenizer=tokenizer,
        data=data,
        training_log=train_log,
        memorization_log=mem_log,
    )
