from dataclasses import dataclass, field
from typing import Optional

from defs import EXPERIMENT_SEEDS, BaseConfigArgs
from lib_llm.models import ModelConfig, get_tokenizer_type
from lib_project.experiment import ExperimentHandle
from utils.memorization import get_memorization_training_config

from .experiment import (
    EXP_ABBREVIATION,
    ConditionalRandomStringConfig,
    DeterministicRuleStringConfig,
    ExperimentConfig,
    MemorizationConfig,
    PrefixEvalConfig,
    RandomStringConfig,
    pm_experiment,
)


@dataclass
class ConfigArgs(BaseConfigArgs):
    """Arguments describing one named experiment configuration.

    Instances are registered in ``CONFIG_ARGS`` and expanded into a full
    ``ExperimentConfig`` by ``create_config``. Model-related fields such as
    ``model_id`` are not declared here; they come from ``BaseConfigArgs``.
    """

    # Forwarded to ExperimentConfig as ``group``.
    config_group: str
    # Forwarded as ``num_tokens`` (random strings) or ``string_length``
    # (deterministic-rule strings).
    num_tokens: int
    # Forwarded to DeterministicRuleStringConfig only.
    num_strings: int = 1
    # -1 selects random-string data in create_config; any other value selects
    # deterministic-rule data and is forwarded as its ``premise_length``.
    premise_length: int = -1
    # Forwarded to the memorization training config as ``num_epochs``.
    num_epochs: int = 100
    # Forwarded to ModelConfig as ``pretrained``.
    pretrained: bool = True
    # Forwarded to whichever data config is built.
    alphabet_size: int = 26
    # Forwarded as ``entropy_like`` to RandomStringConfig; not used when a
    # conditional or deterministic data config is built.
    entropy_target: int | None = None
    # NOTE(review): set by the deterministic-rules sweep but not read by
    # create_config in this file - confirm where it is consumed.
    dense_prefix_sampling_end: int = 21
    # Forwarded to PrefixEvalConfig as ``relative_context_size``.
    size_change: float = 1.0
    # Forwarded unchanged to ExperimentConfig.
    replacement_strategy: str = "rand_id"
    # Forwarded to ExperimentConfig as ``prefix_eval_epochs``; a fresh list is
    # created per instance so the default is never shared.
    eval_epochs: list[int] = field(default_factory=lambda: [-1])
    # Any value other than 1 makes create_config build a
    # ConditionalRandomStringConfig (which then requires ngram_length > 0).
    relative_probability: int = 1
    # Forwarded to ConditionalRandomStringConfig only.
    ngram_length: int = 0


