import dataclasses
import math
import random
from collections.abc import Callable, Iterable

import torch

from recognizers.tools.linked_list import LinkedList
from recognizers.automata.reserved import ReservedSymbol
from recognizers.automata.automaton import State, Symbol, SymbolOrEpsilon
from recognizers.automata.pushdown_automaton import (
    StackSymbol,
    PushdownAutomatonTransition,
    WeightedPushdownAutomaton
)
from recognizers.automata.pushdown_automaton_allsum import (
    top_down_pushdown_automaton_allsum,
    top_down_pushdown_automaton_stringsum
)
from recognizers.dataset_generation.weighted_language import String
from recognizers.automata.log_counting_semiring import LogCountingSemiring

# A pop computation is identified by a (state, stack symbol, state) triple,
# written [p, X, q] in the comments below.
PopComputation = tuple[State, StackSymbol, State]
# Every choice has a generated symbol (which may be ReservedSymbol.EPSILON)
# and a tuple of zero, one or two pop computations to be recursively sampled.
# Each recursive pop computation is paired with its target length.
Choice = tuple[Symbol, tuple[tuple[PopComputation, int], ...]]

@dataclasses.dataclass
class Actions:
    """Action weights of one pop computation plus its per-length sampling tables."""
    # One entry per action, with pop, replace and push actions stacked in that
    # order; allocated with `semiring.zeros(size=(num_actions,), ...)` in
    # `NormalizedCountingPushdownAutomaton.from_parts`.
    weights: torch.Tensor
    # Entry i holds the cumulative softmax-normalized weights used to sample an
    # action for target length i.
    # NOTE(review): `compute_cumulative_weights_by_length` appends the result
    # of `.numpy()`, so the stored items are NumPy arrays, not tensors.
    cumulative_weights_per_length: list[torch.Tensor]

