from lbrayuela.automata.finite_automaton import FiniteAutomatonContainer, FiniteAutomatonTransition

import random
import torch
import math
from collections import defaultdict
import numpy as np
from lbrayuela.automata.automaton import State
from lbrayuela.automata.finite_automaton import (
    FiniteAutomaton,
    WeightedFiniteAutomaton,
    WeightedFiniteAutomatonContainer,
)
from lbrayuela.automata.log_counting_semiring import LogCountingSemiring
from lbrayuela.automata.log_counting_semiring_overflow import OverflowLogCountingSemiring
from lbrayuela.string_sampling.prepare_sampler import get_accept_weight

from lbrayuela.automata.reserved import ReservedSymbol
from lbrayuela.string_sampling.weighted_language import (
    String,
    ValidNextSymbolList
)
from lbrayuela.automata.automaton import Transition

# Default symbol id counted by the symbol-targeting code paths: the
# ``tgt_symbol`` default of counting_sample and counting_sample_alo, and the
# hard-coded example in construct_weighted_automaton.
# MAX_LENGTH = 256  (former cap on sampled string length; currently unused)
TGT_SYMBOL = 0


from lrayuela.fsa.fsa import FSA
from lrayuela.base.semiring import Real
from lrayuela.base.state import State as RayuelaState
from lrayuela.base.symbol import Sym
import torch


def to_rayuela_fsa(
    automaton
):
    """Convert a weighted automaton into a rayuela ``FSA`` over the Real semiring.

    Every log-space weight vector (transition or accept) is collapsed to a
    single real number: the exponential of its largest entry. This is exact
    when the vector has a single finite entry.
    NOTE(review): confirm no caller passes weights with several finite counts.
    """
    def to_real(log_weights):
        # Largest log entry -> probability space -> rayuela Real.
        return Real(torch.exp(log_weights.max()).item())

    fsa = FSA(R=Real)
    fsa.set_I(RayuelaState(automaton.initial_state()))

    for arc, log_weights in automaton._transitions.items():
        arc_weight = to_real(log_weights)
        fsa.set_arc(
            RayuelaState(arc.state_from),
            Sym(arc.symbol),
            RayuelaState(arc.state_to),
            arc_weight,
        )

    for accept_state, log_weights in automaton._accept_weights.items():
        fsa.set_F(RayuelaState(accept_state), to_real(log_weights))

    return fsa


def counting_sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool,
        # max_length: int = MAX_LENGTH,
        tgt_symbol: int = TGT_SYMBOL
    ) -> dict:
    """Sample one string whose targeted-event count is conditioned on ``length``.

    ``self`` must provide ``initial_state``, ``actions_at_state(state)``,
    ``target_state`` and ``target_transition``. At every step the next action
    is drawn from column ``c`` of ``actions.cumulative_weights``, where ``c``
    is the number of targeted events still to be produced. Action indices past
    ``actions.transitions`` denote accepting, which ends the string.

    The counter ``c`` is decremented when the taken transition matches the
    target:
      * the source state equals ``self.target_state`` (state intervention);
      * it equals ``self.target_transition`` (transition intervention);
      * otherwise, its symbol equals ``tgt_symbol``.

    Args:
        length: Number of targeted events the sampled string should contain.
        generator: Source of randomness for the action choices.
        include_log_probability: Accumulate the log-probability of the sample
            (and fill ``log_probs``).
        include_next_symbols: Record ``actions.next_symbols`` at every visited
            state.
        tgt_symbol: Symbol that is counted when neither a target state nor a
            target transition is set.

    Returns:
        A dict with:
          * ``sampled_string``: tuple of symbols.
          * ``log_probability``: float, or None.
          * ``log_probs``: per-transition log weights.
          * ``states``: states entered after each transition.
          * ``transitions``: ``(state_from, state_to, symbol)`` triples.
          * ``next_symbols``: list, or None.

    Raises:
        ValueError: if an epsilon transition is sampled, or if the action
            weights at the current state/count cannot be sampled from (e.g.
            they are all zero).
    """
    sampled_string = []
    log_probs = []
    log_probability = 0.0 if include_log_probability else None
    next_symbols = [] if include_next_symbols else None

    # The kind of intervention is fixed for the whole sample; decide it once.
    is_state_intervention = self.target_state is not None
    is_transition_intervention = self.target_transition is not None
    assert not (is_state_intervention and is_transition_intervention), \
        "Cannot have both a target state and a target transition"

    state = self.initial_state
    length_counter = length

    # for logging
    states = []
    transitions = []

    while True:
        # Defensive clamp: the counter is used as a column index below.
        length_counter = max(length_counter, 0)

        actions = self.actions_at_state(state)
        if include_next_symbols:
            # The set of next symbols is precomputed for each state.
            next_symbols.append(actions.next_symbols)

        # Randomly sample the next action.
        cum_weights = actions.cumulative_weights[:, length_counter].tolist()
        try:
            index, = generator.choices(range(len(cum_weights)), cum_weights=cum_weights)
        except (ValueError, IndexError) as err:
            raise ValueError(
                f"cannot sample an action at state {state} with "
                f"{length_counter} target events remaining "
                f"(cumulative weights: {cum_weights})"
            ) from err

        if index >= len(actions.transitions):
            # The last action is always accept.
            if include_log_probability:
                log_probability += actions.accept_log_weight[length_counter].item()
            break

        transition = actions.transitions[index]

        # logging
        transitions.append((state, transition.state_to, transition.symbol))

        if transition.symbol == ReservedSymbol.EPSILON:
            raise ValueError(
                "epsilon transitions are not supported; only deterministic machines can be sampled"
            )

        sampled_string.append(transition.symbol)
        if include_log_probability:
            step_log_prob = transition.log_weight[length_counter].item()
            log_probability += step_log_prob
            log_probs.append(step_log_prob)

        if is_state_intervention:
            # Count by source state: entering a state with outgoing
            # transitions implies leaving it, so counting departures is
            # equivalent to counting visits.
            is_hit = state == self.target_state
        elif is_transition_intervention:
            target = self.target_transition
            is_hit = (
                target.state_from == state
                and target.state_to == transition.state_to
                and target.symbol == transition.symbol
            )
        else:
            is_hit = transition.symbol == tgt_symbol
        if is_hit:
            length_counter -= 1

        # Update the current state
        state = transition.state_to
        states.append(state)

    return {
        "sampled_string": tuple(sampled_string),
        "log_probability": log_probability,
        "log_probs": log_probs,
        "states": states,
        "transitions": transitions,
        "next_symbols": next_symbols,
    }
    



