

import random

from lbrayuela.string_sampling.weighted_language import (
    String,
    ValidNextSymbolList
)

import torch

# Length cap that appears to have bounded the sampling loops at one point; it
# is left commented out, so neither sampler below limits the string length.
#MAX_LENGTH = 254
# NOTE(review): not referenced anywhere in this chunk; presumably used by
# callers elsewhere -- confirm its meaning before changing or removing it.
TGT_SYMBOL = 0


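# THIS IS BROKEN, KEPT ONLY AS A REFERENCE AND A (SHAMEFUL) WARNING:
# transitions that share the same (source, target) state pair overwrite each
# other. Use `vanilla_sampler` instead.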
def vanilla_sampler_old(machine,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool,
        # max_length: int = MAX_LENGTH,
    ) -> dict:
    """Reference-only, BROKEN sampler -- do not call; use ``vanilla_sampler``.

    Random walk over ``machine`` from its initial state. At each step it draws
    either a target state (following the transition into it) or the extra
    "accept" column, and it stops once the accept column is drawn. There is
    no length limit.

    Known defects, kept intentionally so the mistake stays visible:

    * Parallel transitions are lost. ``state_to_trans_map`` and
      ``trans_to_symbol`` are keyed by ``(src, tgt)`` only. When several
      transitions share a source and target state (with different symbols),
      the last one written overwrites the others: its mass replaces theirs
      instead of adding to it, and only its symbol can ever be emitted.
    * ``include_log_probability`` and ``include_next_symbols`` have no effect.
      ``log_probs`` is always filled, and ``next_symbols`` is never populated
      or returned.
    * NOTE(review): ``generator`` is forwarded to ``torch.multinomial``, whose
      ``generator`` argument expects a ``torch.Generator``. The annotated
      ``random.Random`` would be rejected -- confirm what callers pass.

    Returns:
        dict with keys:
        ``"sampled_string"``: tuple of emitted symbols.
        ``"log_probs"``: stacked tensor of the log of the (un-normalised)
            mass of every draw, including the final accept draw.
        ``"states"``: the state before each emitted symbol.
        ``"transitions"``: ``(state, next_state, symbol)`` per emitted symbol.
    """

    sampled_string = []
    # NOTE: `log_probability` and `next_symbols` are set from the include_*
    # flags but never read afterwards, so the flags change nothing.
    if include_log_probability:
        log_probability = 0.0
    else:
        log_probability = None
    if include_next_symbols:
        next_symbols = []
    else:
        next_symbols = None

    log_probs = []
    states = []
    transitions = []

    def get_prob(weight):
        # maps the counting weight structure to the transition probability
        # (exponentiate the weight tensor and sum it into one scalar mass)
        return torch.exp(weight).sum().item()

    state = machine.initial_state()

    accept_wgts = machine._accept_weights

    # (src, tgt) -> symbol. BUG: a later transition with the same (src, tgt)
    # but a different symbol silently replaces the earlier entry.
    trans_to_symbol = {}

    # Dense table: row = source state; columns 0..num_states-1 = target
    # state; the extra last column holds the accept mass (filled below).
    state_to_trans_map = torch.zeros(machine.num_states(), machine.num_states() + 1)
    for trans, weight in machine._transitions.items():
        prob = get_prob(weight)
        src = trans.state_from
        tgt = trans.state_to
        symbol = trans.symbol
        trans_to_symbol[(src, tgt)] = symbol
        # BUG: plain assignment, not accumulation -- parallel (src, tgt)
        # transitions overwrite each other's mass instead of summing.
        state_to_trans_map[src, tgt] = prob

    # we make the last column the accept weight
    for pos in range(machine.num_states()):
        # states without an accept weight get exp(-inf) == 0 accept mass
        accept_wgt = get_prob(accept_wgts.get(pos, torch.tensor(-float('inf'))))
        state_to_trans_map[pos, -1] = accept_wgt

    # this is the idx where we store the accept probability
    accept_state_idx = machine.num_states()

    while True: #len(sampled_string) < max_length:
        # This also accounts for the final state probability,
        # we added it above.
        state_probs = state_to_trans_map[state]

        # torch.multinomial normalises the row itself, so masses need not sum to 1
        next_state = torch.multinomial(state_probs, 1, generator=generator).item()

        # get the probability of the transition
        # (un-normalised mass from the table, not divided by the row sum)
        prob = state_probs[next_state].item()
        # log the probability
        log_prob = torch.log(torch.tensor(prob))
        log_probs.append(log_prob)

        if next_state == accept_state_idx:
            # We took the accept prob
            break

        # only the last-written symbol for this (state, next_state) survives
        symbol = trans_to_symbol[(state, next_state)]

        transitions.append((state, next_state, symbol))

        sampled_string.append(symbol)

        states.append(state)
        state = next_state


    #return tuple(sampled_string), torch.stack(log_probs), is_tgt_arc, states, is_tgt_state, is_tgt_symbol
    # make a json
    return {
        "sampled_string": tuple(sampled_string),
        "log_probs": torch.stack(log_probs),
        "states": states,
        "transitions": transitions
    }



def vanilla_sampler(
        machine,
        generator: "random.Random | torch.Generator | None",
        include_log_probability: bool,
        include_next_symbols: bool,
    ) -> dict:
    """Sample one string from ``machine`` by a random walk from its initial state.

    At each step the walk does one of two things:

    * follow one outgoing transition of the current state, emitting its symbol
      and moving to its target state; or
    * take the state's accept option and stop.

    Each option is drawn with probability proportional to
    ``exp(weight).sum()`` of its weight tensor. For a transition this is its
    entry in ``machine._transitions``; for accepting it is the state's entry
    in ``machine._accept_weights``, and a state with no entry there has
    accept mass 0. Parallel transitions between the same pair of states are
    kept as separate options. The walk has no length limit.

    Args:
        machine: automaton exposing ``num_states()``, ``initial_state()``,
            ``_transitions`` (mapping from objects with ``state_from``,
            ``state_to`` and ``symbol`` attributes to weight tensors) and
            ``_accept_weights`` (mapping from state to weight tensor).
        generator: source of randomness. A ``random.Random`` is used through
            ``random.Random.choices``. Anything else (a ``torch.Generator``,
            or ``None`` for torch's global RNG) is passed to
            ``torch.multinomial``.
        include_log_probability: if true, fill ``"log_probs"``.
        include_next_symbols: if true, fill ``"next_symbols"``.

    Returns:
        dict with keys:
        ``"sampled_string"``: tuple of emitted symbols.
        ``"log_probs"``: 1-D tensor holding the log of the mass of every
            option drawn (one per emitted symbol, plus a final entry for the
            accept option), or ``None``. NOTE(review): masses are not
            renormalised, so these are true log-probabilities only if each
            state's outgoing plus accept mass sums to 1 -- confirm for the
            machines in use.
        ``"states"``: state the walk was in before each emitted symbol.
        ``"transitions"``: ``(state, next_state, symbol)`` per emitted symbol.
        ``"next_symbols"``: for each emitted symbol, the symbols of all
            outgoing transitions of the state it was emitted from (zero-mass
            ones included; no entry for the final accept step), or ``None``.

    Raises:
        RuntimeError: if the walk reaches a state whose total outgoing plus
            accept mass is not positive, so no option can be drawn.
    """

    def weight_to_prob(weight_tensor):
        # Collapse a weight tensor to one scalar mass: exponentiate, then sum.
        return torch.exp(weight_tensor).sum().item()

    num_states = machine.num_states()

    # outgoing[s] lists (symbol, target_state, mass) for each transition out of
    # s. A per-state list (rather than a (src, tgt)-keyed table) keeps
    # parallel transitions between the same two states distinct.
    outgoing = [[] for _ in range(num_states)]
    for tr, w in machine._transitions.items():
        outgoing[tr.state_from].append((tr.symbol, tr.state_to, weight_to_prob(w)))

    no_accept = torch.tensor(-float("inf"))  # exp(-inf) == 0 -> cannot accept
    accept_prob = [
        weight_to_prob(machine._accept_weights.get(s, no_accept))
        for s in range(num_states)
    ]

    # torch.multinomial's `generator` only accepts a torch.Generator (or None),
    # so a Python random.Random has to draw through its own API instead.
    use_python_rng = isinstance(generator, random.Random)

    sampled_symbols = []
    log_probs = [] if include_log_probability else None
    next_symbols = [] if include_next_symbols else None
    states = []
    transitions = []

    state = machine.initial_state()

    while True:
        choices = outgoing[state]
        # Option i < len(choices) is transition i; the final option is accept.
        probs = [p for _, _, p in choices] + [accept_prob[state]]
        if not sum(probs) > 0:  # written this way so NaN masses are caught too
            raise RuntimeError(
                f"state {state} has no positive outgoing or accept mass; "
                "cannot continue sampling"
            )
        probs_tensor = torch.tensor(probs)

        if use_python_rng:
            idx = generator.choices(range(len(probs)), weights=probs, k=1)[0]
        else:
            idx = torch.multinomial(probs_tensor, 1, generator=generator).item()

        if include_log_probability:
            # Equal to torch.log(torch.tensor(p)) for the chosen mass p.
            log_probs.append(torch.log(probs_tensor[idx]))

        if idx == len(choices):  # took the accept option
            break

        symbol, next_state, _ = choices[idx]

        sampled_symbols.append(symbol)
        states.append(state)
        transitions.append((state, next_state, symbol))

        if include_next_symbols:
            next_symbols.append([sym for sym, _, _ in choices])

        state = next_state

    return {
        "sampled_string": tuple(sampled_symbols),
        "log_probs": torch.stack(log_probs) if include_log_probability else None,
        "states": states,
        "transitions": transitions,
        "next_symbols": next_symbols,
    }
