import logging
import string
from dataclasses import dataclass
from typing import Iterator, cast

import numpy as np
import pandas as pd
from transformers import PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from data.synthetic_strings.utils import alphabet_encoding
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, experiment

from ..substrings import Substring


# Module-level logger (not used by the code in this section).
logger = logging.getLogger(__name__)

# Experiment name passed to the @experiment decorator on rt_experiment.
EXP_NAME = "repeated_tokens"
# Short abbreviation of the experiment name (not referenced in this section).
EXP_ABBREVIATION = "rtok"


@dataclass
class TokenConfig:
    """Settings for one planted prefix followed by special token(s)."""

    # Number of random alphabet tokens that precede the special token.
    prefix_length: int
    # Total number of times the (prefix + special token) window is inserted,
    # summed over token_1 and token_2.
    num_occurrences: int
    # Fraction of occurrences (rounded) followed by token_1; the remaining
    # occurrences are followed by token_2, so the same prefix gets an
    # agreeing and a disagreeing continuation.
    agreement_ratio: float = 1.0


@dataclass
class ModificationConfig:
    """Describes which prefix/token pairs are planted in the random string."""

    # Seed for the np.random.default_rng that samples prefixes and insertion
    # positions.
    seed: int
    # One entry per prefix/token pair to plant, processed in list order.
    tokens: list[TokenConfig]


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Full configuration of the repeated-tokens experiment."""

    # Experiment-level seed; not read by the code in this section (the
    # string modification uses modification.seed instead).
    seed: int
    # Settings for the base random string (passed to generate_random_strings).
    data: RandomStringConfig
    # Which prefix/token pairs to plant, and the seed used to place them.
    modification: ModificationConfig
    # Model to load via load_model_tokenizer.
    model: ModelConfig
    # Passed to train() as its `config` argument.
    training: TrainingConfig


@dataclass
class TokenPrefixPair:
    """A planted random prefix together with the special token(s) after it."""

    # Ids of the random prefix shared by token_1 and token_2.
    prefix_token_ids: list[int]
    # The prefix ids, each decoded individually with the tokenizer.
    prefix_tokens: list[str]
    # Special token following the prefix in
    # round(num_occurrences * agreement_ratio) occurrences; None when that
    # count is 0.
    token_1: Substring | None
    # Special token following the prefix in the remaining occurrences;
    # None when there are none.
    token_2: Substring | None


@dataclass
class ExperimentResult:
    """Everything rt_experiment returns."""

    # Random-string data after the prefix/token pairs were planted (alphabet
    # extended with the special symbols).
    data: RandomStringData
    # One entry per configured TokenConfig, in configuration order.
    token_prefix_pairs: list[TokenPrefixPair]
    # The `training_log` of the result returned by train().
    training_log: pd.DataFrame
    # The result of the memorization-dynamics evaluation callback.
    memorization_log: pd.DataFrame


@experiment(EXP_NAME)
def rt_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train a model on a random string seeded with repeated prefix/token pairs.

    The random string described by ``config.data`` is modified according to
    ``config.modification`` (see ``construct_token_string``). The model is
    then trained on it, while a memorization-dynamics callback is evaluated
    on the ``"test"`` split.

    Returns:
        An ``ExperimentResult`` with the modified data, the planted
        prefix/token pairs, the training log and the memorization log.
    """
    model, tokenizer = load_model_tokenizer(config.model)
    string_data, planted_pairs = construct_token_string(
        config.data, config.modification, tokenizer
    )
    splits = string_data.dataset()

    # The alphabet passed here already includes the special symbols that
    # construct_token_string appended.
    memorization_callback = memorization_dynamics_metrics(
        string_data.alphabet_tokens,
        string_data.alphabet_token_ids,
        splits["test"],
    )
    outcome = train(
        experiment_id,
        (config.model.model_id_not_none, model),
        (config.data.name, splits),
        config=config.training,
        tokenizer=tokenizer,
        callbacks=[memorization_callback],
        data_already_preprocessed=True,
    )

    memorization_log = memorization_callback.result()
    training_log = outcome.training_log
    assert training_log is not None
    return ExperimentResult(
        data=string_data,
        token_prefix_pairs=planted_pairs,
        training_log=training_log,
        memorization_log=memorization_log,
    )