def counting_sample_alo(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool,
        tgt_symbol: int = TGT_SYMBOL
    ) -> dict:
    """At-least-once variant of ``counting_sample``.

    ``self`` must provide:
      * ``initial_state`` and ``actions_at_state(state)``;
      * ``target_state`` and ``target_transition``;
      * ``cumulative_weights_all`` on the returned actions.

    While targeted events are still required (or when ``length`` is 0), the
    next action is drawn from column ``c`` of ``actions.cumulative_weights``,
    where ``c`` is the remaining count. Once a positive ``length`` has been
    reached, the constraint is satisfied and actions are drawn from
    ``actions.cumulative_weights_all`` instead. Targets are matched exactly as
    in ``counting_sample`` (source state, transition, or ``tgt_symbol``).

    Returns:
        A dict with:
          * ``sampled_string``;
          * ``log_probability`` (None unless requested);
          * ``log_probs``, ``states`` and ``transitions``;
          * ``next_symbols`` (None unless requested).

    Raises:
        ValueError: if an epsilon transition is sampled, or if the action
            weights at the current state cannot be sampled from.
    """
    sampled_string = []
    log_probs = []
    log_probability = 0.0 if include_log_probability else None
    next_symbols = [] if include_next_symbols else None

    # The kind of intervention is fixed for the whole sample; decide it once.
    is_state_intervention = self.target_state is not None
    is_transition_intervention = self.target_transition is not None
    assert not (is_state_intervention and is_transition_intervention), \
        "Cannot have both a target state and a target transition"

    state = self.initial_state
    length_counter = length

    # for logging
    states = []
    transitions = []

    while True:
        # Defensive clamp: the counter is used as a column index below.
        length_counter = max(length_counter, 0)

        actions = self.actions_at_state(state)
        if include_next_symbols:
            # The set of next symbols is precomputed for each state.
            next_symbols.append(actions.next_symbols)

        # Randomly sample the next action.
        if length_counter or not length:
            # Still owe target events, or none were requested at all
            # (length == 0 must keep using column 0).
            cum_weights = actions.cumulative_weights[:, length_counter].tolist()
        else:
            # The at-least-once requirement is met; any continuation is allowed.
            cum_weights = actions.cumulative_weights_all.tolist()
        try:
            index, = generator.choices(range(len(cum_weights)), cum_weights=cum_weights)
        except (ValueError, IndexError) as err:
            raise ValueError(
                f"cannot sample an action at state {state} with "
                f"{length_counter} target events remaining "
                f"(cumulative weights: {cum_weights})"
            ) from err

        if index >= len(actions.transitions):
            # The last action is always accept.
            if include_log_probability:
                log_probability += actions.accept_log_weight[length_counter].item()
            break

        transition = actions.transitions[index]

        # logging
        transitions.append((state, transition.state_to, transition.symbol))

        if transition.symbol == ReservedSymbol.EPSILON:
            raise ValueError(
                "epsilon transitions are not supported; only deterministic machines can be sampled"
            )

        sampled_string.append(transition.symbol)
        if include_log_probability:
            # NOTE(review): after the requirement is met, the action was drawn
            # from cumulative_weights_all, but the log weight still comes from
            # column length_counter (0) — confirm this is the intended score.
            step_log_prob = transition.log_weight[length_counter].item()
            log_probability += step_log_prob
            log_probs.append(step_log_prob)

        if is_state_intervention:
            # Count by source state (entering implies leaving).
            is_hit = state == self.target_state
        elif is_transition_intervention:
            target = self.target_transition
            is_hit = (
                target.state_from == state
                and target.state_to == transition.state_to
                and target.symbol == transition.symbol
            )
        else:
            is_hit = transition.symbol == tgt_symbol
        if is_hit:
            length_counter -= 1

        # Update the current state
        state = transition.state_to
        states.append(state)

    return {
        "sampled_string": tuple(sampled_string),
        "log_probability": log_probability,
        "log_probs": log_probs,
        "states": states,
        "transitions": transitions,
        "next_symbols": next_symbols,
    }


