import numpy as np
import torch
from dyck import RandomWalkDyck
from lm import conditional_nn_generate
from train_values import compute_rep_from_seq, compute_pool_range
import time

def sample_with_weights(choices, probs):
    """Draw one element of ``choices`` with probability proportional to ``probs``.

    Args:
        choices: Indexable collection of options.
        probs: Unnormalized weights, one per choice. Anything ``numpy`` can
            convert to a float array is accepted (ints, floats, ...).

    Returns:
        ``choices[i]`` for the sampled index ``i``.

    Raises:
        ValueError: If the weights sum to zero or NaN, so that no
            distribution can be formed from them.
    """
    # Copy into float64: the in-place normalization below would fail on an
    # integer array, and the caller's data must not be mutated.
    weights = np.array(probs, dtype=np.float64)
    total = weights.sum()
    if total == 0 or np.isnan(total):
        raise ValueError(f"Transition probabilities were {weights}")

    weights /= total

    # Sample the index of the chosen option
    chosen_idx = np.random.choice(len(choices), p=weights)
    return choices[chosen_idx]

def hashify_sequence(sequence):
    """Return a string key for ``sequence``: its items' ``str`` forms joined by ``_``."""
    assert "_" not in sequence
    pieces = map(str, sequence)
    return "_".join(pieces)

