import numpy as np
import torch
from utils import one_hot_encode
from dyck import RandomWalkDyck
from lm import conditional_nn_generate

def detokenize_tmp(seq):
    """Turn a sequence of token ids into a string.

    Every id is used as an index into the fixed eight-character alphabet
    "([)]BEPS"; the looked-up characters are concatenated in order.

    Args:
        seq: Iterable of integer token ids.

    Returns:
        The concatenated symbol string (empty for an empty input).
    """
    alphabet = "([)]BEPS"
    return "".join(alphabet[tok] for tok in seq)

def sample_with_weights(choices, probs, return_chosen_weight=False):
    """Draw one element of ``choices`` with probability proportional to ``probs``.

    Args:
        choices: Sequence of candidates to pick from.
        probs: Unnormalized weights, one per entry of ``choices``. Each entry
            may be anything ``float()`` accepts (Python numbers, NumPy
            scalars, 0-d tensors).
        return_chosen_weight: When true, also return the original
            (unnormalized) entry of ``probs`` for the selected candidate.

    Returns:
        The chosen element, or ``(element, weight)`` when
        ``return_chosen_weight`` is true.

    Raises:
        AssertionError: If the weights sum to zero or to NaN.
    """
    # Normalize a float64 copy: integer weights cannot be divided in place,
    # and the caller's ``probs`` must stay untouched for the returned weight.
    weights = np.array([float(p) for p in probs], dtype=np.float64)
    total = sum(weights)
    assert total != 0 and not np.isnan(total), \
        f"Transition probabilities were {probs}"

    chosen_idx = np.random.choice(len(choices), p=weights / total)
    if return_chosen_weight:
        return choices[chosen_idx], probs[chosen_idx]
    return choices[chosen_idx]

def hashify_sequence(sequence):
    """Build a string key for a token sequence.

    The tokens are stringified and joined with underscores, e.g.
    ``[1, 2, 3]`` becomes ``"1_2_3"``.

    Args:
        sequence: Iterable of tokens; the container itself must not contain
            the element ``"_"`` (checked with a membership test).

    Returns:
        The underscore-joined string (empty for an empty sequence).
    """
    assert "_" not in sequence
    return "_".join(map(str, sequence))

class Sampler:
    """Base class for samplers that complete a prefix using value functions.

    It stores the reference policy, the per-step value functions and a cache
    of next-token probabilities. ``sample`` is a stub here and is meant to be
    overridden by subclasses.
    """

    def __init__(self, piref, value_functions, completion_length):
        """
        Set up a sampler for prefix-based generation

        Args:
            piref: LM policy reference
            value_functions: Mapping from completion step to the value
                function used at that step (looked up as
                ``value_functions[step]``); must have exactly
                ``completion_length`` keys
            completion_length: Length of completion to generate (B)
        """
        self.piref = piref
        self.value_functions = value_functions
        self.completion_length = completion_length
        # Next-token probabilities keyed by the hashified sequence
        self.cache = {}
        # Optional reward callable; subclasses consult it when it is set
        self.reward_func = None

        # Validate value functions: one is required per completion step
        num_value_functions = len(value_functions.keys())
        assert num_value_functions == completion_length, \
            f"Expected {completion_length} value functions, got {num_value_functions}"

    def get_weight(self, seq, step):
        """
        Evaluate the value function of a given step on a sequence

        Args:
            seq: Full sequence (prefix + completion so far)
            step: Current completion step (1-based, step=length of completion)

        Returns:
            The scalar output of ``value_functions[step]`` on the one-hot
            encoding of ``seq``
        """
        value_fn = self.value_functions[step]

        # One-hot encode the sequence and move it to the GPU for the value function
        one_hot = one_hot_encode(seq, len(seq), self.piref.vocab_size)
        inputs = torch.tensor(one_hot, dtype=torch.float32).cuda()
        with torch.no_grad():
            return value_fn(inputs).item()

    def compute_or_load_token_probs(self, sequence):
        """Return ``piref.next_token_probs(sequence)``, memoized in ``self.cache``."""
        key = hashify_sequence(sequence)
        if key not in self.cache:
            self.cache[key] = self.piref.next_token_probs(sequence)
        return self.cache[key]

    def sample(self, prefix):
        """Stub: does nothing and returns ``None``; subclasses override it."""
        return None

    def sample_multiple_completions(self, prefix, K):
        """Call ``sample(prefix)`` ``K`` times and return the results as a list."""
        return [self.sample(prefix) for _ in range(K)]