def sample_weights(N, accept_prob=None, np_gen=None):
    """
    Generate N positive values that sum to 1, with the option to fix the last value
    to accept_prob.

    Parameters:
    -----------
    N : int
        Number of weights to generate
    accept_prob : float or None
        If truthy, the last weight is set to this value and the remaining N-1
        weights (Dirichlet-sampled) sum to (1 - accept_prob). If falsy (None or
        0), all N weights are drawn jointly from a flat Dirichlet.
    np_gen : numpy.random.Generator or None
        Random generator used for every draw. Falls back to NumPy's global
        RNG only when None, so seeded callers get reproducible weights.

    Returns:
    --------
    numpy.ndarray
        Array of N positive values that sum to 1. When accept_prob is truthy
        and N < 2, an array of ones is returned instead.

    Raises:
    -------
    ValueError
        If accept_prob is truthy but outside [0, 1).
    """
    rng = np.random if np_gen is None else np_gen

    if accept_prob:
        # Validate inputs
        if not (0 <= accept_prob < 1):
            raise ValueError("accept_prob must be between 0 and 1")
        if N < 2:
            print("---saw state with no outgoing transitions, acceptance set to 1")
            return np.ones(N)

        # Generate N-1 values that sum to (1-accept_prob)
        scaled_values = rng.dirichlet(np.ones(N - 1)) * (1 - accept_prob)

        # Add the accept_prob as the last element
        random_values = np.append(scaled_values, accept_prob)
    else:
        # Standard case: generate N values summing to 1
        random_values = rng.dirichlet(np.ones(N))

    # Renormalize to absorb floating-point drift.
    return random_values / np.sum(random_values)


def get_transition_weight(
    transition: Transition,
    log_probability: float,
    weight_size: int,
    dtype: torch.dtype,
    device: torch.device,
    target_symbol: int = 0
) -> torch.Tensor:
    """Place ``log_probability`` into a one-hot log counting-semiring vector.

    The result has ``weight_size`` entries, all ``-inf`` except one. That entry
    is index 1 when ``transition`` reads ``target_symbol`` (it contributes one
    count), and index 0 otherwise.
    """
    reads_target = transition.symbol == target_symbol
    slot = 1 if reads_target else 0
    lifted = torch.full((weight_size,), -math.inf, dtype=dtype, device=device)
    lifted[slot] = log_probability
    return lifted


def get_transition_weight_arc(
    transition: Transition,
    log_probability: float,
    weight_size: int,
    dtype: torch.dtype,
    device: torch.device,
    target_transition: Transition
) -> torch.Tensor:
    """Place ``log_probability`` into a one-hot log counting-semiring vector.

    The result has ``weight_size`` entries, all ``-inf`` except one. That entry
    is index 1 when ``transition`` equals ``target_transition`` (it contributes
    one count), and index 0 otherwise.
    """
    is_target_arc = transition == target_transition
    slot = 1 if is_target_arc else 0
    lifted = torch.full((weight_size,), -math.inf, dtype=dtype, device=device)
    lifted[slot] = log_probability
    return lifted


def get_transition_weight_state(
    transition: Transition,
    log_probability: float,
    weight_size: int,
    dtype: torch.dtype,
    device: torch.device,
    target_state: State
) -> torch.Tensor:
    """Place ``log_probability`` into a one-hot log counting-semiring vector.

    Args:
        transition: Transition being lifted; only ``state_from`` is read.
        log_probability: Log weight to store.
        weight_size: Length of the weight vector (max count + 1).
        dtype: Torch dtype of the result.
        device: Torch device of the result.
        target_state: The state whose departures are counted. It is compared
            against ``transition.state_from``.

    Returns:
        A vector of ``weight_size`` entries, all ``-inf`` except one. That entry
        is index 1 when the transition leaves ``target_state`` (one count), and
        index 0 otherwise.
    """
    weight = torch.full((weight_size,), -math.inf, dtype=dtype, device=device)
    index = 1 if transition.state_from == target_state else 0
    weight[index] = log_probability
    return weight


def sample_finite_automaton(num_states, alphabet_size, generator, accept_prob):
    """Sample a random (unweighted) finite automaton.

    For every (state, symbol) pair, a destination is drawn uniformly from all
    states. The transition is kept when ``generator.random() <= 0.5``. Each
    state is then made accepting when ``generator.random() <= accept_prob``.

    Args:
        num_states: Number of states.
        alphabet_size: Number of symbols.
        generator: ``random.Random``-like source of randomness.
        accept_prob: Per-state probability of being accepting.

    Returns:
        The sampled ``FiniteAutomatonContainer``.
    """
    automaton = FiniteAutomatonContainer(num_states=num_states, alphabet_size=alphabet_size)
    for source in range(num_states):
        for symbol in range(alphabet_size):
            # The destination is drawn before the coin flip, even when the
            # transition is dropped; this keeps the random stream unchanged.
            destination = generator.choice(range(num_states))
            keep_arc = generator.random() <= 0.5
            if keep_arc:
                automaton.add_transition(FiniteAutomatonTransition(source, symbol, destination))

        make_accepting = generator.random() <= accept_prob
        if make_accepting:
            automaton.add_accept_state(source)

    return automaton



