import logging
from dataclasses import dataclass, field

import pandas as pd
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainerCallback

from data.synthetic_strings import deunion_data, get_config
from data.synthetic_strings.deterministic_rules import (
    DeterministicRuleStringConfig,
    DeterministicRuleStringData,
)
from data.synthetic_strings.random import RandomStringConfig, RandomStringData
from defs import LLMExperimentConfig
from lib_llm.eval.metrics import TokenEvaluationTask
from lib_llm.models import ModelConfig
from lib_project.experiment import ExperimentID, NoSave, NoSaveValue, experiment
from utils.memorization import MemorizationConfig, get_memorized_model_tokenizer


# Logger named after this module.
logger = logging.getLogger(name=__name__)
# Whether torch can see a CUDA device; probed once, at import time.
HAS_CUDA = torch.cuda.is_available()

# Experiment identifiers; EXP_NAME is the name handed to the @experiment decorator.
EXP_NAME, EXP_ABBREVIATION = "memorization_dynamics", "md"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for one memorization-dynamics run (see ``md_experiment``).

    ``random_data`` and ``deterministic_rule_data`` are both handed to
    ``get_config`` to build the training-data config. Presumably exactly one
    of them is non-None; TODO confirm against ``get_config``.
    """

    seed: int  # presumably the RNG seed; not read directly in this file
    random_data: RandomStringConfig | None  # random-string data config, or None
    deterministic_rule_data: DeterministicRuleStringConfig | None  # rule-based data config, or None
    model: ModelConfig  # model to train; ``model.name`` becomes part of the run's config name
    memorization: MemorizationConfig  # memorization settings; its optional ``freeze`` alters the config name


@dataclass
class ExperimentResult:
    """Outputs returned by ``md_experiment``.

    The two data fields are filled from ``deunion_data(...)``. Presumably only
    the one matching the configured data type is non-None; verify in
    ``deunion_data``. ``model`` and ``tokenizer`` are typed ``NoSave[...]``,
    presumably so the experiment framework does not persist them. Confirm
    this in ``lib_project.experiment``.
    """

    random_data: RandomStringData | None  # random-string data used for training, or None
    deterministic_rule_data: DeterministicRuleStringData | None  # rule-based data used for training, or None
    training_log: pd.DataFrame  # training log produced by get_memorized_model_tokenizer
    memorization_log: pd.DataFrame  # memorization-evaluation log produced by get_memorized_model_tokenizer
    model: NoSave[PreTrainedModel]  # trained model, wrapped via NoSaveValue
    tokenizer: NoSave[PreTrainedTokenizer]  # matching tokenizer, wrapped via NoSaveValue


@experiment(EXP_NAME)
def md_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
    override_memorization_eval_task: TokenEvaluationTask | None = None,
    additional_callbacks: list[TrainerCallback] | None = None,
) -> ExperimentResult:
    """Train a model to memorize synthetic string data and collect its logs.

    Builds the data config from ``config``, names the run after the data
    config and the model, and delegates training and memorization
    evaluation to ``get_memorized_model_tokenizer``.

    Args:
        config: Experiment configuration (data, model and memorization
            settings).
        experiment_id: Identifier of this run. Side effect: its
            ``config_name`` is overwritten with
            ``[data_config.name, model_name]``.
        override_memorization_eval_task: Optional evaluation task, forwarded
            to ``get_memorized_model_tokenizer`` as ``override_eval_task``.
        additional_callbacks: Extra trainer callbacks forwarded to
            ``get_memorized_model_tokenizer``. ``None`` (the default) means
            no extra callbacks.

    Returns:
        An ``ExperimentResult`` containing the data (split via
        ``deunion_data``), the training and memorization logs, and the model
        and tokenizer wrapped in ``NoSaveValue``.
    """
    # Use a fresh list per call. A shared ``[]`` default would carry any
    # callbacks appended downstream over into every later call.
    if additional_callbacks is None:
        additional_callbacks = []

    # When a freeze setting is configured, its postfix is appended so such
    # runs get a distinct config name.
    model_name = config.model.name
    if config.memorization.freeze is not None:
        model_name += config.memorization.freeze.id_postfix
    data_config = get_config(config.random_data, config.deterministic_rule_data)
    experiment_id.config_name = [
        data_config.name,
        # data_config.sid,
        model_name,
    ]

    ft_res = get_memorized_model_tokenizer(
        config.memorization,
        data_config,
        config.model,
        experiment_id,
        override_eval_task=override_memorization_eval_task,
        additional_callbacks=additional_callbacks,
    )
    # Report through the module logger rather than a bare print.
    logger.info("Memorization log:\n%s", ft_res.memorization_log)

    return ExperimentResult(
        **deunion_data(ft_res.data),
        training_log=ft_res.training_log,
        memorization_log=ft_res.memorization_log,
        model=NoSaveValue(ft_res.model),
        tokenizer=NoSaveValue(ft_res.tokenizer),
    )
