import logging
import string
from dataclasses import dataclass
from typing import Iterator, cast

import numpy as np
import pandas as pd
from transformers import PreTrainedTokenizer

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from data.synthetic_strings.utils import alphabet_encoding
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, experiment

from ..substrings import Substring


logger = logging.getLogger(__name__)

EXP_NAME = "repeated_strings"
EXP_ABBREVIATION = "rs"


@dataclass
class SubstringConfig:
    """Settings for the distinct substrings that are repeated in the data."""

    # Seed of the NumPy generator used to draw the substrings and, for the
    # "random" placement order, to pick which substring is placed next.
    seed: int
    # Number of tokens in every substring.
    length: int
    # Number of different substrings to generate.
    num_distinct_substrings: int
    # Order in which the substring occurrences are laid out:
    #   "random":      each slot gets a randomly chosen substring that still
    #                  has repetitions left,
    #   "consecutive": all repetitions of one substring, then the next one,
    #   "iterative":   cycle through the substrings round-robin.
    placement_order: str = "random"  # random, consecutive, iterative


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Top-level configuration of the repeated-strings experiment."""

    # NOTE(review): not read anywhere in this module; presumably consumed by
    # the experiment framework -- confirm.
    seed: int
    # Settings of the overall token string (total number of tokens, alphabet,
    # tokenizer type, ...).
    data: RandomStringConfig
    # Settings of the substrings that the token string is assembled from.
    substrings: SubstringConfig
    # Model (and tokenizer) to load.
    model: ModelConfig
    # Configuration forwarded to the training run.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Everything the repeated-strings experiment returns."""

    # The assembled token string the model was trained on.
    data: RandomStringData
    # One entry per distinct substring, including the positions at which it
    # occurs in the token string.
    substrings: list[Substring]
    # Training log reported by the training run.
    training_log: pd.DataFrame
    # Result of the memorization-dynamics evaluation callback.
    memorization_log: pd.DataFrame