@dataclasses.dataclass
class NormalizedCountingPushdownAutomaton:
    """Pushdown automaton bundled with precomputed per-length action weights.

    Instances are built by `from_parts`; they can sample strings of a given
    length (`sample`) and test membership (`accepts`).
    """

    # For every pop computation, the weights of all actions and the cumulative
    # weights per target length used when sampling.
    actions: dict[PopComputation, Actions]
    # (initial_state, initial_stack_symbol, accept_state); sampling starts here.
    initial_pop_computation: PopComputation
    num_states: int
    # Number of symbols, not counting epsilon.
    alphabet_size: int
    stack_alphabet_size: int
    initial_state: State
    # NOTE(review): `from_parts` passes a StackSymbol here although the field
    # is annotated as Symbol.
    initial_stack_symbol: Symbol
    # Indexed as [state, stack symbol, state, length].
    pop_computation_weights: torch.Tensor
    # Indexed by string length; `valid_lengths` treats an entry of -inf as
    # "no string of this length".
    total_length_weights: torch.Tensor
    accept_state: State
    # Size of the last dimension of `pop_computation_weights`.
    max_length: int
    deterministic: bool
    # Maps (state_from, popped_symbol, symbol) to a list of
    # (state_to, pushed_symbols) pairs.
    transitions: dict
    # Iterated again by `accepts` when the automaton is not deterministic.
    transition_weights: Iterable[tuple[PushdownAutomatonTransition, torch.Tensor]]

    @staticmethod
    def from_parts(
        num_states: int,
        initial_state: State,
        alphabet_size: int,
        stack_alphabet_size: int,
        initial_stack_symbol: StackSymbol,
        accept_state: State,
        allsum: torch.Tensor,
        transition_weights: Iterable[tuple[PushdownAutomatonTransition, torch.Tensor]],
        pop_computation_weights: torch.Tensor,
        deterministic: bool,
        dtype: torch.dtype,
        device: torch.device
    ) -> 'NormalizedCountingPushdownAutomaton':
        max_length = pop_computation_weights.size(-1)
        semiring = LogCountingSemiring(size=max_length)
        # size of the alphabet including epsilon
        epsilon_alphabet_size = alphabet_size + 1
        # total number of actions given a pop computation
        # (|Sigma| + 1) + (|Sigma| + 1) * |Gamma| * |Q| + max_length * (|Sigma| + 1) * |Gamma|ˆ2 * |Q|ˆ2
        # For instance, if the given pop computation is [p,X,q] we can sample the transition (action)
        # p,X -- a --> r,Y, followed by the recursive pop computation [r,Y,q].
        # We stack vertically all these actions, starting with the pop actions, followed by replace
        # actions, followed by push actions. Inside the replace block, we sort the actions first by
        # symbol, then stack symbol, then state. Inside the push block, we sort the actions first by
        # i, then the first stack symbol, then the second stack symbol, then first state, then second state.
        num_pop_actions = epsilon_alphabet_size
        num_replace_actions = epsilon_alphabet_size * stack_alphabet_size * num_states
        num_push_actions = (max_length - 1) * epsilon_alphabet_size * stack_alphabet_size ** 2 * num_states ** 2
        num_actions = num_pop_actions + num_replace_actions + num_push_actions
        actions = {(p, X, q): Actions(semiring.zeros(size=(num_actions,), dtype=dtype, device=device), []) \
                   for p in range(num_states) for X in range(stack_alphabet_size) for q in range(num_states)}
        transitions_by_key = {}
        for t, w in transition_weights:
            a = t.symbol
            p = t.state_from
            r = t.state_to
            X = t.popped_symbol
            gamma = t.pushed_symbols
            key = (p, X, a)
            if key not in transitions_by_key:
                transitions_by_key[key] = []
            transitions_by_key[key].append((r, gamma))
            match gamma:
                case ():
                    # Compute index of the action
                    index = a if a != ReservedSymbol.EPSILON else alphabet_size
                    semiring.add_in_place(
                        semiring.transform_tensors(actions[(p,X,r)].weights, lambda x: x[index]),
                        w
                    )
                case (Y,):
                    # Compute number of actions stored before the current one
                    num_before_replace_actions = (a if a != ReservedSymbol.EPSILON else alphabet_size) * stack_alphabet_size * num_states
                    num_before_replace_actions += Y * num_states
                    num_before_replace_actions += r
                    # Compute index of the action
                    index = num_pop_actions + num_before_replace_actions
                    for q in range(num_states):
                        ww = w[1] + shift_vector(pop_computation_weights[r, Y, q], 1)
                        semiring.add_in_place(
                            semiring.transform_tensors(actions[(p, X, q)].weights, lambda x: x[index]),
                            ww
                        )
                        # actions[p, X, q].choices.append((a, ww, ((r, Y, q),), None))
                case (Y, Z):
                    # Compute number of actions stored before the current one
                    num_pop_replace_actions = num_pop_actions + num_replace_actions
                    c = int(a != ReservedSymbol.EPSILON)
                    for i in range(1, max_length-c):
                        # Compute the following only once per i, outside the loop over s
                        num_before_push_actions = (i - 1) * epsilon_alphabet_size * stack_alphabet_size**2 * num_states**2
                        num_before_push_actions += (a if a != ReservedSymbol.EPSILON else alphabet_size) * stack_alphabet_size**2 * num_states**2
                        num_before_push_actions += Y * stack_alphabet_size * num_states**2
                        num_before_push_actions += Z * num_states**2
                        num_before_push_actions += r * num_states
                        for s in range(num_states):
                            index = num_pop_replace_actions + num_before_push_actions + s
                            for q in range(num_states):
                                ww = (
                                    w[c] +
                                    pop_computation_weights[r, Z, s][i] +
                                    shift_vector(pop_computation_weights[s, Y, q], c + i)
                                )
                                semiring.add_in_place(
                                    semiring.transform_tensors(actions[(p, X, q)].weights, lambda x: x[index]),
                                    ww
                                )
                                # actions[p, X, q].choices.append((a, ww, ((r, Z, s),(s, Y, q)), i))
                case _:
                    raise ValueError('the pushdown automaton is not in normal form')

        actions = compute_cumulative_weights_by_length(actions,
                                           num_states,
                                           alphabet_size,
                                           stack_alphabet_size,
                                           max_length,
                                           num_pop_actions,
                                           num_replace_actions)
        return NormalizedCountingPushdownAutomaton(
            actions=actions,
            num_states=num_states,
            alphabet_size=alphabet_size,
            initial_state=initial_state,
            stack_alphabet_size=stack_alphabet_size,
            initial_stack_symbol=initial_stack_symbol,
            initial_pop_computation=(initial_state, initial_stack_symbol, accept_state),
            accept_state=accept_state,
            total_length_weights=allsum,
            pop_computation_weights=pop_computation_weights,
            max_length=max_length,
            deterministic=deterministic,
            transitions=transitions_by_key,
            transition_weights=transition_weights
        )

    def get_choice_from_index(
        self,
        index: int,
        pop_computation: PopComputation,
        length: int
    ) -> Choice:
        p, X, q = pop_computation
        epsilon_alphabet_size = self.alphabet_size + 1
        num_pop_actions = epsilon_alphabet_size
        num_replace_actions = epsilon_alphabet_size * self.stack_alphabet_size * self.num_states
        if length == 0 and index == 0:
            a = ReservedSymbol.EPSILON
            return (a, ())
        if index < num_pop_actions:
            a = index
            if a == self.alphabet_size:
                a = ReservedSymbol.EPSILON
            return (a, ())
        else:
            remainder = index - num_pop_actions
            if remainder < num_replace_actions:
                a = remainder // (self.stack_alphabet_size * self.num_states)
                remainder -= a * self.stack_alphabet_size * self.num_states
                if a == self.alphabet_size:
                    a = ReservedSymbol.EPSILON
                Y = remainder // self.num_states
                remainder -= Y * self.num_states
                r = remainder
                return (a, (((r, Y, q), length - 1),))
            else:
                remainder = index - num_pop_actions - num_replace_actions
                i = remainder // (epsilon_alphabet_size * self.stack_alphabet_size**2 * self.num_states**2)
                remainder -= i * epsilon_alphabet_size * self.stack_alphabet_size**2 * self.num_states**2
                a = remainder // (self.stack_alphabet_size**2 * self.num_states**2)
                remainder -= a * self.stack_alphabet_size**2 * self.num_states**2
                if a == self.alphabet_size:
                    a = ReservedSymbol.EPSILON
                Y = remainder // (self.stack_alphabet_size * self.num_states**2)
                remainder -= Y * self.stack_alphabet_size * self.num_states**2
                Z = remainder // (self.num_states**2)
                remainder -= Z * self.num_states**2
                r = remainder // self.num_states
                remainder -= r * self.num_states
                s = remainder
                return (a, (((r, Z, s), i + 1),((s, Y, q), length - i - 1 - int(a != ReservedSymbol.EPSILON))))

    def accepts(self,
                string: String,
                dtype: torch.dtype,
                device: torch.device) -> bool:
        if self.deterministic:
            state = self.initial_state
            stack = [self.initial_stack_symbol]
            index = 0
            while index < len(string):
                if len(stack) < 1:
                    return False
                popped_symbol = stack.pop()
                config = self.transitions.get((state, popped_symbol, string[index]))
                if config is None:
                    config = self.transitions.get((state, popped_symbol, ReservedSymbol.EPSILON))
                    if config is None:
                        return False
                else:
                    index += 1
                state, pushed_symbols = config[0]
                stack.extend(list(pushed_symbols))
            return state == self.accept_state and len(stack) == 0
        elif string == ():
            for t, _ in self.transition_weights:
                a = t.symbol
                p = t.state_from
                r = t.state_to
                X = t.popped_symbol
                gamma = t.pushed_symbols
                if (a == ReservedSymbol.EPSILON and \
                    p == self.initial_state and \
                    r == self.accept_state and \
                    X == self.initial_stack_symbol and \
                    len(gamma) == 0):
                            return True
            return False
        else:
            # This only works for strings != EPSILON
            stringsum = top_down_pushdown_automaton_stringsum(self.transition_weights,
                                                              self.num_states,
                                                              self.stack_alphabet_size,
                                                              dtype,
                                                              device,
                                                              string)
            return stringsum[0, self.initial_state, self.initial_stack_symbol, len(string), self.accept_state]

    def valid_lengths(self, length_range: tuple[int, int]) -> list[int]:
        lo, hi = length_range
        is_valid = (self.total_length_weights[lo:hi+1] > -math.inf).tolist()
        return [
            l
            for l, l_is_valid in zip(
                range(lo, hi + 1),
                is_valid,
                strict=True
            )
            if l_is_valid
        ]

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> String:
        if include_log_probability:
            log_probability = 0.0
        else:
            log_probability = None
        if include_next_symbols:
            next_symbols = []
        else:
            next_symbols = None
        return tuple(self.sample_pop_computation(
            self.initial_pop_computation,
            length,
            generator
        )), log_probability, next_symbols

    def total_length_weight(self, length: int) -> float:
        return self.total_length_weights[length].item()

    def sample_pop_computation(self,
        c: PopComputation,
        length: int,
        generator: random.Random
    ) -> LinkedList[Symbol]:
        actions = self.actions_for_pop_computation(c)
        cumulative_weights = actions.cumulative_weights_per_length[length]
        index, = generator.choices(
            range(len(cumulative_weights)),
            cum_weights=cumulative_weights
        )
        symbol, items = self.get_choice_from_index(index, c, length)
        result = LinkedList()
        if symbol != ReservedSymbol.EPSILON:
            result.append(symbol)
        match items:
            case (sub_c,):
                sub_c, sub_length = sub_c
                result.extend(self.sample_pop_computation(sub_c, sub_length, generator))
            case (sub_c1, sub_c2):
                sub_c1, sub_length1 = sub_c1
                sub_c2, sub_length2 = sub_c2
                result.extend(self.sample_pop_computation(sub_c1, sub_length1, generator))
                result.extend(self.sample_pop_computation(sub_c2, sub_length2, generator))
        return result

    def actions_for_pop_computation(self, c: PopComputation) -> Actions:
        return self.actions[c]

