import numpy as np
import torch
from lm import conditional_nn_generate
import time

def sample_with_weights(choices, probs):
    """Draw one element of ``choices`` with probability proportional to ``probs``.

    Args:
        choices: Indexable sequence of candidates.
        probs: Weights, one per element of ``choices``. They need not sum to
            1; they are normalized here. Integer weights are accepted.

    Returns:
        The selected element of ``choices``.

    Raises:
        AssertionError: If the weights sum to zero or to NaN.
    """
    # Copy into a float64 array: an integer array cannot be divided in place
    # by its total (numpy raises a casting error), and copying guarantees the
    # caller's array is never normalized behind its back.
    probs = np.array(probs, dtype=np.float64)
    total = probs.sum()
    # The previous message interpolated an undefined name (`sequence`), so a
    # failing check surfaced as NameError instead of this assertion.
    assert total != 0 and not np.isnan(total), \
        f"Transition probabilities were {probs}"

    probs /= total
    chosen_idx = np.random.choice(len(choices), p=probs)
    return choices[chosen_idx]

def hashify_sequence(sequence):
    """Return a string key identifying ``sequence``.

    Each token is converted with ``str`` and the results are joined with
    ``"_"``. The key is only unambiguous if no token's string form itself
    contains ``"_"``, so that is checked.

    Args:
        sequence: Iterable of tokens.

    Returns:
        The ``"_"``-joined string of the tokens.

    Raises:
        AssertionError: If any token's string form contains ``"_"``.
    """
    parts = [str(token) for token in sequence]
    # Check the rendered tokens rather than membership of "_" in the
    # sequence: the membership test missed multi-character tokens such as
    # "a_b", letting ["a_b"] and ["a", "b"] collide on the same key.
    assert all("_" not in part for part in parts), \
        f"Token containing '_' in {sequence}"
    return "_".join(parts)

class Sampler:
    """Base class for samplers that produce completions of a token prefix.

    Subclasses override ``sample``; the base implementation returns None.
    """

    def __init__(self, piref, reward_function, device):
        # Model object; subclasses call its next_token_probs() and
        # is_complete() methods.
        self.piref = piref
        # Callable that scores a token sequence.
        self.reward_function = reward_function
        self.device = device

    def sample(self, prefix):
        """Sample one completion of ``prefix``; meant to be overridden."""
        return None

    def sample_multiple_completions(self, prefix, K):
        """Call ``sample(prefix)`` ``K`` times and return the results as a list."""
        return [self.sample(prefix) for _ in range(K)]

        

