import dataclasses
import math
import random
from collections.abc import Callable, Iterable

import torch
import numpy

from recognizers.tools.linked_list import LinkedList
from recognizers.automata.automaton import State, Symbol, SymbolOrEpsilon
from recognizers.grammars.grammar import (
    Nonterminal,
    Production
)
from recognizers.grammars.context_free_grammar import (
    WeightedContextFreeGrammar
)
from recognizers.dataset_generation.weighted_language import String
from recognizers.grammars.context_free_grammar import WeightedContextFreeGrammar
from recognizers.grammars.context_free_grammar_allsum import context_free_grammar_allsum

Choice = Symbol | None | tuple[tuple[Nonterminal, int], tuple[Nonterminal, int]]
"""One way of expanding a nonterminal while sampling a string of fixed length.

A choice is one of:

* ``None`` -- the empty string, coming from a rule ``X -> ε``;
* a terminal ``Symbol`` -- coming from a rule ``X -> a``;
* a pair ``((Y, j), (Z, k))`` -- coming from a rule ``X -> Y Z``. It means
  that ``Y`` must derive a string of length ``j``, followed by ``Z``
  deriving a string of length ``k``. Both halves are sampled recursively."""

@dataclasses.dataclass
class Actions:
    """The possible expansions of one (nonterminal, length) pair, ready to sample.

    The data is in the form expected by the ``cum_weights`` argument of
    :meth:`random.Random.choices`. ``choices[k]`` is picked with
    probability proportional to its increment in ``cum_weights``.
    """

    # Candidate expansions, in the same order as cum_weights.
    choices: list[Choice]
    # Cumulative (running-sum) weights, one entry per choice.
    cum_weights: numpy.ndarray

@dataclasses.dataclass
class NormalizedCountingContextFreeGrammar:

    actions: dict[tuple[Nonterminal, int], Actions]
    alphabet_size: int
    num_nonterminals: int
    start_symbol: Nonterminal
    derivation_weights: torch.Tensor
    total_length_weights: torch.Tensor
    max_length: int
    production_weights: Iterable[tuple[Production, torch.Tensor]]

    @staticmethod
    def from_parts(
        G: WeightedContextFreeGrammar[torch.Tensor],
        derivation_weights: torch.Tensor,
    ) -> 'NormalizedCountingContextFreeGrammar':
        max_length = derivation_weights.size(-1)
        alphabet_size = G.alphabet_size()
        num_terminals = G.num_terminals()
        num_nonterminals = G.num_nonterminals()
        start_symbol = G.start_symbol()
        allsum = derivation_weights[G.variable_index(G.start_symbol())]
        production_weights=list(G.production_weights())
        actions = {(X, i): [] for X in range(num_terminals, num_terminals + num_nonterminals) for i in range(max_length + 1)}
        for p, w in production_weights:
            X = p.left_hand_side
            rhs = p.right_hand_side
            match rhs:
                case ():
                    actions[(X, 0)].append((None, w[0]))
                case (a,):
                    actions[(X, 1)].append((a, w[1]))
                case (Y, Z):
                    for i in range(1, max_length):
                        for j in range(0, i + 1):
                            ww = (
                                w[0] +
                                derivation_weights[G.variable_index(Y)][j] +
                                derivation_weights[G.variable_index(Z)][i-j]
                            )
                            actions[(X, i)].append((((Y, j),(Z, i - j)), ww))
                case _:
                    raise ValueError('the grammar is not in Chomsky normal form')
        normalized_actions = normalize_weights(actions)
        return NormalizedCountingContextFreeGrammar(
            actions=normalized_actions,
            alphabet_size=alphabet_size,
            num_nonterminals=num_nonterminals,
            start_symbol=start_symbol,
            total_length_weights=allsum,
            derivation_weights=derivation_weights,
            max_length=max_length,
            production_weights=production_weights
        )

    def accepts_epsilon(self) -> bool:
        # CKY does not work for EPSILON, so this case
        # needs to be handled separately
        for p, _ in self.production_weights:
            X = p.left_hand_side
            rhs = p.right_hand_side
            if (X == self.start_symbol and len(rhs) == 0):
                return True
        return False


    def valid_lengths(self, length_range: tuple[int, int]) -> list[int]:
        lo, hi = length_range
        is_valid = (self.total_length_weights[lo:hi+1] > -math.inf).tolist()
        return [
            l
            for l, l_is_valid in zip(
                range(lo, hi + 1),
                is_valid,
                strict=True
            )
            if l_is_valid
        ]

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> String:
        if include_log_probability:
            log_probability = 0.0
        else:
            log_probability = None
        if include_next_symbols:
            next_symbols = []
        else:
            next_symbols = None
        return tuple(self.sample_derivation(
            self.start_symbol,
            length,
            generator
        )), log_probability, next_symbols

    def total_length_weight(self, length: int) -> float:
        return self.total_length_weights[length].item()

    def sample_derivation(self,
        n: Nonterminal,
        length: int,
        generator: random.Random
    ) -> LinkedList[Symbol]:
        actions = self.actions_for_nonterminal(n, length)
        choices, cum_weights = actions.choices, actions.cum_weights
        index, = generator.choices(
            range(len(choices)),
            cum_weights=cum_weights
        )
        item = choices[index]
        result = []
        if isinstance(item, tuple):
            ((Y, i), (Z, j)) = item
            result += self.sample_derivation(Y, i, generator)
            result += self.sample_derivation(Z, j, generator)
        else:
            if item is not None:
                result.append(item)
        return result

    def actions_for_nonterminal(self, n: Nonterminal, length: int) -> Actions:
        return self.actions[(n, length)]

