from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import torch
import re
from typing import Tuple, List
from preference import *
from config_bbh import *
import os
# Restrict the CUDA devices visible to this process to GPU_ID (a name supplied by
# one of the star imports above).
# NOTE(review): this assignment runs after `import torch`; presumably it still takes
# effect because CUDA has not been initialised at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

# Define device: use CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and the causal LM from MODEL_DIR/RUNNING_MODEL at import time.
# The model is loaded with device_map="auto"; the helpers below move their input
# tensors to `device` before calling generate().
model_name = MODEL_DIR + '/' + RUNNING_MODEL
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

class StopOnTokensInSet(StoppingCriteria):
    """Stop generation once the trailing token(s) all belong to an allowed set.

    How many trailing tokens are inspected depends on ``RUNNING_MODEL``:
    two for "Mistral-7B"; one for "Gemma-7B", "Meta-Llama-3.1-8B" and any
    model whose name contains "Qwen2.5".
    """

    def __init__(self, acceptable_ids):
        # acceptable_ids is an iterable of allowed token ids, e.g. [1, 28705, 13]
        self.acceptable_ids = set(acceptable_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Only the first row of ``input_ids`` is inspected. ``scores`` and any
        extra keyword arguments are accepted but not used.

        Raises:
            ValueError: if ``RUNNING_MODEL`` is not one of the supported models.
        """
        # With fewer than two tokens in the sequence, never stop (applies to
        # every model, including those that only look at the last token).
        if input_ids.shape[1] < 2:
            return False

        if RUNNING_MODEL == "Mistral-7B":
            window = 2
        elif RUNNING_MODEL in ("Gemma-7B", "Meta-Llama-3.1-8B") or "Qwen2.5" in RUNNING_MODEL:
            window = 1
        else:
            # Raising a bare string is itself a TypeError in Python 3, which hid
            # this message; raise a real exception instead.
            raise ValueError(f"NO SUCH RUNNING_MODEL:{RUNNING_MODEL}")

        trailing_ids = input_ids[0, -window:]
        # Stop only if every inspected token is in the acceptable set.
        return all(token.item() in self.acceptable_ids for token in trailing_ids)

# Module-level stopping criteria passed to generate() in get_follow_up_output.
# The allowed token ids are looked up per model in `acceptable_tokens`, which is
# not defined in this file (it comes from one of the star imports).
stopping_criteria = StoppingCriteriaList([StopOnTokensInSet(acceptable_tokens[RUNNING_MODEL])])

# Pre-compute large Fibonacci list
MAX_K = 100
# fib_cache holds 0, 1, 1, 2, 3, 5, ... and ends up with MAX_K + 2 entries.
fib_cache = [0, 1]
while len(fib_cache) < MAX_K + 2:
    fib_cache.append(fib_cache[-2] + fib_cache[-1])


def get_fib_indices(k):
    """Return zero-based rank indices derived from the Fibonacci sequence.

    The cached values starting at position 2 (1, 2, 3, 5, 8, ...) are taken,
    ``k`` of them, and each is reduced by one, giving 0, 1, 2, 4, 7, ...
    Because the cache is finite, at most ``MAX_K`` indices are returned.
    """
    selected = fib_cache[2:k + 2]
    return [value - 1 for value in selected]

# Generate one token and return the logit score
def get_next_token_logit(model, tokenizer, query):
    """Greedily generate exactly one token after ``query`` and return its scores.

    The prompt is tokenized as a batch of one, moved to the module-level
    ``device``, and passed to ``model.generate`` with ``max_new_tokens=1``.
    The return value is the last entry of the generation's ``scores`` tuple,
    i.e. the score tensor of that single step.
    """
    encoded = tokenizer([query], return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    single_step = model.generate(
        **encoded,
        max_new_tokens=1,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return single_step.scores[-1]

# Generate top-k branches based on next token probabilities
def get_k_branch(model, tokenizer, query, k=5):
    """Continue ``query`` greedily from each of the k highest-scoring next tokens.

    Returns a list of decoded sequences (special tokens skipped), one per
    candidate token. Candidates come from an ascending argsort, so the list is
    ordered from the weakest of the top-k to the strongest. Generation length
    is left to the model's default generation settings.
    """
    next_scores = get_next_token_logit(model, tokenizer, query)
    # argsort is ascending: the last k positions are the k best token ids.
    candidate_ids = next_scores[0].argsort()[-k:]

    def _continue_from(token_id):
        # Append the decoded candidate to the prompt text and re-tokenize it.
        prompt = query + tokenizer.decode(token_id)
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        result = model.generate(
            **encoded,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
        )
        return tokenizer.decode(result.sequences[0], skip_special_tokens=True)

    return [_continue_from(token_id) for token_id in candidate_ids]

# Compute token path probability
def get_token_path_prob(gen_out, num_append: int = 1):
    """Return the trailing token ids and the max softmax probability per step.

    ``gen_out`` must expose ``scores`` (one score tensor per generated step)
    and ``sequences``. The returned ids are the last
    ``len(scores) + num_append`` ids of the first sequence; the returned
    tensor has one entry per step, so it is ``num_append`` shorter than the ids.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    peak_probs = []
    for step in step_scores:
        distribution = torch.nn.functional.softmax(step[0], dim=-1)
        peak_probs.append(distribution.max())
    return output_ids, torch.stack(peak_probs)

# Compute top-2 token probability gap
def get_token_top2_prob_gap(gen_out, num_append: int = 1):
    """Return the trailing token ids and, per step, p(top-1) minus p(top-2).

    The ids are the last ``len(gen_out.scores) + num_append`` ids of the first
    sequence; the gap tensor has one entry per generation step.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    def _gap(step):
        distribution = torch.nn.functional.softmax(step[0], dim=-1)
        best, runner_up = torch.topk(distribution, k=2).values
        return best - runner_up

    return output_ids, torch.stack([_gap(step) for step in step_scores])

# Compute top-k probability gap ratio
def get_token_topk_gap_ratio(gen_out, top_k: int = 10, num_append: int = 1):
    """Return the trailing ids and, per step, the share of the top-k drop owned by rank 1.

    For each step the ``top_k`` largest probabilities are taken in descending
    order; the ratio is (p1 - p2) divided by the sum of all consecutive
    differences. If that sum is exactly zero the ratio is reported as 0.0.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    ratios = []
    for step in step_scores:
        distribution = torch.nn.functional.softmax(step[0], dim=-1)
        ranked = torch.topk(distribution, k=top_k).values
        drops = ranked[:-1] - ranked[1:]
        total_drop = drops.sum()
        if total_drop.item() != 0:
            ratios.append(drops[0] / total_drop)
        else:
            # Flat top-k (or top_k == 1): avoid dividing by zero.
            ratios.append(torch.tensor(0.0, device=distribution.device))

    return output_ids, torch.stack(ratios)

# Compute p1 / p2
def get_token_p1p2(gen_out, num_append=1):
    """Return the trailing ids and, per step, the ratio p(top-1) / p(top-2).

    When either of the two probabilities is at most 1e-9 the ratio is
    reported as +inf instead of being computed.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    def _ratio(step):
        distribution = torch.nn.functional.softmax(step[0], dim=-1)
        p1, p2 = torch.topk(distribution, k=2).values
        degenerate = p1.item() <= 1e-9 or p2.item() <= 1e-9
        if degenerate:
            return torch.tensor(float('inf'), device=distribution.device)
        return p1 / p2

    return output_ids, torch.stack([_ratio(step) for step in step_scores])

# Compute p1 / (p2 + ... + pk)
def get_token_p1_over_rest(gen_out, top_k=10, num_append=1):
    """Return the trailing ids and, per step, p1 / (p2 + ... + pk) over the top-k.

    Special cases: fewer than two top-k values gives 0.0; a denominator of at
    most 1e-9 gives +inf.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    def _dominance(step):
        distribution = torch.nn.functional.softmax(step[0], dim=-1)
        top_vals = torch.topk(distribution, k=top_k).values
        if top_vals.size(0) < 2:
            return torch.tensor(0.0, device=distribution.device)
        rest = top_vals[1:].sum()
        if rest.item() <= 1e-9:
            return torch.tensor(float('inf'), device=distribution.device)
        return top_vals[0] / rest

    return output_ids, torch.stack([_dominance(step) for step in step_scores])

# Compute top-k coverage: sum of top-k probabilities
def get_token_topk_coverage(gen_out, top_k=10, num_append=1):
    """Return the trailing ids and, per step, the summed probability of the top-k tokens.

    The ids are the last ``len(gen_out.scores) + num_append`` ids of the first
    sequence; the coverage tensor has one entry per generation step.
    """
    step_scores = gen_out.scores
    tail_len = len(step_scores) + num_append
    output_ids = gen_out.sequences[0][-tail_len:]

    coverages = [
        torch.topk(torch.nn.functional.softmax(step[0], dim=-1), k=top_k).values.sum()
        for step in step_scores
    ]
    return output_ids, torch.stack(coverages)

# Compute entropy for each decoding step
def get_token_entropy(gen_out, num_append=1):
    """Return the trailing ids and the entropy (in nats) of each step's distribution.

    The ids are the last ``len(gen_out.scores) + num_append`` ids of the first
    sequence; the entropy tensor has one entry per generation step.

    Tokens whose probability is exactly zero (for example ``-inf`` scores, or
    softmax underflow) contribute 0 to the sum, following the convention
    ``0 * log(0) = 0``, rather than turning the whole entropy into NaN.
    """
    logits = gen_out.scores
    num_output = len(logits)
    output_ids = gen_out.sequences[0][-num_output - num_append:]

    entropy_list = []
    for score in logits:
        probs = torch.nn.functional.softmax(score[0], dim=-1)
        # probs.log() is -inf where probs == 0 and 0 * -inf is NaN, so those
        # positions are replaced by 0 before summing.
        plogp = torch.where(probs > 0, probs * probs.log(), torch.zeros_like(probs))
        entropy_list.append(-plogp.sum())
    scores_tensor = torch.stack(entropy_list)
    return output_ids, scores_tensor

# Convert token-level probabilities into word-level
def get_path_prob(gen_out, init_token_prob=None):
    """Aggregate token-level max probabilities into per-word averages.

    Token ids and probabilities come from ``get_token_path_prob``. When
    ``init_token_prob`` is given it is concatenated in front of the per-step
    probabilities, and one extra leading token id is taken from the sequence
    to pair with it; it should therefore hold one value.

    The ids are decoded cumulatively with the module-level ``tokenizer`` and
    the text is split on space, newline, period or colon. Each word's value is
    the mean probability of the tokens that extended it.

    Returns:
        A list of ``(word, mean_prob)`` tuples in generation order.
    """
    if init_token_prob is None:
        token_ids, probs = get_token_path_prob(gen_out, num_append=0)
    else:
        token_ids, probs = get_token_path_prob(gen_out)
        probs = torch.concat([init_token_prob, probs])
    current_n_words = 0
    word_probs = []
    ids = []
    current_n_tokens = 0
    word_prob = 0
    for token_id, prob in zip(token_ids, probs):
        ids.append(token_id)
        decode_seq = tokenizer.decode(ids)
        # The previous pattern r' |\n|\.\|:' escaped the pipe, so its last
        # alternative only matched the literal text ".|:" and never split on
        # "." or ":". A character class states the four separators directly.
        words = re.split(r'[ \n.:]', decode_seq)
        word = words[-1]
        if len(words) == current_n_words:
            # Same word count: this token extends the current word.
            word_prob += prob
            current_n_tokens += 1
            word_probs[-1] = (word, word_prob / current_n_tokens)
        elif len(words) > current_n_words:
            # Word count grew: start accumulating a new word.
            word_prob = prob
            current_n_tokens = 1
            word_probs.append((word, word_prob / current_n_tokens))
            current_n_words += 1
    return word_probs


# Get average top-2 gap for each word in output
def get_path_top2_gap(gen_out, init_token_gap=None):
    """
    Compute the average top2 probability gap for each word in the generated output.

    It uses get_token_top2_prob_gap to obtain token-level gap values, then decodes
    the token sequence step by step and splits it into words (using space, newline, period, or colon as separators).
    For each word, it averages the gap values of the tokens that constitute that word.

    When ``init_token_gap`` is given it is concatenated in front of the
    per-step gaps, and one extra leading token id is taken from the sequence
    to pair with it; it should therefore hold one value.

    Returns a list of ``(word, mean_gap)`` tuples in generation order.
    """
    if init_token_gap is None:
        token_ids, probs = get_token_top2_prob_gap(gen_out, num_append=0)
    else:
        token_ids, probs = get_token_top2_prob_gap(gen_out)
        probs = torch.concat([init_token_gap, probs])
    current_n_words = 0
    word_probs = []
    ids = []
    current_n_tokens = 0
    word_prob = 0
    for token_id, prob in zip(token_ids, probs):
        ids.append(token_id)
        decode_seq = tokenizer.decode(ids)
        # The previous pattern r' |\n|\.\|:' escaped the pipe, so its last
        # alternative only matched the literal text ".|:" and never split on
        # "." or ":" as documented. A character class lists the separators.
        words = re.split(r'[ \n.:]', decode_seq)
        word = words[-1]
        if len(words) == current_n_words:
            # Same word count: this token extends the current word.
            word_prob += prob
            current_n_tokens += 1
            word_probs[-1] = (word, word_prob / current_n_tokens)
        elif len(words) > current_n_words:
            # Word count grew: start accumulating a new word.
            word_prob = prob
            current_n_tokens = 1
            word_probs.append((word, word_prob / current_n_tokens))
            current_n_words += 1
    return word_probs

# Generate k candidate outputs and return word-level path probabilities
def get_k_path_prob(model, tokenizer, query, k, max_new_tokens=250):
    """Branch on the k best next tokens and return word-level path probabilities.

    For each of the k highest-scoring next tokens (ordered weakest to
    strongest, as produced by an ascending argsort) the token is appended to
    ``query``, the result is continued greedily for up to ``max_new_tokens``
    tokens, and ``get_path_prob`` turns the generation into
    ``(word, mean_prob)`` tuples.

    The probability attached to the branching token is its softmax weight
    among the k candidates only (not over the full vocabulary).

    Returns:
        A list with one ``get_path_prob`` result per candidate token.
    """
    logit = get_next_token_logit(model, tokenizer, query)
    k_token = logit[0].argsort()[-k:]
    k_prob = torch.nn.functional.softmax(logit[0][k_token], dim=0)
    k_response = []
    for token, token_prob in zip(k_token, k_prob):
        new_query = query + tokenizer.decode(token)
        candidate_inputs = tokenizer(new_query, return_tensors="pt")
        candidate_inputs = {name: tensor.to(device) for name, tensor in candidate_inputs.items()}
        gen_out = model.generate(**candidate_inputs, output_scores=True, return_dict_in_generate=True, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=False)
        # get_path_prob prepends init_token_prob to the per-step probabilities and
        # pairs it with a single extra token id, so only this branch's own
        # probability may be passed. Passing the whole k-vector shifted every
        # later probability against its token.
        path_probs = get_path_prob(gen_out, token_prob.unsqueeze(0))
        k_response.append(path_probs)
    return k_response

# Follow-up generation with appended template, for answer extraction
def get_follow_up_output(model, tokenizer, follow_up_template, gen_out, max_new_tokens=50):
    """Append a tokenized template to a finished generation and keep generating.

    The ids of ``follow_up_template`` are concatenated after
    ``gen_out.sequences`` along dimension 1, an all-ones attention mask is
    built for the combined ids, and ``model.generate`` is run greedily with the
    module-level ``stopping_criteria``. The new generation output is returned.

    NOTE(review): the template is tokenized with the tokenizer's default
    settings, so any special tokens it adds would land in the middle of the
    sequence — confirm this is intended.
    """
    template_ids = tokenizer(follow_up_template, return_tensors="pt")['input_ids'].to(device)
    extended_ids = torch.cat([gen_out.sequences, template_ids], axis=1)
    follow_up_inputs = {
        'input_ids': extended_ids.to(device),
        "attention_mask": torch.ones_like(extended_ids).to(device),
    }
    return model.generate(
        **follow_up_inputs,
        output_scores=True,
        return_dict_in_generate=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
        stopping_criteria=stopping_criteria,
    )

# Select branch index from greedy path probabilities
def find_branch_step_greedy(chosen_probs, branch_method="highest"):
    """Pick the index in ``chosen_probs`` at which to branch.

    Args:
        chosen_probs: per-step scores of a greedy path, as a list or a 1-D tensor.
        branch_method: selection rule:
            "highest"   - index of the largest value;
            "lowest"    - index of the smallest value;
            "local_max" - first interior index strictly above its left neighbour
                          and not below its right neighbour;
            "local_min" - first interior index strictly below its left neighbour
                          and not above its right neighbour;
            anything else - index 0.

    Returns:
        The selected index as an int. For the two "local_*" rules the last
        index is returned when no interior point qualifies.
    """
    if branch_method in ("highest", "lowest"):
        # as_tensor reuses an existing tensor instead of copy-constructing it,
        # which torch.tensor() does with a UserWarning.
        values = torch.as_tensor(chosen_probs)
        picked = values.argmax() if branch_method == "highest" else values.argmin()
        return int(picked.item())

    if branch_method == "local_max":
        for i in range(1, len(chosen_probs) - 1):
            if chosen_probs[i] > chosen_probs[i - 1] and chosen_probs[i] >= chosen_probs[i + 1]:
                return i
        return len(chosen_probs) - 1

    if branch_method == "local_min":
        for i in range(1, len(chosen_probs) - 1):
            if chosen_probs[i] < chosen_probs[i - 1] and chosen_probs[i] <= chosen_probs[i + 1]:
                return i
        return len(chosen_probs) - 1

    # Unknown methods (including the callers' default "first_token") branch at step 0.
    return 0

# Compute path probabilities using specified scoring method
def get_path_prob_by_measure(gen_out, measure_method="logits"):
    """Score every generation step of ``gen_out`` with the chosen measure.

    Recognised ``measure_method`` values are "gap", "gap_ratio", "p1p2",
    "p1_over_rest", "topk_coverage" and "entropy"; the top-k based measures
    use ``top_k=5``. Any other value (including the default "logits") falls
    back to ``get_token_path_prob``.

    Returns whatever the selected scorer returns: a pair of token ids and a
    per-step score tensor. Every scorer is called with its default
    ``num_append``, so the ids include one token more than there are scores.
    """
    scorers = (
        ("gap", get_token_top2_prob_gap, {}),
        ("gap_ratio", get_token_topk_gap_ratio, {"top_k": 5}),
        ("p1p2", get_token_p1p2, {}),
        ("p1_over_rest", get_token_p1_over_rest, {"top_k": 5}),
        ("topk_coverage", get_token_topk_coverage, {"top_k": 5}),
        ("entropy", get_token_entropy, {}),
    )
    for name, scorer, extra_kwargs in scorers:
        if measure_method == name:
            return scorer(gen_out, **extra_kwargs)
    return get_token_path_prob(gen_out)



# Main function for token-level top-k path generation and probability extraction
def get_token_k_path_prob(model, tokenizer, query, k, max_new_tokens=250, measure_method="logits", sample_method="sequential", branch_method="first_token"):
    """Branch a greedy generation at one step and score each alternative path.

    Steps:
      1. Generate greedily from ``query`` and score each step with
         ``measure_method`` (see ``get_path_prob_by_measure``).
      2. Choose a branch index with ``find_branch_step_greedy`` and decode the
         returned ids up to and including that index as the branch prefix.
      3. Rank the next-token scores after ``query + prefix`` and pick
         candidate tokens: the first ``k`` ranks, or, when ``sample_method`` is
         "fibonacci", the ranks given by ``get_fib_indices(k)``.
      4. For each candidate, generate greedily from
         ``query + prefix + candidate`` and collect the decoded text and the
         ``(token_text, score)`` pairs.

    NOTE(review): the scorers are called with their default ``num_append``,
    so the id list starts one token before the first scored step. This means
    the prefix and the (token, score) pairs look shifted by one position —
    confirm this is intended.

    Returns:
        ``(k_response, k_gen_path_probs)``: decoded outputs and per-path lists
        of ``(token_text, score)`` tuples.
    """
    def _greedy(prompt):
        # Tokenize the prompt text and run deterministic (greedy) generation.
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        return model.generate(**encoded, output_scores=True, return_dict_in_generate=True,
                              max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=False)

    # Greedy reference path and the step at which to branch.
    greedy_out = _greedy(query)
    greedy_ids, greedy_scores = get_path_prob_by_measure(greedy_out, measure_method=measure_method)
    pivot = find_branch_step_greedy(greedy_scores, branch_method=branch_method)

    branch_prefix = tokenizer.decode(greedy_ids[:pivot + 1])
    stem = query + branch_prefix

    # Rank all next-token candidates after the stem, best first.
    ranked = get_next_token_logit(model, tokenizer, stem)[0].argsort(descending=True)
    if sample_method == "fibonacci":
        picks = [ranked[pos] for pos in get_fib_indices(k) if 0 <= pos < len(ranked)]
        k_token = torch.tensor(picks)
    else:
        k_token = ranked[:k]

    k_response = []
    k_gen_path_probs = []
    for token in k_token:
        branch_out = _greedy(stem + tokenizer.decode(token))
        decoded = tokenizer.decode(branch_out.sequences[0], skip_special_tokens=True).strip()
        branch_ids, branch_scores = get_path_prob_by_measure(branch_out, measure_method=measure_method)
        token_scores = [(tokenizer.decode([tid]), value) for tid, value in zip(branch_ids, branch_scores)]
        k_response.append(decoded)
        k_gen_path_probs.append(token_scores)
    return k_response, k_gen_path_probs


# Same as above but with follow-up generation template appended
def get_token_k_path_prob_follow_up(model, tokenizer, query, k, max_new_tokens=250,
                              follow_up_template="\nSo the final answer is: ", measure_method="logits", sample_method="sequential", branch_method="first_token"):
    """Branch a greedy generation and extend every branch with a follow-up prompt.

    Works like ``get_token_k_path_prob``: a greedy path from ``query`` is
    scored, a branch index is chosen, and candidate next tokens are taken
    after ``query + prefix`` (first ``k`` ranks, or Fibonacci ranks when
    ``sample_method`` is "fibonacci"). Each branch is generated greedily and
    then continued through ``get_follow_up_output`` with
    ``follow_up_template``. The decoded text is that of the follow-up sequence.

    NOTE(review): the scorers are called with their default ``num_append``,
    so the id list starts one token before the first scored step. This means
    the prefix and the (token, score) pairs look shifted by one position —
    confirm this is intended.

    Returns:
        ``(k_response, k_gen_path_probs, k_follow_path_probs)``: decoded
        follow-up outputs, and per-path ``(token_text, score)`` lists for the
        branch generation and for the follow-up generation.
    """
    def _greedy(prompt):
        # Tokenize the prompt text and run deterministic (greedy) generation.
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        return model.generate(**encoded, output_scores=True, return_dict_in_generate=True,
                              max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=False)

    def _labelled(generation):
        # Pair each returned token id (decoded on its own) with its score.
        ids, values = get_path_prob_by_measure(generation, measure_method=measure_method)
        return [(tokenizer.decode([tid]), value) for tid, value in zip(ids, values)]

    # Greedy reference path and the step at which to branch.
    greedy_out = _greedy(query)
    greedy_ids, greedy_scores = get_path_prob_by_measure(greedy_out, measure_method=measure_method)
    pivot = find_branch_step_greedy(greedy_scores, branch_method=branch_method)

    branch_prefix = tokenizer.decode(greedy_ids[:pivot + 1])
    stem = query + branch_prefix

    # Rank all next-token candidates after the stem, best first.
    ranked = get_next_token_logit(model, tokenizer, stem)[0].argsort(descending=True)
    if sample_method == "fibonacci":
        picks = [ranked[pos] for pos in get_fib_indices(k) if 0 <= pos < len(ranked)]
        k_token = torch.tensor(picks)
    else:
        k_token = ranked[:k]

    k_response = []
    k_gen_path_probs = []
    k_follow_path_probs = []
    for token in k_token:
        branch_out = _greedy(stem + tokenizer.decode(token))
        follow_up_out = get_follow_up_output(model, tokenizer, follow_up_template, branch_out)
        decoded = tokenizer.decode(follow_up_out.sequences[0], skip_special_tokens=True).strip()

        branch_token_scores = _labelled(branch_out)
        follow_token_scores = _labelled(follow_up_out)

        k_response.append(decoded)
        k_gen_path_probs.append(branch_token_scores)
        k_follow_path_probs.append(follow_token_scores)
    return k_response, k_gen_path_probs, k_follow_path_probs

# Two-level token branching with follow-up generation and path probability analysis
def get_token_k_j_path_prob_follow_up(model, tokenizer, query, k, j, max_new_tokens=250,
                                      follow_up_template="\nSo the final answer is: ",
                                      measure_method="logits", sample_method="sequential", lowest_prob=0.2, branch_method="lowest"):
    """Two-level branching with follow-up generation and per-token scores.

    First level: a greedy generation from ``query`` provides a one-token
    prefix; the next-token scores after ``query + prefix`` yield ``k``
    candidate tokens (first ``k`` ranks, or Fibonacci ranks when
    ``sample_method`` is "fibonacci"). Each candidate is generated greedily,
    extended with ``follow_up_template`` and recorded.

    Second level: within each first-level path a step is selected by
    ``branch_method`` (ignoring the first scored step). If that step's score is
    below ``lowest_prob`` and the selected index is not 0, the top ``j``
    next tokens at a second prefix are ranked and all but the best are
    generated, extended and recorded as well. ``sample_method`` does not apply
    at this level.

    All generations in this function use ``repetition_penalty=1.2``; the
    single-step calls to ``get_next_token_logit`` do not.

    Returns:
        ``(responses, gen_path_probs_list, follow_path_probs_list)``. The lists
        are aligned; each first-level entry is followed by that branch's
        second-level entries (if any). The path lists hold
        ``(token_text, score)`` tuples.
    """
    # Step 1: Generate initial output and determine the first-level branching prefix
    candidate_inputs = tokenizer(query, return_tensors="pt")
    candidate_inputs = {key: val.to(device) for key, val in candidate_inputs.items()}
    gen_out = model.generate(**candidate_inputs, output_scores=True, return_dict_in_generate=True,
                             max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id,
                             do_sample=False, repetition_penalty=1.2)
    gen_output_ids, gen_path_prob = get_path_prob_by_measure(gen_out, measure_method=measure_method)

    # First-level branching uses the first token (index 0) as branch point
    # NOTE(review): get_path_prob_by_measure uses the scorers' default num_append,
    # so gen_output_ids appears to start one token before the generated tokens;
    # element 0 may be the last prompt token rather than the first generated one — confirm.
    branch_index_1 = 0
    branch_prefix_1 = tokenizer.decode(gen_output_ids[:branch_index_1 + 1])

    # Compute logits for next token after query + branch_prefix_1
    logit_first = get_next_token_logit(model, tokenizer, query + branch_prefix_1)

    # Select top-k tokens for first-level branching
    if sample_method == "fibonacci":
        sorted_indices = logit_first[0].argsort(descending=True)
        fib_idx_list = get_fib_indices(k)
        # Fibonacci ranks that fall outside the vocabulary are dropped.
        k_tokens = torch.tensor([sorted_indices[idx] for idx in fib_idx_list if 0 <= idx < len(sorted_indices)])
    else:
        k_tokens = logit_first[0].argsort(descending=True)[:k]

    responses = []
    gen_path_probs_list = []
    follow_path_probs_list = []

    # Iterate over top-k tokens for first-level branching
    for token in k_tokens:
        first_branch_token = tokenizer.decode(token)
        new_query_level1 = query + branch_prefix_1 + first_branch_token

        candidate_inputs = tokenizer(new_query_level1, return_tensors="pt")
        candidate_inputs = {key: val.to(device) for key, val in candidate_inputs.items()}
        gen_out_level1 = model.generate(**candidate_inputs, output_scores=True, return_dict_in_generate=True,
                                        max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id,
                                        do_sample=False, repetition_penalty=1.2)

        # Append the follow-up template to the finished branch and decode the combined sequence.
        follow_up_out = get_follow_up_output(model, tokenizer, follow_up_template, gen_out_level1)
        full_output = tokenizer.decode(follow_up_out.sequences[0], skip_special_tokens=True).strip()

        # Collect token-level path probabilities
        gen_output_ids_level1, gen_path_prob_level1 = get_path_prob_by_measure(gen_out_level1, measure_method=measure_method)
        follow_output_ids, follow_path_prob = get_path_prob_by_measure(follow_up_out, measure_method=measure_method)
        # zip() stops at the shorter input, i.e. at the number of scored steps.
        gen_token_probs = [(tokenizer.decode([tid]), prob) for tid, prob in zip(gen_output_ids_level1, gen_path_prob_level1)]
        follow_token_probs = [(tokenizer.decode([tid]), prob) for tid, prob in zip(follow_output_ids, follow_path_prob)]

        responses.append(full_output)
        gen_path_probs_list.append(gen_token_probs)
        follow_path_probs_list.append(follow_token_probs)

        # If only one token was generated, skip second-level branching
        if len(gen_path_prob_level1) <= 1:
            continue

        # Use specified method (e.g. "lowest") to determine second-level branching point
        # The index is relative to the slice [1:], hence the "+ 1" when reading the score back below.
        branch_index_2 = find_branch_step_greedy(gen_path_prob_level1[1:], branch_method=branch_method)
        if gen_path_prob_level1[branch_index_2 + 1] >= lowest_prob or branch_index_2 == 0:
            continue  # Skip second-level branching if lowest prob is too high or branch would overlap

        # NOTE(review): the second-level prompt is query + ids[:branch_index_2]; it does not
        # re-add branch_prefix_1, and the slice end differs from the scored position
        # (branch_index_2 + 1) — confirm that this prefix is the intended branch point.
        branch_prefix_2 = tokenizer.decode(gen_output_ids_level1[:branch_index_2])

        logit_second = get_next_token_logit(model, tokenizer, query + branch_prefix_2)

        # Select top-j tokens for second-level branching
        j_tokens = logit_second[0].argsort(descending=True)[:j]

        # Generate output for each second-level token (excluding the 1st one as it's already covered)
        for token2 in j_tokens[1:]:
            second_branch_token = tokenizer.decode(token2)
            new_query_level2 = query + branch_prefix_2 + second_branch_token

            candidate_inputs = tokenizer(new_query_level2, return_tensors="pt")
            candidate_inputs = {key: val.to(device) for key, val in candidate_inputs.items()}
            gen_out_level2 = model.generate(**candidate_inputs, output_scores=True, return_dict_in_generate=True,
                                            max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id,
                                            do_sample=False, repetition_penalty=1.2)
            follow_up_out = get_follow_up_output(model, tokenizer, follow_up_template, gen_out_level2)
            full_output = tokenizer.decode(follow_up_out.sequences[0], skip_special_tokens=True).strip()

            gen_output_ids_level2, gen_path_prob_level2 = get_path_prob_by_measure(gen_out_level2, measure_method=measure_method)
            follow_output_ids, follow_path_prob = get_path_prob_by_measure(follow_up_out, measure_method=measure_method)
            gen_token_probs = [(tokenizer.decode([tid]), prob) for tid, prob in zip(gen_output_ids_level2, gen_path_prob_level2)]
            follow_token_probs = [(tokenizer.decode([tid]), prob) for tid, prob in zip(follow_output_ids, follow_path_prob)]

            # Second-level results go into the same flat lists as first-level ones.
            responses.append(full_output)
            gen_path_probs_list.append(gen_token_probs)
            follow_path_probs_list.append(follow_token_probs)

    return responses, gen_path_probs_list, follow_path_probs_list