class Sampler:
    """Common base for samplers that steer a reference policy with a value function.

    Holds the policy (``piref``), the value function, the representation
    settings read from ``cfg_rep`` and a few caches. ``sample`` does nothing
    here; the concrete samplers override it.
    """

    def __init__(self, piref, value_function, cfg_rep, device):
        """Store the collaborators and start with empty caches.

        Args:
            piref: Reference policy object queried for token probabilities
                and representations.
            value_function: Callable applied to representation tensors.
            cfg_rep: Object exposing ``layer_index`` and ``tokens``.
            device: Target handed to ``Tensor.to`` for representation tensors.
        """
        self.piref = piref
        self.value_function = value_function
        self.device = device
        self.layer_for_vf = cfg_rep.layer_index
        self.which_tokens_for_vf = cfg_rep.tokens
        self.cache = {}
        self.weights_cache = {}
        self.weights_by_length = {}
        self.reward_func = None
        self.relevant_weight = None
        self.last_weight = None

    def _clipped_values(self, reps):
        """Apply the value function without gradients and clamp the result to [0, 1]."""
        with torch.no_grad():
            raw = self.value_function(reps).detach().cpu()
        return torch.clamp(raw, min=0, max=1)

    def get_batch_weights(self, base_sequence, sequences, pool_range):
        """Score a batch of sequences.

        Representations come from ``piref.compute_batch_reps``; they are
        converted to a float32 tensor on ``self.device`` before scoring.

        Returns:
            CPU tensor of value-function outputs clamped to [0, 1].
        """
        reps = self.piref.compute_batch_reps(
            base_sequence,
            sequences,
            self.layer_for_vf,
            self.which_tokens_for_vf,
            pool_range,
        )
        reps_on_device = torch.tensor(reps, dtype=torch.float32).to(self.device)
        return self._clipped_values(reps_on_device)

    def get_batch_weights_from_reps(self, batch_reps):
        """Score representations that are already a tensor; output clamped to [0, 1]."""
        return self._clipped_values(batch_reps)

    def get_weight_from_rep(self, rep):
        """Score a single representation and return a Python scalar clipped to [0, 1]."""
        rep_tensor = torch.tensor(rep, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            score = self.value_function(rep_tensor).item()
        return min(max(score, 0), 1)

    def compute_or_load_token_probs(self, sequence):
        """Return ``piref.next_token_probs(sequence)``, memoized in ``self.cache``."""
        key = hashify_sequence(sequence)
        if key not in self.cache:
            self.cache[key] = self.piref.next_token_probs(sequence)
        return self.cache[key]

    def sample(self, prefix):
        """Placeholder; returns ``None``. Overridden by concrete samplers."""
        return None

    def sample_multiple_completions(self, prefix, K):
        """Call ``sample(prefix)`` ``K`` times and return the results as a list."""
        return [self.sample(prefix) for _ in range(K)]

    def get_weight(self, seq):
        """
        Get value for sequence (currently disabled by an unconditional assert).

        Args:
            seq: Full sequence (prefix + completion so far)
        """
        assert False
        key = hashify_sequence(seq)
        if key in self.weights_cache:
            return self.weights_cache[key]

        # Compute the representation and convert it to a tensor for the value function
        rep = compute_rep_from_seq(seq, self.piref.model, self.layer_for_vf)
        rep_tensor = torch.tensor(rep, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            score = self.value_function(rep_tensor).item()
        return min(max(score, 0), 1)

def GenericBlockMJSSampler(num_candidates, block_length):
    """Build a block-level momentum sampler class bound to the given settings.

    Args:
        num_candidates: Number of forward blocks proposed at every step.
        block_length: Requested block size in tokens; the size actually
            requested is capped using ``piref.max_new_tokens``.

    Returns:
        A ``Sampler`` subclass whose ``sample`` walks the completion one
        block at a time while carrying an ``'up'``/``'down'`` momentum flag.
    """
    class BlockMJSSampler(Sampler):
        def sample(self, prefix):
            """Walk from ``prefix`` until ``piref.is_complete`` holds.

            Returns:
                Tuple ``(prefix + completion, number_of_steps)``.
            """
            completion, momentum = [], 'down'
            num_steps = 0
            while not self.piref.is_complete(prefix, prefix + completion):
                sequence = prefix + completion
                at_root = len(completion) == 0

                # With nothing generated there is nowhere to go up to.
                if at_root:
                    momentum = 'down'

                # Representation of the current node; its score drives the upward flow.
                current_range = range(len(prefix), len(sequence))
                _, current_rep = self.piref.next_token_probs_and_rep(
                    sequence, self.layer_for_vf, self.which_tokens_for_vf, current_range
                )
                if at_root:
                    up_flow = 0
                else:
                    up_flow = self.get_weight_from_rep(current_rep) * num_candidates

                k = min(block_length, self.piref.max_new_tokens - 1 - len(completion))

                # Propose candidate blocks and score each extended sequence.
                extended_range = range(len(prefix), len(sequence) + k)
                blocks, block_reps = self.piref.compute_batch_blocks_and_reps(
                    sequence, num_candidates, k,
                    self.layer_for_vf, self.which_tokens_for_vf, extended_range
                )
                down_flows = self.get_batch_weights_from_reps(block_reps)

                signed_cross_flow = up_flow - sum(down_flows)

                if momentum == 'down':
                    choices = [
                        (completion + blocks[i].tolist(), 'down')
                        for i in range(num_candidates)
                    ]
                    probs = [down_flows[i] for i in range(num_candidates)]
                    # Surplus upward flow lets the walk flip to 'up' in place.
                    if signed_cross_flow > 0:
                        choices.append((completion, 'up'))
                        probs.append(signed_cross_flow)
                else:
                    choices = [(completion[:-block_length], 'up')]
                    probs = [up_flow]
                    # Surplus downward flow lets the walk flip to 'down' in place.
                    if signed_cross_flow < 0:
                        choices.append((completion, 'down'))
                        probs.append(-signed_cross_flow)

                completion, momentum = sample_with_weights(choices, probs)
                num_steps += 1

            return prefix + completion, num_steps
    return BlockMJSSampler


class MomentumJSSampler(Sampler):
    """Token-level sampler that carries an 'up'/'down' momentum flag between steps.

    NOTE(review): this class reads ``self.completion_length`` and calls
    ``self.get_weight(seq, n)`` with two arguments. ``Sampler.__init__`` in
    this file does not set ``completion_length`` and ``Sampler.get_weight``
    takes a single argument and fails an unconditional assert, so this looks
    like code written against an older base class -- confirm before use.
    """
    def forward(self, prefix, completion, momentum):
        """Take one step of the walk.

        Args:
            prefix: Token list the completion is appended to.
            completion: Tokens generated so far.
            momentum: ``'up'`` or ``'down'``; forced to ``'down'`` when the
                completion is empty.

        Returns:
            Tuple ``(new_completion, new_momentum)`` drawn by
            ``sample_with_weights``.
        """
        choices = []
        probs = []
        sequence = prefix + completion
        completion_so_far = len(completion)

        # Nothing to backtrack from at the root, so the only direction is down.
        if completion_so_far == 0:
            momentum = 'down'
        if completion_so_far == self.completion_length:
            assert False, "reached leaf already"
        
        token_probs = self.compute_or_load_token_probs(sequence)

        # Upward flow is the weight of the current node (zero at the root).
        if completion_so_far > 0:
            up_flow = self.get_weight(sequence, completion_so_far)
        else:
            up_flow = 0

        # Downward flow per token: reference probability times the child's weight.
        down_flows = []
        for token in range(self.piref.vocab_size):
            next_seq = sequence + [token]
            # Use value function for current step
            forward_weight = self.get_weight(next_seq, completion_so_far + 1)
            forward_prob = token_probs[token] * forward_weight
            down_flows.append(forward_prob)
        
        # Positive: more flow up than down; negative: the reverse.
        signed_cross_flow = up_flow - sum(down_flows)

        if momentum == 'down':
            for token in range(self.piref.vocab_size):
                choices.append((completion + [token], 'down'))
                probs.append(down_flows[token])
            # Surplus upward flow: stay in place but flip momentum to 'up'.
            if signed_cross_flow > 0:
                choices.append((completion, 'up'))
                probs.append(signed_cross_flow)
        else:
            # Moving up drops the last token.
            choices.append((completion[:-1], 'up'))
            probs.append(up_flow)
            # Surplus downward flow: stay in place but flip momentum to 'down'.
            if signed_cross_flow < 0:
                choices.append((completion, 'down'))
                probs.append(-signed_cross_flow)
        
        return sample_with_weights(choices, probs)

    def sample(self, prefix):
        """Repeat ``forward`` until the completion has ``completion_length`` tokens.

        Returns:
            Tuple ``(prefix + completion, number_of_steps)``.
        """
        completion = []
        momentum = 'down'
        steps_used = 0
        while len(completion) < self.completion_length:
            completion, momentum = self.forward(prefix, completion, momentum)
            steps_used += 1
        return prefix + completion, steps_used


def GenericJSSampler(up_prob=1.0, min_steps=0, fixed_up_weight=None, fake_layer=False):
    """Build a token-level sampler class that can step forward or backtrack.

    Args:
        up_prob: Multiplier applied to the backtracking weight.
        min_steps: ``sample`` keeps stepping until at least this many steps
            were taken, even if the sequence is already complete.
        fixed_up_weight: Debug option; when not ``None`` it replaces the
            value-function score used for backtracking.
        fake_layer: Debug option; when true the representation is requested
            with layer ``-1`` and no pooling arguments.

    Returns:
        A ``Sampler`` subclass.
    """
    class JSSamplerWithCompletions(Sampler):
        def forward(self, prefix, completion):
            """Take one step and return the new completion (one token longer or shorter)."""
            sequence = prefix + completion
            num_candidates = 32

            # Debug option (was to reproduce a bug we found); fake_layer is False for the experiments
            if fake_layer:
                token_probs, rep = self.piref.next_token_probs_and_rep(sequence, -1)
            else:
                pool_range = compute_pool_range(prefix, sequence, self.which_tokens_for_vf)
                token_probs, rep = self.piref.next_token_probs_and_rep(
                    sequence, self.layer_for_vf, self.which_tokens_for_vf, pool_range
                )

            # Debug option; fixed_up_weight is None for the experiments
            if fixed_up_weight is None:
                backtrack_weight = self.get_weight_from_rep(rep)
            else:
                backtrack_weight = fixed_up_weight

            # Candidate tokens are drawn from the reference distribution rather
            # than enumerating the alphabet (implicit rejection sampling).
            tokens = torch.multinomial(token_probs, num_samples=num_candidates, replacement=True)
            sequences = [sequence + [tokens[i]] for i in range(num_candidates)]
            pool_range = compute_pool_range(prefix, sequences[0], self.which_tokens_for_vf)
            weights = self.get_batch_weights(sequence, sequences, pool_range)

            # Option to use true rewards at last layer
            if self.reward_func is not None and self.piref.is_complete(prefix, sequences[0]):
                for i, candidate in enumerate(sequences):
                    weights[i] = self.reward_func(candidate)

            choices, probs = [], []

            # Backtrack option, only available once something was generated.
            if len(completion) > 0:
                choices.append(completion[:-1])
                probs.append(up_prob * backtrack_weight * num_candidates)

            # Forward options, only while the sequence is not complete.
            if not self.piref.is_complete(prefix, prefix + completion):
                for i in range(num_candidates):
                    choices.append(completion + [tokens[i]])
                    probs.append(weights[i])

            if sum(probs) == 0:
                # default to going down the autoregressive tree
                return choices[-1]

            probs = np.array(probs, dtype=np.float64)
            probs = probs / sum(probs)
            return choices[np.random.choice(len(choices), p=probs)]

        def sample(self, prefix):
            """Step until the sequence is complete and ``min_steps`` steps were taken.

            Returns:
                Tuple ``(prefix + completion, number_of_steps)``.
            """
            completion = []
            steps_used = 0

            while True:
                done = self.piref.is_complete(prefix, prefix + completion)
                if done and steps_used >= min_steps:
                    break
                completion = self.forward(prefix, completion)
                steps_used += 1
                if steps_used % 100 == 0:
                    print(f"used {steps_used} steps",flush=True)
            return prefix + completion, steps_used

    return JSSamplerWithCompletions

class TokenwiseSamplerWithCompletions(Sampler):
    """Forward-only sampler that reweights sampled candidate tokens by their value scores."""

    def _sample_forwards(self, prefix):
        """
        Generate one completion token by token.

        At every position 32 candidate tokens are drawn from the reference
        distribution and one of them is picked with probability proportional
        to its weight (the first candidate if all weights are zero).

        Returns:
            Tuple ``(prefix + completion, len(completion))``.
        """
        num_candidates = 32
        completion = []
        while not self.piref.is_complete(prefix, prefix + completion):
            sequence = prefix + completion

            token_probs = self.piref.next_token_probs(sequence, self.layer_for_vf)
            tokens = torch.multinomial(token_probs, num_samples=num_candidates, replacement=True)
            sequences = [sequence + [tokens[i]] for i in range(num_candidates)]

            pool_range = compute_pool_range(prefix, sequences[0], self.which_tokens_for_vf)
            weights = self.get_batch_weights(sequence, sequences, pool_range)

            # When a reward function is set and the candidates are complete,
            # use its output instead of the value-function scores.
            if self.reward_func is not None and self.piref.is_complete(prefix, sequences[0]):
                for i, candidate in enumerate(sequences):
                    weights[i] = self.reward_func(candidate)

            if sum(weights) == 0:
                chosen_idx = 0
            else:
                weights = np.array(weights, dtype=np.float64)
                chosen_idx = np.random.choice(len(tokens), p=weights / sum(weights))
            completion.append(tokens[chosen_idx])
        return prefix + completion, len(completion)

    def sample(self, prefix):
        """Sample a completion, retrying while ``reward_func`` (if set) does not return 1.

        Returns:
            Tuple ``(sequence, total_steps_over_all_attempts)``.
        """
        seq, total_steps = self._sample_forwards(prefix)
        if self.reward_func is None:
            return seq, total_steps
        while self.reward_func(seq) != 1:
            print("Failed test; retrying",flush=True)
            seq, extra_steps = self._sample_forwards(prefix)
            total_steps += extra_steps
        return seq, total_steps
            

def GenericBlockBoNSampler(num_candidates, block_length):
    """Build a sampler class that greedily keeps the best of ``num_candidates`` blocks.

    NOTE(review): the generated class reads ``self.completion_length`` and
    calls ``self.get_weight`` with two arguments; neither matches the
    ``Sampler`` base class visible in this file -- confirm before use.
    """
    class BlockBoNSampler(Sampler):
        def sample(self, prefix):
            """Extend the completion block by block, keeping the highest-valued block each time.

            Returns:
                Tuple ``(prefix + completion, num_candidates * completion_length)``.
            """
            completion = []
            while len(completion) < self.completion_length:
                sequence = prefix + completion
                remaining = self.completion_length - len(completion)
                k = min(block_length, remaining)

                blocks, scores = [], []
                for _ in range(num_candidates):
                    block = self.piref.next_k_tokens(sequence, k)
                    scores.append(self.get_weight(sequence + block, len(completion + block)))
                    blocks.append(block)

                completion = completion + blocks[np.argmax(scores)]
            return prefix + completion, num_candidates * self.completion_length

    return BlockBoNSampler

def GenericBlockJSSampler(num_candidates, block_length):
    """Build a block-level sampler class that can extend by a block or backtrack by one.

    Args:
        num_candidates: Number of forward blocks proposed at every step.
        block_length: Requested block size in tokens; also the number of
            tokens dropped when backtracking.

    Returns:
        A ``Sampler`` subclass.
    """
    class BlockJSSampler(Sampler):
        def sample(self, prefix):
            """Step until ``piref.is_complete`` holds.

            Returns:
                Tuple ``(prefix + completion, number_of_steps)``.
            """
            completion = []
            num_steps = 0
            while not self.piref.is_complete(prefix, prefix + completion):
                sequence = prefix + completion

                # Score the current node; it weights the backtracking move.
                current_range = range(len(prefix), len(sequence))
                _, current_rep = self.piref.next_token_probs_and_rep(
                    sequence, self.layer_for_vf, self.which_tokens_for_vf, current_range
                )
                backtrack_weight = self.get_weight_from_rep(current_rep) * num_candidates

                # Backtracking is only offered once a full block can be removed.
                can_backtrack = len(completion) >= block_length
                candidates = [completion[:-block_length]] if can_backtrack else []
                weights = [backtrack_weight] if can_backtrack else []

                k = min(block_length, self.piref.max_new_tokens - 1 - len(completion))

                extended_range = range(len(prefix), len(sequence) + k)
                blocks, block_reps = self.piref.compute_batch_blocks_and_reps(
                    sequence, num_candidates, k,
                    self.layer_for_vf, self.which_tokens_for_vf, extended_range
                )
                block_weights = self.get_batch_weights_from_reps(block_reps)

                for i in range(num_candidates):
                    candidates.append(completion + blocks[i].tolist())
                    weights.append(block_weights[i])

                probs = np.array(weights)
                probs /= sum(probs)
                completion = candidates[np.random.choice(len(candidates), p=probs)]

                num_steps += 1
                if num_steps % 100 == 0:
                    print(f"Steps used: {num_steps}")
            return prefix + completion, num_steps

    return BlockJSSampler


def GenericBlockPropSampler(num_candidates, block_length):
    """Build a forward-only sampler class that picks blocks proportionally to their weights.

    Args:
        num_candidates: Number of blocks proposed at every step.
        block_length: Requested block size in tokens.

    Returns:
        A ``Sampler`` subclass.
    """
    class BlockPropSampler(Sampler):
        def sample(self, prefix):
            """Append weighted-sampled blocks until ``piref.is_complete`` holds.

            Returns:
                Tuple ``(prefix + completion, len(completion))``.
            """
            completion = []
            while not self.piref.is_complete(prefix, prefix + completion):
                sequence = prefix + completion

                # The result of this call is discarded; the call itself is kept as in
                # the original control flow.
                current_range = range(len(prefix), len(sequence))
                self.piref.next_token_probs_and_rep(
                    sequence, self.layer_for_vf, self.which_tokens_for_vf, current_range
                )

                k = min(block_length, self.piref.max_new_tokens - 1 - len(completion))

                extended_range = range(len(prefix), len(sequence) + k)
                blocks, block_reps = self.piref.compute_batch_blocks_and_reps(
                    sequence, num_candidates, k,
                    self.layer_for_vf, self.which_tokens_for_vf, extended_range
                )
                block_weights = self.get_batch_weights_from_reps(block_reps)

                picked = sample_with_weights(blocks, block_weights)
                completion = completion + picked.tolist()
            return prefix + completion, len(completion)

    return BlockPropSampler


class UnguidedLMSampler(Sampler):
    """Sampler that draws tokens straight from the reference policy, ignoring the value function."""

    def _sample_forwards(self, prefix):
        """Sample tokens one at a time until ``piref.is_complete`` holds.

        Returns:
            Tuple ``(prefix + completion, len(completion))``.
        """
        completion = []
        while not self.piref.is_complete(prefix, prefix + completion):
            token_probs = self.piref.next_token_probs(prefix + completion, self.layer_for_vf)
            drawn = torch.multinomial(token_probs, num_samples=1, replacement=True)
            completion.append(drawn[0].item())
        return prefix + completion, len(completion)

    def sample(self, prefix):
        """Sample a completion, retrying while ``reward_func`` (if set) does not return 1.

        Returns:
            Tuple ``(sequence, total_steps_over_all_attempts)``.
        """
        seq, total_steps = self._sample_forwards(prefix)
        if self.reward_func is None:
            return seq, total_steps
        while self.reward_func(seq) != 1:
            print("Failed test; retrying",flush=True)
            seq, extra_steps = self._sample_forwards(prefix)
            total_steps += extra_steps
        return seq, total_steps

class TesterUnguidedLMSampler(Sampler):
    """Unguided sampler that also records value-function scores while generating.

    After a run, ``self.last_weight`` holds the score computed at the final
    step and ``self.relevant_weight`` the score from the step just before the
    decoded completion first ended in a newline (``None`` if that never
    happened).
    """

    def _sample_forwards(self, prefix):
        """Sample tokens from the reference policy, scoring every intermediate state.

        Returns:
            Tuple ``(prefix + completion, len(completion), relevant_weight,
            last_weight)``. Both weights are ``None`` when ``prefix`` is
            already complete and no step is taken.
        """
        completion = []
        weights_and_strs = []
        self.last_weight = None
        self.relevant_weight = None
        while not self.piref.is_complete(prefix, prefix + completion):
            sequence = prefix + completion

            # Score the current state from its pooled representation.
            mean_pool = range(len(prefix), len(sequence))
            token_probs, rep = self.piref.next_token_probs_and_rep(
                sequence, self.layer_for_vf, self.which_tokens_for_vf, mean_pool
            )
            s = self.piref.tokenizer.decode(completion)
            weight = self.get_weight_from_rep(rep)
            weights_and_strs.append((weight, s))

            # The decoded text ends in its only newline: remember the score
            # recorded one step earlier.
            if len(s) > 0 and s[-1] == '\n' and len(weights_and_strs) > 1 and '\n' not in s[:-1]:
                self.relevant_weight = weights_and_strs[-2][0]

            chosen_token = torch.multinomial(token_probs, num_samples=1, replacement=True)[0]
            completion.append(chosen_token)

        # Guard against an already-complete prefix: with no steps taken there
        # is no score to report, so last_weight stays None instead of raising
        # IndexError on the empty list.
        if weights_and_strs:
            self.last_weight = weights_and_strs[-1][0]
        return prefix + completion, len(completion), self.relevant_weight, self.last_weight

    def sample(self, prefix):
        """Sample a completion, retrying while ``reward_func`` (if set) does not return 1.

        Returns:
            Tuple ``(sequence, total_steps_over_all_attempts)``; the recorded
            weights stay available as attributes.
        """
        seq, total_steps, relevant_weight, last_weight = self._sample_forwards(prefix)
        if self.reward_func is not None:
            while self.reward_func(seq) != 1:
                print("Failed test; retrying",flush=True)
                seq, steps, relevant_weight, last_weight = self._sample_forwards(prefix)
                total_steps += steps
        return seq, total_steps

class SequencePropSampler(Sampler):
    """Sequence-level rejection sampler built on top of ``UnguidedLMSampler``.

    Whole completions are drawn from the unguided sampler and each is
    accepted with probability equal to its value-function score.
    """

    def __init__(self, piref, value_function, layer_for_vf, which_tokens_for_vf, device):
        """Set up this sampler and the unguided sampler it draws candidates from.

        Args:
            piref: Reference policy object.
            value_function: Callable applied to representation tensors.
            layer_for_vf: Layer index used when requesting representations.
            which_tokens_for_vf: Token selection used when requesting representations.
            device: Target device for representation tensors.
        """
        from types import SimpleNamespace

        # Sampler.__init__ expects a config object with ``layer_index`` and
        # ``tokens``; wrap the two loose arguments so that both this instance
        # and the base sampler are initialised through the regular path.
        cfg_rep = SimpleNamespace(layer_index=layer_for_vf, tokens=which_tokens_for_vf)
        super().__init__(piref, value_function, cfg_rep, device)
        self.base_sampler = UnguidedLMSampler(piref, value_function, cfg_rep, device)

    def sample(self, prefix):
        """Draw up to 32 unguided completions and return the first accepted one.

        A candidate is accepted when a uniform draw is at most its score; the
        last candidate is always accepted.

        Returns:
            Tuple ``(sequence, total_steps_over_all_candidates)``.
        """
        num_candidates = 32
        num_steps = 0
        for i in range(num_candidates):
            seq, steps = self.base_sampler.sample(prefix)
            num_steps += steps
            pool_range = compute_pool_range(prefix, seq, self.which_tokens_for_vf)
            _, rep = self.piref.next_token_probs_and_rep(
                seq, self.layer_for_vf, self.which_tokens_for_vf, pool_range
            )
            weight = self.get_weight_from_rep(rep)
            print(f"Candidate {i} has value {weight}")
            if np.random.random() <= weight or i == num_candidates - 1:
                return seq, num_steps

class TokenwiseArgmaxValueSampler(Sampler):
    """Greedy sampler that always appends the token whose extension has the highest value.

    NOTE(review): reads ``self.completion_length`` and calls
    ``self.get_weight`` with two arguments; neither matches the ``Sampler``
    base class visible in this file -- confirm before use.
    """

    def sample(self, prefix):
        """Build a completion of ``completion_length`` tokens by per-step argmax.

        Returns:
            Tuple ``(prefix + completion, completion_length)``.
        """
        completion = []
        while len(completion) < self.completion_length:
            sequence = prefix + completion
            next_length = len(completion) + 1

            # Score every one-token extension; token ids double as list indices.
            values = [
                self.get_weight(sequence + [token], next_length)
                for token in range(self.piref.vocab_size)
            ]
            completion.append(int(np.argmax(values)))
        return prefix + completion, self.completion_length

def GenericEstimatedBestOfNSampler(base_sampler_class, best_of):
    """Build a sampler class that keeps the best of ``best_of`` samples from a base sampler.

    Args:
        base_sampler_class: Class instantiated as
            ``base_sampler_class(piref, value_functions, completion_length)``.
        best_of: Number of samples drawn per call to ``sample``.

    Returns:
        A ``Sampler`` subclass.
    """
    class EstimatedBestOfNSampler(Sampler):
        def __init__(self, piref, value_functions, completion_length):
            """
            Modified Sampler for prefix-based generation

            Args:
                piref: LM policy reference
                value_functions: List of value functions, one for each completion step
                completion_length: Length of completion to generate (B)
            """
            self.piref = piref
            self.value_functions = value_functions
            self.completion_length = completion_length
            self.best_of = best_of
            self.base_sampler = base_sampler_class(piref, value_functions, completion_length)

        def sample(self, prefix):
            """Draw ``best_of`` samples and return the one with the highest estimated value.

            Returns:
                Tuple ``(best_sequence, summed_cost_of_all_samples)``.
            """
            drawn = []
            scores = []
            total_cost = 0
            for _ in range(self.best_of):
                candidate, cost = self.base_sampler.sample(prefix)
                scores.append(self.get_weight(candidate, self.completion_length))
                drawn.append(candidate)
                total_cost += cost
            return drawn[np.argmax(scores)], total_cost

    return EstimatedBestOfNSampler


