from dataclasses import dataclass
from typing import Optional

from defs import ARTIFACTS_DIR, BaseConfigArgs
from lib_dl.analysis.experiment import ExperimentHandle

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    FinetuningConfig,
    RandomStringConfig,
    mt_experiment,
)


# Fine-tuning epochs passed to FinetuningConfig by create_config for every
# configuration (the "test" configuration is overridden to 1 epoch there).
NUM_EPOCHS = 100


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Per-run arguments for the random-string fine-tuning sweep.

    ``model_id`` is not declared here, yet every ``ConfigArgs(...)`` call in
    this module passes it, so it presumably is a field of ``BaseConfigArgs``
    (confirm). ``create_config`` also reads ``model_name`` and ``model_dir``,
    which likewise must be provided by the base class.

    Field order matters: dataclass fields define the parameter order of the
    generated ``__init__``, and fields without defaults must precede those
    with defaults.
    """

    # Forwarded to RandomStringConfig(alphabet_size=...).
    alphabet_size: int
    # Forwarded to RandomStringConfig(num_tokens=...).
    num_tokens: int
    # Forwarded to RandomStringConfig(num_partitions=...); in this module only
    # MULTI_STRING_ARGS uses values other than 1.
    num_partitions: int = 1
    # Forwarded as RandomStringConfig(first_char_prob=...). None defers to
    # RandomStringConfig's own default (presumably a uniform distribution —
    # confirm); ENTROPY_CONTROL_ARGS sets explicit probabilities.
    first_character_prob: float | None = None
    # NOTE(review): not read by create_config; presumably consumed by
    # BaseConfigArgs (e.g. when resolving model_dir) — confirm.
    pretrained: bool = True
    # Forwarded to FinetuningConfig(save_model=...).
    save_model: bool = False


# Baseline sweep: every model x alphabet size x string length, each on a single
# unpartitioned string (12 * 5 * 7 = 420 entries). Keys look like
# "pyt-70m_a-2_t-16"; insertion order is model, then alphabet, then length.
SINGLE_STRING_ARGS = {
    f"{model_id}_a-{alphabet_size}_t-{num_tokens}": ConfigArgs(
        num_tokens=num_tokens,
        alphabet_size=alphabet_size,
        model_id=model_id,
    )
    for model_id in (
        "pyt-70m", "pyt-160m", "pyt-410m", "pyt-1b",
        "pyt-1.4b", "pyt-2.8b", "pyt-6.9b", "pyt-12b",
        "llama2-7b", "llama2-13b",
        "gpt2", "gpt2-xl",
    )
    for alphabet_size in (2, 4, 7, 13, 26)
    for num_tokens in (16, 32, 64, 128, 256, 512, 1024)
}
# Partition sweep: a fixed pyt-1b / 26-letter / 1024-token setup split into an
# increasing number of partitions. save_model is enabled so the final
# checkpoint can be evaluated against the full string once training ends.
MULTI_STRING_ARGS = {
    f"{model_id}_a-{alphabet_size}_t-{num_tokens}_p-{num_partitions}": ConfigArgs(
        num_partitions=num_partitions,
        num_tokens=num_tokens,
        alphabet_size=alphabet_size,
        model_id=model_id,
        save_model=True,
    )
    for model_id in ("pyt-1b",)
    for alphabet_size in (26,)
    for num_tokens in (1024,)
    for num_partitions in (2, 4, 8, 16, 32, 64)
}
# Entropy controls: keep the 26-letter alphabet but bias the first character
# so the string's entropy equals that of a uniform alphabet of smaller size
# (the "eea" value in the key). Each pair below is
# (equivalent alphabet size, first-character probability), found numerically.
# The values are consistent with the remaining mass being spread uniformly over
# the other 25 characters (e.g. p = 0.8914 yields an entropy of ln 2); the
# actual sampling presumably happens in RandomStringConfig — confirm.
ENTROPY_CONTROL_ARGS = {
    f"{model_id}_a-26_t-{num_tokens}_eea-{equivalent_alphabet_size}": ConfigArgs(
        first_character_prob=first_char_prob,
        num_tokens=num_tokens,
        alphabet_size=26,
        model_id=model_id,
    )
    for model_id in ("pyt-1b",)
    for num_tokens in (1024,)
    for equivalent_alphabet_size, first_char_prob in (
        (13, 0.41385763230705264),
        (7, 0.6040332620868685),
        (4, 0.7455315447635649),
        (2, 0.8913995544185638),
    )
}


# Registry of every runnable configuration, keyed by eval_type. The small
# "test" entry comes first, followed by the sweep tables in declaration order.
# On a key clash the later table would win, but the distinct key suffixes
# ("_t-N", "_p-N", "_eea-N") keep the tables disjoint.
CONFIG_ARGS = {
    "test": ConfigArgs(
        model_id="pyt-70m",
        alphabet_size=7,
        num_tokens=64,
        num_partitions=1,
    ),
    **SINGLE_STRING_ARGS,
    **MULTI_STRING_ARGS,
    **ENTROPY_CONTROL_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Assemble the ExperimentConfig registered under ``eval_type``.

    Every seed-related field starts at the sentinel ``-1``. When ``seed_id``
    is given, ``set_seeds`` overwrites the ``seed_id`` fields of the top-level
    config and of its data / fine-tuning sub-configs. ``seed`` and
    ``local_rank`` are left at ``-1`` here; they are presumably resolved later
    by the experiment framework (confirm).

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the experiment name.
        seed_id: Optional seed index forwarded to ``set_seeds``.

    Returns:
        The assembled experiment configuration.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
    """
    args = CONFIG_ARGS[eval_type]
    unset = -1  # sentinel for fields that are not fixed at construction time

    # NOTE(review): model_name / model_dir are not fields declared on
    # ConfigArgs itself; presumably derived by BaseConfigArgs from model_id.
    data_config = RandomStringConfig(
        seed_id=unset,
        alphabet_size=args.alphabet_size,
        num_tokens=args.num_tokens,
        num_partitions=args.num_partitions,
        first_char_prob=args.first_character_prob,
        artifacts_dir=ARTIFACTS_DIR,
    )
    finetuning_config = FinetuningConfig(
        seed_id=unset,
        model_id=args.model_name,
        epochs=NUM_EPOCHS,
        base_model_dir=args.model_dir,
        save_model=args.save_model,
    )
    config = ExperimentConfig(
        name=eval_type,
        seed_id=unset,
        seed=unset,
        local_rank=unset,
        data=data_config,
        fine_tuning=finetuning_config,
    )

    # The smoke-test entry only needs a single pass over the data.
    if eval_type == "test":
        config.fine_tuning.epochs = 1

    # set_seeds mutates config in place.
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Write ``seed_id`` onto ``config`` and its data / fine-tuning sub-configs.

    The config is mutated in place and also returned for convenience.
    ``config.seed`` is not touched.
    """
    config.seed_id = seed_id
    for sub_config in (config.data, config.fine_tuning):
        sub_config.seed_id = seed_id
    return config


def get_configs() -> list[ExperimentConfig]:
    """Build one unseeded config per ``CONFIG_ARGS`` entry, in registry order."""
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Bundles this experiment's entry point with its config factory and seeding
# hook under the EXP_ABBREVIATION id.
MTHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    experiment=mt_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
)