def shift_vector(v, c):
    """Shift `v` right by `c` positions along its first dimension.

    The first `c` entries of the result are -inf and the rest are the leading
    entries of `v`; the trailing `c` entries of `v` are dropped. A shift of
    `len(v)` or more yields a result that is entirely -inf. `v` itself is not
    modified.

    Args:
        v: Tensor to shift; must have a dtype that can hold -inf.
        c: Non-negative number of positions to shift by.

    Returns:
        A new tensor with the same shape, dtype and device as `v`.
    """
    size = len(v)
    # Clamp so that an over-long shift does not turn `size - c` into a
    # negative slice bound, which would select the wrong elements.
    shift = min(c, size)
    result = torch.empty_like(v)
    result[:shift] = -math.inf
    result[shift:] = v[:size - shift]
    return result

def compute_cumulative_weights_by_length(
    actions: dict[PopComputation, Actions],
    num_states: int,
    alphabet_size: int,
    stack_alphabet_size: int,
    max_length: int,
    num_pop_actions: int,
    num_replace_actions: int
) -> dict[PopComputation, Actions]:
    """Compute the cumulative weights of the candidate actions for each length
    in [0, max_length - 1].

    For every pop computation, one NumPy array per length is appended to its
    `cumulative_weights_per_length` list, holding the running sum of the
    softmax over the candidate actions' weights at that length:

    * length 0: only the epsilon pop action;
    * length 1: only the pop actions;
    * length >= 2: the pop and replace actions plus the first `length - 1`
      blocks of push actions.

    Args:
        actions: Action tables to update in place; `weights` is indexed as
            [action, length].
        num_states: Number of states.
        alphabet_size: Number of symbols, not counting epsilon.
        stack_alphabet_size: Number of stack symbols.
        max_length: Number of lengths to tabulate.
        num_pop_actions: Number of pop actions at the start of each table.
        num_replace_actions: Number of replace actions following them.

    Returns:
        The same `actions` dict that was passed in.
    """
    push_actions_block_size = (alphabet_size + 1) * stack_alphabet_size ** 2 * num_states ** 2
    num_pop_replace_actions = num_pop_actions + num_replace_actions
    for entry in actions.values():
        weights = entry.weights
        for length in range(max_length):
            if length == 0:
                candidates = weights[alphabet_size, length].unsqueeze(dim=0)
            elif length == 1:
                candidates = weights[:alphabet_size + 1, length]
            else:
                limit = num_pop_replace_actions + push_actions_block_size * (length - 1)
                candidates = weights[:limit, length]
            cumulative = torch.cumsum(torch.softmax(candidates, dim=0), dim=0)
            # `.numpy()` alone fails for tensors that live on another device
            # or that are attached to the autograd graph.
            entry.cumulative_weights_per_length.append(cumulative.detach().cpu().numpy())
    return actions

