import logging
from dataclasses import dataclass, replace

import numpy as np
import pandas as pd
import torch
from datasets import concatenate_datasets
from transformers import PreTrainedModel, PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import LLMExperimentConfig
from experiments.memorization_dynamics.experiment import (
    ExperimentConfig as MemorizationDynamicsConfig,
)
from experiments.memorization_dynamics.experiment import (
    MemorizationConfig,
    ModelConfig,
    md_experiment,
)
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.eval.metrics import TokenEvaluationTask
from lib_llm.models.load import load_tokenizer
from lib_project.experiment import ExperimentID, experiment


# Identifiers of this experiment; EXP_NAME is the name icef_experiment is
# registered under via the @experiment decorator below.
EXP_NAME = "in_context_effect"
EXP_ABBREVIATION = "icef"  # short form; presumably used for run/file naming

logger = logging.getLogger(__name__)
# Evaluated once, at import time.
HAS_CUDA = torch.cuda.is_available()


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the in-context-effect (icef) experiment.

    Fields not declared here (e.g. ``seed_id``, which ``icef_experiment``
    reads) are presumably inherited from ``LLMExperimentConfig`` -- TODO
    confirm.
    """

    # Seed forwarded to the memorization dynamics experiment config.
    seed: int
    # Random-string data the model is trained on; also the first eval dataset.
    training_data: RandomStringConfig
    # Extra random-string datasets; in this module they are only used to
    # build the memorization evaluation task.
    additional_eval_data: list[RandomStringConfig]
    # Model settings; used to load the tokenizer and forwarded to training.
    model: ModelConfig
    # Memorization settings, forwarded unchanged to the memorization
    # dynamics experiment.
    memorization: MemorizationConfig


@dataclass
class ExperimentResult:
    """Outputs of the in-context-effect experiment.

    Note: the field ``trainig_data`` is misspelled. It is kept unchanged so
    that existing callers (and possibly stored results keyed by field name)
    keep working; prefer the ``training_data`` property in new code.
    """

    # Random-string data the model was trained on.
    trainig_data: RandomStringData
    # Extra eval datasets, in the same order as the config's
    # ``additional_eval_data`` list.
    additional_eval_data: list[RandomStringData]
    # Training log returned by the memorization dynamics experiment.
    training_log: pd.DataFrame
    # Memorization log returned by the memorization dynamics experiment.
    memorization_log: pd.DataFrame

    @property
    def training_data(self) -> RandomStringData:
        """Correctly spelled, read-only alias for ``trainig_data``."""
        return self.trainig_data


@experiment(EXP_NAME)
def icef_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train on random strings and track memorization on related datasets.

    Delegates training to the memorization dynamics experiment
    (``md_experiment``), but replaces its memorization evaluation task with
    one built by ``_get_eval_task``. That task covers the training data, a
    re-seeded dataset generated from the same config, and every dataset in
    ``config.additional_eval_data``.

    Args:
        config: Experiment configuration.
        experiment_id: Not used in the body; presumably required by the
            ``@experiment`` decorator's calling convention.

    Returns:
        The training data, the additional eval datasets, and the training
        and memorization logs produced by ``md_experiment``.

    Raises:
        RuntimeError: If ``md_experiment`` returns no random training data.
    """
    tokenizer = load_tokenizer(config.model, max_length=2048)
    eval_task, additional_eval_data = _get_eval_task(
        config.training_data,
        config.additional_eval_data,
        tokenizer,
    )

    mem_dynamics_config = MemorizationDynamicsConfig(
        seed_id=config.seed_id,
        name="override",
        seed=config.seed,
        random_data=config.training_data,
        # This experiment trains on random strings only.
        deterministic_rule_data=None,
        model=config.model,
        memorization=config.memorization,
    )
    # Call the memorization dynamics experiment to get the trained model,
    # evaluating memorization with our combined eval task instead of its own.
    mem_res = md_experiment(
        mem_dynamics_config,
        override_memorization_eval_task=eval_task,
    )

    training_data = mem_res.random_data
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would then let `None` flow silently into the result.
    if training_data is None:
        raise RuntimeError(
            "md_experiment returned no random training data although "
            "random_data was configured"
        )
    return ExperimentResult(
        trainig_data=training_data,
        additional_eval_data=additional_eval_data,
        training_log=mem_res.training_log,
        memorization_log=mem_res.memorization_log,
    )


def _get_eval_task(
    training_data_config: RandomStringConfig,
    additional_eval_data_configs: list[RandomStringConfig],
    tokenizer: PreTrainedTokenizer,
) -> tuple[TokenEvaluationTask, list[RandomStringData]]:
    """Build one memorization evaluation task over several datasets.

    The evaluated datasets are, in order:

    1. the training data (``training_data_config``),
    2. an alternative dataset generated from the same config with its seed
       offset by 10000 (and ``seed_id`` cleared), and
    3. one dataset per entry of ``additional_eval_data_configs``.

    The alphabets of all datasets are merged, and their ``"test"`` splits
    are concatenated into a single evaluation dataset.

    Args:
        training_data_config: Config of the data the model is trained on.
        additional_eval_data_configs: Configs of extra datasets to evaluate.
        tokenizer: Tokenizer used to generate/encode all datasets.

    Returns:
        The evaluation task, and the generated additional eval datasets in
        the same order as ``additional_eval_data_configs``.
    """
    training_data = generate_random_strings(training_data_config, tokenizer)
    # Same config, different seed: presumably a set of strings the model
    # never trains on but that follows the training distribution.
    alternative_training_data = generate_random_strings(
        replace(
            training_data_config,
            seed=training_data_config.seed_value + 10000,
            seed_id=None,
        ),
        tokenizer,
    )
    additional_eval_data = [
        generate_random_strings(eval_data_config, tokenizer)
        for eval_data_config in additional_eval_data_configs
    ]
    eval_data = [training_data, alternative_training_data, *additional_eval_data]

    # Merge the alphabets with dict.fromkeys instead of set(). dict.fromkeys
    # de-duplicates while keeping first-occurrence order, which has two
    # benefits:
    # - The result is reproducible across processes; str hashing is
    #   randomized per process, so set iteration order is not.
    # - The token list stays position-aligned with the id list, assuming
    #   each dataset's alphabet_tokens / alphabet_token_ids are parallel and
    #   token <-> id is one-to-one.
    alphabet_tokens = list(
        dict.fromkeys(
            token for data in eval_data for token in data.alphabet_tokens
        )
    )
    alphabet_token_ids = list(
        dict.fromkeys(
            token_id
            for data in eval_data
            for token_id in data.alphabet_token_ids
        )
    )
    combined_datasets = concatenate_datasets(
        [data.dataset()["test"] for data in eval_data]
    )

    eval_task = memorization_dynamics_metrics(
        alphabet_tokens,
        np.array(alphabet_token_ids, dtype=int),
        combined_datasets,
    )
    return eval_task, additional_eval_data