def construct_token_string(
    data_config: RandomStringConfig,
    modification_config: ModificationConfig,
    tokenizer: PreTrainedTokenizer,
) -> tuple[RandomStringData, list[TokenPrefixPair]]:
    """Build a random string and plant the configured prefix/token pairs in it.

    The base string comes from ``generate_random_strings``. For every entry of
    ``modification_config.tokens``, ``create_token`` overwrites windows of the
    single token-id sequence with a random prefix followed by a special token.
    The special tokens come from a dedicated 32-symbol alphabet (A-Z and 0-5).
    Afterwards the special symbols are appended to the data's alphabet, and
    ``raw_tokens[0]`` is rebuilt from the modified ids.

    Returns:
        The modified data and one ``TokenPrefixPair`` per token config.
    """
    assert data_config.num_partitions == 1
    string_data = generate_random_strings(data_config, tokenizer)
    sequence_length = len(string_data.raw_tokens[0])

    # 26 uppercase letters + 6 digits = 32 symbols; each token config may
    # consume up to two of them (token_1 and token_2).
    special_symbols = string.ascii_uppercase + string.digits[:6]
    assert len(modification_config.tokens) * 2 <= len(special_symbols)
    special_ids, _ = alphabet_encoding(
        tokenizer,
        data_config.tokenizer_type,
        special_symbols,
    )
    unused_special_ids = iter(special_ids)

    rng = np.random.default_rng(modification_config.seed)
    # True marks a position not yet covered by any planted window.
    free_mask = np.ones(sequence_length, dtype=bool)
    pairs = [
        create_token(
            token_config,
            tokenizer,
            string_data,
            unused_special_ids,
            free_mask,
            rng,
        )
        for token_config in modification_config.tokens
    ]

    # Planted windows never overlap, so each occurrence consumes exactly
    # prefix_length + 1 positions.
    consumed = sum(
        cfg.num_occurrences * (cfg.prefix_length + 1)
        for cfg in modification_config.tokens
    )
    assert np.sum(free_mask) == sequence_length - consumed

    string_data.alphabet_token_ids = np.concatenate(
        (string_data.alphabet_token_ids, special_ids)
    )
    string_data.alphabet_tokens = string_data.alphabet_tokens + list(
        special_symbols
    )
    string_data.raw_tokens[0] = tokenizer.convert_ids_to_tokens(
        string_data.raw_token_ids[0]
    )
    return string_data, pairs


