from dataclasses import dataclass
from typing import Optional

from defs import BASE_MODEL_DIR
from lib_dl.analysis.experiment import ExperimentHandle
from lib_llm.training import OptimizerConfig, TrainingArguments

from .experiment import (
    ContextProbingConfig,
    ExperimentConfig,
    ModelConfig,
    RandomStringConfig,
    TrainingConfig,
    cs_experiment,
)


# NOTE(review): not referenced by the code in this file — create_config does
# not pass it to TrainingArguments. Confirm whether it should be.
PER_DEVICE_BATCH_SIZE = 16

# Number of random strings in the training data (RandomStringConfig.num_sequences)
# and the number of fine-tuning epochs over them.
NUM_SEQUENCES = 1
NUM_EPOCHS = 50

# Seed id -> concrete seed value applied by set_seeds.
SEEDS = dict(enumerate((5932, 4152, 4967, 2938, 84163)))

# Short model id -> Pythia model name (resolved through ConfigArgs.model_name).
_PYTHIA_SIZES = ("70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b")
MODELS = {f"pyt-{size}": f"pythia-{size}" for size in _PYTHIA_SIZES}


@dataclass
class ConfigArgs:
    """Settings that vary between the entries of ``CONFIG_ARGS``."""

    # Short model id; must be a key of MODELS (e.g. "pyt-1b").
    model_id: str
    # Forwarded to ContextProbingConfig.prefix_length.
    prefix_length: int
    # Forwarded to ContextProbingConfig.num_samples_per_string.
    num_samples_per_string: int
    # Number of tokens per random string (RandomStringConfig.num_tokens).
    sequence_length: int
    # Forwarded to ModelConfig.pretrained.
    pretrained: bool = True

    @property
    def model_name(self) -> str:
        """Full model name for ``model_id``; raises ``KeyError`` if unknown."""
        return MODELS[self.model_id]


# Sequence length -> (prefix lengths to probe, num_samples_per_string).
SEQUENCE_LENGTH_ARGS = {
    1024: ([512, 128, 32, 8, 2], 400),
    512: ([256, 128, 32, 8, 2], 200),
    256: ([128, 64, 32, 8, 2], 100),
}


def _build_prefix_length_args() -> dict[str, ConfigArgs]:
    """Expand MODELS x SEQUENCE_LENGTH_ARGS into one ConfigArgs per prefix length.

    Keys have the form ``"<model_id>_sl-<seq_length>_pl-<prefix_length>"`` and
    are inserted in model, then sequence length, then prefix length order.
    """
    grid: dict[str, ConfigArgs] = {}
    for model_id in MODELS:
        for seq_length, (prefix_lengths, num_samples) in SEQUENCE_LENGTH_ARGS.items():
            for prefix_length in prefix_lengths:
                key = f"{model_id}_sl-{seq_length}_pl-{prefix_length}"
                grid[key] = ConfigArgs(
                    model_id=model_id,
                    prefix_length=prefix_length,
                    num_samples_per_string=num_samples,
                    sequence_length=seq_length,
                )
    return grid


PREFIX_LENGTH_ARGS = _build_prefix_length_args()


# All selectable evaluation settings: a hand-written "test" entry first,
# followed by the full prefix-length grid.
CONFIG_ARGS = {
    "test": ConfigArgs(
        model_id="pyt-1b",
        prefix_length=256,
        num_samples_per_string=32,
        sequence_length=1024,
    ),
    **PREFIX_LENGTH_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the ``ExperimentConfig`` for one entry of ``CONFIG_ARGS``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the experiment
            name. The ``"test"`` entry is additionally limited to a single
            training epoch.
        seed_id: Optional key into ``SEEDS``. When given, ``set_seeds`` is
            applied to the finished config; otherwise every seed field is
            left at ``-1``.

    Returns:
        The populated config.

    Raises:
        KeyError: If ``eval_type`` is not in ``CONFIG_ARGS`` or ``seed_id``
            is not in ``SEEDS``.
        ValueError: If ``NUM_SEQUENCES`` exceeds the per-device training
            batch size of the constructed ``TrainingArguments``.
    """
    args = CONFIG_ARGS[eval_type]

    # NOTE(review): passed as eval_steps to evaluate "once per epoch". That
    # equals the real number of optimizer steps per epoch only while
    # NUM_SEQUENCES == 1: with all strings fitting in one batch (checked
    # below), an epoch is presumably a single step. Revisit if it grows.
    steps_per_epoch = NUM_SEQUENCES
    config = ExperimentConfig(
        name=eval_type,
        seed_id=-1,
        seed=-1,
        local_rank=-1,
        model=ModelConfig(
            model_id=args.model_name,
            base_dir=BASE_MODEL_DIR,
            pretrained=args.pretrained,
        ),
        data=RandomStringConfig(
            data_type="rand",
            characterwise_tokenization=True,
            seed=-1,
            num_sequences=NUM_SEQUENCES,
            num_tokens=args.sequence_length,
            alphabet="latin",
            alphabet_size=26,
        ),
        fine_tuning=TrainingConfig(
            seed=-1,
            args=TrainingArguments(
                deepspeed="src/experiments/memorization_hyperparam_rel/ds_memorization_config.json",
                num_train_epochs=NUM_EPOCHS,
                eval_steps=steps_per_epoch,
                # NOTE(review): PER_DEVICE_BATCH_SIZE is not passed here, so
                # the batch-size check below compares against the
                # TrainingArguments default — confirm this is intended.
            ),
            optimizer=OptimizerConfig(
                learning_rate=5e-6,
                schedule="lin",
                warmup=0.05,
            ),
            save_final_checkpoint=False,
            wandb_project_name="llms_cp",
        ),
        context_probing=ContextProbingConfig(
            seed=-1,
            prefix_length=args.prefix_length,
            num_samples_per_string=args.num_samples_per_string,
        ),
    )

    # An explicit check instead of ``assert``, so it still runs under ``python -O``.
    num_sequences = config.data.num_sequences
    batch_size = config.fine_tuning.args.per_device_train_batch_size
    if num_sequences > batch_size:
        raise ValueError(
            f"num_sequences ({num_sequences}) must not exceed "
            f"per_device_train_batch_size ({batch_size})"
        )

    if eval_type == "test":
        # The "test" entry trains for a single epoch only.
        config.fine_tuning.args.num_train_epochs = 1
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Apply the seed registered under ``seed_id`` to ``config`` in place.

    ``config.seed_id`` records the chosen id. The top-level config and its
    ``data``, ``fine_tuning`` and ``context_probing`` sections all receive
    the same seed value.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``.
    """
    chosen = SEEDS[seed_id]
    config.seed_id = seed_id
    # Written in the same order as before: top level first, then sub-configs.
    for section in (
        config,
        config.data,
        config.fine_tuning,
        config.context_probing,
    ):
        section.seed = chosen
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per entry of ``CONFIG_ARGS``, in key order.

    ``create_config`` is called without a ``seed_id``, so every seed field is
    left at ``-1``. Seeds are presumably applied later via the ``set_seed``
    hook registered on ``CSHandle``.
    """
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Bundles this module's config factory, seeding hook and experiment entry
# point into an ExperimentHandle (lib_dl.analysis.experiment) with id "cs".
CSHandle = ExperimentHandle(
    experiment=cs_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
    id="cs",
)
