import logging
from dataclasses import dataclass

import pandas as pd
import torch

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, experiment


# Identity of this experiment: the full name is what the experiment entry
# point in this module is registered under; the abbreviation is a short
# form of the same name (its consumers are not visible in this module).
EXP_NAME = "hyperparameter_search"
EXP_ABBREVIATION = "hs"

# Runtime facilities, resolved once at import time.
HAS_CUDA = torch.cuda.is_available()  # True when torch reports a usable CUDA setup
logger = logging.getLogger(__name__)  # module-scoped logger


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration bundle consumed by ``hs_experiment``.

    Extends ``LLMExperimentConfig`` with the model, data, and training
    settings that the experiment entry point reads.
    """

    # Seed value. The visible experiment function never reads it directly;
    # presumably it is consumed elsewhere (e.g. by the experiment
    # framework) — TODO confirm.
    seed: int
    # Passed to ``load_model_tokenizer``; its ``model_id_not_none`` value is
    # paired with the loaded model when calling ``train``.
    model: ModelConfig
    # Passed to ``generate_random_strings``; its ``name`` value is paired
    # with the generated dataset when calling ``train``.
    data: RandomStringConfig
    # Forwarded to ``train`` as its ``config`` argument.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Outputs returned by ``hs_experiment``."""

    # The object returned by ``generate_random_strings`` for this run.
    data: RandomStringData
    # The ``training_log`` attribute of the value returned by ``train``.
    training_log: pd.DataFrame
    # The ``result()`` of the memorization-dynamics callback after training.
    memorization_log: pd.DataFrame


@experiment(EXP_NAME)
def hs_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train a model on generated random strings and collect its logs.

    The model and tokenizer are loaded from ``config.model``, random-string
    data is generated from ``config.data`` using that tokenizer, and the
    model is trained with a memorization-dynamics callback attached.

    Args:
        config: Model, data, and training settings for this run.
        experiment_id: Identifier forwarded to ``train``.

    Returns:
        An ``ExperimentResult`` holding the generated data, the training
        log, and the memorization log produced by the callback.

    Raises:
        AssertionError: If training returns no training log.
    """
    llm, tok = load_model_tokenizer(config.model)

    # The tokenizer is handed to the generator along with the data config.
    strings = generate_random_strings(config.data, tok)
    splits = strings.dataset()

    # Callback built from the data's alphabet and the "test" split; it is
    # handed to ``train`` below and queried for its result afterwards.
    memorization_probe = memorization_dynamics_metrics(
        strings.alphabet_tokens,
        strings.alphabet_token_ids,
        splits["test"],
    )

    # ``train`` takes the model and the dataset as (label, object) pairs.
    labelled_model = (config.model.model_id_not_none, llm)
    labelled_dataset = (config.data.name, splits)
    outcome = train(
        experiment_id,
        labelled_model,
        labelled_dataset,
        config=config.training,
        tokenizer=tok,
        callbacks=[memorization_probe],
        data_already_preprocessed=True,
    )

    # Fetch the callback's result first, then check the training log.
    memorization_frame = memorization_probe.result()
    training_frame = outcome.training_log
    assert training_frame is not None

    return ExperimentResult(
        data=strings,
        training_log=training_frame,
        memorization_log=memorization_frame,
    )