def create_token(
    token_config: TokenConfig,
    tokenizer: PreTrainedTokenizer,
    base_data: RandomStringData,
    special_alphabet_id_iter: Iterator[int],
    available_positions: np.ndarray,
    rng: np.random.Generator,
) -> TokenPrefixPair:
    """Plant one random prefix followed by special token(s) into the string.

    A prefix of ``token_config.prefix_length`` ids is sampled with replacement
    from ``base_data.alphabet_token_ids``. The prefix is then written
    ``token_config.num_occurrences`` times into ``base_data.raw_token_ids[0]``,
    each time followed by a special token taken from
    ``special_alphabet_id_iter``. ``round(num_occurrences * agreement_ratio)``
    occurrences use token_1 and the rest use token_2.

    Each written window (prefix + token) must lie entirely on positions that
    are still ``True`` in ``available_positions``. Those positions are then
    set to ``False``, so later insertions cannot overlap them.

    Args:
        token_config: Prefix length, number of occurrences, agreement ratio.
        tokenizer: Used to decode the prefix and special token ids.
        base_data: Random-string data; ``raw_token_ids[0]`` is modified in
            place.
        special_alphabet_id_iter: Source of fresh special token ids. One id
            is consumed per token (token_1 / token_2) with >0 occurrences.
        available_positions: Boolean mask over the sequence, updated in place.
        rng: Generator for the prefix and the insertion positions.

    Returns:
        A ``TokenPrefixPair``; ``token_1`` / ``token_2`` is ``None`` when it
        has zero occurrences.

    Raises:
        ValueError: If ``token_config`` has a negative prefix length or
            occurrence count, or an ``agreement_ratio`` outside ``[0, 1]``.
            Also raised if no free window of ``prefix_length + 1`` positions
            remains for an insertion; without this check the rejection
            sampling below would never terminate.
    """
    _validate_token_config(token_config)
    prefix_length = token_config.prefix_length

    # An insertion at position p occupies [p - prefix_length, p], so p must be
    # at least prefix_length for the window to start inside the sequence.
    possible_positions = np.arange(
        prefix_length, len(base_data.raw_tokens[0])
    )
    token_id_prefix = rng.choice(
        base_data.alphabet_token_ids,
        prefix_length,
        replace=True,
    )

    # We have token 1 and 2 to create agreeing/disagreeing tokens
    # for the same prefix.
    token_1_occurrences = round(
        token_config.num_occurrences * token_config.agreement_ratio
    )
    token_2_occurrences = token_config.num_occurrences - token_1_occurrences

    tokens: list[Substring | None] = []
    for token_occurrences in [token_1_occurrences, token_2_occurrences]:
        if token_occurrences == 0:
            tokens.append(None)
            continue

        token_id = cast(np.ndarray, next(special_alphabet_id_iter)).item()
        token_prefix_ids = np.concatenate(
            [
                token_id_prefix,
                np.array([token_id]),
            ]
        )

        token_positions = []
        for _ in range(token_occurrences):
            # Rejection sampling only terminates if a free window exists.
            # The check draws no random numbers, so the positions sampled
            # for a given seed are unchanged.
            if not _has_free_window(available_positions, prefix_length):
                raise ValueError(
                    f"No free window of {prefix_length + 1} consecutive "
                    f"positions left to insert token {token_id} "
                    f"({len(token_positions)} of {token_occurrences} "
                    "occurrences placed); the string is too short or too "
                    "densely modified."
                )
            while True:
                insertion_position = rng.choice(possible_positions, 1).item()
                insertion_range_with_prefix = np.arange(
                    insertion_position - prefix_length,
                    insertion_position + 1,
                )
                if np.all(available_positions[insertion_range_with_prefix]):
                    break
            token_positions.append(insertion_position)
            base_data.raw_token_ids[0][
                insertion_range_with_prefix
            ] = token_prefix_ids
            available_positions[insertion_range_with_prefix] = False
        tokens.append(
            Substring(
                token_ids=[token_id],
                tokens=[tokenizer.decode(token_id)],
                positions=token_positions,
            )
        )
    return TokenPrefixPair(
        prefix_token_ids=token_id_prefix.tolist(),
        prefix_tokens=[
            tokenizer.decode(token_id) for token_id in token_id_prefix
        ],
        token_1=tokens[0],
        token_2=tokens[1],
    )


def _validate_token_config(token_config: TokenConfig) -> None:
    """Reject configs that would yield negative window sizes or counts."""
    if token_config.prefix_length < 0:
        raise ValueError(
            f"prefix_length must be >= 0, got {token_config.prefix_length}"
        )
    if token_config.num_occurrences < 0:
        raise ValueError(
            "num_occurrences must be >= 0, "
            f"got {token_config.num_occurrences}"
        )
    # Outside [0, 1] the token_2 count would become negative.
    if not 0.0 <= token_config.agreement_ratio <= 1.0:
        raise ValueError(
            "agreement_ratio must be in [0, 1], "
            f"got {token_config.agreement_ratio}"
        )


def _has_free_window(available_positions: np.ndarray, prefix_length: int) -> bool:
    """Return True if prefix_length + 1 consecutive positions are all free."""
    window = prefix_length + 1
    # Prefix sums of blocked positions; a window is free iff it contains
    # zero blocked positions. Slices are empty when the window exceeds the
    # sequence, which correctly yields False.
    blocked_prefix_sum = np.concatenate(
        ([0], np.cumsum(np.logical_not(available_positions)))
    )
    blocked_per_window = blocked_prefix_sum[window:] - blocked_prefix_sum[:-window]
    return bool(np.any(blocked_per_window == 0))
