from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from lib_dl.analysis.experiment import ExperimentHandle
from lib_llm.training.fine_tune import OptimizerConfig, TrainingArguments

from .experiment import (
    SDDDExperimentConfig,
    ModelConfig,
    RandomStringConfig,
    TrainingConfig,
    SDDD_experiment,
)


# Resolved directory reached by four `.parent` hops from this file; used as
# the project base directory.
PROJECT_BASE_DIR = Path(__file__).parent.parent.parent.parent.resolve()
# Directory passed to ModelConfig.base_dir by create_config.
BASE_MODEL_DIR = PROJECT_BASE_DIR / "base_models"

# NOTE(review): the three constants below are not read anywhere in the visible
# part of this file (create_config derives its batch size from
# ConfigArgs.num_sequences instead) -- confirm whether they are still needed.
DEFAULT_NUM_SEQUENCES = 1
SEQUENCE_LENGTH = 128
PER_DEVICE_BATCH_SIZE = 16

# Seed table: seed id (0-9, the position in the tuple) -> concrete seed value.
# set_seeds() looks ids up here and writes the value into a config.
SEEDS = dict(
    enumerate(
        (5932, 4152, 4967, 2938, 84163, 42, 1242, 131, 910, 107),
    )
)

# Short model id ("pyt-<size>") -> model name ("pythia-<size>"), one entry per
# size suffix, in ascending size order.
MODELS = {
    f"pyt-{size}": f"pythia-{size}"
    for size in ("70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b")
}


@dataclass
class ConfigArgs:
    """Compact per-experiment settings that create_config expands into a
    full SDDDExperimentConfig.

    Only ``model_id`` is required; every other field has a default so the
    sweep dictionaries in this module only spell out what they vary.
    """

    # Short key into MODELS (e.g. "pyt-70m"); see model_name.
    model_id: str
    # Forwarded to ModelConfig.pretrained.
    pretrained: bool = True
    # Forwarded to OptimizerConfig.learning_rate.
    learning_rate: float = 5e-6
    # Forwarded to OptimizerConfig.schedule; the sweeps in this module use
    # "const", "lin" and "cos".
    lr_schedule: str = "lin"
    # Forwarded to OptimizerConfig.warmup.
    # NOTE(review): the default looks like a ratio while the sweeps in this
    # module pass values such as 0, 10, 50, 100 -- confirm the semantics
    # OptimizerConfig expects.
    warmup: float = 0.05
    # The remaining data fields are forwarded unchanged to RandomStringConfig.
    data_type: str = "rand"
    num_sequences: int = 64
    sequence_length: int = 128
    characterwise_tokenization: bool = True
    alphabet: str = "latin"
    alphabet_size: int = 26
    # Forwarded to SDDDExperimentConfig.shifted_pos_token_experiment.
    shifted_pos_token_experiment: bool = False
    # Forwarded to SDDDExperimentConfig.eval_1024.
    eval_1024: bool = False
    # Number of training epochs; create_config reads this as
    # ``args.num_epochs``. The field was missing, which made every
    # create_config call fail with AttributeError. Appended last so existing
    # positional construction keeps working.
    # TODO(review): 100 is a placeholder default -- confirm the intended value.
    num_epochs: int = 100

    @property
    def model_name(self):
        """Return the full model name registered in MODELS for ``model_id``.

        Raises:
            KeyError: if ``model_id`` is not a key of MODELS.
        """
        return MODELS[self.model_id]


# Training hyperparameter exploration configs
# Sweep: number of sequences x learning rate, for two model sizes.
# Keys look like "<model_id>_ns-<num_sequences>_lr-<learning_rate>".
NUM_SEQUENCES_VARIATION_ARGS = dict(
    (
        f"{size}_ns-{count}_lr-{lr}",
        ConfigArgs(
            model_id=size,
            learning_rate=lr,
            warmup=10,
            num_sequences=count,
        ),
    )
    for size in ("pyt-70m", "pyt-1b")
    for count in (1, 16, 256)
    for lr in (1e-4, 1e-5, 5e-6, 1e-6, 5e-7)
)
# Sweep: learning rate x learning-rate schedule, for two model sizes.
# Keys look like "<model_id>_lrs-<schedule>_lr-<learning_rate>".
LR_SCHEDULE_VARIATION_ARGS = dict(
    (
        f"{size}_lrs-{sched}_lr-{lr}",
        ConfigArgs(
            size,
            learning_rate=lr,
            lr_schedule=sched,
            warmup=10,
            num_sequences=64,
        ),
    )
    for size in ("pyt-70m", "pyt-1b")
    for lr in (1e-4, 1e-5, 1e-6)
    for sched in ("const", "lin", "cos")
)
# Sweep: learning rate x warmup value with the "lin" schedule, for two model
# sizes. Keys look like "<model_id>_ws-<warmup>_lr-<learning_rate>".
WARMUP_STEPS_VARIATION_ARGS = dict(
    (
        f"{size}_ws-{steps}_lr-{lr}",
        ConfigArgs(
            size,
            learning_rate=lr,
            lr_schedule="lin",
            warmup=steps,
            num_sequences=64,
        ),
    )
    for size in ("pyt-70m", "pyt-1b")
    for lr in (1e-5, 1e-6)
    for steps in (0, 10, 50, 100)
)