def push_pushdown_automaton_weights(
    M: WeightedPushdownAutomaton[torch.Tensor],
    dtype: torch.dtype,
    device: torch.device
):
    """Turn the weighted pushdown automaton `M` into a
    `NormalizedCountingPushdownAutomaton`.

    The pop computation weights are computed with
    `top_down_pushdown_automaton_allsum`; the entry for the initial pop
    computation (initial state, initial stack symbol, accept state) is passed
    on as the total weights.

    Args:
        M: The automaton to convert.
        dtype: Forwarded to the allsum computation and to `from_parts`.
        device: Forwarded to the allsum computation and to `from_parts`.
    """
    weights = top_down_pushdown_automaton_allsum(M, dtype, device)
    start_state = M.initial_state()
    start_symbol = M.initial_stack_symbol()
    final_state = M.accept_state()
    parts = dict(
        num_states=M.num_states(),
        initial_state=start_state,
        alphabet_size=M.alphabet_size(),
        stack_alphabet_size=M.stack_alphabet_size(),
        initial_stack_symbol=start_symbol,
        accept_state=final_state,
        allsum=weights[start_state, start_symbol, final_state],
        transition_weights=M.transition_weights(),
        pop_computation_weights=weights,
        deterministic=M.deterministic,
        dtype=dtype,
        device=device,
    )
    return NormalizedCountingPushdownAutomaton.from_parts(**parts)
