from dataclasses import dataclass
from typing import Optional

from defs import BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_llm.training import TrainingArguments
from lib_project.experiment import ExperimentHandle
from utils.memorization import MEMORIZATION_DEEPSPEED_CONFIG

from .experiment import (
    EXP_ABBREVIATION,
    ExperimentConfig,
    RandomStringConfig,
    TrainingConfig,
    hs_experiment,
)


# NOTE(review): the next two constants are not referenced in this module;
# presumably consumed by importers — TODO confirm.
DEFAULT_NUM_SEQUENCES = 1
SEQUENCE_LENGTH = 128
# Batch size used per device for both training and evaluation.
PER_DEVICE_BATCH_SIZE = 16

# Seed id (0, 1, 2, ...) -> concrete random seed, looked up by set_seeds.
SEEDS = dict(enumerate((5932, 4152, 4967, 2938, 84163)))


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Per-run settings from which ``create_config`` assembles a config."""

    config_group: str
    pretrained: bool = True
    learning_rate: float = 5e-6
    # Short schedule code; translated by ``learning_rate_schedule``.
    lr_schedule: str = "lin"
    warmup_steps: int = 0
    weight_decay: float = 0
    num_epochs: int = 100
    # NOTE(review): not read by ``create_config`` in this module.
    data_type: str = "rand"
    num_tokens: int = 1024
    num_partitions: int = 1
    # NOTE(review): not read by ``create_config`` in this module.
    alphabet: str = "latin"
    alphabet_size: int = 26

    @property
    def learning_rate_schedule(self) -> str:
        """Map the short ``lr_schedule`` code to a scheduler type name.

        Returns:
            ``"constant_with_warmup"``, ``"linear"`` or ``"cosine"`` for the
            codes ``"const"``, ``"lin"`` and ``"cos"`` respectively.

        Raises:
            ValueError: If ``lr_schedule`` is not one of those codes.
        """
        scheduler_by_code = {
            "const": "constant_with_warmup",
            "lin": "linear",
            "cos": "cosine",
        }
        scheduler = scheduler_by_code.get(self.lr_schedule)
        if scheduler is None:
            raise ValueError(
                f"Unknown learning rate schedule {self.lr_schedule}"
            )
        return scheduler


# Training hyperparameter exploration configs
# Sweep: model x learning rate x schedule code (keys keep this nesting order).
LR_SCHEDULE_VARIATION_ARGS = {
    f"{model}_lrs-{sched}_lr-{lr:.0e}": ConfigArgs(
        config_group="lr_schedule",
        model_id=model,
        lr_schedule=sched,
        learning_rate=lr,
    )
    for model in (
        "pyt-70m", "pyt-160m", "pyt-410m", "pyt-1b", "pyt-1.4b", "pyt-2.8b",
        "pyt-12b",
        "gpt2-124m", "gpt2-medium", "gpt2-large", "gpt2-1.5b",
        "llama2-7b",
        # "phi-1",  (disabled)
        "phi-1.3b", "phi-2.7b",
        "opt-350m",
    )
    for lr in (
        1e-02, 1e-03, 5e-04, 1e-04, 5e-05, 1e-05, 5e-06, 1e-06, 5e-07,
    )
    for sched in ("const", "lin", "cos")
}
# Sweep: model x learning rate x number of warmup steps.
WARMUP_STEPS_VARIATION_ARGS = {
    f"{model}_ws-{steps}_lr-{lr:.0e}": ConfigArgs(
        config_group="warmup_steps",
        model_id=model,
        warmup_steps=steps,
        learning_rate=lr,
    )
    for model in ("pyt-70m", "pyt-1b", "pyt-1.4b")
    for lr in (1e-03, 1e-04, 1e-05, 5e-06, 1e-06, 5e-07)
    for steps in (0, 5, 10, 20, 50, 100)
}
# Sweep: model x alphabet size x learning rate.
ALPHABET_SIZE_LR_ARGS = {
    f"{model}_a-{size}_lr-{lr:.0e}": ConfigArgs(
        # NOTE(review): group name matches the LR-schedule sweep rather than
        # something like "alphabet_size" — confirm this is intentional.
        config_group="lr_schedule",
        model_id=model,
        alphabet_size=size,
        learning_rate=lr,
    )
    for model in (
        "pyt-70m", "pyt-160m", "pyt-410m", "pyt-1b", "pyt-1.4b", "pyt-2.8b",
        "pyt-12b",
        "gpt2-124m", "gpt2-medium", "gpt2-large", "gpt2-1.5b",
        "llama2-7b",
        "phi-1.3b", "phi-2.7b",
    )
    # Sizes 4, 7 and 13 are currently disabled.
    for size in (2, 26)
    for lr in (
        1e-02, 1e-03, 5e-04, 1e-04, 5e-05, 1e-05, 5e-06, 1e-06, 5e-07,
    )
}
# Sweep over randomly initialised (non-pretrained) models, trained for
# 500 epochs instead of the default.
UNTRAINED_MODEL_ARGS = {
    f"{model}_ut_lr-{lr:.0e}": ConfigArgs(
        config_group="untrained",
        model_id=model,
        pretrained=False,
        num_epochs=500,
        learning_rate=lr,
    )
    for model in ("pyt-70m", "pyt-1b", "phi-1.3b", "phi-2.7b", "llama2-13b")
    for lr in (1e-03, 5e-04, 1e-04, 5e-05, 1e-05, 5e-06, 1e-06)
}
# Sweep: model x weight decay (keys use the plain repr, e.g. "_wd-0.01").
WEIGHT_DECAY_ARGS = {
    f"{model}_wd-{decay}": ConfigArgs(
        config_group="weight_decay",
        weight_decay=decay,
        model_id=model,
    )
    for model in ("pyt-70m", "pyt-1b", "pyt-1.4b")
    for decay in (0, 0.01, 0.05, 0.1, 0.2, 0.5)
}

# Data type and default vs. character-wise tokenization (disabled; note that
# ConfigArgs currently has no `characterwise_tokenization` field)
# DATA_TYPE_ARGS = {
#     f"{model_id}_dt-{data_type}_tok-{tokenization_id}": ConfigArgs(
#         config_group="data_type",
#         model_id=model_id,
#         data_type=data_type,
#         num_partitions=64 * 20,
#         num_tokens=20,
#         characterwise_tokenization=characterwise_tokenization,
#     )
#     for model_id in ["pyt-70m", "pyt-1b"]
#     for data_type in ["rand", "rand-names", "sci-names", "wiki"]
#     for tokenization_id, characterwise_tokenization in [
#         # Character-wise and default tokenization
#         ("char", True),
#         ("def", False),
#     ]
# }

# # Model exploration configs
# Model size
# Pretrained vs untrained

# Registry of every runnable configuration, keyed by run name. Later sweeps
# override earlier entries on a key clash, exactly as a dict merge would.
CONFIG_ARGS = {
    # Small smoke-test run.
    "test": ConfigArgs(
        config_group="test",
        model_id="pyt-1b",
        data_type="rand",
        pretrained=True,
        num_partitions=1,
        num_tokens=128,
    ),
    **LR_SCHEDULE_VARIATION_ARGS,
    **WARMUP_STEPS_VARIATION_ARGS,
    **ALPHABET_SIZE_LR_ARGS,
    **UNTRAINED_MODEL_ARGS,
    **WEIGHT_DECAY_ARGS,
    # **DATA_TYPE_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the ``ExperimentConfig`` registered under ``eval_type``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: Optional key into ``SEEDS``. When given, the config is
            seeded via ``set_seeds``; otherwise every seed field keeps the
            placeholder value ``-1``.

    Returns:
        The assembled experiment config.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS`` or
            ``seed_id`` is not a key of ``SEEDS``.
    """
    args = CONFIG_ARGS[eval_type]

    # Optimizer steps per epoch on one device. A trailing partial batch is
    # still a step (assuming the dataloader does not drop the last batch),
    # so round up instead of truncating; clamp to at least 1 so that
    # ``eval_steps`` is never 0.
    # NOTE(review): ignores world size and gradient accumulation; with
    # evaluation_strategy="epoch" this value is presumably informational.
    steps_per_epoch = max(
        1, -(-args.num_partitions // PER_DEVICE_BATCH_SIZE)
    )
    config = ExperimentConfig(
        name=eval_type,
        # Seeds stay at -1 until ``set_seeds`` assigns real values.
        seed_id=-1,
        seed=-1,
        model=ModelConfig(
            model_id=args.model_name,
            base_dir=args.model_dir,
            pretrained=args.pretrained,
        ),
        data=RandomStringConfig(
            seed_id=-1,
            num_tokens=args.num_tokens,
            num_partitions=args.num_partitions,
            alphabet_size=args.alphabet_size,
            tokenizer_type=get_tokenizer_type(args.model_name),
        ),
        training=TrainingConfig(
            seed=-1,
            args=TrainingArguments(
                learning_rate=args.learning_rate,
                lr_scheduler_type=args.learning_rate_schedule,
                warmup_steps=args.warmup_steps,
                deepspeed=MEMORIZATION_DEEPSPEED_CONFIG,
                num_train_epochs=args.num_epochs,
                evaluation_strategy="epoch",
                eval_steps=steps_per_epoch,
                logging_steps=10,
                save_strategy="no",
                weight_decay=args.weight_decay,
                per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
                per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
            ),
            save_final_checkpoint=False,
            wandb_project_name=f"llms_{EXP_ABBREVIATION}",
        ),
    )
    if eval_type == "test":
        # Smoke test: a single epoch is enough.
        config.training.args.num_train_epochs = 1
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Seed ``config`` in place with the entry ``SEEDS[seed_id]``.

    The seed id is recorded on the config and on its data config, and the
    concrete seed value is stored on the config and on its training config.
    The seed is looked up before anything is modified, so an unknown id
    leaves ``config`` untouched.

    Returns:
        The same ``config`` object.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``.
    """
    seed_value = SEEDS[seed_id]
    config.seed_id = seed_id
    config.seed = seed_value
    config.data.seed_id = seed_id
    config.training.seed = seed_value
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded config per ``CONFIG_ARGS`` entry, in key order."""
    return [create_config(name) for name in CONFIG_ARGS]


# Handle for this experiment: its id, the config factory, the seeding
# function and the experiment callable.
HSHandle = ExperimentHandle(
    experiment=hs_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
    id=EXP_ABBREVIATION,
)