class MomentumJSSampler(Sampler):
    """Sampler that walks over completions while carrying an 'up'/'down' direction."""

    def forward(self, prefix, completion, momentum):
        """Take one step of the walk.

        With 'down' momentum the candidates are every one-token extension
        (weighted by token probability times the child's weight) plus, when
        the current sequence's weight exceeds the summed extension weights,
        a flip to 'up' in place weighted by that excess. With 'up' momentum
        the candidates are dropping the last token (weighted by the current
        sequence's weight) plus, when the extension weights exceed it, a flip
        to 'down' in place weighted by the difference.

        Args:
            prefix: Token list the completion is conditioned on.
            completion: Completion tokens generated so far.
            momentum: 'down' or 'up'; forced to 'down' for an empty completion.

        Returns:
            Tuple ``(new_completion, new_momentum)``.
        """
        sequence = prefix + completion
        depth = len(completion)

        # An empty completion has nothing to remove, so it can only move down
        if depth == 0:
            momentum = 'down'
        if depth == self.completion_length:
            assert False, "reached leaf already"

        token_probs = self.compute_or_load_token_probs(sequence)

        up_flow = self.get_weight(sequence, depth) if depth > 0 else 0

        vocab_size = self.piref.vocab_size
        down_flows = []
        for token in range(vocab_size):
            child = sequence + [token]
            child_weight = self.get_weight(child, depth + 1)
            # On the last step an explicit reward replaces the value estimate
            if depth + 1 == self.completion_length and self.reward_func is not None:
                child_weight = self.reward_func(child)
            down_flows.append((token_probs[token] * child_weight).item())

        signed_cross_flow = up_flow - sum(down_flows)

        if momentum == 'down':
            choices = [(completion + [token], 'down') for token in range(vocab_size)]
            probs = list(down_flows)
            if signed_cross_flow > 0:
                choices.append((completion, 'up'))
                probs.append(signed_cross_flow)
        else:
            choices = [(completion[:-1], 'up')]
            probs = [up_flow]
            if signed_cross_flow < 0:
                choices.append((completion, 'down'))
                probs.append(-signed_cross_flow)

        picked, chosen_weight = sample_with_weights(choices, probs, return_chosen_weight=True)
        completion, momentum = picked
        print(detokenize_tmp(prefix+completion), chosen_weight)
        return completion, momentum

    def sample(self, prefix):
        """Run ``forward`` until the completion reaches ``completion_length``.

        Returns:
            Tuple ``(prefix + completion, steps_used)`` where ``steps_used``
            counts the ``forward`` calls made.
        """
        state = ([], 'down')
        steps_used = 0
        while len(state[0]) < self.completion_length:
            state = self.forward(prefix, state[0], state[1])
            steps_used += 1
        return prefix + state[0], steps_used