def GenericJSSampler(up_prob=1.0, num_candidates=32):
    """Build a sampler class that random-walks over completions with backtracking.

    Each walk step either removes the last completion token (backtrack) or
    appends one of ``num_candidates`` tokens drawn from ``piref``, with odds
    set by ``reward_function``.

    Args:
        up_prob: Multiplier on the backtrack weight; 0 disables backtracking.
        num_candidates: Number of next tokens drawn from ``piref`` per step.

    Returns:
        A ``Sampler`` subclass.
    """
    class JSSamplerWithCompletions(Sampler):
        def forward(self, prefix, completion):
            """Take one random-walk step from ``prefix + completion``.

            Candidate moves:
              * backtrack to ``completion[:-1]`` (only when ``completion`` is
                non-empty), weighted by
                ``up_prob * reward(sequence) * num_candidates``;
              * extend by one of ``num_candidates`` tokens sampled from
                ``piref.next_token_probs`` (only when the sequence is not
                complete), each weighted by ``reward(sequence + [token])``.

            Returns:
                The new completion list. If every weight is zero, the last
                candidate is taken (an extension when any exist). If there
                are no candidate moves at all, ``completion`` is returned.
            """
            sequence = prefix + completion
            choices = []
            probs = []

            if completion:
                # Only evaluate the backtrack reward when backtracking is
                # possible. Scaled by num_candidates, presumably to balance it
                # against the summed weights of the num_candidates extensions.
                backtrack_weight = self.reward_function(sequence)
                choices.append(completion[:-1])
                probs.append(up_prob * backtrack_weight * num_candidates)

            if not self.piref.is_complete(prefix, sequence):
                # Only sample extensions when they can actually be chosen.
                token_probs = self.piref.next_token_probs(sequence)
                tokens = torch.multinomial(token_probs, num_samples=num_candidates, replacement=True)
                for token in tokens:
                    choices.append(completion + [token])
                    probs.append(self.reward_function(sequence + [token]))

            if not choices:
                # Nothing to do (complete sequence, empty completion); the
                # old code raised IndexError here.
                return completion

            if sum(probs) == 0:
                # Default to going down: the last candidate is an extension
                # whenever one exists, otherwise the backtrack move.
                return choices[-1]

            probs = np.array(probs, dtype=np.float64)
            chosen_idx = np.random.choice(len(choices), p=probs / probs.sum())
            return choices[chosen_idx]

        def sample(self, prefix):
            """Run the random walk from an empty completion until it is complete.

            Args:
                prefix: Token list to complete.

            Returns:
                Tuple ``(prefix + completion, steps_used)`` where
                ``steps_used`` counts walk steps, including backtracking steps.
            """
            completion = []
            steps_used = 0
            while not self.piref.is_complete(prefix, prefix + completion):
                completion = self.forward(prefix, completion)
                steps_used += 1
                if steps_used % 100 == 0:
                    print(f"used {steps_used} steps",flush=True)
            return prefix + completion, steps_used

    return JSSamplerWithCompletions

def GenericTWSampler(reset=False, num_candidates=32):
    """Build a sampler class that grows a completion one token at a time.

    At each step ``num_candidates`` next tokens are drawn from ``piref``, and
    one of them is kept with probability proportional to its reward.

    Args:
        reset: If True, discard the whole completion whenever every candidate
            has zero reward; otherwise only redraw the candidates.
        num_candidates: Number of next tokens drawn from ``piref`` per step.

    Returns:
        A ``Sampler`` subclass.
    """
    class TWSampler(Sampler):
        def sample(self, prefix):
            """Sample a completion of ``prefix`` token by token.

            Prints the total time spent in ``reward_function`` before
            returning.

            Returns:
                Tuple ``(prefix + completion, steps)`` where ``steps`` counts
                loop iterations, including rounds in which every candidate
                had zero reward.
            """
            completion = []
            steps = 0
            total_rep_time = 0.0
            while not self.piref.is_complete(prefix, prefix + completion):
                steps += 1
                sequence = prefix + completion
                token_probs = self.piref.next_token_probs(sequence)
                tokens = torch.multinomial(token_probs, num_samples=num_candidates, replacement=True)
                candidates = [sequence + [token] for token in tokens]

                start_time = time.perf_counter()
                weights = [self.reward_function(candidate) for candidate in candidates]
                total_rep_time += time.perf_counter() - start_time

                if sum(weights) == 0:
                    # No candidate has positive reward: retry, optionally from
                    # an empty completion. NOTE(review): this can loop
                    # indefinitely if no extension ever gets positive reward.
                    if reset:
                        completion = []
                    continue

                weights = np.array(weights, dtype=np.float64)
                chosen_idx = np.random.choice(len(tokens), p=weights / weights.sum())
                completion.append(tokens[chosen_idx])

            print(f"Time spent in reward_function: {total_rep_time:.4f}")
            return prefix + completion, steps

    return TWSampler

class UnguidedLMSampler(Sampler):
    """Sampler that draws every token directly from ``piref``, ignoring rewards."""

    def sample(self, prefix):
        """Sample a completion of ``prefix`` one token at a time from ``piref``.

        Returns:
            Tuple ``(prefix + completion, len(completion))``.
        """
        generated = []
        while not self.piref.is_complete(prefix, prefix + generated):
            dist = self.piref.next_token_probs(prefix + generated)
            draw = torch.multinomial(dist, num_samples = 1, replacement=True)
            generated.append(draw[0])
        return prefix + generated, len(generated)