def initialize_weighted_automaton(
    M: FiniteAutomaton,
    dtype: torch.dtype,
    device: torch.device,
    accept_prob=None,
    generator=None
) -> WeightedFiniteAutomaton[torch.Tensor]:
    """
    Initialize a weighted finite automaton with randomly sampled weights.

    For every state with at least one action, the outgoing transitions (plus
    the accept action when the state is accepting) get probabilities from
    ``sample_weights`` that sum to one. Each weight is stored as a length-1
    tensor holding its log-probability, under a ``LogCountingSemiring(1)``.
    ``modify_weights_for_target`` can later lift these into counting weights.

    Args:
        M: The input finite automaton (or weighted).
        dtype: Torch data type for weights.
        device: Torch device for weights.
        accept_prob: If truthy, the accept probability of every accepting state
            is fixed to this value. Otherwise it is sampled jointly with the
            transition probabilities.
        generator: ``random.Random`` used to seed the NumPy generator. Required.

    Returns:
        A weighted finite automaton with random weights.

    Raises:
        ValueError: if ``generator`` is None.
    """
    if generator is None:
        raise ValueError("generator must be provided")

    num_states = M.num_states()
    result = WeightedFiniteAutomatonContainer[torch.Tensor](
        num_states=num_states,
        alphabet_size=M.alphabet_size(),
        initial_state=M.initial_state(),
        semiring=LogCountingSemiring(1)
    )

    # Group transitions by source state for easier processing
    grouped_transitions = [[] for _ in range(num_states)]
    for t in M.transitions():
        grouped_transitions[t.state_from].append(t)

    # Derive a NumPy generator from the Python one so results are reproducible.
    npseed = generator.getrandbits(64)
    np_gen = np.random.default_rng(npseed)

    def to_log_tensor(prob):
        # Length-1 log-probability tensor, as expected by LogCountingSemiring(1).
        return torch.tensor(math.log(prob), dtype=dtype, device=device).unsqueeze(dim=0)

    for state_from, transitions in enumerate(grouped_transitions):
        is_accept_state = M.is_accept_state(State(state_from))
        num_actions = len(transitions) + int(is_accept_state)
        if num_actions == 0:
            continue

        # Non-accepting states pass 0.0, which sample_weights treats as "no
        # fixed accept weight", so all weights go to the transitions.
        local_accept = accept_prob if is_accept_state else 0.0
        weights = sample_weights(num_actions, local_accept, np_gen)

        # The first len(transitions) weights belong to the transitions.
        for t, prob in zip(transitions, weights):
            result.set_transition_weight(t, to_log_tensor(prob))

        # The last weight is the accept weight when the state accepts.
        if is_accept_state:
            result.set_accept_weight(State(state_from), to_log_tensor(weights[-1]))

    return result