def GenericJSSampler(up_prob = 1.0):
    """Build a sampler class whose walk can backtrack or extend the completion.

    Args:
        up_prob: Multiplier applied to the current sequence's value when
            weighting the backtrack (remove last token) move.

    Returns:
        The ``JSSamplerWithCompletions`` class (not an instance).
    """
    class JSSamplerWithCompletions(Sampler):
        def forward(self, prefix, completion):
            """Take one random step: drop the last token or append a new one.

            Args:
                prefix: Token list the completion is conditioned on.
                completion: Completion tokens generated so far.

            Returns:
                The new completion (one token shorter or one token longer).

            Raises:
                AssertionError: If the candidate weights sum to zero or NaN.
            """
            choices = []
            probs = []
            sequence = prefix + completion

            # Calculate current completion step
            completion_so_far = len(completion)

            # Part 1: Backtrack option (if we have completion tokens)
            if completion_so_far > 0:
                # Use value function for current step (evaluating current sequence)
                backtrack_weight = self.get_weight(sequence, completion_so_far)
                choices.append(completion[:-1])
                probs.append(up_prob * backtrack_weight)

            # Part 2: Forward options (if we haven't reached completion length)
            if completion_so_far < self.completion_length:
                token_probs = self.compute_or_load_token_probs(sequence)

                for token in range(self.piref.vocab_size):
                    next_seq = sequence + [token]
                    # Use value function for the next step
                    forward_weight = self.get_weight(next_seq, completion_so_far + 1)
                    # On the last step an explicit reward replaces the value estimate
                    if completion_so_far+1 == self.completion_length and self.reward_func is not None:
                        forward_weight = self.reward_func(next_seq)
                    forward_prob = token_probs[token] * forward_weight
                    choices.append(completion + [token])
                    probs.append(forward_prob.item())

            # Normalize probabilities
            probs_array = np.array(probs)
            total = sum(probs_array)
            assert total != 0 and not np.isnan(total), \
                f"JS transition probabilies at {sequence} were {probs}"

            probs_array /= total

            # Sample next completion
            chosen_idx = np.random.choice(len(choices), p=probs_array)
            return choices[chosen_idx]

        def sample(self, prefix):
            """
            Sample a completion for the given prefix

            Calls ``forward`` repeatedly until the completion reaches
            ``completion_length`` and prints the number of steps taken.

            Args:
                prefix: Token list to complete

            Returns:
                Tuple of (prefix + completion, number of forward steps used)
            """
            completion = []
            steps_used = 0

            # Continue until we reach completion length
            while len(completion) < self.completion_length:
                completion = self.forward(prefix, completion)
                steps_used += 1
            print(steps_used, flush=True)
            return prefix+completion, steps_used

        def forward_same_length_jobs(self, jobs):
            """Advance every job by one step, batching the policy query.

            Each job is a dict with keys 'prefix', 'completion', 'sequence'
            and 'steps'; it is updated in place. The caller in
            ``sample_from_jobs`` groups jobs by sequence length before
            calling this, so one ``batch_next_token_probs`` call serves all
            of them.

            Args:
                jobs: List of job dicts to advance.
            """
            sequences = [job['sequence'] for job in jobs]
            batch_token_probs = self.piref.batch_next_token_probs(sequences)
            for token_probs, job in zip(batch_token_probs, jobs):
                choices = []
                probs = []

                # Calculate current completion step
                completion_so_far = len(job['completion'])

                if completion_so_far > 0:
                    # Use value function for current step (evaluating current sequence)
                    backtrack_weight = self.get_weight(job['sequence'], completion_so_far)
                    choices.append(job['completion'][:-1])
                    # Same backtrack scaling as forward(), instead of a fixed 1.0
                    probs.append(up_prob * backtrack_weight)

                # NOTE(review): unlike forward(), this path does not substitute
                # self.reward_func on the final step -- confirm whether that is intended.
                for token in range(self.piref.vocab_size):
                    next_seq = job['sequence'] + [token]
                    # Use value function for the next step
                    forward_weight = self.get_weight(next_seq, completion_so_far + 1)
                    choices.append(job['completion'] + [token])
                    # Plain floats, mirroring the .item() conversion in forward()
                    probs.append(float(token_probs[token] * forward_weight))

                job['completion'] = sample_with_weights(choices, probs)
                job['sequence'] = job['prefix'] + job['completion']
                job['steps'] += 1

        def sample_from_jobs(self, jobs):
            """Run all jobs to completion, batching jobs of equal sequence length.

            Every job dict must provide 'prefix'; its 'completion', 'sequence'
            and 'steps' entries are (re)initialized here and updated in place.
            Progress information is printed on every round.

            Args:
                jobs: List of job dicts.

            Returns:
                The list of finished job dicts, in the order they finished.
            """
            finished_jobs = []
            open_jobs = jobs
            for job in open_jobs:
                job['completion'] = []
                job['sequence'] = job['prefix'] + job['completion']
                job['steps'] = 0
            while len(open_jobs) > 0:
                print(f"Number of remaining jobs: {len(open_jobs)}",flush=True)
                jobs_by_length = {}
                for job in open_jobs:
                    jobs_by_length.setdefault(len(job['sequence']), []).append(job)
                # debug
                length_counts = {
                    job_length: len(group)
                    for job_length, group in jobs_by_length.items()
                }
                print(length_counts,flush=True)
                open_jobs = []
                for same_length_jobs in jobs_by_length.values():
                    self.forward_same_length_jobs(same_length_jobs)
                    for job in same_length_jobs:
                        if len(job['completion']) == self.completion_length:
                            finished_jobs.append(job)
                        else:
                            open_jobs.append(job)
            return finished_jobs

    return JSSamplerWithCompletions

