import numpy as np
import random
from config import alphabet, alphabet_idx, Hs, alphas, num_reps, Ns
from .token_sampler import sample_edge_noise

class JSSampler:
    """Random walk over token sequences of length at most ``horizon``.

    Each state is a list of tokens. From a state the walk moves either to its
    parent (the sequence with its last token dropped) or to one of its
    children (the sequence extended by one token from ``piref.alphabet``).
    Every edge is weighted by its *child* endpoint:
    ``piref.sequence_prob(child) * get_weight(child)``. The next state is
    drawn with probability proportional to those edge weights.

    Args:
        piref: Reference model. Must expose ``alphabet`` (iterable of tokens)
            and ``sequence_prob(seq)``; its values are expected to be
            non-negative, since ``np.random.choice`` rejects negative
            probabilities.
        horizon: Maximum sequence length. ``sample`` runs until a sequence of
            this length is reached.
        estimated_value: Callable mapping a token list to a value estimate.
        edge_noise: Mapping from ``(tuple(prefix), last_token)`` to a
            multiplicative factor applied to the value of
            ``prefix + [last_token]``. Missing keys default to 1.0.
    """

    def __init__(self, piref, horizon, estimated_value, edge_noise):
        self.piref = piref
        self.horizon = horizon
        self.estimated_value = estimated_value
        self.edge_noise = edge_noise

    def get_weight(self, seq):
        """Return the noise-scaled value estimate for ``seq``.

        The empty sequence gets the raw ``estimated_value``. Any other
        sequence has its value multiplied by the ``edge_noise`` factor of the
        edge leading into it (default 1.0).
        """
        f_val = self.estimated_value(seq)

        if len(seq) == 0:
            return f_val

        # Noise is keyed by the edge (prefix -> last token); the prefix is
        # converted to a tuple so it is hashable.
        prefix = seq[:-1]
        last_token = seq[-1]
        noise_factor = self.edge_noise.get((tuple(prefix), last_token), 1.0)
        return noise_factor * f_val

    def forward(self, sequence):
        """Take one step of the walk from ``sequence`` and return the new state.

        Candidate moves are:
        - the parent ``sequence[:-1]``, if ``sequence`` is non-empty;
        - every one-token extension, if ``len(sequence) < horizon``.

        Each candidate is weighted by ``piref.sequence_prob`` times
        ``get_weight`` of the child end of its edge. For the parent move,
        the child end is ``sequence`` itself.

        If no move can be sampled, ``sequence`` itself is returned unchanged.
        That happens when there are no candidates, or when the weights sum to
        zero, a negative number, NaN, or infinity.

        Args:
            sequence: Current state as a ``list`` of tokens. It must be a list
                because children are built with ``sequence + [a]``.
        """
        choices = []
        probs = []
        # NOTE: a "lazy" variant of this walk would also include ``sequence``
        # itself as a candidate (unnormalised weight 0.5) before the moves below.

        if len(sequence) > 0:
            # The parent edge is weighted by its child end: the current state.
            choices.append(sequence[:-1])
            probs.append(self.piref.sequence_prob(sequence) * self.get_weight(sequence))

        if len(sequence) < self.horizon:
            for a in self.piref.alphabet:
                next_seq = sequence + [a]
                choices.append(next_seq)
                probs.append(self.piref.sequence_prob(next_seq) * self.get_weight(next_seq))

        probs = np.array(probs, dtype=np.float64)
        total = probs.sum()
        # An infinite total would turn the normalised weights into NaN, and a
        # non-positive total is not a distribution. In both cases
        # np.random.choice would raise, so stay put instead.
        if not np.isfinite(total) or total <= 0:
            return sequence
        probs /= total
        return choices[np.random.choice(len(choices), p=probs)]

    def sample(self, min_steps=5, max_steps=None):
        """Run the walk from the empty sequence until it reaches ``horizon``.

        Args:
            min_steps: Number of steps always taken first, regardless of where
                the walk currently is.
            max_steps: Optional cap on the total number of ``forward`` calls.
                ``None`` (the default) means unbounded.
                If ``forward`` keeps returning its input (all weights zero),
                an unbounded walk would loop forever; with a cap it raises.

        Returns:
            ``(sequence, steps)``: the final sequence
            (``len(sequence) >= horizon``) and the total number of ``forward``
            calls made.

        Raises:
            RuntimeError: if ``max_steps`` is given and the horizon has not been
                reached after that many steps.
        """
        sequence = []
        steps = 0
        for _ in range(min_steps):
            steps += 1
            sequence = self.forward(sequence)

        while len(sequence) < self.horizon:
            if max_steps is not None and steps >= max_steps:
                raise RuntimeError(
                    f"JSSampler.sample did not reach horizon {self.horizon} "
                    f"within max_steps={max_steps} "
                    f"(current length {len(sequence)})"
                )
            steps += 1
            sequence = self.forward(sequence)
        return sequence, steps
