from tools_bbh import *
import math


def select_top1_path(paths, input_answers, ground_truth, task_type):
    """
    Strategy 1: Directly select the first (top-1) path.

    Args:
        paths: List of paths; each path is a sequence of (token, prob) pairs
            whose prob exposes ``.item()`` (presumably a scalar tensor --
            confirm against the caller).
        input_answers: List of raw answer strings, one per path.
        ground_truth: Reference answer compared against the extracted answer.
        task_type: Task identifier forwarded to ``extract_answer_bbh``.

    Returns:
        A dictionary with:
        - "index": 0, or -1 when ``paths`` or ``input_answers`` is empty.
        - "right_index": ``[0]`` if the first answer equals ``ground_truth``,
          otherwise ``[]``.
        - "prob": average token probability of the first path (0 for an
          empty path, 0.0 for the invalid result).
        - "answer": the extracted answer with surrounding whitespace
          stripped, or "[invalid]".
    """
    if not input_answers or not paths:
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    # Only the first path is ever reported, so only its average is computed;
    # the remaining paths are not touched by this strategy.
    first_path = paths[0]
    if len(first_path) > 0:
        prob = sum(p.item() for _, p in first_path) / len(first_path)
    else:
        prob = 0

    answer = extract_answer_bbh(input_answers[0], task_type)

    return {
        "index": 0,
        # The comparison uses the extracted answer as-is (not stripped).
        "right_index": [0] if answer == ground_truth else [],
        "prob": prob,
        "answer": answer.strip(),
    }


def select_max_path(paths, input_answers, ground_truth, task_type):
    """
    Strategy 2: Select the path with the highest average probability.

    Args:
        paths: List of paths; each path is a sequence of (token, prob) pairs
            whose prob exposes ``.item()``.
        input_answers: List of raw answer strings, one per path.
        ground_truth: Reference answer compared against each extracted answer.
        task_type: Task identifier forwarded to ``extract_answer_bbh``.

    Returns:
        A dictionary with the index of the chosen path ("index"), the indices
        of all answers equal to ``ground_truth`` ("right_index"), the chosen
        path's average probability ("prob"), and its extracted answer with
        surrounding whitespace stripped ("answer"). When ``paths`` or
        ``input_answers`` is empty, the "[invalid]" result with index -1 is
        returned instead.
    """
    if not (input_answers and paths):
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    # Mean token probability per path; an empty path scores 0.
    avg_probs = []
    for path in paths:
        if len(path) > 0:
            avg_probs.append(sum(p.item() for _, p in path) / len(path))
        else:
            avg_probs.append(0)

    # First occurrence of the maximum wins on ties.
    best = avg_probs.index(max(avg_probs))
    chosen = extract_answer_bbh(input_answers[best], task_type)

    matches = [
        pos
        for pos, text in enumerate(input_answers)
        if extract_answer_bbh(text, task_type) == ground_truth
    ]

    return {
        "index": best,
        "right_index": matches,
        "prob": avg_probs[best],
        "answer": chosen.strip(),
    }


def select_max_path_new(paths1, paths2, input_answers, ground_truth, task_type):
    """
    Strategy 2 (Enhanced): Select the path with the highest confidence.

    Confidence is ``log(1 + len(path1)) / max_log * average_prob(path2)``,
    where ``max_log`` is the largest ``log(1 + length)`` over ``paths1``.

    Args:
        paths1: List of paths used only for their lengths.
        paths2: List of paths; each is a sequence of (token, prob) pairs whose
            prob exposes ``.item()``; used for the average probability.
        input_answers: List of raw answer strings, one per path.
        ground_truth: Reference answer compared against each extracted answer.
        task_type: Task identifier forwarded to ``extract_answer_bbh``.

    Returns:
        A dictionary with the chosen index ("index"), the indices of all
        answers equal to ``ground_truth`` ("right_index"), the chosen path's
        confidence ("prob"), and its extracted, stripped answer ("answer").
        When any of ``paths1``, ``paths2`` or ``input_answers`` is empty, the
        "[invalid]" result with index -1 is returned instead.
    """
    # paths1 is checked as well: without it the confidence list would be
    # empty and max() below would raise ValueError.
    if not input_answers or not paths1 or not paths2:
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    probs2 = [sum(prob.item() for _, prob in path) / len(path) if len(path) > 0 else 0 for path in paths2]

    right_index = [
        i
        for i, candidate_answer in enumerate(input_answers)
        if extract_answer_bbh(candidate_answer, task_type) == ground_truth
    ]

    log_lengths = [math.log(1 + len(path)) for path in paths1]
    # If every path in paths1 is empty the maximum is 0.0; fall back to 1 so
    # the normalisation does not divide by zero (all confidences become 0).
    max_log = max(log_lengths) or 1

    # zip() stops at the shorter of the two lists.
    confidences = [(ll / max_log) * ap for ll, ap in zip(log_lengths, probs2)]

    # First occurrence of the maximum wins on ties.
    index = confidences.index(max(confidences))
    answer = extract_answer_bbh(input_answers[index], task_type)

    return {
        "index": index,
        "right_index": right_index,
        "prob": confidences[index],
        "answer": answer.strip(),
    }