class TokenwiseSamplerWithCompletions(Sampler):
    """Sampler that picks one token at a time, weighting policy probabilities by value."""

    def _sample_forwards(self, prefix):
        """
        Generate one full completion token by token

        At each step every token is weighted by its policy probability times
        the value of the extended sequence; on the last step ``reward_func``
        (when set) is used in place of the value function.

        Args:
            prefix: Token list to complete

        Returns:
            Tuple of (prefix + completion, completion_length)
        """
        completion = []
        while len(completion) < self.completion_length:
            sequence = prefix + completion

            # Calculate current completion step
            completion_so_far = len(completion)
            token_probs = self.compute_or_load_token_probs(sequence)

            # The reward replaces the value estimate on the final step, so the
            # value function is only evaluated when its result is actually used.
            use_reward = (completion_so_far + 1 == self.completion_length
                          and self.reward_func is not None)
            probs = []
            for token in range(self.piref.vocab_size):
                next_seq = sequence + [token]
                if use_reward:
                    forward_weight = self.reward_func(next_seq)
                else:
                    forward_weight = self.get_weight(next_seq, completion_so_far + 1)
                probs.append(float(token_probs[token] * forward_weight))

            # Normalize in float64, as UnguidedLMSampler does
            probs_array = np.array(probs, dtype=np.float64)
            total = sum(probs_array)
            if total == 0:
                # No token has any weight: fall back to token 0
                chosen_idx = 0
            else:
                probs_array /= total
                chosen_idx = np.random.choice(len(probs_array), p=probs_array)
            completion.append(int(chosen_idx))
        return prefix + completion, self.completion_length

    def sample(self, prefix):
        """
        Sample a completion, retrying until ``reward_func`` accepts it

        When ``reward_func`` is set, whole completions are regenerated until
        ``reward_func(seq) == 1``; without it the first completion is returned.

        Args:
            prefix: Token list to complete

        Returns:
            Tuple of (prefix + completion, total steps over all attempts)
        """
        seq, total_steps = self._sample_forwards(prefix)
        if self.reward_func is not None:
            while self.reward_func(seq) != 1:
                seq, steps = self._sample_forwards(prefix)
                total_steps += steps
        print("Succeeded",flush=True)
        return seq, total_steps

def GenericBlockBoNSampler(num_candidates, block_length):
    """Build a sampler class that extends the completion block by block,
    keeping the highest-valued of ``num_candidates`` drawn blocks each time.

    Args:
        num_candidates: Number of candidate blocks drawn per block position.
        block_length: Maximum number of tokens per block.

    Returns:
        The ``BlockBoNSampler`` class (not an instance).
    """
    class BlockBoNSampler(Sampler):
        def sample(self, prefix):
            """Complete ``prefix`` greedily over blocks.

            Returns:
                Tuple ``(prefix + completion, num_candidates * completion_length)``.
            """
            completion = []
            while len(completion) < self.completion_length:
                sequence = prefix + completion
                # The last block may be shorter than block_length
                remaining = self.completion_length - len(completion)
                k = min(block_length, remaining)

                scored_blocks = []
                for _ in range(num_candidates):
                    candidate = self.piref.next_k_tokens(sequence, k)
                    value = self.get_weight(sequence + candidate, len(completion + candidate))
                    scored_blocks.append((value, candidate))

                values = [value for value, _ in scored_blocks]
                best_idx = np.argmax(values)
                completion = completion + scored_blocks[best_idx][1]
            return prefix + completion, num_candidates * self.completion_length

    return BlockBoNSampler