@experiment(EXP_NAME)
def sss_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Train a model on a string of repeated substrings and track memorization.

    Args:
        config: Model, data, substring and training settings.
        experiment_id: Identifier that is handed on to the training run.

    Returns:
        The generated data, the substrings with their positions, the training
        log and the memorization log collected by the evaluation callback.

    Raises:
        AssertionError: If the training run did not produce a training log.
    """
    model, tokenizer = load_model_tokenizer(config.model)

    # Build the token string and remember where each substring was placed.
    string_data, repeated_substrings = create_repeated_substring_data(
        config.data, config.substrings, tokenizer
    )
    splits = string_data.dataset()

    # The memorization metrics are computed on the "test" split and are
    # attached to the training run as a callback.
    memorization_tracker = memorization_dynamics_metrics(
        string_data.alphabet_tokens,
        string_data.alphabet_token_ids,
        splits["test"],
    )
    named_model = (config.model.model_id_not_none, model)
    named_dataset = (config.data.name, splits)
    outcome = train(
        experiment_id,
        named_model,
        named_dataset,
        config=config.training,
        tokenizer=tokenizer,
        callbacks=[memorization_tracker],
        data_already_preprocessed=True,
    )

    memorization_log = memorization_tracker.result()
    training_log = outcome.training_log
    assert training_log is not None
    return ExperimentResult(
        data=string_data,
        substrings=repeated_substrings,
        training_log=training_log,
        memorization_log=memorization_log,
    )


def create_repeated_substring_data(
    data_config: RandomStringConfig,
    substring_config: SubstringConfig,
    tokenizer: PreTrainedTokenizer,
) -> tuple[RandomStringData, list[Substring]]:
    """Assemble a token string out of repeated occurrences of a few substrings.

    Every distinct substring occurs equally often, so the requested total
    number of tokens has to be a multiple of
    ``num_distinct_substrings * length``.

    Args:
        data_config: Settings of the overall string; ``num_tokens`` is the
            length of the assembled string.
        substring_config: Length, count, seed and placement order of the
            substrings.
        tokenizer: Used to generate the substrings and to map token ids back
            to tokens.

    Returns:
        The assembled string as a single-row ``RandomStringData`` and one
        ``Substring`` per distinct substring with the start positions of all
        of its occurrences.

    Raises:
        AssertionError: If ``num_tokens`` is not a multiple of the combined
            length of all distinct substrings.
    """
    tokens_per_round = (
        substring_config.num_distinct_substrings * substring_config.length
    )
    assert data_config.num_tokens % tokens_per_round == 0
    repetitions = data_config.num_tokens // tokens_per_round

    # A single generator drives both the substring generation and the
    # placement, in that order.
    rng = np.random.default_rng(substring_config.seed)
    reference_data, prototypes = _create_substring_prototypes(
        data_config,
        substring_config,
        tokenizer,
        rng,
    )
    pieces, start_positions = _place_substrings(
        prototypes,
        num_repetitions=repetitions,
        placement_order=substring_config.placement_order,
        rng=rng,
    )

    substrings = []
    for prototype, starts in zip(prototypes, start_positions):
        ids = prototype.tolist()
        substrings.append(
            Substring(
                token_ids=ids,
                tokens=tokenizer.convert_ids_to_tokens(ids),
                positions=starts,
            )
        )

    all_token_ids = np.concatenate(pieces)
    all_tokens = tokenizer.convert_ids_to_tokens(all_token_ids.tolist())
    # Alphabet and start-of-string metadata are copied from the data of the
    # first generated substring.
    full_data = RandomStringData(
        config=data_config,
        raw_token_ids=all_token_ids.reshape(1, -1),
        raw_tokens=[all_tokens],
        alphabet_token_ids=reference_data.alphabet_token_ids,
        alphabet_tokens=reference_data.alphabet_tokens,
        start_of_string_token_id=reference_data.start_of_string_token_id,
        start_of_string_token=reference_data.start_of_string_token,
    )
    return full_data, substrings


def _create_substring_prototypes(
    data_config: RandomStringConfig,
    substring_config: SubstringConfig,
    tokenizer: PreTrainedTokenizer,
    rng: np.random.Generator,
) -> tuple[RandomStringData, list[np.ndarray]]:
    """Generate the pairwise distinct substrings that are repeated later on.

    Each candidate is an independent random string of
    ``substring_config.length`` tokens that shares tokenizer type, alphabet
    size and ``entropy_like`` setting with ``data_config`` but uses a fresh
    seed drawn from ``rng``. A candidate whose token ids equal those of an
    already accepted substring is discarded and a new one is drawn.

    Args:
        data_config: Supplies tokenizer type, alphabet size and
            ``entropy_like`` for the generated substrings.
        substring_config: Supplies the substring length and the number of
            distinct substrings to generate.
        tokenizer: Tokenizer handed to ``generate_random_strings``.
        rng: Generator the per-substring seeds are drawn from.

    Returns:
        The ``RandomStringData`` of the first accepted substring and the
        token ids of all accepted substrings, in the order they were
        accepted.

    Raises:
        RuntimeError: If more than ``1000 * num_distinct_substrings``
            candidates were rejected as duplicates.
        IndexError: If ``num_distinct_substrings`` is 0, as there is no first
            substring to return.
    """
    string_prototypes: list[RandomStringData] = []
    string_prototype_ids: list[np.ndarray] = []
    rejected_samples = 0
    max_rejected_samples = 1000 * substring_config.num_distinct_substrings
    while len(string_prototypes) < substring_config.num_distinct_substrings:
        # Convert to a plain int so the config holds the annotated type
        # instead of a NumPy scalar.
        string_seed = int(rng.integers(0, 2**32 - 1))
        string_config = RandomStringConfig(
            tokenizer_type=data_config.tokenizer_type,
            seed=string_seed,
            num_tokens=substring_config.length,
            alphabet_size=data_config.alphabet_size,
            entropy_like=data_config.entropy_like,
        )
        prototype_data = generate_random_strings(string_config, tokenizer)
        prototype_token_ids = prototype_data.raw_token_ids[0]

        # Compare against the substrings accepted so far. The check has to
        # run over `string_prototype_ids`; iterating over the candidate's own
        # tokens would never detect a duplicate.
        is_duplicate = any(
            np.array_equal(prototype_token_ids, accepted_ids)
            for accepted_ids in string_prototype_ids
        )
        if is_duplicate:
            rejected_samples += 1
            if rejected_samples > max_rejected_samples:
                raise RuntimeError(
                    "Could not find enough distinct shuffled substrings."
                )
            continue

        # Only add the substring if it is unique
        string_prototypes.append(prototype_data)
        string_prototype_ids.append(prototype_token_ids)
    return string_prototypes[0], string_prototype_ids


def _place_substrings(
    shuffled_substrings: list[np.ndarray],
    num_repetitions: int,
    placement_order: str,
    rng: np.random.Generator,
) -> tuple[list[np.ndarray], list[list[int]]]:
    """Lay out ``num_repetitions`` occurrences of every substring in a row.

    Args:
        shuffled_substrings: The distinct substrings; the length of the first
            one is used as the length of every slot.
        num_repetitions: How often each substring occurs in the result.
        placement_order: ``"random"`` picks, for every slot, one of the
            substrings that still has repetitions left; ``"consecutive"``
            places all repetitions of a substring before moving on to the
            next one; ``"iterative"`` cycles through the substrings
            round-robin.
        rng: Generator used for the ``"random"`` order.

    Returns:
        The placed substrings in order and, for each distinct substring, the
        start offsets (in tokens) of its occurrences in the concatenated
        string.

    Raises:
        ValueError: If ``placement_order`` is not one of the known orders
            (only detected once a substring is about to be placed).
    """
    num_kinds = len(shuffled_substrings)
    repetitions_left = np.full(num_kinds, num_repetitions)
    kind_ids = np.arange(num_kinds)
    num_slots = num_repetitions * num_kinds
    slot_length = shuffled_substrings[0].shape[0]

    placed: list[np.ndarray] = []
    start_offsets: list[list[int]] = [[] for _ in range(num_kinds)]
    cursor = 0
    for slot in range(num_slots):
        if placement_order == "random":
            candidates = kind_ids[repetitions_left > 0]
            chosen = rng.choice(candidates, 1).item()
        elif placement_order == "consecutive":
            chosen = cursor
            # This slot takes the last repetition of the current substring,
            # so the following slot starts with the next one.
            if repetitions_left[cursor] == 1:
                cursor += 1
        elif placement_order == "iterative":
            chosen = cursor
            cursor = (cursor + 1) % num_kinds
        else:
            raise ValueError(f"Unknown placement order: {placement_order}")

        # Exactly one substring is placed per slot, so the slot index times
        # the substring length is the start offset in the final string.
        start_offsets[chosen].append(slot * slot_length)
        placed.append(shuffled_substrings[chosen])
        repetitions_left[chosen] -= 1
    return placed, start_offsets