def modify_weights_for_target(
    wfa: WeightedFiniteAutomaton[torch.Tensor],
    max_count: int,
    target_symbol: int = 0,
    target_arc: bool = False,
    target_state: bool = False,
    tgt_arc_idx=None,
    tgt_state_idx=None,
    generator=None,
    at_least_once_semiring=False
) -> WeightedFiniteAutomaton[torch.Tensor]:
    """
    Lift the weights of ``wfa`` into counting-semiring weights for a target, in place.

    Each transition's log-probability is read from entry 0 of its current
    weight, e.g. as produced by ``initialize_weighted_automaton``. It is then
    re-placed in a vector of length ``max_count + 1``: at index 1 if the
    transition matches the target, and at index 0 otherwise. Accept weights are
    handled as follows:
      * with a target state, index 1 for the target state's own accept weight
        and index 0 for other states;
      * otherwise, lifted with ``get_accept_weight``.

    Args:
        wfa: The weighted finite automaton to modify. It is mutated and returned.
        max_count: Maximum count. Weight vectors have ``max_count + 1`` entries.
        target_symbol: Symbol to count when neither target flag is set.
        target_arc: Count traversals of one specific transition.
        target_state: Count departures from one specific state.
        tgt_arc_idx: Index of the target arc in the transitions sorted by
            (state_from, symbol, state_to). Sampled with ``generator`` if None.
        tgt_state_idx: Target state. If None, it is sampled with ``generator``
            among states that have outgoing transitions.
        generator: Random number generator used when a target must be sampled.
        at_least_once_semiring: Use ``OverflowLogCountingSemiring(2)`` instead
            of ``LogCountingSemiring(max_count + 1)``.

    Returns:
        ``wfa`` itself, with its ``target_transition`` and ``target_state``
        attributes set.

    Raises:
        ValueError: if ``wfa`` has no transitions, or no valid target state
            (one with outgoing transitions) exists.
    """
    assert not (target_state and target_arc), "Cannot have both target state and target arc"
    weight_size = max_count + 1

    if not wfa._transitions:
        raise ValueError("Cannot lift an automaton without transitions")
    # Reuse the dtype/device of the existing weights for the lifted vectors.
    first_weight = next(iter(wfa._transitions.values()))
    dtype = first_weight.dtype
    device = first_weight.device

    result = wfa  # We'll modify the input wfa
    result.target_transition = None
    result.target_state = None

    # set new semiring object
    if at_least_once_semiring:
        result._semiring = OverflowLogCountingSemiring(2)
    else:
        result._semiring = LogCountingSemiring(weight_size)

    # Group transitions by source state for processing
    grouped_transitions = [[] for _ in range(wfa.num_states())]
    for t in wfa.transitions():
        grouped_transitions[t.state_from].append(t)

    # Handle target arc if specified
    if target_arc:
        if tgt_arc_idx is None:
            result.target_transition = generator.choice(list(wfa.transitions()))
        else:
            # Sort to get deterministic lookup
            keys = sorted(wfa.transitions(), key=lambda t: (t.state_from, t.symbol, t.state_to))
            result.target_transition = keys[tgt_arc_idx]

    # Handle target state if specified: it must have outgoing transitions.
    if target_state:
        if tgt_state_idx is not None:
            if not grouped_transitions[tgt_state_idx]:
                raise ValueError(f"Target state {tgt_state_idx} has no outgoing transitions")
            result.target_state = tgt_state_idx
        else:
            if not any(grouped_transitions):
                raise ValueError("No state has outgoing transitions to target")
            candidates = list(wfa.states())
            target_state_val = None
            # Rejection sampling over all states, so the random stream matches
            # repeated generator.choice calls.
            while target_state_val is None:
                cand_state = generator.choice(candidates)
                if grouped_transitions[cand_state]:
                    target_state_val = cand_state
            result.target_state = target_state_val

    # Now modify the weights based on the target
    for state_from, transitions in enumerate(grouped_transitions):
        is_accept_state = wfa.is_accept_state(State(state_from))

        for t in transitions:
            # Get the current log probability
            log_prob = result._transitions[t][0]

            if target_arc:
                weight = get_transition_weight_arc(t, log_prob, weight_size, dtype, device, result.target_transition)
            elif target_state:
                weight = get_transition_weight_state(t, log_prob, weight_size, dtype, device, result.target_state)
            else:
                # Target symbol
                weight = get_transition_weight(t, log_prob, weight_size, dtype, device, target_symbol)

            result.set_transition_weight(t, weight)

        if is_accept_state:
            end_log_prob = result._accept_weights[state_from][0]

            if target_state:
                # Decide from state_from itself; an accept state may have no
                # outgoing transition to consult.
                weight = torch.full((weight_size,), -math.inf, dtype=dtype, device=device)
                weight[1 if state_from == result.target_state else 0] = end_log_prob
            else:
                weight = get_accept_weight(end_log_prob, weight_size, dtype, device)

            result.set_accept_weight(
                State(state_from),
                weight
            )

    return result


def lift_finite_automaton_random_weights(
    M: FiniteAutomaton,
    max_count: int,
    dtype: torch.dtype,
    device: torch.device,
    target_symbol: int = 0,
    target_arc: bool = False,
    target_state: bool = False,
    tgt_arc_idx=None,
    tgt_state_idx=None,
    accept_prob=None,
    generator=None
) -> WeightedFiniteAutomaton[torch.Tensor]:
    """
    Sample random weights for ``M`` and lift them for the requested target.

    This thin wrapper is kept for backward compatibility. It chains
    ``initialize_weighted_automaton`` and ``modify_weights_for_target``, using
    the default (non at-least-once) semiring.
    """
    weighted = initialize_weighted_automaton(
        M,
        dtype,
        device,
        accept_prob=accept_prob,
        generator=generator,
    )
    return modify_weights_for_target(
        weighted,
        max_count,
        target_symbol=target_symbol,
        target_arc=target_arc,
        target_state=target_state,
        tgt_arc_idx=tgt_arc_idx,
        tgt_state_idx=tgt_state_idx,
        generator=generator,
    )