def select_open_aggregated_path(paths, input_answers, ground_truth, decay_rate=1, confidence="AVG", task_type='default',
                                 similarity_threshold=0.8):
    """
    Strategy 4 (Open Aggregation):
    1. Extract the answer from each path's answer text.
    2. Group answers by semantic similarity (threshold-based); the first
       answer of a group acts as its representative.
    3. Aggregate probabilities within each group with decay.
    4. Choose the group with the highest aggregated confidence.
    5. Within the group, select the path with the highest probability.

    Args:
        paths: List of paths, each a list of (word, prob); prob must expose
            ``.item()``.
        input_answers: List of final answer strings per path.
        ground_truth: The correct answer.
        decay_rate: Decay rate applied to repeated answers in the same group;
            the n-th member contributes ``prob * decay_rate ** (n - 1)``.
        confidence: Use "AVG" for average prob, otherwise total prob.
        task_type: Task identifier passed to extract_answer_bbh.
        similarity_threshold: Threshold at or above which answers are grouped
            semantically.

    Returns:
        A dictionary with selected path index, correct indices, the selected
        path's probability, and the winning group's representative answer
        (stripped). When ``paths`` or ``input_answers`` is empty, the
        "[invalid]" result with index -1 is returned.
    """
    if not input_answers or not paths:
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    if confidence == "AVG":
        probs = [sum(prob.item() for _, prob in path) / len(path) if path else 0 for path in paths]
    else:
        probs = [sum(prob.item() for _, prob in path) for path in paths]

    # Extract each answer once and reuse it for both correctness and grouping.
    answers = [extract_answer_bbh(ans, task_type) for ans in input_answers]
    right_index = [i for i, ans in enumerate(answers) if ans == ground_truth]

    group_agg = {}      # representative answer -> aggregated (decayed) score
    group_indices = {}  # representative answer -> member path indices

    for i, (ans, prob) in enumerate(zip(answers, probs)):
        found_group = None
        # Groups are probed in creation order; the first match wins.
        for group_rep in group_agg:
            if semantic_similarity(ans, group_rep) >= similarity_threshold:
                found_group = group_rep
                break
        if found_group is None:
            group_agg[ans] = prob
            group_indices[ans] = [i]
        else:
            members = group_indices[found_group]
            members.append(i)
            # len(members) is this path's 1-based position within the group.
            group_agg[found_group] += prob * (decay_rate ** (len(members) - 1))

    # Both inputs are non-empty here, so at least one group exists and every
    # group has at least one member.
    best_group = max(group_agg, key=group_agg.get)
    best_idx = max(group_indices[best_group], key=lambda i: probs[i])
    return {
        "index": best_idx,
        "right_index": right_index,
        "prob": probs[best_idx],
        "answer": best_group.strip(),
    }


