from dataclasses import dataclass
from typing import Optional

from defs import EXPERIMENT_SEEDS, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    RandomStringConfig,
    SubstringConfig,
    sss_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Arguments describing one entry of the experiment sweep.

    Extends ``BaseConfigArgs`` with the settings that ``create_config``
    forwards into the data, substring and training sub-configs.
    """

    # Forwarded to RandomStringConfig.alphabet_size.
    alphabet_size: int
    # Forwarded to RandomStringConfig.num_tokens.
    string_length: int
    # Forwarded to SubstringConfig.length.
    substring_length: int
    # Forwarded to SubstringConfig.num_distinct_substrings.
    num_distinct_substrings: int
    # Forwarded to SubstringConfig.placement_order.
    placement_order: str = "random"
    # Forwarded to RandomStringConfig.entropy_like; None when not set.
    entropy_target: int | None = None
    # Forwarded as num_epochs to get_memorization_training_config.
    num_epochs: int = 100


# Sweep over the `placement_order` setting: one ConfigArgs per combination of
# model (3), alphabet size (3), placement order (2), substring length (9) and
# number of distinct substrings (5), i.e. 810 entries, all with string_length
# fixed at 1024. Each key encodes every swept value and ends in
# "_plo-<placement_order>", which keeps these keys distinct from the
# ALPHABET_SIZES keys that otherwise share the same prefix format.
PLACEMENT_ORDERS = {
    (
        f"{model_id}_a-{alphabet_size}_sl-{substring_length}_ns-{num_distinct_substrings}_plo-{placement_order}"
    ): ConfigArgs(
        model_id=model_id,
        alphabet_size=alphabet_size,
        string_length=1024,
        substring_length=substring_length,
        num_distinct_substrings=num_distinct_substrings,
        placement_order=placement_order,
    )
    for model_id in [
        "pyt-1b",
        "phi-2.7b",
        "llama2-13b",
    ]
    for alphabet_size in [2, 7, 26]
    for placement_order in ["consecutive", "iterative"]
    for substring_length in [2, 4, 8, 16, 32, 64, 128, 256, 512]
    for num_distinct_substrings in [1, 8, 16, 32, 64]
}
# Sweep over alphabet size for a single model ("pyt-1b"): one ConfigArgs per
# combination of alphabet size (5), substring length (6) and number of
# distinct substrings (5), i.e. 150 entries with string_length fixed at 1024.
# placement_order is not passed, so the ConfigArgs default ("random") applies.
ALPHABET_SIZES = {
    (
        f"{model_id}_a-{alphabet_size}_sl-{substring_length}_ns-{num_distinct_substrings}"
    ): ConfigArgs(
        model_id=model_id,
        alphabet_size=alphabet_size,
        string_length=1024,
        substring_length=substring_length,
        num_distinct_substrings=num_distinct_substrings,
    )
    for model_id in [
        "pyt-1b",
    ]
    for alphabet_size in [2, 4, 7, 13, 26]
    for substring_length in [2, 4, 8, 16, 32, 64]
    for num_distinct_substrings in [8, 16, 32, 64, 128]
}
# Sweep over `entropy_target` for a single model ("pyt-1b"): one ConfigArgs
# per combination of entropy target (4), substring length (5) and number of
# distinct substrings (4), i.e. 80 entries. alphabet_size is held at 26 and
# string_length at 1024; keys use an "h-<entropy_target>" segment instead of
# the "a-<alphabet_size>" segment used by the other sweeps.
ENTROPIES = {
    (
        f"{model_id}_h-{entropy_target}_sl-{substring_length}_ns-{num_distinct_substrings}"
    ): ConfigArgs(
        model_id=model_id,
        alphabet_size=26,
        entropy_target=entropy_target,
        string_length=1024,
        substring_length=substring_length,
        num_distinct_substrings=num_distinct_substrings,
    )
    for model_id in [
        "pyt-1b",
    ]
    for entropy_target in [2, 4, 7, 13]
    for substring_length in [2, 4, 8, 16, 32]
    for num_distinct_substrings in [8, 16, 32, 64]
}


# Registry of every named configuration: a small "test" entry (pyt-70m,
# 32-token string) merged with the three sweeps above. Keys are the
# `eval_type` values accepted by `create_config`. With the dict `|` operator
# a later operand would override an earlier one on a duplicate key.
CONFIG_ARGS = (
    {
        "test": ConfigArgs(
            model_id="pyt-70m",
            string_length=32,
            alphabet_size=7,
            substring_length=4,
            num_distinct_substrings=4,
        ),
    }
    | PLACEMENT_ORDERS
    | ALPHABET_SIZES
    | ENTROPIES
)


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Assemble the ``ExperimentConfig`` registered under ``eval_type``.

    The matching ``ConfigArgs`` entry of ``CONFIG_ARGS`` is expanded into
    model, data, substring and training sub-configs. Every seed field starts
    out as the placeholder ``-1``; real seeds are only filled in (through
    ``set_seeds``) when ``seed_id`` is supplied.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: Optional seed index forwarded to ``set_seeds``.

    Returns:
        The assembled experiment configuration.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
    """
    cfg_args = CONFIG_ARGS[eval_type]

    model = ModelConfig(
        model_id=cfg_args.model_name,
        base_dir=cfg_args.model_dir,
        pretrained=True,
    )
    random_string = RandomStringConfig(
        seed_id=-1,
        alphabet="latin",
        alphabet_size=cfg_args.alphabet_size,
        num_tokens=cfg_args.string_length,
        entropy_like=cfg_args.entropy_target,
        tokenizer_type=get_tokenizer_type(cfg_args.model_name),
    )
    substrings = SubstringConfig(
        seed=-1,
        length=cfg_args.substring_length,
        num_distinct_substrings=cfg_args.num_distinct_substrings,
        placement_order=cfg_args.placement_order,
    )
    training = get_memorization_training_config(
        seed_id=-1,
        model_id=model.model_id_not_none,
        num_epochs=cfg_args.num_epochs,
        wandb_project_name=f"llm_mem_{EXP_ABBREVIATION}",
    )

    experiment_config = ExperimentConfig(
        name=eval_type,
        seed_id=-1,
        seed=-1,
        data=random_string,
        substrings=substrings,
        model=model,
        training=training,
    )

    if eval_type == "test":
        # The "test" entry is capped at a single training epoch.
        experiment_config.training.args.num_train_epochs = 1

    if seed_id is None:
        return experiment_config

    set_seeds(experiment_config, seed_id)
    return experiment_config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Fill in the seed fields of ``config`` for the given seed index.

    The concrete seed is looked up in ``EXPERIMENT_SEEDS`` and written to
    the top-level, substring and training configs; the seed index itself is
    written to the top-level and data configs. ``config`` is modified in
    place and also returned.

    Args:
        config: Experiment configuration whose seed fields are overwritten.
        seed_id: Index into ``EXPERIMENT_SEEDS``.

    Returns:
        The same ``config`` object, with its seed fields updated.

    Raises:
        Whatever ``EXPERIMENT_SEEDS[seed_id]`` raises for an unknown
        ``seed_id`` (the container type is defined outside this file).
    """
    seed = EXPERIMENT_SEEDS[seed_id]
    config.seed_id = seed_id
    # Must be an assignment: a bare ``config.seed`` expression would leave
    # the ``-1`` placeholder from ``create_config`` in place.
    config.seed = seed
    config.data.seed_id = seed_id
    config.substrings.seed = seed
    config.training.seed = seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Build one unseeded ``ExperimentConfig`` per ``CONFIG_ARGS`` entry.

    Configs are created in the iteration order of ``CONFIG_ARGS`` and
    without a ``seed_id``, so their seed fields keep the ``-1`` placeholder.
    """
    return [create_config(config_name) for config_name in CONFIG_ARGS]


# Handle bundling this experiment's pieces: its id (EXP_ABBREVIATION), the
# config factory, the seed setter and the experiment function itself.
# NOTE(review): presumably consumed by the project's experiment runner to
# enumerate configs and seed them via `set_seed` — confirm against
# lib_project.experiment.ExperimentHandle.
RSHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=sss_experiment,
)
