import dataclasses
import math
import random
from collections.abc import Callable, Iterable

import torch
import numpy

from recognizers.tools.linked_list import LinkedList
from recognizers.automata.reserved import ReservedSymbol
from recognizers.automata.automaton import State, Symbol, SymbolOrEpsilon
from recognizers.automata.pushdown_automaton import (
    StackSymbol,
    PushdownAutomatonTransition,
    WeightedPushdownAutomaton
)
from recognizers.automata.pushdown_automaton_allsum import (
    top_down_pushdown_automaton_allsum
)
from recognizers.string_sampling.weighted_language import String

PopComputation = tuple[State, StackSymbol, State]
"""A pop computation ``(p, X, q)``: it is keyed by the state ``p`` a
transition leaves, the stack symbol ``X`` that transition pops, and the state
``q`` in which the computation ends."""
Choice = tuple[SymbolOrEpsilon, tuple[PopComputation, ...], int|None]
"""Every choice has a generated symbol (possibly epsilon) and a tuple of zero,
one or two pop computations to be recursively sampled. The final entry is only
set when there are two pop computations: it is the target length of the first
one, and the length of the second follows from the total length."""

@dataclasses.dataclass
class Actions:
    """Per-length sampling table for a single pop computation.

    Both lists are indexed by target length and are filled in lockstep by
    ``actions_by_length``.
    """

    # choices[l] holds the choices available at length l: an empty list when
    # there are none, otherwise the tuple produced by zip().
    choices: list[list[Choice] | tuple[Choice, ...]]
    # cum_weights[l] holds the cumulative probabilities aligned with
    # choices[l], or None when choices[l] is empty.
    cum_weights: list[numpy.ndarray | None]

@dataclasses.dataclass
class NormalizedCountingPushdownAutomaton:

    actions: dict[PopComputation, Actions]
    initial_pop_computation: PopComputation
    num_states: int
    alphabet_size: int
    stack_alphabet_size: int
    initial_state: State
    initial_stack_symbol: Symbol
    pop_computation_weights: torch.Tensor
    total_length_weights: torch.Tensor
    accept_state: State
    max_length: int
    deterministic: bool
    transitions: dict
    transition_weights: Iterable[tuple[PushdownAutomatonTransition, torch.Tensor]]

    @staticmethod
    def from_parts(
        num_states: int,
        initial_state: State,
        alphabet_size: int,
        stack_alphabet_size: int,
        initial_stack_symbol: StackSymbol,
        accept_state: State,
        allsum: torch.Tensor,
        transition_weights: Iterable[tuple[PushdownAutomatonTransition, torch.Tensor]],
        pop_computation_weights: torch.Tensor,
        deterministic: bool
    ) -> 'NormalizedCountingPushdownAutomaton':
        max_length = pop_computation_weights.size(-1)
        actions = {(p, X, q): [] for p in range(num_states) \
                   for X in range(stack_alphabet_size) for q in range(num_states)}
        transitions_by_key = {}
        for t, w in transition_weights:
            a = t.symbol
            p = t.state_from
            r = t.state_to
            X = t.popped_symbol
            gamma = t.pushed_symbols
            key = (p, X, a)
            if key not in transitions_by_key:
                transitions_by_key[key] = []
            transitions_by_key[key].append((r, gamma))
            match gamma:
                case ():
                    actions[p, X, r].append((a, w, (), None))
                case (Y,):
                    for q in range(num_states):
                        ww = w[1] + shift_vector(pop_computation_weights[r, Y, q], 1)
                        actions[p, X, q].append((a, ww, ((r, Y, q),), None))
                case (Y, Z):
                    c = int(a != ReservedSymbol.EPSILON)
                    for i in range(1, max_length-c):
                        for s in range(num_states):
                            for q in range(num_states):
                                ww = (
                                    w[c] +
                                    pop_computation_weights[r, Z, s][i] +
                                    shift_vector(pop_computation_weights[s, Y, q], c + i)
                                )
                                actions[p, X, q].append((a, ww, ((r, Z, s),(s, Y, q)), i))
                case _:
                    raise ValueError('the pushdown automaton is not in normal form')
        actions = actions_by_length(actions, max_length)
        return NormalizedCountingPushdownAutomaton(
            actions=actions,
            num_states=num_states,
            alphabet_size=alphabet_size,
            initial_state=initial_state,
            stack_alphabet_size=stack_alphabet_size,
            initial_stack_symbol=initial_stack_symbol,
            initial_pop_computation=(initial_state, initial_stack_symbol, accept_state),
            accept_state=accept_state,
            total_length_weights=allsum,
            pop_computation_weights=pop_computation_weights,
            max_length=max_length,
            deterministic=deterministic,
            transitions=transitions_by_key,
            transition_weights=transition_weights
        )

    def accepts(self,
                string: String,
                dtype: torch.dtype,
                device: torch.device) -> bool:
        if self.deterministic:
            state = self.initial_state
            stack = [self.initial_stack_symbol]
            index = 0
            while index < len(string):
                if len(stack) < 1:
                    return False
                popped_symbol = stack.pop()
                config = self.transitions.get((state, popped_symbol, string[index]))
                if config is None:
                    config = self.transitions.get((state, popped_symbol, ReservedSymbol.EPSILON))
                    if config is None:
                        return False
                else:
                    index += 1
                state, pushed_symbols = config[0]
                stack.extend(list(pushed_symbols))
            return state == self.accept_state and len(stack) == 0
        elif string == ():
            for t, _ in self.transition_weights:
                a = t.symbol
                p = t.state_from
                r = t.state_to
                X = t.popped_symbol
                gamma = t.pushed_symbols
                if (a == ReservedSymbol.EPSILON and \
                    p == self.initial_state and \
                    r == self.accept_state and \
                    X == self.initial_stack_symbol and \
                    len(gamma) == 0):
                            return True
            return False
        else:
            # This only works for strings != EPSILON
            stringsum = top_down_pushdown_automaton_stringsum(self.transition_weights,
                                                              self.num_states,
                                                              self.stack_alphabet_size,
                                                              dtype,
                                                              device,
                                                              string)
            return stringsum[0, self.initial_state, self.initial_stack_symbol, len(string), self.accept_state]

    def valid_lengths(self, length_range: tuple[int, int]) -> list[int]:
        lo, hi = length_range
        is_valid = (self.total_length_weights[lo:hi+1] > -math.inf).tolist()
        return [
            l
            for l, l_is_valid in zip(
                range(lo, hi + 1),
                is_valid,
                strict=True
            )
            if l_is_valid
        ]

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> String:
        if include_log_probability:
            log_probability = 0.0
        else:
            log_probability = None
        if include_next_symbols:
            next_symbols = []
        else:
            next_symbols = None
        return tuple(self.sample_pop_computation(
            self.initial_pop_computation,
            length,
            generator
        )), log_probability, next_symbols

    def total_length_weight(self, length: int) -> float:
        return self.total_length_weights[length].item()

    def sample_pop_computation(self,
        c: PopComputation,
        length: int,
        generator: random.Random
    ) -> LinkedList[Symbol]:
        actions = self.actions_for_pop_computation(c)
        choices, cum_weights = actions.choices[length], actions.cum_weights[length]
        index, = generator.choices(
            range(len(choices)),
            cum_weights=cum_weights
        )
        symbol, items, i = choices[index]
        result = LinkedList()
        if symbol != ReservedSymbol.EPSILON:
            result.append(symbol)
        match items:
            case (sub_c,):
                sub_length = length - 1
                result.extend(self.sample_pop_computation(sub_c, sub_length, generator))
            case (sub_c1, sub_c2):
                sub_length1 = i
                sub_length2 = length - i - int(symbol != ReservedSymbol.EPSILON)
                result.extend(self.sample_pop_computation(sub_c1, sub_length1, generator))
                result.extend(self.sample_pop_computation(sub_c2, sub_length2, generator))
        return result

    def actions_for_pop_computation(self, c: PopComputation) -> Actions:
        return self.actions[c]