# Data exploration configs
# Sweep: sequence length x number of sequences, for three model sizes.
# Keys look like "<model_id>_sl-<sequence_length>_ns-<num_sequences>".
SEQUENCE_LENGTH_ARGS = {
    f"{model_id}_sl-{sequence_length}_ns-{num_sequences}": ConfigArgs(
        model_id=model_id,
        # Use the swept value. It is part of the key, so the previous
        # hard-coded 64 registered three differently named but identical
        # configs for every (model, sequence length) pair.
        num_sequences=num_sequences,
        sequence_length=sequence_length,
    )
    for model_id in ["pyt-70m", "pyt-1b", "pyt-12b"]
    for sequence_length in [4, 16, 64, 128, 256, 1024]
    for num_sequences in [1, 16, 64]
}
# Random vs natural vs semi-random (previous names data)
# Sweep over the number of sequences only, for two model sizes; every other
# setting (including sequence_length) keeps its ConfigArgs default.
# Keys look like "<model_id>_ns-<num_sequences>".
NUM_SEQUENCE_ARGS = dict(
    (
        f"{size}_ns-{count}",
        ConfigArgs(model_id=size, num_sequences=count),
    )
    for size in ("pyt-70m", "pyt-1b")
    for count in (1, 16, 256)
)
# Sweep: data type x tokenization mode (character-wise "char" vs default
# "def"), for two model sizes, with 64 sequences of length 20.
# Keys look like "<model_id>_dt-<data_type>_tok-<char|def>".
DATA_TYPE_ARGS = dict(
    (
        f"{size}_dt-{kind}_tok-{tok_label}",
        ConfigArgs(
            model_id=size,
            data_type=kind,
            num_sequences=64,
            sequence_length=20,
            characterwise_tokenization=per_character,
        ),
    )
    for size in ("pyt-70m", "pyt-1b")
    for kind in ("rand", "rand-names", "sci-names", "wiki")
    for tok_label, per_character in (("char", True), ("def", False))
)
# Sweep over alphabet size for a single sequence of length 512, for two model
# sizes. Only the "latin" alphabet is listed at the moment.
# Keys look like "<model_id>_alph-<alphabet>-<alphabet_size>".
ALPHABET_ARGS = dict(
    (
        f"{size}_alph-{name}-{n_symbols}",
        ConfigArgs(
            model_id=size,
            num_sequences=1,
            sequence_length=512,
            alphabet=name,
            alphabet_size=n_symbols,
        ),
    )
    for size in ("pyt-70m", "pyt-1b")
    for name in ("latin",)
    for n_symbols in (2, 13, 26)
)

# Model exploration configs (placeholders; no sweeps defined here yet)
#   - Model size
#   - Pretrained vs untrained