# Full grid over model, string length and alphabet size. Every entry belongs
# to the "intermediate_eval" group, trains for 100 epochs and is keyed as
# "<model_id>_a-<alphabet_size>_t-<num_tokens>".
SEQUENCE_LENGTH_ALPHABET_SIZE_ARGS = {
    f"{model_id}_a-{alphabet_size}_t-{num_tokens}": ConfigArgs(
        model_id=model_id,
        config_group="intermediate_eval",
        num_tokens=num_tokens,
        alphabet_size=alphabet_size,
        eval_epochs=[5, 10, 15, 20, 30, 40, 50, 100],
        num_epochs=100,
    )
    for model_id in (
        "pyt-70m", "pyt-1b", "pyt-12b",
        "llama2-7b", "llama2-13b",
        "gpt2-124m", "gpt2-1.5b",
        "phi-2.7b", "opt-350m",
    )
    for num_tokens in (16, 32, 64, 128, 256, 512, 1024)
    for alphabet_size in (2, 4, 7, 13, 26)
}
# Entropy-controlled sweep: the alphabet is fixed at 26 symbols while
# ``entropy_target`` varies. Entries are keyed as
# "<model_id>_h-<entropy_target>_t-<num_tokens>" and belong to the
# "intermediate_eval" group.
ENTROPY_CONTROL_ARGS = {
    f"{model_id}_h-{entropy_target}_t-{num_tokens}": ConfigArgs(
        config_group="intermediate_eval",
        model_id=model_id,
        alphabet_size=26,
        # Use the swept value rather than a literal so the key and the
        # config cannot drift apart if more lengths are added below.
        num_tokens=num_tokens,
        entropy_target=entropy_target,
        num_epochs=100,
        eval_epochs=[5, 10, 15, 20, 30, 40, 50, 100],
    )
    for model_id in [
        "pyt-70m",
        "pyt-1b",
        "pyt-12b",
        "llama2-7b",
        "llama2-13b",
        # "gpt2-124m" and "gpt2-1.5b" are currently disabled for this sweep.
        "phi-2.7b",
    ]
    for num_tokens in [1024]
    for entropy_target in [13, 7, 4, 2]
}
# Sweep over ``size_change`` for several string lengths. Entries are keyed as
# "<model_id>_t-<num_tokens>_sc-<size_change>" and belong to the "size_change"
# group. The size_change literals deliberately mix ints and floats: their
# text form is part of the key (e.g. "sc-1", not "sc-1.0").
# NOTE: per the original author, 1024 tokens with size_change > 1 is too large
# for GPT2; those combinations are nevertheless still generated here.
SIZE_CHANGE_ARGS = {
    f"{model_id}_t-{num_tokens}_sc-{size_change}": ConfigArgs(
        model_id=model_id,
        config_group="size_change",
        size_change=size_change,
        num_tokens=num_tokens,
    )
    for model_id in ("pyt-1b", "gpt2-124m", "phi-2.7b", "llama2-13b", "opt-350m")
    for num_tokens in (128, 256, 512, 1024)
    for size_change in (0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2)
}
# Sweep over the replacement strategy. Entries are keyed as
# "<model_id>_t-<num_tokens>_rs-<replacement_strategy>" and belong to the
# "replacement_strategy" group.
REPLACEMENT_STRATEGY_ARGS = {
    f"{model_id}_t-{num_tokens}_rs-{replacement_strategy}": ConfigArgs(
        model_id=model_id,
        config_group="replacement_strategy",
        replacement_strategy=replacement_strategy,
        num_tokens=num_tokens,
    )
    for model_id in ("pyt-1b", "gpt2-124m", "phi-2.7b", "llama2-13b", "opt-350m")
    for num_tokens in (128, 512, 1024)
    # The "rand_ood" and "const_ood" strategies are currently disabled.
    for replacement_strategy in ("rand_id", "const_id")
}
# Alphabet-size sweep with ``pretrained=False`` at a fixed length of 1024
# tokens. These runs train for 300 epochs, belong to the "intermediate_eval"
# group and are keyed as "<model_id>-ut_a-<alphabet_size>".
UNTRAINED_ARGS = {
    f"{model_id}-ut_a-{alphabet_size}": ConfigArgs(
        model_id=model_id,
        config_group="intermediate_eval",
        pretrained=False,
        num_tokens=1024,
        alphabet_size=alphabet_size,
        eval_epochs=[10, 20, 40, 80, 160, 300],
        num_epochs=300,
    )
    for model_id in ("pyt-1b", "llama2-13b", "phi-2.7b")
    for alphabet_size in (2, 4, 7, 13, 26)
}
# Conditional-probability sweep over alphabet size and n-gram length at a
# fixed length of 1024 tokens. Entries belong to the
# "conditional_probability" group and are keyed as
# "<model_id>[_u]_a-<alphabet_size>_rp-<relative_probability>_n-<ngram_length>",
# where "_u" is appended when ``pretrained`` is False.
CONDITIONAL_PROB_ARGS = {
    (
        f"{model_id}"
        + ("_u" if not pretrained else "")
        + f"_a-{alphabet_size}"
        + f"_rp-{relative_probability}_n-{ngram_length}"
    ): ConfigArgs(
        config_group="conditional_probability",
        model_id=model_id,
        # Forward the swept flag so a "_u" key always describes a config that
        # really has pretrained=False.
        pretrained=pretrained,
        alphabet_size=alphabet_size,
        num_tokens=1024,
        relative_probability=relative_probability,
        ngram_length=ngram_length,
        eval_epochs=[5, 10, 15, 20, 30, 40, 50, 100],
    )
    for model_id in [
        "pyt-1b",
        "llama2-13b",
        "phi-2.7b",
    ]
    for pretrained in [True]
    for alphabet_size in [2, 7, 26]
    for relative_probability in [16]
    for ngram_length in [1, 2, 3, 4]
}
# Deterministic-rule sweep. Setting ``premise_length`` to a value other than
# -1 is what makes create_config build deterministic-rule data for these
# entries. Keys follow
# "<model_id>_det_a-<alphabet_size>_l-<num_tokens>_n-<num_strings>_prem-<premise_length>"
# and every entry belongs to the "deterministic_rules" group.
DETERMINISTIC_RULES_ARGS = {
    f"{model_id}_det_a-{alphabet_size}_l-{num_tokens}_n-{num_strings}_prem-{premise_length}": ConfigArgs(
        model_id=model_id,
        config_group="deterministic_rules",
        premise_length=premise_length,
        num_strings=num_strings,
        num_tokens=num_tokens,
        alphabet_size=alphabet_size,
        # At least 5, and always beyond the premise itself.
        dense_prefix_sampling_end=max(5, premise_length + 1),
        eval_epochs=[0, 5, 10, 15, 20, 25, 30, 40, 100],
    )
    for model_id in ("pyt-1b",)
    for alphabet_size in (2, 4, 26)
    for num_tokens in (16, 32, 64, 1024)
    for num_strings in (1, 2, 4, 8)
    for premise_length in (1, 2, 3, 4)
}


