import logging
from dataclasses import dataclass, field

import pandas as pd
import torch

from data.synthetic_strings import deunion_data, get_config, load_data
from data.synthetic_strings.conditional_random import (
    ConditionalRandomStringConfig,
    create_ngram_conditional_dataset,
)
from data.synthetic_strings.deterministic_rules import (
    DeterministicRuleStringConfig,
    DeterministicRuleStringData,
)
from data.synthetic_strings.random import RandomStringConfig, RandomStringData
from defs import LLMExperimentConfig
from experiments.memorization_dynamics.experiment import (
    ExperimentConfig as MemorizationDynamicsConfig,
)
from experiments.memorization_dynamics.experiment import (
    MemorizationConfig,
    ModelConfig,
    md_experiment,
)
from lib_llm.eval.memorization.prefix_mappings import (
    PrefixEvalConfig,
    PrefixEvalTask,
)
from lib_llm.models.load import load_tokenizer
from lib_llm.models.utils import get_tokenizer_type
from lib_project.experiment import ExperimentID, experiment
from utils.prefix_mappings import setup_replacements


# Identity of this experiment; EXP_NAME is what @experiment is given below.
EXP_NAME: str = "prefix_mappings"
EXP_ABBREVIATION: str = "pm"  # abbreviated experiment name

# Module-scoped logger plus a CUDA availability flag probed once at import.
logger: logging.Logger = logging.getLogger(__name__)
HAS_CUDA: bool = torch.cuda.is_available()


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the prefix-mappings experiment.

    Extends ``LLMExperimentConfig`` (which presumably provides the
    ``seed_id`` that ``pm_experiment`` reads — confirm) with data, model,
    training and prefix-evaluation settings.
    """

    # Seed forwarded to the memorization-dynamics experiment.
    seed: int
    # Random-string data configuration, if any. pm_experiment additionally
    # checks whether this is a ConditionalRandomStringConfig and, if so,
    # builds the dataset with create_ngram_conditional_dataset.
    random_data: RandomStringConfig | None
    # Deterministic-rule string data configuration, if any.
    deterministic_rule_data: DeterministicRuleStringConfig | None
    # Model configuration; also used to load the tokenizer.
    model: ModelConfig
    # Memorization/training settings (training.args.num_train_epochs is read
    # to resolve the final epoch).
    memorization: MemorizationConfig
    # Settings for the prefix-mapping evaluation task and replacements.
    prefix_testing: PrefixEvalConfig
    # Epochs at which prefix mappings are evaluated; the default [-1]
    # selects only the final training epoch.
    prefix_eval_epochs: list[int] = field(default_factory=lambda: [-1])
    # Strategy name forwarded to setup_replacements.
    replacement_strategy: str = "rand_id"


@dataclass
class ExperimentResult:
    """Outputs of ``pm_experiment``: the dataset used and its prefix mappings."""

    # Random-string dataset, or None (presumably when deterministic-rule data
    # was used instead — confirm against deunion_data).
    random_data: RandomStringData | None
    # Deterministic-rule dataset, or None (see the note on random_data).
    deterministic_rule_data: DeterministicRuleStringData | None
    # Table returned by PrefixEvalTask.result(); its schema is defined there.
    prefix_mappings: pd.DataFrame


def _resolve_eval_epochs(
    prefix_eval_epochs: list[int],
    num_train_epochs: float,
) -> list[float]:
    """Return the concrete epochs at which prefix mappings are evaluated.

    Every ``-1`` entry is a sentinel for the final training epoch and is
    replaced by ``num_train_epochs``; all other entries pass through
    unchanged and in order. ``[-1]`` therefore yields ``[num_train_epochs]``.
    """
    return [
        num_train_epochs if epoch == -1 else epoch
        for epoch in prefix_eval_epochs
    ]


@experiment(EXP_NAME)
def pm_experiment(
    config: ExperimentConfig,
    description: ExperimentID,
) -> ExperimentResult:
    """Train a model while evaluating its prefix mappings at chosen epochs.

    Builds the string dataset described by ``config``, sets up a
    ``PrefixEvalTask``, and runs the memorization-dynamics experiment with
    that task attached as an additional callback. Evaluation happens at the
    epochs listed in ``config.prefix_eval_epochs`` (``-1`` = final epoch).

    Args:
        config: Full experiment configuration.
        description: Experiment identifier. It is not used in this function
            body (presumably required by the ``@experiment`` decorator).

    Returns:
        The dataset, split into its ``random_data`` /
        ``deterministic_rule_data`` parts by ``deunion_data``, together with
        the prefix mappings collected by the evaluation task.
    """
    tokenizer = load_tokenizer(config.model, max_length=2048)
    # Conditional n-gram data is generated directly; every other data
    # configuration goes through the generic config + loader path.
    if isinstance(config.random_data, ConditionalRandomStringConfig):
        data = create_ngram_conditional_dataset(config.random_data, tokenizer)
    else:
        data_config = get_config(
            config.random_data, config.deterministic_rule_data
        )
        data = load_data(data_config, tokenizer)

    get_replacements = setup_replacements(
        config.prefix_testing,
        replacement_strategy=config.replacement_strategy,
        tokenizer=tokenizer,
        tokenizer_type=get_tokenizer_type(config.model.model_id_not_none),
        # NOTE(review): sized from the first sequence only; this assumes all
        # sequences in the dataset share that length — confirm.
        replacement_length=len(data.token_ids[0]),
    )
    eval_epochs = _resolve_eval_epochs(
        config.prefix_eval_epochs,
        config.memorization.training.args.num_train_epochs,
    )
    prefix_eval_task = PrefixEvalTask(
        config=config.prefix_testing,
        data=(data.tokens, data.batch_encoding()),
        tokenizer=tokenizer,
        get_replacements=get_replacements,
        # Exact membership test: presumably `state.epoch` holds a whole
        # number (possibly as a float) when this condition is checked.
        eval_condition=lambda state: state.epoch in eval_epochs,
    )

    mem_dynamics_config = MemorizationDynamicsConfig(
        seed_id=config.seed_id,
        name="override",
        seed=config.seed,
        random_data=config.random_data,
        deterministic_rule_data=config.deterministic_rule_data,
        model=config.model,
        memorization=config.memorization,
    )
    # Run the memorization-dynamics experiment (which trains the model); the
    # prefix-eval task is passed as an extra callback and its results are
    # read back below via result().
    md_experiment(
        mem_dynamics_config,
        additional_callbacks=[prefix_eval_task],
    )

    prefix_mappings = prefix_eval_task.result()
    # Use the module logger rather than print so output honours the
    # application's logging configuration; %-args keep formatting lazy.
    logger.info("prefix mappings: %s", prefix_mappings)

    return ExperimentResult(
        **deunion_data(data),
        prefix_mappings=prefix_mappings,
    )
