from dataclasses import dataclass
from itertools import product

import numpy as np
import pandas as pd
from transformers import PreTrainedTokenizer

from .utils import SyntheticStringConfig, SyntheticStringData, sample_tokens


@dataclass
class DeterministicRuleStringConfig(SyntheticStringConfig):
    """Settings for strings generated by a deterministic rule table.

    After a random prefix of ``premise_length`` tokens, every token of a
    generated string is fully determined by the ``premise_length`` tokens
    that precede it.
    """

    # Number of tokens each generated string is grown to.
    string_length: int
    # How many strings to generate.
    num_strings: int
    # Number of preceding tokens a rule conditions on.
    premise_length: int

    def __post_init__(self):
        # Nothing extra to validate here; delegate to the base config.
        super().__post_init__()

    @property
    def name(self) -> str:
        """Identifier of the form ``det_<alphabet_id>_l-<L>_n-<N>_prem-<K>``."""
        parts = (
            "det_",
            super().alphabet_id,
            f"_l-{self.string_length}",
            f"_n-{self.num_strings}",
            f"_prem-{self.premise_length}",
        )
        return "".join(parts)


# Rule table type: a premise (a tuple of ``premise_length`` token ids) maps to
# the single token id that must follow it.
TokenIdRules = dict[tuple[int, ...], int]


@dataclass
class DeterministicRuleStringData(SyntheticStringData):
    """Synthetic string dataset produced from a deterministic rule table."""

    # Configuration the data was generated with.
    config: DeterministicRuleStringConfig
    # Rule table. As built by ``generate_deterministic_rule_strings`` it has
    # one row per rule and the columns premise_token_ids,
    # conclusion_token_id, premise_tokens and conclusion_token.
    rules: pd.DataFrame


def generate_deterministic_rule_strings(
    config: DeterministicRuleStringConfig,
    tokenizer: PreTrainedTokenizer,
) -> DeterministicRuleStringData:
    """Build a dataset of strings that obey a randomly sampled rule table.

    The rule table is sampled with ``sample_rules`` and strings are then
    generated from it with ``_sample_strings``. The rules are also rendered
    as a DataFrame that holds both token ids and their token strings.

    Args:
        config: Generation settings.
        tokenizer: Tokenizer used for sampling and for id -> token conversion.

    Returns:
        A ``DeterministicRuleStringData`` containing the config, the rule
        table, the raw string token ids, the alphabet and start-of-string
        token ids, and whatever fields ``SyntheticStringData.ids_to_tokens``
        returns.
    """
    sampled = sample_rules(config=config, tokenizer=tokenizer)
    raw_ids = _sample_strings(config=config, rules=sampled)

    # Dict keys and values iterate in the same order, so row i pairs a
    # premise with its own conclusion.
    premise_ids = list(sampled.rules.keys())
    conclusion_ids = list(sampled.rules.values())
    rule_table = pd.DataFrame(
        {
            "premise_token_ids": premise_ids,
            "conclusion_token_id": conclusion_ids,
            "premise_tokens": [
                tokenizer.convert_ids_to_tokens(list(premise))
                for premise in premise_ids
            ],
            "conclusion_token": [
                tokenizer.convert_ids_to_tokens([conclusion])[0]
                for conclusion in conclusion_ids
            ],
        }
    )

    alphabet_ids = sampled.alphabet_token_ids
    sos_id = sampled.start_of_string_token_id
    token_fields = SyntheticStringData.ids_to_tokens(
        tokenizer=tokenizer,
        token_ids=raw_ids,
        alphabet_token_ids=alphabet_ids,
        start_of_string_token_id=sos_id,
    )
    return DeterministicRuleStringData(
        config=config,
        rules=rule_table,
        raw_token_ids=raw_ids,
        alphabet_token_ids=alphabet_ids,
        start_of_string_token_id=sos_id,
        **token_fields,
    )


