import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path

import pandas as pd
import torch

from defs import LLMExperimentConfig
from lib_dl.analysis.experiment import ExperimentTaskDescription, experiment
from utils.data.random_strings import RandomStringConfig, get_random_strings
from utils.finetuning.finetune import (
    FinetuningConfig,
    LogType,
    get_finetuned_model_tokenizer,
    load_log,
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether torch reports CUDA as available; evaluated once, at import time.
HAS_CUDA = torch.cuda.is_available()

# Name passed to the ``@experiment`` decorator on ``mt_experiment``.
EXP_NAME = "model_training"
# Short form of the experiment name; not referenced in this part of the file.
EXP_ABBREVIATION = "mt"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the model-training experiment.

    Extends ``LLMExperimentConfig`` with the fields declared below.
    """

    # Seed for the run. Not read directly in this module -- presumably
    # consumed by the experiment framework or helpers (TODO confirm).
    seed: int
    # Passed to ``get_random_strings`` and, as the second positional
    # argument, to ``get_finetuned_model_tokenizer``.
    data: RandomStringConfig
    # Passed as the first positional argument to
    # ``get_finetuned_model_tokenizer``.
    fine_tuning: FinetuningConfig


@dataclass
class ExperimentResult:
    strings: list[str]
    storage_sub_dir: Path
    base_storage_dir: Path | None = None

    @cached_property
    def training_log(self) -> pd.DataFrame:
        return self._get_log("training")

    @cached_property
    def loss_log(self) -> pd.DataFrame:
        return self._get_log("loss")

    @cached_property
    def memorization_log(self) -> pd.DataFrame:
        return self._get_log("memorization")

    def _get_log(self, log_type: LogType) -> pd.DataFrame:
        if self.base_storage_dir is None:
            raise ValueError(
                "base_storage_dir must be set to access training log"
            )
        storage_dir = self.base_storage_dir / self.storage_sub_dir
        return load_log(storage_dir, log_type)


@experiment(EXP_NAME)
def mt_experiment(
    config: ExperimentConfig,
    description: ExperimentTaskDescription,
) -> ExperimentResult:
    """Run the model-training experiment for ``config``.

    Calls ``get_random_strings`` with ``config.data``, then
    ``get_finetuned_model_tokenizer`` with the fine-tuning and data
    configs, prints the returned loss and memorization logs to stdout, and
    packages the strings with the fine-tuning storage sub-directory.

    Args:
        config: Experiment configuration; ``config.local_rank`` is expected
            to come from the ``LLMExperimentConfig`` base (TODO confirm).
        description: Task description supplied by the caller; not used in
            the body of this function.

    Returns:
        An ``ExperimentResult`` whose ``base_storage_dir`` is left at its
        default of ``None``.
    """
    # Keep this call ahead of fine-tuning, matching the established order.
    random_strings = get_random_strings(config.data)

    finetuning_result = get_finetuned_model_tokenizer(
        config.fine_tuning,
        config.data,
        local_rank=config.local_rank,
    )

    loss_log = finetuning_result.loss_log
    print(loss_log)
    memorization_log = finetuning_result.memorization_log
    print(memorization_log)

    result = ExperimentResult(
        strings=random_strings.strings,
        storage_sub_dir=finetuning_result.storage_sub_dir,
    )
    return result
