from dataclasses import dataclass, replace
from typing import Optional

from defs import EXPERIMENT_SEEDS, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ContextDataConfig,
    ExperimentConfig,
    RandomStringConfig,
    pmd_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Arguments describing one named configuration of this experiment.

    Instances are stored in ``CONFIG_ARGS`` and turned into an
    ``ExperimentConfig`` by ``create_config``. Fields inherited from
    ``BaseConfigArgs`` (e.g. ``model_id``) are declared in ``defs``.
    """

    # Forwarded to RandomStringConfig.alphabet_size.
    alphabet_size: int
    # Forwarded to RandomStringConfig.num_tokens.
    random_length: int
    # One of "wiki", "rand_same_al" or "rand_diff_al"; create_config raises
    # ValueError for anything else.
    context_type: str
    # Forwarded to ContextDataConfig.sequence_length.
    context_length: int
    # Used for both the training config and ContextDataConfig.batch_size.
    batch_size: int
    # Number of training epochs passed to get_memorization_training_config.
    num_epochs: int
    # NOTE(review): create_config in this module does not read this field;
    # confirm where (or whether) it is consumed.
    inject_every_n_steps: int = 0
    # Forwarded to ModelConfig.pretrained and to the training config's
    # is_pretrained flag.
    pretrained: bool = True

    # num_epochs is a plain field rather than a property derived from
    # `pretrained`; the sweeps in this module pass it explicitly
    # (e.g. 200 if pretrained else 400).


# Model ids iterated over by the sweep dictionaries in this module.
MODEL_IDS = [
    # "pyt-70m",
    "pyt-1b",
    "llama2-13b",
    "phi-2.7b",
]
# Alphabet sizes iterated over by the sweeps (ConfigArgs.alphabet_size).
ALPHABET_SIZES = [2, 7, 26]
# Context kinds iterated over by the sweeps. create_config maps "wiki" to the
# wikitext dataset and the "rand_*" values to the "random" dataset with the
# "same_alphabet" / "diff_alphabet" variant.
CONTEXT_TYPES = ["wiki", "rand_same_al", "rand_diff_al"]
# Batch-size sweep. Keys have the form
# "<model>[_u]_a-<alphabet>_t-<tokens>_c-<context>_b-<batch>", where "_u"
# marks entries with pretrained=False. The context is as long as the random
# string, and non-pretrained entries get twice as many epochs.
BATCH_SIZE_ARGS = {
    (
        f"{model}{'' if use_pretrained else '_u'}"
        f"_a-{n_symbols}_t-{n_tokens}_c-{ctx_type}_b-{bsz}"
    ): ConfigArgs(
        model_id=model,
        pretrained=use_pretrained,
        alphabet_size=n_symbols,
        random_length=n_tokens,
        context_type=ctx_type,
        context_length=n_tokens,
        batch_size=bsz,
        num_epochs=400 if not use_pretrained else 200,
        inject_every_n_steps=1,
    )
    for model in MODEL_IDS
    for use_pretrained in (True, False)
    for n_symbols in ALPHABET_SIZES
    for n_tokens in (1024,)
    for ctx_type in CONTEXT_TYPES
    for bsz in (1, 4, 16, 64)
}
# Relative-context-size sweep. Keys have the form
# "<model>[_u]_a-<alphabet>_t-<tokens>_c-<context>_x-<factor>", where "_u"
# marks entries with pretrained=False. The context length is the random
# string length multiplied by <factor>; batch size is fixed to 1.
RELATIVE_CONTEXT_SIZE_ARGS = {
    (
        f"{model}{'' if use_pretrained else '_u'}"
        f"_a-{n_symbols}_t-{n_tokens}_c-{ctx_type}_x-{ctx_factor}"
    ): ConfigArgs(
        model_id=model,
        pretrained=use_pretrained,
        alphabet_size=n_symbols,
        random_length=n_tokens,
        context_type=ctx_type,
        context_length=n_tokens * ctx_factor,
        batch_size=1,
        num_epochs=200 if not use_pretrained else 100,
        inject_every_n_steps=1,
    )
    for model in MODEL_IDS
    for use_pretrained in (True, False)
    for n_symbols in ALPHABET_SIZES
    for n_tokens in (256,)
    for ctx_type in CONTEXT_TYPES
    for ctx_factor in (1, 2, 4, 8)
}
# Injection-interval sweep. Keys have the form
# "<model>[_u]_a-<alphabet>_t-<tokens>_c-<context>_i-<interval>", where "_u"
# marks entries with pretrained=False. Only inject_every_n_steps varies;
# batch size is fixed to 1 and the context is as long as the random string.
INJECTION_ARGS = {
    (
        f"{model}{'' if use_pretrained else '_u'}"
        f"_a-{n_symbols}_t-{n_tokens}_c-{ctx_type}_i-{interval}"
    ): ConfigArgs(
        model_id=model,
        pretrained=use_pretrained,
        alphabet_size=n_symbols,
        random_length=n_tokens,
        context_type=ctx_type,
        context_length=n_tokens,
        batch_size=1,
        num_epochs=400 if not use_pretrained else 200,
        inject_every_n_steps=interval,
    )
    for model in MODEL_IDS
    for use_pretrained in (True, False)
    for n_symbols in ALPHABET_SIZES
    for n_tokens in (1024,)
    for ctx_type in CONTEXT_TYPES
    for interval in (1, 2, 4, 8)
}


# Registry of every named configuration: a single small "test" entry followed
# by the batch-size, relative-context-size and injection sweeps, in that order.
CONFIG_ARGS = {
    # Alternative model ids left by the author for this entry:
    # "pyt-70m", "pyt-1b", "llama2-7b".
    "test": ConfigArgs(
        model_id="llama2-13b",
        alphabet_size=26,
        random_length=32,
        context_type="wiki",
        context_length=32,
        batch_size=1,
        num_epochs=1,
        inject_every_n_steps=1,
    ),
    **BATCH_SIZE_ARGS,
    **RELATIVE_CONTEXT_SIZE_ARGS,
    **INJECTION_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the ExperimentConfig registered under ``eval_type``.

    Every seed slot is first filled with the placeholder ``-1``; the real
    seeds are only applied (via ``set_seeds``) when ``seed_id`` is given.

    Args:
        eval_type: Key into ``CONFIG_ARGS``; also used as the config name.
        seed_id: Optional seed index handed to ``set_seeds``.

    Returns:
        The assembled experiment configuration.

    Raises:
        KeyError: If ``eval_type`` is not a key of ``CONFIG_ARGS``.
        ValueError: If the entry's ``context_type`` is not supported.

    NOTE(review): ``args.inject_every_n_steps`` is not forwarded to anything
    built here, so the ``INJECTION_ARGS`` entries differ only by name at this
    point. Confirm that the value is consumed elsewhere.
    """
    args = CONFIG_ARGS[eval_type]

    model_config = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
        pretrained=args.pretrained,
    )

    training_config = get_memorization_training_config(
        seed_id=-1,
        model_id=model_config.model_id_not_none,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        wandb_project_name=f"llm_mem_{EXP_ABBREVIATION}",
        save_final_checkpoint=False,
        is_pretrained=args.pretrained,
    )
    # Evaluate after every training step. WandB reporting is disabled because
    # it caused connection issues in some runs.
    training_config.args = replace(
        training_config.args,
        evaluation_strategy="steps",
        eval_steps=1,
        report_to=["none"],
    )

    # context_type -> (dataset, dataset variant)
    context_sources = {
        "wiki": ("wikitext", "wikitext-103-raw-v1"),
        "rand_same_al": ("random", "same_alphabet"),
        "rand_diff_al": ("random", "diff_alphabet"),
    }
    if args.context_type not in context_sources:
        raise ValueError(f"Invalid context type: {args.context_type}")
    context_dataset, context_variant = context_sources[args.context_type]

    random_data = RandomStringConfig(
        seed_id=-1,
        alphabet="latin",
        alphabet_size=args.alphabet_size,
        num_tokens=args.random_length,
        num_partitions=1,
        tokenizer_type=get_tokenizer_type(args.model_name),
    )
    context_data = ContextDataConfig(
        seed=-1,
        dataset=context_dataset,
        dataset_variant=context_variant,
        sequence_length=args.context_length,
        batch_size=args.batch_size,
    )
    config = ExperimentConfig(
        name=eval_type,
        seed_id=-1,
        seed=-1,
        random_data=random_data,
        context_data=context_data,
        model=model_config,
        training=training_config,
    )

    # The "test" configuration always trains for exactly one epoch.
    if eval_type == "test":
        config.training.args.num_train_epochs = 1

    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Replace the seed placeholders of ``config`` for the given seed index.

    ``create_config`` builds configs with ``-1`` in every seed slot. This
    function overwrites those slots in place, including the top-level
    ``config.seed``.

    Args:
        config: Experiment configuration to update; it is mutated.
        seed_id: Index used to look up the seed in ``EXPERIMENT_SEEDS``.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        Whatever ``EXPERIMENT_SEEDS[seed_id]`` raises for an unknown
        ``seed_id``; in that case ``config`` is left unmodified.
    """
    # Resolve the seed first so a bad seed_id cannot leave the config
    # half-updated.
    seed = EXPERIMENT_SEEDS[seed_id]

    config.seed_id = seed_id
    # Previously left at the -1 placeholder set by create_config.
    config.seed = seed

    config.random_data.seed_id = seed_id
    config.context_data.seed = seed
    config.training.seed = seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded ExperimentConfig per entry of ``CONFIG_ARGS``.

    The configs are created without a ``seed_id`` and appear in the
    iteration order of ``CONFIG_ARGS``.
    """
    return [create_config(config_name) for config_name in CONFIG_ARGS]


# Handle bundling this experiment's id, its config factory, its seed setter
# and the experiment entry point imported from .experiment.
PMDHandle = ExperimentHandle(
    id=EXP_ABBREVIATION,
    create_configs=get_configs,
    set_seed=set_seeds,
    experiment=pmd_experiment,
)