def shift_vector(v, c):
    """Return a copy of ``v`` shifted right by ``c`` positions along its
    first dimension.

    The first ``c`` entries of the result are ``-inf`` and the remaining
    entries are the leading entries of ``v``; entries shifted past the end are
    dropped. ``v`` itself is not modified.

    :param v: Tensor to shift.
    :param c: Number of positions to shift by. Values larger than ``len(v)``
        are treated like ``len(v)`` and yield a result filled with ``-inf``.
    """
    n = len(v)
    # Without the clamp, n - c would go negative and slice from the wrong end.
    shift = min(c, n)
    result = torch.empty_like(v)
    result[:shift] = -math.inf
    result[shift:] = v[:n - shift]
    return result

def actions_by_length(actions, max_length) -> dict[PopComputation, Actions]:
    """Turn the raw candidate lists into per-length sampling tables.

    :param actions: Maps each pop computation to a list of
        ``(symbol, weight, items, i)`` candidates, where ``weight`` is a
        tensor that is indexed by length and ``i`` is either ``None`` or the
        length assigned to the first sub-computation.
    :param max_length: Tables are built for lengths ``0 .. max_length - 1``.
    :return: For every pop computation, an ``Actions`` whose entry at length
        ``l`` holds the candidates with ``i is None or i < l`` and the
        cumulative sum of the softmax of their weights at ``l``. Lengths
        without any candidate get an empty list and ``None``.
    """
    new_actions = {}
    for c, candidates in actions.items():
        table = Actions([], [])
        new_actions[c] = table
        for l in range(max_length):
            actions_for_length = [
                ((a, items, i), w)
                for (a, w, items, i) in candidates
                if i is None or i < l
            ]
            if not actions_for_length:
                table.choices.append([])
                table.cum_weights.append(None)
                continue
            sym_and_items, weights = zip(*actions_for_length)
            weights_at_length = torch.stack(weights, dim=0)[:, l]
            cum_weights = torch.cumsum(
                torch.softmax(weights_at_length, dim=0),
                dim=0
            )
            table.choices.append(sym_and_items)
            # Tensor.numpy() rejects tensors that require grad or live on a
            # non-CPU device, so detach and move to the CPU first.
            table.cum_weights.append(cum_weights.detach().cpu().numpy())

    return new_actions

def push_pushdown_automaton_weights(
    M: WeightedPushdownAutomaton[torch.Tensor],
    dtype: torch.dtype,
    device: torch.device
):
    """Build a ``NormalizedCountingPushdownAutomaton`` from the weighted
    pushdown automaton ``M``.

    The pop computation weights are computed with
    ``top_down_pushdown_automaton_allsum``; the entry for
    (initial state, initial stack symbol, accept state) is passed on as the
    total weights.
    """
    weights = top_down_pushdown_automaton_allsum(M, dtype, device)
    start_state = M.initial_state()
    bottom_symbol = M.initial_stack_symbol()
    final_state = M.accept_state()
    return NormalizedCountingPushdownAutomaton.from_parts(
        num_states=M.num_states(),
        initial_state=start_state,
        alphabet_size=M.alphabet_size(),
        stack_alphabet_size=M.stack_alphabet_size(),
        initial_stack_symbol=bottom_symbol,
        accept_state=final_state,
        allsum=weights[start_state, bottom_symbol, final_state],
        transition_weights=M.transition_weights(),
        pop_computation_weights=weights,
        deterministic=M.deterministic
    )