# Registry of every named configuration: a tiny smoke-test entry followed by
# all sweeps. On a key collision the value from the later mapping wins.
CONFIG_ARGS = {
    "test": ConfigArgs(
        config_group="test",
        model_id="pyt-70m",
        num_tokens=16,
        num_epochs=1,
        ngram_length=1,
        relative_probability=16,
        eval_epochs=[0, 1],
    ),
    **SEQUENCE_LENGTH_ALPHABET_SIZE_ARGS,
    **ENTROPY_CONTROL_ARGS,
    **SIZE_CHANGE_ARGS,
    **REPLACEMENT_STRATEGY_ARGS,
    **UNTRAINED_ARGS,
    **CONDITIONAL_PROB_ARGS,
    **DETERMINISTIC_RULES_ARGS,
}


def create_config(
    eval_type: str,
    seed_id: Optional[int] = None,
) -> ExperimentConfig:
    """Build the ``ExperimentConfig`` registered under ``eval_type``.

    The data source is chosen from the registered ``ConfigArgs``:

    * ``premise_length != -1``: deterministic-rule strings.
    * otherwise ``relative_probability != 1``: conditional (n-gram) random
      strings; ``entropy_target`` is not used on this path.
    * otherwise: plain random strings.

    All seed fields start out as ``-1`` placeholders and are only filled in
    when ``seed_id`` is given.

    Args:
        eval_type: Key into ``CONFIG_ARGS``.
        seed_id: Optional key used to look up ``EXPERIMENT_SEEDS``; when
            provided, the config is seeded via ``set_seeds`` before return.

    Returns:
        The assembled experiment configuration.

    Raises:
        KeyError: If ``eval_type`` is not registered in ``CONFIG_ARGS``.
        ValueError: If conditional random strings are requested with a
            non-positive ``ngram_length``.
    """
    args = CONFIG_ARGS[eval_type]

    model_config = ModelConfig(
        model_id=args.model_name,
        base_dir=args.model_dir,
        pretrained=args.pretrained,
    )

    use_rule_data = args.premise_length != -1
    use_conditional_data = not use_rule_data and args.relative_probability != 1
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, which would let an invalid n-gram length through.
    if use_conditional_data and args.ngram_length <= 0:
        raise ValueError(
            f"Config {eval_type!r} sets relative_probability="
            f"{args.relative_probability} and therefore needs "
            f"ngram_length > 0, got {args.ngram_length}."
        )

    # Every data config uses the tokenizer type of the selected model.
    tokenizer_type = get_tokenizer_type(args.model_name)
    random_data = None
    deterministic_data = None
    if use_rule_data:
        deterministic_data = DeterministicRuleStringConfig(
            seed_id=-1,
            string_length=args.num_tokens,
            num_strings=args.num_strings,
            premise_length=args.premise_length,
            alphabet_size=args.alphabet_size,
            tokenizer_type=tokenizer_type,
        )
    elif use_conditional_data:
        random_data = ConditionalRandomStringConfig(
            seed_id=-1,
            num_tokens=args.num_tokens,
            alphabet_size=args.alphabet_size,
            tokenizer_type=tokenizer_type,
            ngram_length=args.ngram_length,
        )
        random_data.set_relative_probability(args.relative_probability)
    else:
        random_data = RandomStringConfig(
            seed_id=-1,
            num_tokens=args.num_tokens,
            alphabet_size=args.alphabet_size,
            entropy_like=args.entropy_target,
            tokenizer_type=tokenizer_type,
        )

    # NOTE(review): args.dense_prefix_sampling_end is not forwarded anywhere
    # below - confirm whether PrefixEvalConfig is supposed to receive it.
    config = ExperimentConfig(
        group=args.config_group,
        name=eval_type,
        seed_id=-1,
        seed=-1,
        random_data=random_data,
        deterministic_rule_data=deterministic_data,
        model=model_config,
        memorization=MemorizationConfig(
            training=get_memorization_training_config(
                seed_id=-1,
                model_id=model_config.model_id_not_none,
                num_epochs=args.num_epochs,
                batch_size=1,
                wandb_project_name=f"llm_mem_{EXP_ABBREVIATION}",
                use_wandb=False,
            ),
        ),
        prefix_eval_epochs=args.eval_epochs,
        prefix_testing=PrefixEvalConfig(
            seed=-1,
            # Shorter strings get more samples per prefix.
            num_samples_per_prefix=(20 if args.num_tokens <= 128 else 10),
            relative_context_size=args.size_change,
            max_token_samples=256,
        ),
        replacement_strategy=args.replacement_strategy,
    )
    if eval_type == "test":
        # Keep the smoke-test configuration as cheap as possible.
        config.memorization.training.args.num_train_epochs = 1
        config.prefix_testing.num_samples_per_prefix = 2
    if seed_id is not None:
        set_seeds(config, seed_id)
    return config