def select_aggregated_path(paths, input_answers, ground_truth, decay_rate=1, confidence="AVG", task_type='default'):
    """
    Strategy 4 (Simple Aggregation):
    1. Extract the answer from each path's answer text.
    2. Aggregate (decayed) probability per exact answer string.
    3. Select the answer with the highest aggregated confidence.
    4. Pick the highest-probability path among those giving that answer.

    Args:
        paths: List of paths, each a list of (word, prob); prob must expose
            ``.item()``.
        input_answers: Final answers, one per path.
        ground_truth: Correct answer.
        decay_rate: Decay rate for repeated answers; the n-th occurrence
            contributes ``prob * decay_rate ** (n - 1)``.
        confidence: "AVG" for average probability; any other value sums.
        task_type: Used by answer extractor.

    Returns:
        Dict with selected path index, correct indices, the selected path's
        probability, and the winning answer (stripped). When ``paths`` or
        ``input_answers`` is empty, the "[invalid]" result with index -1 is
        returned.
    """
    if not input_answers or not paths:
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    if confidence == "AVG":
        probs = [sum(prob.item() for _, prob in path) / len(path) if path else 0 for path in paths]
    else:
        probs = [sum(prob.item() for _, prob in path) for path in paths]

    # Extract each answer once and reuse it for both correctness and grouping.
    answers = [extract_answer_bbh(a, task_type) for a in input_answers]
    right_index = [i for i, ans in enumerate(answers) if ans == ground_truth]

    agg = {}      # answer -> aggregated (decayed) score
    members = {}  # answer -> indices of the paths that produced it

    # Indices are recorded while aggregating, so only paths that actually
    # have a probability (zip stops at the shorter list) can be selected.
    for i, (ans, prob) in enumerate(zip(answers, probs)):
        idxs = members.setdefault(ans, [])
        idxs.append(i)
        # len(idxs) is this occurrence's 1-based count for the answer.
        agg[ans] = agg.get(ans, 0) + prob * (decay_rate ** (len(idxs) - 1))

    # Both inputs are non-empty here, so agg holds at least one answer.
    best_ans = max(agg, key=agg.get)
    best_idx = max(members[best_ans], key=lambda i: probs[i])
    return {
        "index": best_idx,
        "right_index": right_index,
        "prob": probs[best_idx],
        "answer": best_ans.strip(),
    }


def select_open_aggregated_path_new(paths1, paths2, input_answers, ground_truth, decay_rate=1, task_type='default',
                                     similarity_threshold=0.8):
    """
    Strategy 4 (Modified - Open Aggregation with Semantic Grouping and Length Scaling):
    1. paths1 contains token lists used only for length computation.
    2. paths2 contains (token, prob) pairs used for computing average prob;
       prob must expose ``.item()``.
    3. Each path's length score is ``len(path1) / sum of all lengths``.
       NOTE(review): select_max_path_new scales by normalized log(1+length)
       instead -- confirm which scaling is intended here.
    4. Multiply the length score with average probability to get path confidence.
    5. Group answers by semantic similarity using a threshold; the first
       answer of a group acts as its representative.
    6. Decay confidence within each group using decay_rate.
    7. Select group with highest total confidence.
    8. Within the group, return path with highest confidence.

    Args:
        paths1: List of paths used for length scaling.
        paths2: List of paths used for average probability.
        input_answers: List of final answer strings per path.
        ground_truth: The correct answer.
        decay_rate: The n-th member of a group contributes
            ``confidence * decay_rate ** (n - 1)``.
        task_type: Task identifier passed to extract_answer_bbh.
        similarity_threshold: Threshold at or above which answers are grouped.

    Returns:
        Dict with selected path index, right index list, confidence value
        (under the key "confidence"), and the winning group's representative
        answer (stripped). When any input list is empty, the "[invalid]"
        result with index -1 is returned.
    """
    if not input_answers or not paths1 or not paths2:
        return {
            "index": -1,
            "right_index": [],
            "confidence": 0.0,
            "answer": "[invalid]",
        }

    probs2 = [sum(prob.item() for _, prob in path) / len(path) if path else 0 for path in paths2]

    # Extract each answer once and reuse it for both correctness and grouping.
    answers = [extract_answer_bbh(a, task_type) for a in input_answers]
    right_index = [i for i, ans in enumerate(answers) if ans == ground_truth]

    lengths = [len(p) for p in paths1]
    # If every path in paths1 is empty the total is 0; fall back to 1 so the
    # normalisation does not divide by zero (all confidences become 0).
    total_length = sum(lengths) or 1
    confidences = [(l / total_length) * ap for l, ap in zip(lengths, probs2)]

    group_agg = {}      # representative answer -> aggregated (decayed) score
    group_indices = {}  # representative answer -> member path indices

    for i, (ans, conf) in enumerate(zip(answers, confidences)):
        found_group = None
        # Groups are probed in creation order; the first match wins.
        for rep in group_agg:
            if semantic_similarity(ans, rep) >= similarity_threshold:
                found_group = rep
                break
        if found_group is None:
            group_agg[ans] = conf
            group_indices[ans] = [i]
        else:
            members = group_indices[found_group]
            members.append(i)
            # len(members) is this path's 1-based position within the group.
            group_agg[found_group] += conf * (decay_rate ** (len(members) - 1))

    # All inputs are non-empty here, so at least one group exists and every
    # group has at least one member.
    best_group = max(group_agg, key=group_agg.get)
    best_idx = max(group_indices[best_group], key=lambda i: confidences[i])
    return {
        "index": best_idx,
        "right_index": right_index,
        "confidence": confidences[best_idx],
        "answer": best_group.strip(),
    }


