from dataclasses import dataclass
from typing import Optional

from defs import BASE_MODEL_DIR
from lib_dl.analysis.experiment import ExperimentHandle
from lib_llm.training import OptimizerConfig, TrainingArguments

from .experiment import (
    ContextSearchConfig,
    ExperimentConfig,
    ModelConfig,
    RandomStringConfig,
    TrainingConfig,
    rcs_experiment,
)


# Per-device training batch size. create_config rejects any setting whose
# sequences would not all fit into one batch of this size.
PER_DEVICE_BATCH_SIZE = 16

# Fine-tuning epochs for every config (the "test" config overrides this).
NUM_EPOCHS = 50

# Fixed random seeds, keyed by seed id 0..4.
SEEDS = dict(enumerate((5932, 4152, 4967, 2938, 84163)))

# Short model ids ("pyt-<size>") mapped to model names ("pythia-<size>").
MODELS = {
    f"pyt-{size}": f"pythia-{size}"
    for size in ("70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b")
}


@dataclass
class ConfigArgs:
    """Parameters that distinguish one named experiment setting from another."""

    model_id: str  # short key into MODELS, e.g. "pyt-1b"
    sequence_length: int  # used as RandomStringConfig.num_tokens
    num_sequences: int  # used as RandomStringConfig.num_sequences
    num_token_samples: int  # used as ContextSearchConfig.num_samples_per_sequence
    pretrained: bool = True  # forwarded to ModelConfig.pretrained

    @property
    def model_name(self) -> str:
        """Return the model name registered for ``model_id`` (KeyError if unknown)."""
        return MODELS[self.model_id]


# Sweep over sequence length (paired with its token-sample count) and number
# of sequences. Keys look like "pyt-1b_sl-128_ns-8".
SEQUENCE_LENGTH_ARGS = {
    f"{cfg.model_id}_sl-{cfg.sequence_length}_ns-{cfg.num_sequences}": cfg
    for cfg in (
        ConfigArgs(
            model_id=model,
            sequence_length=length,
            num_sequences=count,
            num_token_samples=samples,
        )
        for model in ("pyt-1b",)
        for length, samples in ((64, 32), (128, 100), (256, 100), (512, 100))
        for count in (1, 8)
    )
}


# All selectable settings: a tiny "test" smoke setting (create_config further
# shortens it) followed by the sequence-length sweep.
CONFIG_ARGS = {
    "test": ConfigArgs(
        model_id="pyt-70m",
        sequence_length=64,
        num_sequences=8,
        num_token_samples=1,
    ),
    **SEQUENCE_LENGTH_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the experiment config for the named setting.

    Args:
        eval_type: Key into ``CONFIG_ARGS`` selecting model, sequence length,
            number of sequences and number of token samples. It is also used
            as the config name. ``"test"`` additionally shrinks the run to one
            epoch and one shuffle sample per token.
        seed_id: Optional key into ``SEEDS``. When given, every seed field is
            filled in via ``set_seeds``; otherwise all seeds stay at ``-1``.

    Returns:
        The populated ``ExperimentConfig``.

    Raises:
        KeyError: If ``eval_type`` is not in ``CONFIG_ARGS`` or ``seed_id``
            is not in ``SEEDS``.
        ValueError: If the setting's sequences do not fit into a single
            per-device training batch.
    """
    args = CONFIG_ARGS[eval_type]

    # All sequences must fit into one batch, otherwise the hard-coded
    # steps_per_epoch below is wrong. Raise explicitly so the check survives
    # `python -O` (which strips asserts).
    if args.num_sequences > PER_DEVICE_BATCH_SIZE:
        raise ValueError(
            f"{eval_type}: num_sequences={args.num_sequences} exceeds "
            f"PER_DEVICE_BATCH_SIZE={PER_DEVICE_BATCH_SIZE}; all sequences "
            "must fit into a single training batch."
        )
    # One batch per epoch => one optimizer step per epoch.
    # NOTE(review): assumes a single device and no gradient accumulation.
    steps_per_epoch = 1
    config = ExperimentConfig(
        name=eval_type,
        # -1 marks "not seeded yet"; set_seeds overwrites these.
        seed_id=-1,
        seed=-1,
        local_rank=-1,
        model=ModelConfig(
            model_id=args.model_name,
            base_dir=BASE_MODEL_DIR,
            pretrained=args.pretrained,
        ),
        data=RandomStringConfig(
            data_type="rand",
            characterwise_tokenization=True,
            seed=-1,
            num_sequences=args.num_sequences,
            num_tokens=args.sequence_length,
            alphabet="latin",
            alphabet_size=26,
        ),
        fine_tuning=TrainingConfig(
            seed=-1,
            args=TrainingArguments(
                deepspeed="src/experiments/memorization_hyperparam_rel/ds_memorization_config.json",
                num_train_epochs=NUM_EPOCHS,
                # Previously only asserted against, never applied, so the
                # library default batch size was silently used instead.
                per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
                # Measured in optimizer steps; presumably one evaluation per
                # epoch given steps_per_epoch above.
                eval_steps=steps_per_epoch,
                save_strategy="no",
                save_total_limit=1,
            ),
            optimizer=OptimizerConfig(
                learning_rate=5e-6,
                schedule="lin",
                warmup=0.05,
            ),
            save_final_checkpoint=True,
            wandb_project_name="llms_cia",
        ),
        context_search=ContextSearchConfig(
            seed=-1,
            num_samples_per_sequence=args.num_token_samples,
            num_shuffle_samples_per_token=10,
        ),
    )
    # Re-check against the batch size the training arguments actually hold.
    train_batch_size = config.fine_tuning.args.per_device_train_batch_size
    if config.data.num_sequences > train_batch_size:
        raise ValueError(
            f"{eval_type}: num_sequences={config.data.num_sequences} exceeds "
            f"per_device_train_batch_size={train_batch_size}."
        )
    if eval_type == "test":
        # Smoke-test setting: keep the run as short as possible.
        config.fine_tuning.args.num_train_epochs = 1
        config.context_search.num_shuffle_samples_per_token = 1
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Apply the seed registered under ``seed_id`` to ``config`` in place.

    Records ``seed_id`` on the config and writes ``SEEDS[seed_id]`` to the
    ``seed`` field of the config itself and of its data, fine-tuning and
    context-search sections.

    Returns:
        The same ``config`` object, to allow chaining.

    Raises:
        KeyError: If ``seed_id`` is not in ``SEEDS``. Nothing is modified in
            that case.
    """
    chosen_seed = SEEDS[seed_id]
    config.seed_id = seed_id
    seeded_parts = (
        config,
        config.data,
        config.fine_tuning,
        config.context_search,
    )
    for part in seeded_parts:
        part.seed = chosen_seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per ``CONFIG_ARGS`` entry, in registry order.

    Seeds are left at -1; apply them afterwards with ``set_seeds``.
    """
    return [create_config(name) for name in CONFIG_ARGS]


# Bundles this module's config factory, seeding hook and experiment function
# under the experiment id "rcs".
RCSHandle = ExperimentHandle(
    experiment=rcs_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
    id="rcs",
)
