import random

from recognizers.random_utils import (
    sample_from_negative_binomial
)
from recognizers.grammars.grammar import Rule
from recognizers.grammars.context_free_grammar import (
    ContextFreeGrammarContainer
)
from recognizers.grammars.cfg_finite_check import is_infinite
from recognizers.grammars.trim_cnf_cfg import trim_cnf_cfg
from recognizers.dataset_generation.weighted_language import FiniteLanguageError

def sample_cnf_cfg(
    mean_num_variables: float,
    mean_num_terminals: float,
    mean_num_lexical_rules: float,
    mean_num_binary_rules: float,
    mean_num_chains: float,
    mean_chain_length: float,
    generator: random.Random
) -> ContextFreeGrammarContainer:
    """Sample a random CNF context-free grammar whose language is infinite.

    The sizes (number of variables, terminals, lexical rules and binary rules)
    are drawn with ``sample_from_negative_binomial`` around the given means.
    Lexical rules are drawn from every ``A -> a`` pair plus the single empty
    rule ``S -> ()``. Binary rules are drawn from every ``A -> B C`` where
    ``B`` and ``C`` are non-start variables. The grammar is then passed through
    ``trim_cnf_cfg``. If more than one variable remains, a randomly sized
    number of self-embedding chains is added with ``add_self_embedding_rules``.

    Args:
        mean_num_variables: Mean number of variables handed to the grammar
            container (presumably counting the start variable -- confirm).
        mean_num_terminals: Mean number of terminal symbols.
        mean_num_lexical_rules: Mean number of lexical-rule draws.
        mean_num_binary_rules: Mean number of binary-rule draws.
        mean_num_chains: Mean number of self-embedding chains to add.
        mean_chain_length: Mean length of each self-embedding chain.
        generator: Source of randomness; every random draw goes through it.

    Returns:
        The sampled grammar after trimming and chain augmentation.

    Raises:
        FiniteLanguageError: If ``is_infinite`` reports that the final grammar
            generates a finite language.
    """
    # NOTE(review): the second argument of sample_from_negative_binomial looks
    # like a lower bound on the sampled count -- confirm in random_utils.
    num_variables = sample_from_negative_binomial(mean_num_variables, 1, generator)
    num_terminals = sample_from_negative_binomial(mean_num_terminals, 1, generator)
    num_lexical_rules = sample_from_negative_binomial(mean_num_lexical_rules, 1, generator)
    num_binary_rules = sample_from_negative_binomial(mean_num_binary_rules, 1, generator)
    G = ContextFreeGrammarContainer(
        num_variables=num_variables,
        num_terminals=num_terminals
    )
    S = G.start_variable()
    non_start_variables = [A for A in G.variables() if A != S]
    # Candidate lexical rules: every A -> a, plus the empty rule S -> ().
    all_lexical_rules = [Rule(A, (a,)) for A in G.variables() for a in G.terminals()]
    all_lexical_rules += [Rule(S, ())]
    # Candidate binary rules: the start variable never appears on a right-hand side.
    all_binary_right_hand_sides = [(B, C) for B in non_start_variables for C in non_start_variables]
    all_binary_rules = [Rule(A, rhs) for A in G.variables() for rhs in all_binary_right_hand_sides]
    # NOTE(review): ``choices`` samples *with* replacement, so the same rule can
    # be drawn more than once. The clamps cap the number of draws at the number
    # of candidates. When there are no binary candidates, the clamp also keeps
    # k == 0, which avoids an IndexError. Kept as-is so a given seed still
    # yields the same grammar.
    num_lexical_rules = min(num_lexical_rules, len(all_lexical_rules))
    lexical_rules = generator.choices(all_lexical_rules, k=num_lexical_rules)
    for rule in lexical_rules:
        G.add_rule(rule)
    num_binary_rules = min(num_binary_rules, len(all_binary_rules))
    # Already clamped on the line above; no second ``min`` is needed here.
    binary_rules = generator.choices(all_binary_rules, k=num_binary_rules)
    for rule in binary_rules:
        G.add_rule(rule)
    G = trim_cnf_cfg(G)
    # add_self_embedding_rules draws only from non-start variables, so skip it
    # when the trimmed grammar has a single variable.
    if G.num_variables() > 1:
        num_chains = sample_from_negative_binomial(mean_num_chains, 1, generator)
        for _ in range(num_chains):
            add_self_embedding_rules(G, mean_chain_length, generator)
    print(f"num terminals: {G.num_terminals()}, num variables: {G.num_variables()}, num rules: {len(G.rules())}")
    if not is_infinite(G):
        raise FiniteLanguageError('finite language, not context-free')
    return G

