import logging
from dataclasses import dataclass, replace

import numpy as np
import pandas as pd
from transformers import PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, experiment

from ..substrings import Substring


# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Name under which the experiment is registered via the ``@experiment``
# decorator below, plus a short abbreviation of that name.
EXP_NAME: str = "icl_memorization_conflict"
EXP_ABBREVIATION: str = "imc"


@dataclass
class RepetitionConflictConfig:
    """Settings for building a repeated token sequence with one conflicted copy.

    Consumed by ``create_repeated_icl_conflict_data``.
    """

    # Seed for the RNG that picks the conflicting token ids and the
    # positions they overwrite.
    seed: int
    # Number of unmodified copies of the base string; values <= 0 produce
    # no clean copies.
    num_clean_repetitions: int
    # Number of distinct positions in the conflicted copy that are
    # overwritten with tokens drawn (with replacement) from the alphabet.
    num_conflicting_tokens: int
    # If True, the conflicted copy precedes the clean repetitions;
    # otherwise it follows them.
    conflicts_first: bool = False


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the ICL memorization-conflict experiment."""

    # NOTE(review): not read directly in this file; presumably consumed by
    # LLMExperimentConfig or the experiment framework — confirm.
    seed: int
    # Base random-string settings; its ``name`` also labels the dataset
    # passed to ``train``.
    data: RandomStringConfig
    # Controls the clean repetitions and the conflicted copy of the string.
    repetition_conflicts: RepetitionConflictConfig
    # Model to load; ``model_id_not_none`` labels the model passed to ``train``.
    model: ModelConfig
    # Training settings forwarded to ``train``.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Outputs produced by ``irc_experiment``."""

    # Generated token data: clean repetitions plus the conflicted copy.
    data: RandomStringData
    # One record per overwritten token, positioned in the full sequence.
    conflicting_tokens: list[Substring]
    # Training log returned by ``train``.
    training_log: pd.DataFrame
    # Result of the memorization-dynamics callback after training.
    memorization_log: pd.DataFrame