def select_aggregated_path_new(paths1, paths2, input_answers, ground_truth, decay_rate=1, task_type='default'):
    """
    Strategy 4 (Modified - No Semantic Grouping):
    1. paths1 used for length normalization (``len(path) / sum of lengths``).
    2. paths2 used for average probability computation; each prob must
       expose ``.item()``.
    3. Compute confidence = normalized_length * average_prob.
    4. Group paths by exact answer string.
    5. Aggregate confidence with decay (the n-th occurrence of an answer
       contributes ``confidence * decay_rate ** (n - 1)``).
    6. Select highest-confidence group and best path in it.

    Returns:
        Dict with selected path index, correct indices, confidence, and final
        answer (stripped). The confidence is reported under both "confidence"
        and "prob" in every result, so the invalid and successful results
        share one shape. When any input list is empty, the "[invalid]" result
        with index -1 is returned.
    """
    if not input_answers or not paths1 or not paths2:
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "confidence": 0.0,
            "answer": "[invalid]",
        }

    probs2 = [sum(prob.item() for _, prob in path) / len(path) if path else 0 for path in paths2]

    # Extract each answer once and reuse it for both correctness and grouping.
    answers = [extract_answer_bbh(a, task_type) for a in input_answers]
    right_index = [i for i, ans in enumerate(answers) if ans == ground_truth]

    lengths = [len(p) for p in paths1]
    # If every path in paths1 is empty the total is 0; fall back to 1 so the
    # normalisation does not divide by zero (all confidences become 0).
    total_length = sum(lengths) or 1
    confidences = [(l / total_length) * ap for l, ap in zip(lengths, probs2)]

    agg = {}      # answer -> aggregated (decayed) confidence
    members = {}  # answer -> indices of the paths that produced it

    # Indices are recorded while aggregating, so only paths that actually
    # have a confidence (zip stops at the shorter list) can be selected.
    for i, (ans, conf) in enumerate(zip(answers, confidences)):
        idxs = members.setdefault(ans, [])
        idxs.append(i)
        # len(idxs) is this occurrence's 1-based count for the answer.
        agg[ans] = agg.get(ans, 0) + conf * (decay_rate ** (len(idxs) - 1))

    # All inputs are non-empty here, so agg holds at least one answer.
    best_ans = max(agg, key=agg.get)
    best_idx = max(members[best_ans], key=lambda i: confidences[i])
    return {
        "index": best_idx,
        "right_index": right_index,
        "prob": confidences[best_idx],
        "confidence": confidences[best_idx],
        "answer": best_ans.strip(),
    }


def select_top5_sum_path(paths, input_answers, ground_truth, top_n=5, task_type='default'):
    """
    Strategy X: Score each path by the sum of its ``top_n`` largest
    token-level probabilities and select the path with the highest score.

    Args:
        paths: List of paths, each [(token_str, prob)]; prob may be a plain
            number or an object exposing ``.item()``.
        input_answers: List of final answer texts.
        ground_truth: Correct answer.
        top_n: Number of top token probabilities to sum.
        task_type: Task identifier passed to extract_answer_bbh.

    Returns:
        Dict with best path index, correct indices, summed prob, and final
        answer (stripped). When ``paths`` or ``input_answers`` is empty, the
        "[invalid]" result with index -1 is returned.
    """
    if not (input_answers and paths):
        return {
            "index": -1,
            "right_index": [],
            "prob": 0.0,
            "answer": "[invalid]",
        }

    def as_number(value):
        # Objects with .item() are unwrapped; anything else is used as-is.
        if hasattr(value, "item"):
            return value.item()
        return value

    correct = []
    for pos, text in enumerate(input_answers):
        if extract_answer_bbh(text, task_type) == ground_truth:
            correct.append(pos)

    scores = []
    for path in paths:
        # Sort the probability values in descending order, keep the top_n.
        ranked = sorted((as_number(entry[1]) for entry in path), reverse=True)
        scores.append(sum(ranked[:top_n]))

    # First path with the maximal score wins on ties.
    winner = max(range(len(scores)), key=scores.__getitem__)
    winner_answer = extract_answer_bbh(input_answers[winner], task_type)

    return {
        "index": winner,
        "right_index": correct,
        "prob": scores[winner],
        "answer": winner_answer.strip(),
    }