def add_self_embedding_rules(
    G: ContextFreeGrammarContainer,
    mean_chain_length: float,
    generator: random.Random
) -> None:
    """Thread one random cycle of binary rules through G's non-start variables.

    A chain of ``length`` binary rules is added in place. It starts at a
    randomly chosen non-start variable (the anchor), visits randomly chosen
    non-start variables, and its last rule points back to the anchor. Each
    step also introduces a random sibling variable. The sibling sits on the
    left of the successor for ``left_count`` steps (at least one) and on the
    right for the remaining steps (also at least one). The steps are shuffled.

    Args:
        G: Grammar that receives the new rules via ``add_rule``.
        mean_chain_length: Mean of the chain length drawn with
            ``sample_from_negative_binomial``. NOTE(review): the ``2`` passed
            there presumably guarantees length >= 2, which ``randint`` below
            requires -- confirm.
        generator: Source of randomness for every draw.
    """
    length = sample_from_negative_binomial(mean_chain_length, 2, generator)
    left_count = generator.randint(1, length - 1)
    # One orientation flag per step: True puts the sibling on the left.
    orientations = [True] * left_count + [False] * (length - left_count)
    generator.shuffle(orientations)

    start = G.start_variable()
    candidates = [X for X in G.variables() if X != start]
    anchor = generator.choice(candidates)

    current = anchor
    last_step = length - 1
    for step, sibling_on_left in enumerate(orientations):
        # The successor is drawn before the sibling; the final step closes the cycle.
        successor = anchor if step == last_step else generator.choice(candidates)
        sibling = generator.choice(candidates)
        rhs = (sibling, successor) if sibling_on_left else (successor, sibling)
        G.add_rule(Rule(current, rhs))
        current = successor

def add_se_rule(cfg, generator, *, semiring=None, lam=0.5):
    """Add one random self-embedding cycle of binary rules to ``cfg``.

    The cycle starts and ends at a randomly chosen non-start nonterminal ``A``.
    Its length is ``k = floor(Exponential(lam)) + 2``, so ``k >= 2``. Of the
    ``k`` steps, ``k_left`` (with ``1 <= k_left <= k - 1``) continue the chain
    through the first body position and the rest continue through the second.
    Every step also gets a random sibling nonterminal. Each rule is added with
    weight ``semiring.one``.

    ``cfg`` is modified in place and the same object is returned; no copy is
    made.

    Args:
        cfg: Grammar exposing ``V`` (iterable of nonterminals), ``S`` (the start
            nonterminal) and ``add(weight, head, *body)``. Presumably a weighted
            CFG from an external library -- TODO confirm.
        generator: ``random.Random`` instance used for every draw.
        semiring: Object whose ``one`` attribute is used as the weight of each
            added rule. When omitted, the name ``Boolean`` is looked up. This
            module does not import ``Boolean``, so omitting this argument only
            works if that name is otherwise in scope.
        lam: Rate of the exponential distribution that sets the cycle length.
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        ``cfg`` itself.

    Raises:
        NameError: If ``semiring`` is omitted and ``Boolean`` is undefined.
        IndexError: If ``cfg`` has no nonterminal other than its start symbol.
    """
    import math  # was used without being imported anywhere in this module

    if semiring is None:
        # NOTE(review): `Boolean` is not imported in this module. Look it up
        # only when the caller did not supply a semiring.
        semiring = Boolean  # noqa: F821
    one = semiring.one

    k = math.floor(generator.expovariate(lam)) + 2

    # Use at least one step of each orientation, so the cycle has a sibling on each side.
    k_left = generator.randint(1, k - 1)
    se_positions = [0] * k_left + [1] * (k - k_left)
    generator.shuffle(se_positions)

    # Sort the nonterminals so a seeded `generator` reproduces the same grammar
    # across runs. Set iteration order follows hash values, and for str symbols
    # those vary per process (PYTHONHASHSEED).
    nts_list = sorted(set(cfg.V) - {cfg.S}, key=str)
    A = generator.choice(nts_list)

    curr_B = A
    for i, pos in enumerate(se_positions):
        # The final step closes the cycle back to A.
        new_B = A if i == k - 1 else generator.choice(nts_list)
        C = generator.choice(nts_list)
        if pos == 0:
            cfg.add(one, curr_B, new_B, C)
        else:
            cfg.add(one, curr_B, C, new_B)
        curr_B = new_B

    return cfg
