from dataclasses import dataclass, replace
from typing import Optional

from defs import EXPERIMENT_SEEDS, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    MemorizationConfig,
    RandomStringConfig,
    icef_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Per-run arguments for the random-string memorization experiment.

    Extends ``BaseConfigArgs`` with the values that ``create_config`` reads
    when it builds the training data, the additional evaluation data and
    the memorization training config.
    """

    # Alphabet size of the training string; forwarded to
    # ``RandomStringConfig.alphabet_size``.
    alphabet_size: int
    # Alphabet sizes of the additional evaluation strings; ``create_config``
    # derives one extra data config per entry from the training data config.
    additional_alphabet_sizes: list[int]
    # Forwarded to ``RandomStringConfig.num_tokens``.
    num_tokens: int
    # Forwarded as ``num_epochs`` to ``get_memorization_training_config``.
    num_epochs: int = 100


# Axes of the single-string sweep. Every combination of the three axes
# yields one named entry in SINGLE_STRING_ARGS.
_SWEEP_MODEL_IDS = ("pyt-1b",)
_SWEEP_ALPHABET_SIZES = (7,)
_SWEEP_TOKEN_COUNTS = (16, 32, 64, 128, 256, 512, 1024)

# Maps "<model>_a-<alphabet size>_t-<token count>" to the arguments of that
# run. Each entry gets its own list of additional alphabet sizes.
SINGLE_STRING_ARGS = {
    f"{model}_a-{alphabet}_t-{token_count}": ConfigArgs(
        model_id=model,
        num_tokens=token_count,
        alphabet_size=alphabet,
        additional_alphabet_sizes=[2, 26],
    )
    for model in _SWEEP_MODEL_IDS
    for alphabet in _SWEEP_ALPHABET_SIZES
    for token_count in _SWEEP_TOKEN_COUNTS
}


# All named argument sets: a "test" entry that uses the "pyt-70m" model,
# followed by every entry of the single-string sweep. Sweep entries are
# merged last, so they take precedence on a key clash.
CONFIG_ARGS = {
    "test": ConfigArgs(
        model_id="pyt-70m",
        alphabet_size=7,
        num_tokens=64,
        additional_alphabet_sizes=[2, 26],
    ),
    **SINGLE_STRING_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the experiment config registered under ``eval_type``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: When given, ``set_seeds`` is applied to the new config.
            Otherwise all seed fields keep the value ``-1`` they are
            initialized with here.

    Returns:
        The assembled ``ExperimentConfig``.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
    """
    args = CONFIG_ARGS[eval_type]

    train_strings = RandomStringConfig(
        seed_id=-1,
        alphabet="latin",
        alphabet_size=args.alphabet_size,
        num_tokens=args.num_tokens,
        num_partitions=1,
        tokenizer_type=get_tokenizer_type(args.model_name),
    )
    model = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
    )

    # The additional evaluation data are copies of the training data config
    # that differ only in their alphabet size.
    eval_strings = [
        replace(train_strings, alphabet_size=size)
        for size in args.additional_alphabet_sizes
    ]
    memorization = MemorizationConfig(
        training=get_memorization_training_config(
            seed_id=-1,
            model_id=model.model_id_not_none,
            num_epochs=args.num_epochs,
            wandb_project_name=f"llms_{EXP_ABBREVIATION}",
        ),
    )

    config = ExperimentConfig(
        name=eval_type,
        seed_id=-1,
        seed=-1,
        training_data=train_strings,
        additional_eval_data=eval_strings,
        model=model,
        memorization=memorization,
    )

    # The "test" config always trains for a single epoch, regardless of
    # args.num_epochs.
    if eval_type == "test":
        config.memorization.training.args.num_train_epochs = 1

    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Apply ``seed_id`` to ``config`` in place and return it.

    The seed id is written to the config itself and to every data config
    (training and additional evaluation data). The seed looked up in
    ``EXPERIMENT_SEEDS`` is written to the memorization training arguments.

    Args:
        config: Experiment config to update; it is mutated.
        seed_id: Index/key into ``EXPERIMENT_SEEDS``.

    Returns:
        The same ``config`` object that was passed in.
    """
    config.seed_id = seed_id
    seed = EXPERIMENT_SEEDS[seed_id]

    # NOTE(review): ``config.seed`` is not updated here, so it keeps the
    # value it already had (``create_config`` initializes it to -1).
    # Confirm that this is intended.
    for data_config in (config.training_data, *config.additional_eval_data):
        data_config.seed_id = seed_id

    config.memorization.training.args.seed = seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per entry of ``CONFIG_ARGS``.

    Configs are created without a seed id and appear in the iteration
    order of ``CONFIG_ARGS``.
    """
    return [create_config(name) for name in CONFIG_ARGS]


# Handle that bundles this experiment's id, its config factory, its seed
# setter and its experiment entry point.
ICEFHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=icef_experiment,
)