def normalize_weights(actions):
    """Turn weighted expansions into per-key sampling tables.

    :param actions: Maps each ``(X, length)`` key to a list of
        ``(choice, log_weight)`` pairs, where each ``log_weight`` is a
        scalar tensor.
    :return: A dict mapping keys to :class:`Actions`. Each
        ``cum_weights`` is the cumulative sum of the softmax of that
        key's weights, as a CPU numpy array. Some keys are omitted:
        those with no expansions, and those whose weights are all
        ``-inf``. For the latter, softmax would produce NaNs, which
        ``random.Random.choices`` cannot sample from.
    """
    normalized_actions = {}
    num_actions = len(actions)
    for i, (key, weighted_choices) in enumerate(actions.items()):
        if i % 100 == 0:
            print(f'{i} actions out of {num_actions}')
        if not weighted_choices:
            continue
        items, weights = zip(*weighted_choices)
        stacked = torch.stack(weights, dim=0)
        if bool(torch.isneginf(stacked).all()):
            # No expansion is possible for this key; softmax would be NaN.
            continue
        cum_weights = torch.cumsum(torch.softmax(stacked, dim=0), dim=0)
        normalized_actions[key] = Actions(
            items,
            cum_weights.to(device='cpu').numpy()
        )
    return normalized_actions

def push_cnf_cfg_weights(
    G: WeightedContextFreeGrammar[torch.Tensor],
    dtype: torch.dtype,
    device: torch.device
) -> 'NormalizedCountingContextFreeGrammar':
    """Build a length-conditioned sampler for the CNF grammar ``G``.

    First computes the per-nonterminal, per-length weights with
    ``context_free_grammar_allsum`` in the given ``dtype`` and on the given
    ``device``. It then prints a progress message and hands the grammar
    and weights to ``NormalizedCountingContextFreeGrammar.from_parts``.
    """
    weights_by_length = context_free_grammar_allsum(G, dtype, device)
    print('finished computing allsum')
    sampler = NormalizedCountingContextFreeGrammar.from_parts(
        G=G,
        derivation_weights=weights_by_length,
    )
    return sampler