def GenericBlockPropSampler(num_candidates, block_length):
    """Build a sampler class that extends the completion block by block,
    choosing among ``num_candidates`` drawn blocks in proportion to their value.

    Args:
        num_candidates: Number of candidate blocks drawn per block position.
        block_length: Maximum number of tokens per block.

    Returns:
        The ``BlockPropSampler`` class (not an instance).
    """
    class BlockPropSampler(Sampler):
        def sample(self, prefix):
            """Complete ``prefix`` by value-proportional block selection.

            Returns:
                Tuple ``(prefix + completion, num_candidates * completion_length)``.
            """
            completion = []
            while len(completion) < self.completion_length:
                sequence = prefix + completion
                # The last block may be shorter than block_length
                remaining = self.completion_length - len(completion)
                k = min(block_length, remaining)

                scored_blocks = []
                for _ in range(num_candidates):
                    candidate = self.piref.next_k_tokens(sequence, k)
                    value = self.get_weight(sequence + candidate, len(completion + candidate))
                    scored_blocks.append((value, candidate))

                # Turn the candidate values into a probability vector
                probs = np.array([value for value, _ in scored_blocks])
                probs /= sum(probs)
                chosen_idx = np.random.choice(len(scored_blocks), p=probs)
                completion = completion + scored_blocks[chosen_idx][1]
            return prefix + completion, num_candidates * self.completion_length

    return BlockPropSampler

class UnguidedLMSampler(Sampler):
    """Sampler that draws from the reference policy without using value functions."""

    def sample(self, prefix):
        """Sample a completion token by token from ``piref.next_token_probs``.

        Returns:
            Tuple ``(prefix + completion, completion_length)``.
        """
        completion = []
        while len(completion) < self.completion_length:
            # Renormalize in float64 before handing the distribution to numpy
            distribution = np.array(
                self.piref.next_token_probs(prefix + completion),
                dtype=np.float64,
            )
            distribution /= sum(distribution)
            completion.append(np.random.choice(len(distribution), p=distribution))
        return prefix + completion, self.completion_length

    def sample_multiple_completions(self, prefix, K):
        """Generate ``K`` completions in batches with ``conditional_nn_generate``.

        Without ``reward_func`` a single batch of ``K`` generations is
        returned. With ``reward_func`` batches of 100 are generated until
        ``K`` sequences with reward 1 have been collected; the average cost
        per accepted sequence is printed and attached to each result.

        Returns:
            List of ``(tokens, cost)`` tuples.
        """
        # NOTE(review): the "- 2" offset in the configured length is taken
        # as-is from the existing code -- confirm what it accounts for.
        dyck = RandomWalkDyck({'length': len(prefix) + self.completion_length - 2})

        if self.reward_func is None:
            generated = conditional_nn_generate(dyck, self.piref.model, [prefix]*K, top_p=1.0, temperature=1.0)
            return [(item['tokens'], self.completion_length) for item in generated]

        accepted = []
        batch_size = 100
        num_generations_required = 0
        while len(accepted) < K:
            batch = conditional_nn_generate(dyck, self.piref.model, [prefix]*batch_size, top_p=1.0, temperature=1.0)
            for item in batch:
                num_generations_required += 1
                if self.reward_func(item['tokens']) != 1:
                    continue
                accepted.append(item['tokens'])
                # Stop counting generations once enough sequences were accepted
                if len(accepted) == K:
                    break
        print(self.completion_length * num_generations_required / K, flush=True)
        return [(accepted[i], self.completion_length * num_generations_required / K) for i in range(K)]