@experiment(EXP_NAME)
def irc_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train on repeated data containing ICL conflicts and track memorization.

    Loads the model and tokenizer described by ``config.model`` and builds
    the conflicted training sequence with ``create_repeated_icl_conflict_data``.
    It then trains with a memorization-dynamics callback that is evaluated
    on the ``"test"`` split.

    Args:
        config: Full experiment configuration.
        experiment_id: Identifier forwarded to ``train``.

    Returns:
        An ``ExperimentResult`` bundling the generated data, the conflicting
        token records, the training log and the memorization log.

    Raises:
        AssertionError: If ``train`` produced no training log (only when
            assertions are enabled).
    """
    model, tokenizer = load_model_tokenizer(config.model)
    string_data, conflicts = create_repeated_icl_conflict_data(
        config.data, config.repetition_conflicts, tokenizer
    )
    splits = string_data.dataset()

    memorization_callback = memorization_dynamics_metrics(
        string_data.alphabet_tokens,
        string_data.alphabet_token_ids,
        splits["test"],
    )
    # ``train`` takes (name, object) pairs for both the model and the data.
    named_model = (config.model.model_id_not_none, model)
    named_dataset = (config.data.name, splits)
    outcome = train(
        experiment_id,
        named_model,
        named_dataset,
        config=config.training,
        tokenizer=tokenizer,
        callbacks=[memorization_callback],
        data_already_preprocessed=True,
    )

    memorization_log = memorization_callback.result()
    training_log = outcome.training_log
    # The training log can apparently be None; narrow it before storing.
    assert training_log is not None
    return ExperimentResult(
        data=string_data,
        conflicting_tokens=conflicts,
        training_log=training_log,
        memorization_log=memorization_log,
    )


def create_repeated_icl_conflict_data(
    data_config: RandomStringConfig,
    repetition_conflict_config: RepetitionConflictConfig,
    tokenizer: PreTrainedTokenizer,
) -> tuple[RandomStringData, list[Substring]]:
    """Build one token sequence of clean repetitions plus a conflicted copy.

    A base random string is generated from ``data_config``. The output
    sequence has two parts:

    - ``num_clean_repetitions`` unmodified copies of that string;
    - one copy in which ``num_conflicting_tokens`` distinct positions are
      overwritten with tokens drawn (with replacement) from the alphabet.

    The conflicted copy goes before the clean repetitions when
    ``conflicts_first`` is set, and after them otherwise.

    Args:
        data_config: Configuration for the base random string. Only the
            first generated string (``raw_token_ids[0]``) is used.
        repetition_conflict_config: Number of clean copies, number of
            overwritten tokens, their placement and the RNG seed.
        tokenizer: Tokenizer used to generate the base string and to map
            token ids back to token strings.

    Returns:
        A tuple ``(data, conflicting_tokens)``.

        - ``data`` holds the concatenated sequence as a single row. Its
          config's ``num_tokens`` is set to the total sequence length.
        - ``conflicting_tokens`` describes every overwritten token. Each
          position is an index into that concatenated sequence.

    Raises:
        ValueError: Raised by numpy if ``num_conflicting_tokens`` exceeds
            the length of the base string. Positions are sampled without
            replacement.
    """
    base_data = generate_random_strings(data_config, tokenizer)
    # Only the first generated string serves as the base sequence.
    base_token_ids = np.asarray(base_data.raw_token_ids[0])

    num_repetitions = repetition_conflict_config.num_clean_repetitions
    if num_repetitions > 0:
        repeated_token_ids = np.concatenate([base_token_ids] * num_repetitions)
    else:
        # Keep the integer dtype of the base ids: a bare ``np.empty(0)`` is
        # float64 and would promote the concatenated token ids to floats.
        repeated_token_ids = np.empty(0, dtype=base_token_ids.dtype)

    rng = np.random.default_rng(repetition_conflict_config.seed)
    # Draw order matters for seed reproducibility: token ids first, then
    # positions.
    conflicting_token_ids = rng.choice(
        base_data.alphabet_token_ids,
        size=repetition_conflict_config.num_conflicting_tokens,
        replace=True,
    )
    # Distinct positions, so each conflict overwrites a different token.
    # NOTE(review): a drawn token may equal the original token at its
    # position, in which case that "conflict" leaves the sequence unchanged.
    conflict_positions = rng.choice(
        len(base_token_ids),
        size=repetition_conflict_config.num_conflicting_tokens,
        replace=False,
    )

    # Maps positions within the conflicted copy to positions in the final
    # concatenated sequence.
    if repetition_conflict_config.conflicts_first:
        position_offset = 0
    else:
        position_offset = len(repeated_token_ids)

    conflicted_copy = base_token_ids.copy()
    conflicting_tokens = []
    for token_id, pos in zip(conflicting_token_ids, conflict_positions):
        conflicted_copy[pos] = token_id
        conflicting_tokens.append(
            Substring(
                token_ids=[token_id.item()],
                tokens=tokenizer.convert_ids_to_tokens([token_id.item()]),
                positions=[pos.item() + position_offset],
            )
        )

    if repetition_conflict_config.conflicts_first:
        token_ids = np.concatenate([conflicted_copy, repeated_token_ids])
    else:
        token_ids = np.concatenate([repeated_token_ids, conflicted_copy])

    repeated_config = replace(
        data_config,
        num_tokens=len(token_ids),
    )
    return (
        RandomStringData(
            config=repeated_config,
            raw_token_ids=token_ids.reshape(1, -1),
            raw_tokens=[tokenizer.convert_ids_to_tokens(token_ids.tolist())],
            alphabet_token_ids=base_data.alphabet_token_ids,
            alphabet_tokens=base_data.alphabet_tokens,
            start_of_string_token_id=base_data.start_of_string_token_id,
            start_of_string_token=base_data.start_of_string_token,
        ),
        conflicting_tokens,
    )
