from dataclasses import dataclass
from typing import Optional

from defs import ARTIFACTS_DIR, BaseConfigArgs
from lib_dl.analysis.experiment import ExperimentHandle

from .experiment import (
    EXP_ABBR,
    ExperimentConfig,
    FinetuningConfig,
    PrefixEvalConfig,
    RandomStringConfig,
    pp_experiment,
)


# Fixed pool of random seeds. `set_seeds` looks seeds up by position
# (SEEDS[seed_id]), so the order of this list defines the seed_id -> seed map.
SEEDS = [5932, 4152, 4967, 2938, 84163, 2663, 27, 8206, 1625, 6232]


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Compact per-run arguments from which `create_config` builds an ExperimentConfig.

    Extends BaseConfigArgs (defined in `defs`, not visible here). That base
    presumably supplies the `model_id` field and the `model_name` /
    `model_dir` attributes that `create_config` reads. Verify against
    `defs.BaseConfigArgs`.
    """

    # Passed to RandomStringConfig as `num_tokens`.
    sequence_length: int
    # Passed to FinetuningConfig as `epochs`.
    num_epochs: int = 100
    # NOTE(review): not read anywhere in this file; presumably controls whether
    # fine-tuning starts from pretrained weights. Confirm where it is consumed.
    pretrained: bool = True
    # Passed to RandomStringConfig as `alphabet_size`.
    alphabet_size: int = 26
    # Passed to PrefixEvalConfig as `relative_non_prefix_size`.
    size_change: float = 1.0
    # Passed to PrefixEvalConfig as `replacement_strategy`. Values used in this
    # file: "rand_id", "const_id", "rand_ood", "const_ood".
    replacement_strategy: str = "rand_id"


# Main sweep: the full cross product of models x sequence lengths x alphabet
# sizes. Keys look like "pyt-70m_sl-16_al-2". Every other ConfigArgs field keeps
# its default value.
SEQUENCE_LENGTH_ALPHABET_SIZE_ARGS = {
    f"{model}_sl-{length}_al-{alphabet}": ConfigArgs(
        model_id=model,
        sequence_length=length,
        alphabet_size=alphabet,
    )
    for model in (
        "pyt-70m",
        "pyt-1b",
        "pyt-12b",
        "llama2-7b",
        "llama2-13b",
        "gpt2",
        "gpt2-xl",
    )
    for length in (16, 32, 64, 128, 256, 512, 1024)
    for alphabet in (2, 4, 7, 13, 26)
}
# Sweep over ConfigArgs.size_change (forwarded as the relative non-prefix size)
# for pyt-1b. The default value 1.0 is deliberately absent from this list.
# 0 and 2 are ints, so their keys render as "..._sc-0" and "..._sc-2".
SIZE_CHANGE_ARGS = {
    f"{model}_sl-{length}_sc-{ratio}": ConfigArgs(
        model_id=model,
        sequence_length=length,
        size_change=ratio,
    )
    for model in ("pyt-1b",)
    for length in (64, 128, 256)
    for ratio in (0, 0.25, 0.5, 0.75, 1.25, 1.5, 1.75, 2)
}
# Sweep over the number of fine-tuning epochs (ConfigArgs.num_epochs, whose
# default is 100) for pyt-1b. Keys look like "pyt-1b_sl-64_te-5".
TRAINING_STAGE_ARGS = {
    f"{model_id}_sl-{seq_length}_te-{training_epochs}": ConfigArgs(
        model_id=model_id,
        sequence_length=seq_length,
        num_epochs=training_epochs,
    )
    for model_id in ["pyt-1b"]
    for seq_length in [64, 128, 256, 512, 1024]
    for training_epochs in [5, 10, 15, 20, 30, 40]
}
# Backward-compatible alias: keeps the original misspelled module-level name
# so existing references to TRAING_STAGE_ARGS keep working.
TRAING_STAGE_ARGS = TRAINING_STAGE_ARGS
# Sweep over ConfigArgs.replacement_strategy for pyt-1b. Note that "rand_id"
# is also the ConfigArgs default. Keys look like "pyt-1b_sl-64_rs-const_ood".
REPLACEMENT_STRATEGY_ARGS = {
    f"{model}_sl-{length}_rs-{strategy}": ConfigArgs(
        model_id=model,
        sequence_length=length,
        replacement_strategy=strategy,
    )
    for model in ("pyt-1b",)
    for length in (64, 128, 256, 512, 1024)
    for strategy in ("rand_id", "const_id", "rand_ood", "const_ood")
}


# Registry of every eval_type accepted by `create_config`.
# "test" is a tiny pyt-70m run: sequence length 64, one epoch. It is followed by
# the sweep grids. On duplicate keys the later entry wins, but the key keeps
# its first position.
CONFIG_ARGS = {
    "test": ConfigArgs(model_id="pyt-70m", sequence_length=64, num_epochs=1),
    **SEQUENCE_LENGTH_ALPHABET_SIZE_ARGS,
    **SIZE_CHANGE_ARGS,
    **TRAING_STAGE_ARGS,
    **REPLACEMENT_STRATEGY_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Assemble the ExperimentConfig registered under ``eval_type``.

    Every seed-related field (and ``local_rank``) starts at the placeholder
    ``-1``. The seed fields are filled in only when ``seed_id`` is given, by
    delegating to ``set_seeds``. ``local_rank`` is left at ``-1`` here.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config's name.
        seed_id: Optional index into ``SEEDS``. When ``None``, the config is
            returned unseeded.

    Returns:
        The newly built ExperimentConfig.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
    """
    placeholder = -1
    run_args = CONFIG_ARGS[eval_type]

    # Sub-configs are built in the same order ExperimentConfig receives them.
    data_cfg = RandomStringConfig(
        seed_id=placeholder,
        num_tokens=run_args.sequence_length,
        alphabet_size=run_args.alphabet_size,
        artifacts_dir=ARTIFACTS_DIR,
    )
    finetune_cfg = FinetuningConfig(
        seed_id=placeholder,
        model_id=run_args.model_name,
        epochs=run_args.num_epochs,
        base_model_dir=run_args.model_dir,
    )
    prefix_cfg = PrefixEvalConfig(
        seed=placeholder,
        num_samples_per_prefix=100,
        relative_non_prefix_size=run_args.size_change,
        replacement_strategy=run_args.replacement_strategy,
    )
    experiment_cfg = ExperimentConfig(
        name=eval_type,
        seed_id=placeholder,
        seed=placeholder,
        local_rank=placeholder,
        data=data_cfg,
        fine_tuning=finetune_cfg,
        prefix_testing=prefix_cfg,
    )

    # The "test" entry is shrunk after construction: 1 epoch, 2 samples/prefix.
    if eval_type == "test":
        experiment_cfg.fine_tuning.epochs = 1
        experiment_cfg.prefix_testing.num_samples_per_prefix = 2

    if seed_id is None:
        return experiment_cfg
    set_seeds(experiment_cfg, seed_id)
    return experiment_cfg


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Seed ``config`` and its sub-configs with ``SEEDS[seed_id]``, in place.

    Sets ``config.seed_id``, ``config.seed``, ``config.data.seed_id``,
    ``config.fine_tuning.seed_id`` and ``config.prefix_testing.seed``.

    Args:
        config: The config to mutate.
        seed_id: Position of the seed in ``SEEDS``. Must be in
            ``range(len(SEEDS))``.

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        IndexError: If ``seed_id`` is outside ``range(len(SEEDS))``.
    """
    # Reject negative ids explicitly. Plain indexing would wrap them
    # (SEEDS[-1] is the last seed), silently pairing seed_id=-1, the "unset"
    # placeholder used by create_config, with a real seed.
    if not 0 <= seed_id < len(SEEDS):
        raise IndexError(
            f"seed_id must be in range(0, {len(SEEDS)}), got {seed_id}"
        )
    seed = SEEDS[seed_id]
    config.seed_id = seed_id
    config.seed = seed
    config.data.seed_id = seed_id
    config.fine_tuning.seed_id = seed_id
    config.prefix_testing.seed = seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded ExperimentConfig per ``CONFIG_ARGS`` entry, in registry order."""
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Handle exposing this experiment under the EXP_ABBR id. It bundles the
# experiment entry point (pp_experiment), the config factory (get_configs)
# and the seeding hook (set_seeds).
PPHandle = ExperimentHandle(
    experiment=pp_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
    id=EXP_ABBR,
)