class TokenwiseArgmaxValueSampler(Sampler):
    """Sampler that greedily appends the token with the highest value at each step."""

    def sample(self, prefix):
        """Build the completion by repeated argmax over per-token values.

        The policy probabilities are not consulted; only ``get_weight`` of
        each one-token extension is compared.

        Returns:
            Tuple ``(prefix + completion, completion_length)``.
        """
        completion = []
        while len(completion) < self.completion_length:
            sequence = prefix + completion
            next_step = len(completion) + 1
            values = [
                self.get_weight(sequence + [token], next_step)
                for token in range(self.piref.vocab_size)
            ]
            # Token ids coincide with their position in ``values``
            completion.append(int(np.argmax(values)))
        return prefix + completion, self.completion_length

def GenericEstimatedBestOfNSampler(base_sampler_class, best_of):
    """Build a sampler class that keeps the best of several base-sampler draws.

    Args:
        base_sampler_class: Sampler class constructed with
            ``(piref, value_functions, completion_length)`` whose ``sample``
            returns ``(sequence, cost)``.
        best_of: Number of draws taken from the base sampler per sample.

    Returns:
        The ``EstimatedBestOfNSampler`` class (not an instance).
    """
    class EstimatedBestOfNSampler(Sampler):
        def __init__(self, piref, value_functions, completion_length):
            """Initialize the shared Sampler state and the wrapped base sampler."""
            # Run the base initializer so cache, reward_func and the
            # value-function count check are set up like any other Sampler.
            super().__init__(piref, value_functions, completion_length)
            self.best_of = best_of
            self.base_sampler = base_sampler_class(piref, value_functions, completion_length)

        def sample(self, prefix):
            """Draw ``best_of`` sequences and return the one with the highest estimated value.

            The estimate is ``get_weight(sequence, completion_length)``.

            Returns:
                Tuple ``(best_sequence, total_cost)`` where ``total_cost``
                sums the cost reported by every base-sampler draw.
            """
            sequences = []
            estimated_values = []
            total_cost = 0
            for _ in range(self.best_of):
                sequence, cost = self.base_sampler.sample(prefix)
                sequences.append(sequence)
                estimated_values.append(self.get_weight(sequence, self.completion_length))
                total_cost += cost
            best_idx = np.argmax(estimated_values)
            return sequences[best_idx], total_cost

    return EstimatedBestOfNSampler


def GenericPropSampler(base_sampler_class, num_candidates):
    """Build a sampler class that picks among base-sampler draws in proportion to value.

    Args:
        base_sampler_class: Sampler class constructed with
            ``(piref, value_functions, completion_length)`` whose ``sample``
            returns ``(sequence, cost)``.
        num_candidates: Number of draws taken from the base sampler per sample.

    Returns:
        The ``PropSampler`` class (not an instance).
    """
    class PropSampler(Sampler):
        def __init__(self, piref, value_functions, completion_length):
            """Initialize the shared Sampler state and the wrapped base sampler."""
            # Run the base initializer so cache, reward_func and the
            # value-function count check are set up like any other Sampler.
            super().__init__(piref, value_functions, completion_length)
            self.num_candidates = num_candidates
            self.base_sampler = base_sampler_class(piref, value_functions, completion_length)

        def sample(self, prefix):
            """Draw ``num_candidates`` sequences and choose one weighted by estimated value.

            The estimate is ``get_weight(sequence, completion_length)``; the
            choice is made with ``sample_with_weights``.

            Returns:
                Tuple ``(chosen_sequence, total_cost)`` where ``total_cost``
                sums the cost reported by every base-sampler draw.
            """
            sequences = []
            estimated_values = []
            total_cost = 0
            for _ in range(self.num_candidates):
                sequence, cost = self.base_sampler.sample(prefix)
                sequences.append(sequence)
                estimated_values.append(self.get_weight(sequence, self.completion_length))
                total_cost += cost
            choice = sample_with_weights(sequences, estimated_values)
            return choice, total_cost

    return PropSampler