# Registry of every named experiment: hand-picked configs first, then the
# generated sweeps. Insertion order is the order in which get_configs()
# builds the configs.
CONFIG_ARGS = (
    {
        # create_config special-cases the "test" key.
        "test": ConfigArgs(
            model_id="pyt-1b",
            data_type="rand",
            pretrained=True,
            num_sequences=64,
            sequence_length=16,
            characterwise_tokenization=True,
        ),
    }
    | {
        # Grid in which num_sequences * sequence_length is always 1024,
        # generated first with shifted_pos_token_experiment=True (key suffix
        # "_shifted_pos_tok") and then without it, for the 1b and 70m models.
        f"ns-{n_seqs}_sl-{seq_len}_{size}{suffix}": ConfigArgs(
            model_id=f"pyt-{size}",
            data_type="rand",
            pretrained=True,
            num_sequences=n_seqs,
            sequence_length=seq_len,
            characterwise_tokenization=True,
            shifted_pos_token_experiment=shifted,
        )
        for shifted, suffix in ((True, "_shifted_pos_tok"), (False, ""))
        for size in ("1b", "70m")
        for n_seqs, seq_len in (
            (64, 16),
            (32, 32),
            (16, 64),
            (8, 128),
            (4, 256),
            (2, 512),
            (1, 1024),
        )
    }
    | {
        # eval_1024 variants on the 1b model.
        # NOTE(review): the last key says "sl-126" while the config uses
        # sequence_length=128. The key is kept unchanged because it is the
        # registered experiment name -- confirm before renaming.
        key: ConfigArgs(
            model_id="pyt-1b",
            data_type="rand",
            pretrained=True,
            num_sequences=n_seqs,
            sequence_length=seq_len,
            characterwise_tokenization=True,
            shifted_pos_token_experiment=False,
            eval_1024=True,
        )
        for key, n_seqs, seq_len in (
            ("pt_ns-2_pt_sl-512_curr_ns-1_curr_sl-1024_1b", 2, 512),
            ("pt_ns-4_pt_sl-256_curr_ns-1_curr_sl-1024_1b", 4, 256),
            ("pt_ns-8_pt_sl-126_curr_ns-1_curr_sl-1024_1b", 8, 128),
        )
    }
    | NUM_SEQUENCES_VARIATION_ARGS
    | LR_SCHEDULE_VARIATION_ARGS
    | WARMUP_STEPS_VARIATION_ARGS
    | SEQUENCE_LENGTH_ARGS
    | DATA_TYPE_ARGS
    | NUM_SEQUENCE_ARGS
    | ALPHABET_ARGS
)


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> SDDDExperimentConfig:
    """Build the experiment config registered under ``eval_type``.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: Optional key into ``SEEDS``. When given, the seed fields
            are filled in through ``set_seeds``; otherwise they all stay at
            the placeholder value ``-1``.

    Returns:
        The assembled ``SDDDExperimentConfig``.

    Raises:
        KeyError: if ``eval_type`` is not in ``CONFIG_ARGS`` or ``seed_id``
            is not in ``SEEDS``.
    """
    args = CONFIG_ARGS[eval_type]

    # The per-device batch holds every training sequence. A lowercase local
    # is used so the module-level PER_DEVICE_BATCH_SIZE is not shadowed.
    per_device_batch_size = args.num_sequences
    steps_per_epoch = max(1, args.num_sequences // per_device_batch_size)

    training_kwargs = dict(
        # NOTE(review): relative path, so it depends on the working
        # directory -- presumably the project root; confirm.
        deepspeed="src/experiments/memorization_hyperparam_rel/ds_memorization_config.json",
        eval_steps=steps_per_epoch,
        per_device_train_batch_size=per_device_batch_size,
        logging_steps=5,
        save_strategy="no",
    )
    # Reading ``args.num_epochs`` unconditionally raised AttributeError when
    # the args object carries no such field. Pass the epoch count only when
    # it is present, so TrainingArguments otherwise uses its own default.
    num_epochs = getattr(args, "num_epochs", None)
    if num_epochs is not None:
        training_kwargs["num_train_epochs"] = num_epochs

    config = SDDDExperimentConfig(
        name=eval_type,
        # -1 marks "not seeded yet"; set_seeds overwrites these fields.
        seed_id=-1,
        seed=-1,
        shifted_pos_token_experiment=args.shifted_pos_token_experiment,
        eval_1024=args.eval_1024,
        local_rank=-1,
        base_model=ModelConfig(
            model_id=args.model_name,
            base_dir=BASE_MODEL_DIR,
            pretrained=args.pretrained,
        ),
        data=RandomStringConfig(
            data_type=args.data_type,
            characterwise_tokenization=args.characterwise_tokenization,
            seed=-1,
            num_sequences=args.num_sequences,
            sequence_length=args.sequence_length,
            alphabet=args.alphabet,
            alphabet_size=args.alphabet_size,
        ),
        fine_tuning=TrainingConfig(
            seed=-1,
            args=TrainingArguments(**training_kwargs),
            optimizer=OptimizerConfig(
                learning_rate=args.learning_rate,
                schedule=args.lr_schedule,
                warmup=args.warmup,
            ),
            save_final_checkpoint=False,
            wandb_project_name="sddd",
        ),
    )
    if eval_type == "test":
        # The epoch count lives on the TrainingArguments stored in
        # ``fine_tuning.args``. Assigning it on ``fine_tuning`` itself only
        # created an attribute that nothing reads, so the test run was never
        # actually shortened.
        config.fine_tuning.args.num_train_epochs = 1
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: SDDDExperimentConfig,
    seed_id: int,
) -> SDDDExperimentConfig:
    """Write the seed registered under ``seed_id`` into ``config`` in place.

    The top-level config receives both the id and the seed value; the
    ``data`` and ``fine_tuning`` sections receive the seed value only.

    Args:
        config: Config to mutate.
        seed_id: Key into ``SEEDS``.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        KeyError: if ``seed_id`` is not in ``SEEDS``.
    """
    seed_value = SEEDS[seed_id]
    config.seed_id, config.seed = seed_id, seed_value
    for section in (config.data, config.fine_tuning):
        section.seed = seed_value
    return config


def get_configs() -> list[SDDDExperimentConfig]:
    """Return one unseeded config per entry of ``CONFIG_ARGS``, in
    registration order."""
    return [create_config(eval_type) for eval_type in CONFIG_ARGS.keys()]


# Experiment handle with id "sddd": bundles the config factory (get_configs),
# the seeding hook (set_seeds) and the experiment entry point
# (SDDD_experiment) for lib_dl's ExperimentHandle.
SDDDHandle = ExperimentHandle(
    id="sddd",
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=SDDD_experiment,
)