def lift_finite_automaton_random_weights_old(
    M: FiniteAutomaton,
    max_count: int,
    dtype: torch.dtype,
    device: torch.device,
    target_symbol: int = 0,
    target_arc: int = False,
    target_state: int = False,
    tgt_arc_idx=None,    # lazy hacks to mix with existing logic,
    tgt_state_idx=None,  # flags for using above, flag for index here, used for targeted machine
    accept_prob=None,
    generator=None
) -> WeightedFiniteAutomaton[torch.Tensor]:
    """
    Legacy one-pass version of ``lift_finite_automaton_random_weights``.

    Samples random weights for ``M`` and lifts them into
    ``LogCountingSemiring(max_count + 1)`` weights for one target:
      * ``target_arc`` / ``target_state`` are flags selecting the target kind;
      * ``tgt_arc_idx`` / ``tgt_state_idx`` optionally pick the target
        deterministically (the arc index refers to transitions sorted by
        (state_from, symbol, state_to));
      * otherwise the target is drawn with ``generator``;
      * with neither flag set, ``target_symbol`` is counted.

    Unlike ``initialize_weighted_automaton``, ``accept_prob`` is passed to
    ``sample_weights`` for every state, accepting or not.

    Raises:
        AssertionError: if ``generator`` is None or both target flags are set.
        ValueError: if no valid target state (one with outgoing transitions) exists.
    """
    assert generator is not None

    assert not (target_state and target_arc), "Cannot have both target state and target arc"

    num_states = M.num_states()
    weight_size = max_count + 1
    result = WeightedFiniteAutomatonContainer[torch.Tensor](
        num_states=num_states,
        alphabet_size=M.alphabet_size(),
        initial_state=M.initial_state(),
        semiring=LogCountingSemiring(weight_size)
    )
    result.target_transition = None
    result.target_state = None

    target_transition = None
    if target_arc:
        # We randomly sample an arc
        if tgt_arc_idx is None:
            target_transition = generator.choice(list(M.transitions()))
        else:
            # SORT TO GET DETERMINISTIC LOOKUP
            keys = sorted(M.transitions(), key=lambda t: (t.state_from, t.symbol, t.state_to))
            target_transition = keys[tgt_arc_idx]
        result.target_transition = target_transition

    # Group by source state, so we can account for the source state being an
    # accept state or not.
    grouped_transitions = [[] for _ in range(num_states)]
    for t in M.transitions():
        grouped_transitions[t.state_from].append(t)

    # Keep the chosen state separate from the ``target_state`` flag, so that
    # choosing state 0 does not disable state targeting below.
    target_state_val = None
    if target_state:
        # We sample a state and make sure it has transitions going out!
        # TODO: This is maybe wrong, we should count the time it goes into the state,
        #        we should also not filter by if there is an outgoing state?
        if tgt_state_idx is not None:
            if not grouped_transitions[tgt_state_idx]:
                raise ValueError(f"Target state {tgt_state_idx} has no outgoing transitions")
            target_state_val = tgt_state_idx
        else:
            if not any(grouped_transitions):
                raise ValueError("No state has outgoing transitions to target")
            candidates = list(M.states())
            while target_state_val is None:
                cand_state = generator.choice(candidates)
                if grouped_transitions[cand_state]:
                    target_state_val = cand_state
        result.target_state = target_state_val

    npseed = generator.getrandbits(64)
    np_gen = np.random.default_rng(npseed)

    for state_from, transitions in enumerate(grouped_transitions):
        # loop over state and its outgoing transitions!
        is_accept_state = M.is_accept_state(State(state_from))
        num_actions = len(transitions) + int(is_accept_state)
        if num_actions == 0:
            continue

        # todo: move this to machine sampling
        # Sample weights that add up to one over the state's actions.
        weights = sample_weights(num_actions, accept_prob, np_gen)

        for idx, t in enumerate(transitions):
            log_prob = math.log(weights[idx])

            if target_arc:
                weight = get_transition_weight_arc(t, log_prob, weight_size, dtype, device, target_transition)
            elif target_state:
                weight = get_transition_weight_state(t, log_prob, weight_size, dtype, device, target_state_val)
            else:
                # Target symbol
                weight = get_transition_weight(t, log_prob, weight_size, dtype, device, target_symbol)

            result.set_transition_weight(t, weight)

        if is_accept_state:
            # the last weight is the accept weight
            end_log_prob = math.log(weights[-1])
            result.set_accept_weight(
                State(state_from),
                get_accept_weight(end_log_prob, weight_size, dtype, device)
            )
    return result
    

def construct_weighted_automaton(
    max_count: int,
    dtype: torch.dtype,
    device: torch.device
) -> WeightedFiniteAutomatonContainer:
    """Build a fixed 7-state, 3-symbol example automaton with counting weights.

    State 6 is initial, and state 0 accepts with probability 0.3. Every weight
    is a one-hot log vector of length ``max_count + 1``. Transitions that read
    ``TGT_SYMBOL`` store their log-probability at index 1 (one count); the
    others store it at index 0.
    """
    weight_size = max_count + 1

    def one_hot_log(slot, prob):
        vec = torch.full((weight_size,), float('-inf'), dtype=dtype, device=device)
        vec[slot] = math.log(prob)
        return vec

    automaton = WeightedFiniteAutomatonContainer(
        num_states=7,
        alphabet_size=3,
        initial_state=State(6),
        semiring=LogCountingSemiring(weight_size)
    )

    automaton.set_accept_weight(0, one_hot_log(0, 0.3))

    # (source, symbol, destination, probability)
    arcs = (
        (6, 0, 4, 1.0),
        (4, 1, 1, 0.2),
        (4, 2, 3, 0.7),
        (4, 0, 5, 0.1),
        (1, 0, 0, 1.0),
        (0, 1, 1, 0.7),
        (2, 0, 0, 0.5),
        (2, 1, 1, 0.2),
        (2, 2, 3, 0.3),
        (3, 0, 2, 1.0),
        (5, 0, 6, 0.9),
        (5, 1, 3, 0.1),
    )

    for src, sym, dst, prob in arcs:
        slot = 1 if sym == TGT_SYMBOL else 0
        automaton.set_transition_weight(
            FiniteAutomatonTransition(src, sym, dst),
            one_hot_log(slot, prob),
        )

    return automaton


