from dataclasses import dataclass
from typing import Optional

from data.synthetic_strings.utils import SEEDS as STRING_SEEDS
from defs import EXPERIMENT_SEEDS, SEED_OFFSET, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    PrefixEvalConfig,
    RandomStringConfig,
    rt_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Arguments describing one named experiment configuration.

    ``num_tokens``, ``alphabet_size`` and ``seed_idxs`` are parallel lists:
    ``create_config`` zips them and builds one ``RandomStringConfig`` per
    position, so entry ``i`` of each list describes the ``i``-th string.
    """

    # Token count of each string, one entry per string.
    num_tokens: list[int]
    # Alphabet size of each string, one entry per string.
    alphabet_size: list[int]
    # Per-string seed index; ``create_config`` multiplies it by SEED_OFFSET
    # to obtain the string's seed.
    seed_idxs: list[int]
    # Forwarded to ModelConfig and to the training config.
    pretrained: bool = True
    # Forwarded to the training config.
    num_epochs: int = 50
    # When not None, ``create_config`` attaches a PrefixEvalConfig and passes
    # this list on to ExperimentConfig.
    prefix_eval_epochs: list[int] | None = None
    # Passed through unchanged to ExperimentConfig.
    save_after_iterations: list[int] | None = None


# We optionally invert seeds, to ensure that training in increasing
# or decreasing order uses the same strings, and the order is the
# only difference.
def _seed_indices(num_strings: int, invert_seeds: bool) -> list[int]:
    return list(
        (
            range(num_strings)
            if not invert_seeds
            else range(num_strings - 1, -1, -1)
        )
    )


# Configs with `num_strings` strings that all share one alphabet size and one
# token count; only the seed index (0 .. num_strings - 1) differs per string.
# One entry per (model, alphabet size, token count, string count) combination,
# keyed as "<model>_a-<alphabet>_t-<tokens>_x<strings>".
SAME_CONFIG = {
    f"{model_id}_a-{alphabet_size}_t-{num_tokens}_x{num_strings}": ConfigArgs(
        model_id=model_id,
        alphabet_size=[alphabet_size] * num_strings,
        num_tokens=[num_tokens] * num_strings,
        seed_idxs=list(range(num_strings)),
        num_epochs=50,
        # prefix_eval_epochs=[5, 10, 15, 20, 30, 40, 50],
        # save_after_iterations=[num_strings - 2, num_strings - 1],
    )
    for model_id in [
        "pyt-1b",
        "phi-2.7b",
        "llama2-13b",
    ]
    # for alphabet_size in [2, 4, 7, 13, 26]
    for alphabet_size in [2, 26]
    for num_tokens in [1024]
    for num_strings in [8, 16, 24, 32]
}
# Same per-string layout as SAME_CONFIG (identical alphabet size and token
# count for every string, seed indices 0 .. num_strings - 1), but with
# pretrained=False. Keys carry a "-ut" suffix on the model id.
# NOTE(review): the key omits num_tokens. That is harmless while the
# num_tokens loop has a single value, but adding a second value would make
# keys collide and silently keep only the last config.
UNTRAINED_CONFIG = {
    f"{model_id}-ut_a-{alphabet_size}_x{num_strings}": ConfigArgs(
        model_id=model_id,
        alphabet_size=[alphabet_size] * num_strings,
        num_tokens=[num_tokens] * num_strings,
        seed_idxs=list(range(num_strings)),
        pretrained=False,
        num_epochs=50,
    )
    for model_id in [
        "pyt-1b",
        "phi-2.7b",
        "llama2-13b",
    ]
    for alphabet_size in [2, 26]
    for num_tokens in [1024]
    for num_strings in [16, 32]
}
# Train for a long time on a single string, to compare it to training on
# multiple different strings.
# Each config holds exactly one string (seed index `string_idx`), runs for
# 500 epochs instead of the default 50, and requests prefix evaluation at
# the listed epochs.
LONG_TRAINING = {
    f"{model_id}_a-{alphabet_size}_t-{num_tokens}_s-{string_idx}": ConfigArgs(
        model_id=model_id,
        alphabet_size=[alphabet_size],
        num_tokens=[num_tokens],
        seed_idxs=[string_idx],
        num_epochs=500,
        prefix_eval_epochs=[50, 100, 150, 200, 500],
    )
    for model_id in [
        "pyt-1b",
    ]
    for alphabet_size in [2]
    for num_tokens in [1024]
    for string_idx in [0, 26, 31]
}
# Configs whose strings share one alphabet size but differ in token count.
# "inc" lists the counts from 8 up to 1024 with ascending seed indices;
# "dec" lists them from 1024 down to 8 with descending seed indices. Each
# token count is therefore paired with the same seed index in both variants
# (e.g. 1024 tokens <-> seed index 7), and only the order differs.
TOKEN_AMOUNT_CHANGE = {
    f"{model_id}_a-{alphabet_size}_t-{change_name}": ConfigArgs(
        model_id=model_id,
        alphabet_size=[alphabet_size] * len(num_tokens),
        num_tokens=num_tokens,
        seed_idxs=_seed_indices(len(num_tokens), invert_seeds),
    )
    for model_id in [
        "pyt-1b",
    ]
    for alphabet_size in [2, 4, 7, 13, 26]
    for change_name, num_tokens, invert_seeds in [
        ("inc", [8, 16, 32, 64, 128, 256, 512, 1024], False),
        ("dec", [1024, 512, 256, 128, 64, 32, 16, 8], True),
    ]
}
# Configs whose strings all have 1024 tokens but differ in alphabet size.
# "inc" lists the sizes from 2 up to 26 with ascending seed indices; "dec"
# lists them from 26 down to 2 with descending seed indices. Each alphabet
# size is therefore paired with the same seed index in both variants
# (e.g. size 26 <-> seed index 4), and only the order differs.
ALPHABET_SIZE_CHANGE = {
    f"{model_id}_a-{alphabet_name}_t-1024": ConfigArgs(
        model_id=model_id,
        alphabet_size=alphabet_sizes,
        num_tokens=[1024] * len(alphabet_sizes),
        seed_idxs=_seed_indices(len(alphabet_sizes), invert_seeds),
    )
    for model_id in [
        "pyt-1b",
    ]
    for alphabet_name, alphabet_sizes, invert_seeds in [
        ("inc", [2, 4, 7, 13, 26], False),
        ("dec", [26, 13, 7, 4, 2], True),
    ]
}

# Registry of every named configuration, keyed by the name passed to
# `create_config`. A small "test" entry (two short strings, one epoch) comes
# first, followed by the generated config tables in the order listed. On a
# key clash the later table wins.
CONFIG_ARGS = {
    "test": ConfigArgs(
        model_id="pyt-70m",
        num_tokens=[32, 64],
        alphabet_size=[7, 13],
        seed_idxs=[0, 1],
        num_epochs=1,
        prefix_eval_epochs=[1],
    ),
    **SAME_CONFIG,
    **UNTRAINED_CONFIG,
    **LONG_TRAINING,
    **TOKEN_AMOUNT_CHANGE,
    **ALPHABET_SIZE_CHANGE,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the experiment config registered under ``eval_type``.

    One ``RandomStringConfig`` is created per position of the parallel lists
    ``alphabet_size``, ``num_tokens`` and ``seed_idxs`` of the selected
    ``ConfigArgs``. Each string's seed starts as ``seed_idx * SEED_OFFSET``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: When given, ``set_seeds`` is applied to the config before
            it is returned. When ``None``, the ``-1`` placeholders passed
            for ``seed_id`` / ``seed`` at construction are left in place.

    Returns:
        The assembled ``ExperimentConfig``.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
        ValueError: If the three per-string lists differ in length.
    """
    args = CONFIG_ARGS[eval_type]

    model_config = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
        pretrained=args.pretrained,
    )

    # The tokenizer type depends only on the model, so look it up once
    # instead of once per string.
    tokenizer_type = get_tokenizer_type(args.model_name)

    # strict=True: the three lists are meant to be parallel. Without it a
    # length mismatch would silently drop the trailing strings.
    string_configs = [
        RandomStringConfig(
            seed=seed_idx * SEED_OFFSET,
            alphabet="latin",
            alphabet_size=alphabet_size,
            num_tokens=num_tokens,
            tokenizer_type=tokenizer_type,
        )
        for alphabet_size, num_tokens, seed_idx in zip(
            args.alphabet_size,
            args.num_tokens,
            args.seed_idxs,
            strict=True,
        )
    ]

    config = ExperimentConfig(
        name=eval_type,
        # Placeholder seeds; overwritten by set_seeds.
        seed_id=-1,
        seed=-1,
        data=string_configs,
        model=model_config,
        training=get_memorization_training_config(
            seed_id=-1,
            model_id=model_config.model_id_not_none,
            num_epochs=args.num_epochs,
            batch_size=1,
            wandb_project_name=f"llm_mem_{EXP_ABBREVIATION}",
            save_final_checkpoint=False,
            use_wandb=False,
            is_pretrained=args.pretrained,
        ),
        save_after_iterations=args.save_after_iterations,
        # Prefix evaluation is configured only when evaluation epochs were
        # requested.
        prefix_eval=(
            PrefixEvalConfig(
                seed=-1,
                num_samples_per_prefix=10,
                max_token_samples=256,
            )
            if args.prefix_eval_epochs is not None
            else None
        ),
        prefix_eval_epochs=args.prefix_eval_epochs,
    )
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Fill in all seeds of ``config`` for ``seed_id``, mutating it in place.

    The experiment, training and (if present) prefix-eval seeds are set to
    ``EXPERIMENT_SEEDS[seed_id]``. Every data entry's seed is shifted by
    ``STRING_SEEDS[seed_id]``.

    Note that the data-seed shift is additive: calling this function more
    than once on the same config object accumulates the offsets.

    Args:
        config: The config to update.
        seed_id: Index into ``EXPERIMENT_SEEDS`` and ``STRING_SEEDS``.

    Returns:
        The same ``config`` object, for convenience.
    """
    seed = EXPERIMENT_SEEDS[seed_id]
    config.seed_id = seed_id
    # Previously a bare `config.seed` expression, which assigned nothing and
    # left the experiment seed at its placeholder value.
    config.seed = seed
    config.training.seed = seed

    string_base_seed = STRING_SEEDS[seed_id]
    for data in config.data:
        assert data.seed is not None
        data.seed += string_base_seed

    if config.prefix_eval is not None:
        config.prefix_eval.seed = seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one config per entry of ``CONFIG_ARGS``, in registry order.

    No ``seed_id`` is passed to ``create_config``, so ``set_seeds`` has not
    been applied to the returned configs.
    """
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Handle for this experiment: bundles its id with the config factory
# (`get_configs`), the seeding function (`set_seeds`) and the experiment
# callable (`rt_experiment`).
RTHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=rt_experiment,
)
