import logging
from dataclasses import dataclass

import numpy as np
import pandas as pd
from transformers import PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, experiment


# Module-level logger named after this module.
# NOTE(review): not referenced anywhere in the visible code of this module.
logger = logging.getLogger(__name__)

# Experiment name passed to the `@experiment` decorator on the experiment
# function defined below.
EXP_NAME = "icl_memorization_relationship"
# Initials of EXP_NAME ("icl_memorization_relationship" -> "imr"); presumably
# used elsewhere in the project (not referenced in this module's visible code).
# NOTE(review): the experiment function below is spelled `irm_experiment`,
# which does not match this abbreviation — possibly a transposition; confirm.
EXP_ABBREVIATION = "imr"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the ICL/memorization-relationship experiment.

    Extends ``LLMExperimentConfig`` (imported from ``defs``; its contents are
    not visible here). If that base is itself a dataclass, its fields come
    before the fields below in the generated ``__init__``, so do not reorder
    these fields casually.
    """

    # Random seed.
    # NOTE(review): not read by `irm_experiment` in this module; presumably
    # consumed elsewhere (e.g. by the experiment framework) — confirm.
    seed: int
    # Passed to `generate_random_strings` to build the random-string data;
    # its `.name` labels the dataset handed to `train`.
    data: RandomStringConfig
    # Passed to `load_model_tokenizer`; its `.model_id_not_none` labels the
    # model handed to `train`.
    model: ModelConfig
    # Forwarded to `train` as its `config=` argument.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Outputs collected from one run of ``irm_experiment``."""

    # The generated random-string data from which the training dataset was built.
    data: RandomStringData
    # The `training_log` attribute of the object returned by `train`.
    training_log: pd.DataFrame
    # Result of the memorization-dynamics callback (`eval_task.result()`),
    # which was evaluated on the dataset's "test" split.
    memorization_log: pd.DataFrame


@experiment(EXP_NAME)
def irm_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train a model on random strings while tracking memorization dynamics.

    Steps:
      1. Load the model and tokenizer described by ``config.model``.
      2. Generate random-string data from ``config.data`` using that
         tokenizer, and build a dataset from it.
      3. Train with ``config.training``, attaching a memorization-dynamics
         callback that is evaluated on the dataset's ``"test"`` split.

    NOTE(review): ``config.seed`` is not read in this function; presumably it
    is applied by the ``@experiment`` decorator or elsewhere — confirm.
    NOTE(review): the name ``irm_experiment`` does not match
    ``EXP_ABBREVIATION`` (``"imr"``). The name is kept unchanged so that
    existing callers keep working.

    Args:
        config: Full experiment configuration.
        experiment_id: Identifier forwarded to ``train``.

    Returns:
        An ``ExperimentResult`` containing the generated data, the training
        log reported by ``train``, and the memorization log produced by the
        evaluation callback.

    Raises:
        RuntimeError: If ``train`` returns a result without a training log.
    """
    model, tokenizer = load_model_tokenizer(config.model)
    data = generate_random_strings(config.data, tokenizer)
    dataset = data.dataset()

    # Training callback that computes memorization metrics on the "test"
    # split, using the alphabet's tokens and token ids from the generated data.
    eval_task = memorization_dynamics_metrics(
        data.alphabet_tokens,
        data.alphabet_token_ids,
        dataset["test"],
    )
    training_res = train(
        experiment_id,
        (config.model.model_id_not_none, model),
        (config.data.name, dataset),
        config=config.training,
        tokenizer=tokenizer,
        callbacks=[eval_task],
        # The data was generated with `tokenizer` above, so presumably it is
        # already in the form `train` expects — skip its preprocessing.
        data_already_preprocessed=True,
    )

    memorization_log = eval_task.result()
    training_log = training_res.training_log
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let `None` slip into a DataFrame-typed field.
    if training_log is None:
        raise RuntimeError(
            "train() returned no training log; cannot build ExperimentResult"
        )
    return ExperimentResult(
        data=data,
        training_log=training_log,
        memorization_log=memorization_log,
    )