def sanity_check_machine_dist(A_l):
    """Sanity-check the lifted weights of ``A_l``.

    Every transition weight must have exactly one non-zero count entry. For
    each source state with outgoing transitions, the outgoing probabilities
    must sum to 1 (relative tolerance 1e-5) unless the state accepts. At most
    one such accepting state is allowed.

    Raises:
        AssertionError: if a weight has zero or several non-zero entries, or a
            non-accepting state's outgoing probabilities do not sum to 1.
        ValueError: if more than one accepting state has outgoing transitions.
    """
    grouped_probs = defaultdict(list)
    for t in A_l.transitions():
        prob = torch.exp(A_l._transitions[t])
        # exp(-inf) is already 0; also zero out any +inf entries.
        prob[torch.isinf(prob)] = 0
        nonzero = torch.nonzero(prob)
        assert len(nonzero) == 1
        grouped_probs[t.state_from].append(prob[nonzero.item()])

    seen_accept = False
    # The checks run for every state, not just the last one visited.
    for state_from, g_probs in grouped_probs.items():
        prob_sum = float(sum(g_probs))
        if not A_l.is_accept_state(State(state_from)):
            assert math.isclose(prob_sum, 1, rel_tol=1e-5)
        else:
            if seen_accept:
                raise ValueError("Only one accept state should be present")
            seen_accept = True
    

def logsumexp_normalize(x, dim=0, is_log=True):
    """Normalize ``x`` along ``dim`` so its entries sum to one.

    With ``is_log=True`` (the default), ``x`` holds log values and the result
    is the log of the normalized distribution, computed as a stable
    log-softmax. With ``is_log=False``, ``x`` holds non-negative values; they
    are moved to log space, normalized, and returned as probabilities.
    """
    log_x = x if is_log else torch.log(x)
    # Shift by the max before exponentiating to avoid overflow.
    peak = log_x.max(dim=dim, keepdim=True).values
    log_partition = peak + torch.log(torch.sum(torch.exp(log_x - peak), dim=dim, keepdim=True))
    log_normalized = log_x - log_partition
    if is_log:
        return log_normalized
    return torch.exp(log_normalized)


def sample_length(Z, K, N, k=0, m=0, Z_pov=None, torch_gen=None):
    """
    Sample the target count of string ``k``, given ``m`` counts already placed.

    In log space, the weight of count ``n`` is
    ``Z[n] + Z_pov[K - k - 1][N - m - n]``. That is, this string has ``n``
    occurrences and the remaining ``K - k - 1`` strings share the other
    ``N - m - n``.

    Args:
        Z: Log weights of a single string, indexed by count.
        K: Total number of strings.
        N: Total number of occurrences across all strings.
        k: Index of the string being sampled.
        m: Occurrences already assigned to strings ``0 .. k-1``.
        Z_pov: Cache of powers of ``Z``. Row ``i`` is row ``i-1`` multiplied by
            ``Z`` in ``LogCountingSemiring``; row 0 has 0.0 at count 0 and -inf
            elsewhere. It is built on the first call when None.
        torch_gen: Torch generator used by ``torch.multinomial``.

    Returns:
        ``(count, Z_pov)``: the sampled count and the (possibly new) cache.

    Raises:
        ValueError: if ``N - m`` exceeds the largest count representable in
            ``Z``, or no count is feasible.
    """
    if N - m >= len(Z):
        raise ValueError(
            f"{N - m} remaining occurrences exceed the largest count "
            f"representable in Z ({len(Z) - 1})"
        )

    power = K - k - 1

    # Calculate Z_pov in log space
    if Z_pov is None:
        # make it a semiring element
        Z_pov = torch.full((K + 1, len(Z)), float('-inf')).to(Z.device)
        Z_pov[0][0] = 0.0
        for i in range(1, K + 1):
            Z_pov[i] = LogCountingSemiring.multiply(Z_pov[i-1], Z)

    probs = torch.full((N + 1,), float('-inf'))
    for n in range(N + 1):
        pov_key = N - m - n
        if pov_key >= 0:
            probs[n] = Z[n] + Z_pov[power][pov_key]

    if not torch.isfinite(probs).any():
        raise ValueError(
            f"no feasible count for string {k}: {m} of {N} occurrences already used"
        )

    # Convert from log probabilities to probabilities and normalize
    probs_norm = torch.softmax(probs, dim=0)
    num = torch.multinomial(probs_norm, 1, generator=torch_gen).item()

    return num, Z_pov


