from dataclasses import dataclass
from typing import Optional

from defs import EXPERIMENT_SEEDS, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import FreezeConfig, get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    MemorizationConfig,
    RandomStringConfig,
    md_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Arguments describing one random-string memorization run.

    Every value in ``CONFIG_ARGS`` is an instance of this class, and
    ``create_config`` expands it into a full ``ExperimentConfig``.
    The model is selected via ``model_id``, which is not declared here and
    is presumably inherited from ``BaseConfigArgs`` together with the
    ``model_name`` / ``model_dir`` values read in ``create_config``
    (confirm in ``defs``).
    """

    # Forwarded to RandomStringConfig.alphabet_size.
    alphabet_size: int
    # Forwarded to RandomStringConfig.num_tokens.
    num_tokens: int
    # Forwarded to RandomStringConfig.alphabet; values used in this file are
    # "latin" (default), "numeric" and "non_latin".
    alphabet: str = "latin"
    # Forwarded to RandomStringConfig.num_partitions and also used as the
    # training batch size in create_config.
    num_partitions: int = 1
    # Forwarded to RandomStringConfig.entropy_like; None leaves it unset.
    entropy_target: int | None = None
    # first_character_prob: float | None = None
    # Number of training epochs requested from
    # get_memorization_training_config.
    num_epochs: int = 100
    # Forwarded to ModelConfig.pretrained and to the training config's
    # is_pretrained flag.
    pretrained: bool = True
    # Forwarded as save_final_checkpoint to the training config.
    save_model: bool = False
    # Forwarded unchanged to MemorizationConfig.freeze.
    freeze: FreezeConfig | None = None
    # create_config prefixes this with "llm_mem_" to form the wandb project
    # name.
    wandb_project_name: str = EXP_ABBREVIATION


# Single-string sweep: one entry per (model, alphabet size, string length)
# combination, keyed as "<model>_a-<alphabet size>_t-<num tokens>".
SINGLE_STRING_ARGS = {
    f"{model}_a-{n_symbols}_t-{n_tokens}": ConfigArgs(
        model_id=model,
        alphabet_size=n_symbols,
        num_tokens=n_tokens,
        # No run of this sweep saves a checkpoint. A disabled variant of this
        # argument saved one only for the single combination
        # model "pyt-1b", alphabet size 26, 1024 tokens.
        save_model=False,
    )
    for model in (
        "pyt-70m",
        "pyt-160m",
        "pyt-410m",
        "pyt-1b",
        "pyt-1.4b",
        "pyt-2.8b",
        "pyt-6.9b",
        "pyt-12b",
        "llama2-7b",
        "llama2-13b",
        "gpt2-124m",
        "gpt2-1.5b",
        # "phi-1" is currently disabled.
        "phi-1.3b",
        "phi-2.7b",
        "opt-350m",
    )
    for n_symbols in (2, 4, 7, 13, 26)
    for n_tokens in (16, 32, 64, 128, 256, 512, 1024)
}
# Entropy-control sweep: always the full 26-character alphabet, with an
# entropy target per entry. Keys look like
# "<model>_h-<entropy target>_t-<num tokens>".
ENTROPY_CONTROL_ARGS = {
    f"{model}_h-{target}_t-{n_tokens}": ConfigArgs(
        model_id=model,
        alphabet_size=26,
        num_tokens=n_tokens,
        entropy_target=target,
    )
    for model in (
        "pyt-70m",
        "pyt-1b",
        "pyt-12b",
        "llama2-7b",
        "llama2-13b",
        "gpt2-124m",
        "gpt2-1.5b",
        "phi-1.3b",
        "phi-2.7b",
        "opt-350m",
    )
    for n_tokens in (1024,)
    # Each target is the size of a smaller uniform alphabet: the entropy of
    # the 26-character alphabet is reduced to that of a uniform alphabet
    # with this many characters.
    # NOTE(review): an earlier comment described these values as
    # first-character probabilities computed via numerical approximation.
    # The values are passed on as `entropy_like` in create_config, so that
    # computation presumably happens downstream -- confirm.
    for target in (13, 7, 4, 2)
}
# Multi-string sweep: the string is split into several partitions. Keys look
# like "<model>_a-<alphabet size>_t-<num tokens>_p-<num partitions>".
MULTI_STRING_ARGS = {
    f"{model}_a-{n_symbols}_t-{n_tokens}_p-{n_parts}": ConfigArgs(
        model_id=model,
        alphabet_size=n_symbols,
        num_tokens=n_tokens,
        num_partitions=n_parts,
        # NOTE(review): the comment previously attached to this argument
        # said the final checkpoint is saved so it can be evaluated on the
        # full string afterwards, yet the value is False -- confirm which
        # one is intended.
        save_model=False,
    )
    for model in (
        "pyt-1b",
        "llama2-13b",
        "phi-2.7b",
    )
    for n_symbols in (2, 26)
    for n_tokens in (1024,)
    for n_parts in (2, 4, 8, 16, 32, 64)
}
# Alphabet-type sweep: fixed 1024-token strings drawn from alternative
# alphabets. Keys look like "<model>_a-<alphabet abbreviation>-<alphabet size>".
ALPHABET_TYPE_ARGS = {
    f"{model}_a-{short_name}-{n_symbols}": ConfigArgs(
        model_id=model,
        alphabet_size=n_symbols,
        alphabet=alphabet_name,
        num_tokens=1024,
    )
    for model in (
        "pyt-70m",
        "pyt-1b",
        "phi-1.3b",
        "phi-2.7b",
        "llama2-13b",
    )
    # Maps the abbreviation used in the key to the alphabet name handed to
    # ConfigArgs.
    for short_name, alphabet_name in {
        "num": "numeric",
        "nlat": "non_latin",
    }.items()
    for n_symbols in (2, 4, 7, 13, 26)
}
# Untrained-model sweep: same 1024-token strings, but the model is created
# with pretrained=False and trained for 300 epochs instead of the default.
# Keys look like "<model>_a-<alphabet size>_ut".
UNTRAINED_ARGS = {
    f"{model}_a-{n_symbols}_ut": ConfigArgs(
        model_id=model,
        alphabet_size=n_symbols,
        num_tokens=1024,
        num_epochs=300,
        pretrained=False,
    )
    for model in (
        "pyt-1b",
        "llama2-13b",
        "phi-2.7b",
    )
    for n_symbols in (2, 4, 7, 13, 26)
}
def _make_freeze_config(freeze: str) -> FreezeConfig:
    """Build the ``FreezeConfig`` selected by a freeze identifier.

    Args:
        freeze: Either ``"layer_<x>"``, which selects a single layer via
            ``layer_x_only``, or the name of a ``FreezeConfig`` attribute
            that is switched on.

    Returns:
        The populated ``FreezeConfig`` instance.
    """
    layer_prefix = "layer_"
    if freeze.startswith(layer_prefix):
        # The layer index is passed on as the string that follows the
        # prefix, exactly as before.
        # NOTE(review): confirm FreezeConfig expects a str rather than an int.
        return FreezeConfig(layer_x_only=freeze[len(layer_prefix) :])

    # setattr() returns None, so the instance has to be kept in a variable
    # and returned explicitly; using the setattr() call itself as the value
    # silently produced `freeze=None`, i.e. no freezing at all.
    freeze_config = FreezeConfig()
    setattr(freeze_config, freeze, True)
    return freeze_config


# Freezing sweep: full 26-character alphabet, 1024 tokens, 200 epochs, logged
# to a dedicated "freezing" wandb project. Keys look like
# "<model>_f-<freeze identifier>".
FREEZE_ARGS = {
    f"{model_id}_f-{freeze}": ConfigArgs(
        model_id=model_id,
        alphabet_size=26,
        num_tokens=1024,
        num_epochs=200,
        wandb_project_name="freezing",
        freeze=_make_freeze_config(freeze),
    )
    for model_id in [
        "pyt-1b",
    ]
    for freeze in [
        "bias",
        "attention_only",
        "mlp_only",
        "layernorm_only",
        "embed_in_only",
        "embed_out_only",
        "layer_7",
    ]
}


# Registry of every named configuration: a small "test" entry followed by all
# sweep tables. Entries keep the insertion order test, single-string,
# multi-string, entropy-control, alphabet-type, untrained, freeze; on a
# duplicate key the later table's value wins.
CONFIG_ARGS = {
    # Smoke-test configuration on the smallest model. Alternative model ids
    # that have been used here: "pyt-1.4b", "llama2-7b".
    "test": ConfigArgs(
        model_id="pyt-70m",
        alphabet_size=7,
        num_tokens=64,
        num_partitions=2,
    ),
    **SINGLE_STRING_ARGS,
    **MULTI_STRING_ARGS,
    **ENTROPY_CONTROL_ARGS,
    **ALPHABET_TYPE_ARGS,
    **UNTRAINED_ARGS,
    **FREEZE_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the experiment configuration registered under ``eval_type``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: When given, ``set_seeds`` is applied to the finished
            config. When omitted, all seed fields keep the placeholder -1.

    Returns:
        The assembled ``ExperimentConfig``.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
    """
    args = CONFIG_ARGS[eval_type]
    # Placeholder for every seed field until set_seeds() fills in real values.
    unseeded = -1

    model_config = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
        pretrained=args.pretrained,
    )
    random_data = RandomStringConfig(
        seed_id=unseeded,
        alphabet=args.alphabet,
        alphabet_size=args.alphabet_size,
        num_tokens=args.num_tokens,
        num_partitions=args.num_partitions,
        entropy_like=args.entropy_target,
        tokenizer_type=get_tokenizer_type(args.model_name),
    )
    training = get_memorization_training_config(
        seed_id=unseeded,
        model_id=model_config.model_id_not_none,
        num_epochs=args.num_epochs,
        # The batch size equals the number of partitions.
        batch_size=args.num_partitions,
        wandb_project_name=f"llm_mem_{args.wandb_project_name}",
        save_final_checkpoint=args.save_model,
        is_pretrained=args.pretrained,
    )
    memorization = MemorizationConfig(
        training=training,
        freeze=args.freeze,
    )
    config = ExperimentConfig(
        name=eval_type,
        seed_id=unseeded,
        seed=unseeded,
        random_data=random_data,
        deterministic_rule_data=None,
        model=model_config,
        memorization=memorization,
    )

    # The smoke-test configuration trains for a single epoch only.
    if eval_type == "test":
        config.memorization.training.args.num_train_epochs = 1

    if seed_id is None:
        return config
    set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Apply ``seed_id`` to ``config`` in place and return it.

    The seed id is stored on the config itself and on whichever data
    configs are present; the seed value looked up in ``EXPERIMENT_SEEDS``
    is stored on the memorization training config.

    NOTE(review): ``config.seed`` is not touched here, so it keeps whatever
    value it had before (``create_config`` initializes it to -1) -- confirm
    that this is intended.

    Args:
        config: Experiment configuration to update.
        seed_id: Index/key into ``EXPERIMENT_SEEDS``.

    Returns:
        The same ``config`` object.
    """
    config.seed_id = seed_id
    resolved_seed = EXPERIMENT_SEEDS[seed_id]

    # Either data config may be absent; only the present ones are updated.
    for data_field in ("random_data", "deterministic_rule_data"):
        data_config = getattr(config, data_field)
        if data_config is not None:
            data_config.seed_id = seed_id

    config.memorization.training.seed = resolved_seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per entry of ``CONFIG_ARGS``, in order.

    No seed id is passed to ``create_config``, so the seed fields of the
    returned configs still hold their placeholder values.
    """
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Handle for this experiment: bundles its id, the factory producing all
# configs, the hook that applies a seed id to a config, and the experiment
# entry point. Presumably consumed by the lib_project experiment runner --
# confirm against ExperimentHandle.
MDHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=md_experiment,
)
