from dataclasses import dataclass
from typing import Optional

from defs import ARTIFACTS_DIR, BaseConfigArgs
from lib_dl.analysis.experiment import ExperimentHandle
from lib_llm.training import OptimizerConfig, TrainingArguments

from .experiment import (
    ExperimentConfig,
    ModelConfig,
    RandomStringConfig,
    TrainingConfig,
    mhr_experiment,
)


# NOTE(review): the next two constants are not referenced in the visible part
# of this module — confirm whether other modules import them.
DEFAULT_NUM_SEQUENCES = 1
SEQUENCE_LENGTH = 128

# Passed to TrainingArguments and used to derive the evaluation interval.
PER_DEVICE_BATCH_SIZE = 16

# seed_id -> seed value; the seed_id is the position in the tuple below.
SEEDS = dict(enumerate((5932, 4152, 4967, 2938, 84163)))


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Per-experiment settings that ``create_config`` expands into an ``ExperimentConfig``.

    Fields inherited from ``BaseConfigArgs`` are declared outside this file.
    """

    # Forwarded to ModelConfig.pretrained.
    pretrained: bool = True
    # Forwarded to OptimizerConfig.learning_rate.
    learning_rate: float = 5e-6
    # Forwarded to OptimizerConfig.schedule; values used in this file are
    # "const", "lin" and "cos".
    lr_schedule: str = "lin"
    # Forwarded to OptimizerConfig.warmup. The default is fractional, while the
    # warmup sweep in this file passes the integers 0, 10, 50 and 100.
    # NOTE(review): presumably a ratio below 1 and a step count otherwise —
    # confirm against OptimizerConfig.
    warmup: float = 0.05
    # Forwarded to TrainingArguments.num_train_epochs.
    num_epochs: int = 50
    # Values used in this file: "rand", "rand-names", "sci-names", "wiki".
    # NOTE(review): create_config does not read this field — confirm where it
    # is consumed.
    data_type: str = "rand"
    # Forwarded to RandomStringConfig.num_tokens.
    num_tokens: int = 512
    # Forwarded to RandomStringConfig.num_partitions; also divided by the
    # per-device batch size to obtain the evaluation interval.
    num_partitions: int = 1
    # Forwarded to ExperimentConfig.characterwise_tokenization.
    characterwise_tokenization: bool = True
    # NOTE(review): create_config does not read this field — confirm where it
    # is consumed.
    alphabet: str = "latin"
    # Forwarded to RandomStringConfig.alphabet_size.
    alphabet_size: int = 26


# Training hyperparameter exploration configs
# Sweep over learning-rate schedule and learning rate for every listed model.
# Keys look like "pyt-70m_lrs-lin_lr-5e-06".
LR_SCHEDULE_VARIATION_ARGS = {
    f"{model_id}_lrs-{schedule}_lr-{learning_rate:.0e}": ConfigArgs(
        # Passed by keyword, like the other config tables in this file, so the
        # entry does not depend on the field order of BaseConfigArgs.
        model_id=model_id,
        learning_rate=learning_rate,
        lr_schedule=schedule,
    )
    for model_id in [
        "pyt-70m",
        "pyt-160m",
        "pyt-410m",
        "pyt-1b",
        "gpt2",
        "gpt2-xl",
    ]
    for learning_rate in [1e-03, 1e-04, 1e-05, 5e-06, 1e-06, 5e-07]
    for schedule in ["const", "lin", "cos"]
}
# Sweep over warmup and learning rate for pyt-70m and pyt-1b.
# Keys look like "pyt-70m_ws-10_lr-1e-03".
# NOTE(review): ConfigArgs.warmup is annotated as float with a fractional
# default, but integers are passed here — presumably interpreted as step
# counts; confirm against OptimizerConfig.
WARMUP_STEPS_VARIATION_ARGS = {
    f"{model_id}_ws-{warmup_steps}_lr-{learning_rate:.0e}": ConfigArgs(
        # Passed by keyword, like the other config tables in this file, so the
        # entry does not depend on the field order of BaseConfigArgs.
        model_id=model_id,
        learning_rate=learning_rate,
        warmup=warmup_steps,
    )
    for model_id in ["pyt-70m", "pyt-1b"]
    for learning_rate in [1e-03, 1e-04, 1e-05, 5e-06, 1e-06, 5e-07]
    for warmup_steps in [0, 10, 50, 100]
}

# Data-type sweep, each combined with character-wise ("char") and default
# ("def") tokenization. Keys look like "pyt-70m_dt-rand_tok-char".
# NOTE(review): create_config in this file never reads `data_type`, so these
# entries differ only in name and tokenization once expanded — confirm that
# this is intended.
DATA_TYPE_ARGS = {
    f"{model}_dt-{source}_tok-{tok_label}": ConfigArgs(
        model_id=model,
        data_type=source,
        num_partitions=64 * 20,
        num_tokens=20,
        characterwise_tokenization=per_character,
    )
    for model in ("pyt-70m", "pyt-1b")
    for source in ("rand", "rand-names", "sci-names", "wiki")
    # Label used in the key -> value of the tokenization flag.
    for tok_label, per_character in {"char": True, "def": False}.items()
}

# Model exploration configs (TODO: not defined yet)
# - Model size
# - Pretrained vs untrained

# Lookup table from experiment name to its arguments; `create_config` indexes
# it by name. Later sources overwrite earlier ones on duplicate keys.
CONFIG_ARGS = {
    # `create_config` cuts this entry down to a single training epoch.
    "test": ConfigArgs(
        model_id="pyt-1b",
        data_type="rand",
        pretrained=True,
        num_partitions=1,
        num_tokens=128,
        characterwise_tokenization=True,
    ),
    **LR_SCHEDULE_VARIATION_ARGS,
    **WARMUP_STEPS_VARIATION_ARGS,
    **DATA_TYPE_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the ``ExperimentConfig`` registered under ``eval_type``.

    All seed-related fields start out as the placeholder ``-1``; they are
    overwritten through ``set_seeds`` only when ``seed_id`` is given.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: Optional key into ``SEEDS``.

    Returns:
        The assembled configuration.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS`` or
            ``seed_id`` is not a key of ``SEEDS``.
    """
    unset_seed = -1
    args = CONFIG_ARGS[eval_type]

    # num_partitions divided by the per-device batch size, truncated, but
    # never below 1; used as the evaluation interval.
    steps_per_epoch = max(1, int(args.num_partitions / PER_DEVICE_BATCH_SIZE))

    model_config = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
        pretrained=args.pretrained,
    )
    data_config = RandomStringConfig(
        seed_id=unset_seed,
        num_tokens=args.num_tokens,
        num_partitions=args.num_partitions,
        alphabet_size=args.alphabet_size,
        artifacts_dir=ARTIFACTS_DIR,
    )
    trainer_args = TrainingArguments(
        deepspeed="src/experiments/memorization_hyperparam_rel/ds_memorization_config.json",
        num_train_epochs=args.num_epochs,
        eval_steps=steps_per_epoch,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        logging_steps=1,
        save_strategy="no",
    )
    optimizer_config = OptimizerConfig(
        learning_rate=args.learning_rate,
        schedule=args.lr_schedule,
        warmup=args.warmup,
    )
    fine_tuning_config = TrainingConfig(
        seed=unset_seed,
        args=trainer_args,
        optimizer=optimizer_config,
        save_final_checkpoint=False,
        wandb_project_name="llms_mhr",
    )
    config = ExperimentConfig(
        name=eval_type,
        seed_id=unset_seed,
        seed=unset_seed,
        local_rank=-1,
        model=model_config,
        data=data_config,
        characterwise_tokenization=args.characterwise_tokenization,
        fine_tuning=fine_tuning_config,
    )

    # The "test" entry trains for a single epoch regardless of num_epochs.
    if eval_type == "test":
        config.fine_tuning.args.num_train_epochs = 1

    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Write the seed selected by ``seed_id`` into ``config`` in place.

    Args:
        config: Configuration to update; it is mutated, not copied.
        seed_id: Key into ``SEEDS``.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``.
    """
    # Resolve first so an unknown id fails before anything is modified.
    seed_value = SEEDS[seed_id]

    # The experiment and its data config store the id ...
    config.seed_id = config.data.seed_id = seed_id
    # ... while the experiment and the fine-tuning config store the value.
    config.seed = config.fine_tuning.seed = seed_value
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per entry of ``CONFIG_ARGS``, in table order."""
    return [create_config(name) for name in CONFIG_ARGS]


# Handle bundling this module's config factory (`get_configs`), seed setter
# (`set_seeds`) and the imported `mhr_experiment` function under the id "mhr".
MHRHandle = ExperimentHandle(
    id="mhr",
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=mhr_experiment,
)