@dataclass
class SampledRules:
    """Result of ``sample_rules``: a rule table plus the vocabulary it uses."""

    # Premise token-id tuple -> conclusion token id.
    rules: TokenIdRules
    # Alphabet token ids returned by ``sample_tokens``; the premises are the
    # cartesian product of these ids.
    alphabet_token_ids: np.ndarray
    # Start-of-string token id returned by ``sample_tokens``; may be None.
    start_of_string_token_id: int | None


def sample_rules(
    config: DeterministicRuleStringConfig,
    tokenizer: PreTrainedTokenizer,
) -> SampledRules:
    """Sample one deterministic conclusion token for every possible premise.

    ``sample_tokens`` is asked for one length-1 string per premise, and the
    single token of string ``i`` becomes the conclusion of premise ``i``.
    Premises are enumerated in
    ``itertools.product(alphabet_token_ids, repeat=premise_length)`` order.

    Args:
        config: Supplies ``alphabet_size`` and ``premise_length``; it is also
            forwarded to ``sample_tokens``.
        tokenizer: Forwarded to ``sample_tokens``.

    Returns:
        ``SampledRules`` with the rule dict and with the alphabet and
        start-of-string token ids reported by ``sample_tokens``.
    """
    premise_length = config.premise_length
    sampled = sample_tokens(
        config=config,
        tokenizer=tokenizer,
        string_length=1,
        num_strings=config.alphabet_size**premise_length,
    )
    # NOTE(review): this assumes len(sampled.alphabet_token_ids) equals
    # config.alphabet_size. A larger alphabet would index past the sampled
    # conclusions (IndexError). Duplicate alphabet ids would collapse
    # premises, with the later rule winning.
    conclusions = sampled.token_ids
    all_premises = product(sampled.alphabet_token_ids, repeat=premise_length)

    rule_table: TokenIdRules = {}
    for row, premise in enumerate(all_premises):
        key = tuple(map(int, premise))
        rule_table[key] = int(conclusions[row][0])

    return SampledRules(
        rules=rule_table,
        alphabet_token_ids=sampled.alphabet_token_ids,
        start_of_string_token_id=sampled.start_of_string_token_id,
    )


def _sample_strings(
    config: DeterministicRuleStringConfig,
    rules: SampledRules,
) -> np.ndarray:
    """Generate ``config.num_strings`` strings that follow ``rules``.

    Each string starts with ``config.premise_length`` tokens drawn uniformly,
    with replacement, from ``rules.alphabet_token_ids``. Every later token is
    looked up in ``rules.rules`` using the last ``premise_length`` tokens as
    the key.

    Args:
        config: Supplies ``seed_value``, ``num_strings``, ``premise_length``
            and ``string_length``.
        rules: Rule table and alphabet to sample from.

    Returns:
        An ``int32`` array of shape ``(num_strings, string_length)``. The
        array stays 2-D even when ``num_strings == 0``. If ``premise_length``
        exceeds ``string_length``, the random prefix is truncated to
        ``string_length`` tokens.

    Raises:
        KeyError: If generation reaches a premise that has no rule.
    """
    # Seed offset, presumably to keep these draws independent of other
    # samplers seeded from seed_value -- confirm.
    rng = np.random.default_rng(config.seed_value + 32)
    prefixes = rng.choice(
        rules.alphabet_token_ids,
        size=(config.num_strings, config.premise_length),
    )

    window = config.premise_length
    length = config.string_length
    # Preallocate so the result has a fixed 2-D shape (also for 0 strings).
    token_ids = np.empty((config.num_strings, length), dtype=np.int32)
    for row, start_prefix in enumerate(prefixes):
        string_ids = [int(tid) for tid in start_prefix]
        while len(string_ids) < length:
            # Slice from an explicit start index: ``string_ids[-window:]``
            # would return the whole list when window == 0.
            premise = tuple(string_ids[len(string_ids) - window :])
            string_ids.append(rules.rules[premise])
        token_ids[row] = string_ids[:length]
    return token_ids