def alo_sample_length(Z, S, T, k=0, m=0, Z_pov=None, torch_gen=None):
    """
    Sample whether string ``k`` of ``S`` contains the event at least once,
    conditioned on exactly ``T`` of the ``S`` strings doing so.

    The remaining ``S - k - 1`` strings are treated as independent Bernoulli
    draws with success probability ``exp(Z[1])``. Each outcome of the current
    string is weighted by ``P(outcome) * Binomial(remaining, still_needed, p)``.

    Args:
        Z: Tensor where Z[0] is log prob of seeing event 0 times, Z[1] is log prob of seeing ≥1 times
        S: Total number of strings to sample
        T: Target number of strings that should have ≥1 occurrence
        k: Current position in sampling (default 0)
        m: Current count of strings with ≥1 occurrence (default 0)
        Z_pov: Unused; returned unchanged for interface compatibility with
            ``sample_length``.
        torch_gen: Torch random generator (default None). The probabilities
            are moved to its device before sampling.

    Returns:
        sampled: 0 or 1 indicating whether current string has ≥1 occurrence.
            For the last string the outcome is forced; it is 0 when ``T`` is
            already reached or unreachable.
        Z_pov: The ``Z_pov`` argument, unchanged.

    Raises:
        ValueError: if, before the last string, neither outcome can still lead
            to exactly ``T`` strings.
    """
    def log_binom(n, r):
        # log C(n, r), with log(0) = -inf outside 0 <= r <= n.
        if r < 0 or r > n:
            return float('-inf')
        if r == 0 or r == n:
            return 0.0
        return (torch.lgamma(torch.tensor(n + 1.0, device=Z.device))
                - torch.lgamma(torch.tensor(r + 1.0, device=Z.device))
                - torch.lgamma(torch.tensor(n - r + 1.0, device=Z.device)))

    def scaled(count, log_p):
        # count * log_p with the convention 0 * log(0) = 0. The plain product
        # is NaN when count == 0 and log_p == -inf (p == 0 or p == 1).
        return 0.0 if count == 0 else count * log_p

    if k == S - 1:
        # Last string: it must succeed only if exactly one more success is needed.
        return (1 if m == T - 1 else 0), Z_pov

    # Number of strings remaining to process after this one
    remaining = S - k - 1

    log_p0 = Z[0]  # Log probability of 0 occurrences
    log_p1 = Z[1]  # Log probability of ≥1 occurrences
    log_not_p1 = torch.log1p(-torch.exp(log_p1))

    def log_remaining(needed):
        # log P(exactly `needed` of the `remaining` strings have >=1 occurrence)
        if needed < 0 or needed > remaining:
            return torch.tensor(float('-inf'), device=Z.device)
        return (log_binom(remaining, needed)
                + scaled(needed, log_p1)
                + scaled(remaining - needed, log_not_p1))

    # Joint log weight of each outcome for the current string.
    log_joint = torch.stack([
        log_p0 + log_remaining(T - m),
        log_p1 + log_remaining(T - m - 1),
    ])

    log_norm = torch.logsumexp(log_joint, dim=0)
    if not torch.isfinite(log_norm):
        raise ValueError(
            f"cannot reach T={T} strings with >=1 occurrence: "
            f"string {k} of {S}, {m} so far"
        )
    probs_norm = torch.exp(log_joint - log_norm)

    if torch_gen is not None:
        # multinomial requires the tensor on the generator's device.
        probs_norm = probs_norm.to(torch_gen.device)
    sampled = torch.multinomial(probs_norm, 1, generator=torch_gen).item()

    return sampled, Z_pov


def probability_t_out_of_s(Z, S, T):
    """
    Calculate probability that exactly T out of S strings have at least one occurrence of an event.
    Uses log space for numerical stability.

    The strings are treated as independent, so the result is
    ``C(S, T) * p^T * q^(S-T)`` with ``p = exp(Z[1])`` and ``q = exp(Z[0])``.
    Zero exponents are handled as ``0 * log(0) = 0``, so degenerate ``p`` or
    ``q`` (0 or 1) give exact results instead of NaN.

    Args:
        Z: Tensor where Z[0] is log prob of seeing event 0 times, Z[1] is log prob of seeing ≥1 times
        S: Total number of strings sampled
        T: Number of strings that should have ≥1 occurrence

    Returns:
        Probability (float) of exactly T out of S strings having ≥1 occurrence.
        It is 0.0 when T is outside [0, S].
    """
    if T < 0 or T > S:
        return 0.0

    def log_binom(n, r):
        # log C(n, r) for 0 <= r <= n.
        if r == 0 or r == n:
            return 0.0
        return (torch.lgamma(torch.tensor(n + 1.0, device=Z.device))
                - torch.lgamma(torch.tensor(r + 1.0, device=Z.device))
                - torch.lgamma(torch.tensor(n - r + 1.0, device=Z.device)))

    def scaled(count, log_p):
        # count * log_p with 0 * log(0) = 0.
        return 0.0 if count == 0 else count * log_p

    log_p = Z[1]  # one string has >=1 occurrence
    log_q = Z[0]  # one string has 0 occurrences

    log_probability = log_binom(S, T) + scaled(T, log_p) + scaled(S - T, log_q)

    # Convert back to probability space
    if isinstance(log_probability, torch.Tensor):
        return torch.exp(log_probability).item()
    return math.exp(log_probability)


def get_sample_lengths(Z, K, N, generator, at_least_once_semiring=False):
    """Draw per-string target counts for K strings totalling N occurrences.

    Strings are processed in order. Each draw is conditioned on the
    occurrences already assigned to earlier strings.

    Args:
        Z: Backward path weights of the automaton, indexed by count.
        K: Total number of strings.
        N: Total number of occurrences across all strings.
        generator: Torch generator forwarded to the per-string sampler.
        at_least_once_semiring: If True, each string contributes 0 or 1 (it
            either contains the target at least once or not) and
            ``alo_sample_length`` is used; otherwise ``sample_length``.

    Returns:
        A list of K counts. A message is printed if they do not sum to N.
    """
    per_string = alo_sample_length if at_least_once_semiring else sample_length

    lengths = []
    occurrences_so_far = 0
    cached_pov = None
    for position in range(K):
        count, cached_pov = per_string(
            Z, K, N, position, occurrences_so_far, cached_pov, generator
        )
        occurrences_so_far += count
        lengths.append(count)

    total = sum(lengths)
    if total != N:
        print(f"found {total} in total but expected {N}")
    return lengths