def set_seeds(
    config: ExperimentConfig,
    seed_id: int,
) -> ExperimentConfig:
    """Write the seed selected by ``seed_id`` into ``config`` in place.

    The top-level config receives both the id and the seed value looked up in
    ``EXPERIMENT_SEEDS``. Whichever data configs are present receive the id,
    while the training and prefix-testing sections receive the seed value.

    Args:
        config: Experiment configuration to mutate.
        seed_id: Key used to look up ``EXPERIMENT_SEEDS``.

    Returns:
        The same ``config`` object, for convenience.
    """
    # Look the seed up first so a bad seed_id fails before anything is mutated.
    resolved_seed = EXPERIMENT_SEEDS[seed_id]
    config.seed_id, config.seed = seed_id, resolved_seed

    # Either data config may be absent (None); only touch the ones that exist.
    random_cfg = config.random_data
    if random_cfg is not None:
        random_cfg.seed_id = seed_id
    rule_cfg = config.deterministic_rule_data
    if rule_cfg is not None:
        rule_cfg.seed_id = seed_id

    config.memorization.training.seed = resolved_seed
    config.prefix_testing.seed = resolved_seed
    return config


def get_configs() -> list[ExperimentConfig]:
    """Return one unseeded ``ExperimentConfig`` per entry of ``CONFIG_ARGS``.

    Configs appear in the registry's iteration order; no seed id is passed to
    ``create_config``, so their seed fields keep the ``-1`` placeholders.
    """
    return [create_config(eval_type) for eval_type in CONFIG_ARGS]


# Handle bundling this experiment's entry point with the config factory and
# the seeding hook defined in this module.
PMHandle = ExperimentHandle(
    experiment=pm_experiment,
    create_configs=get_configs,
    set_seed=set_seeds,
    id=EXP_ABBREVIATION,
)
