import random

import torch
import numpy as np
from scipy.optimize import linprog
from scipy.optimize import minimize
import torch.nn.functional as F

def reachable_probability_intervals(lower: torch.Tensor,
                                       upper: torch.Tensor,
                                       eps: float = 1e-8):
    """Tighten per-class probability bounds along the last dimension.

    For every class ``i`` the bounds are narrowed using the bounds of all
    other classes:

        lower*_i = max(lower_i, 1 - sum_{j != i} upper_j)
        upper*_i = min(upper_i, 1 - sum_{j != i} lower_j)

    and both results are then clipped to ``[0, 1]``.

    Args:
        lower: Per-class lower bounds; classes lie on the last dimension.
        upper: Per-class upper bounds, same layout as ``lower``.
        eps: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        Tuple ``(lower_star, upper_star)`` of tightened, clipped bounds.
    """
    # Totals over all classes; subtracting a class's own bound leaves the
    # sum over the *other* classes (sum_{j != i}).
    total_upper = upper.sum(dim=-1, keepdim=True)
    total_lower = lower.sum(dim=-1, keepdim=True)
    others_upper = total_upper - upper
    others_lower = total_lower - lower

    tightened_lower = torch.maximum(lower, 1.0 - others_upper)
    tightened_upper = torch.minimum(upper, 1.0 - others_lower)

    return (
        tightened_lower.clamp(min=0.0, max=1.0),
        tightened_upper.clamp(min=0.0, max=1.0),
    )


def entropy(p: torch.Tensor) -> torch.Tensor:
    """Return the Shannon entropy (natural log) over the last dimension.

    Args:
        p: Tensor of shape ``[..., C]``. Callers are expected to pass a
            probability vector (non-negative, summing to 1); this is not
            checked here.

    Returns:
        Tensor of shape ``[...]`` holding ``-sum(p * log(p))``.
    """
    # Floor the argument of the log at 1e-8 so zero entries do not yield -inf.
    log_p = torch.log(torch.clamp(p, min=1e-8))
    return -torch.sum(p * log_p, dim=-1)


# v4
def uncertainty_by_inter(inter_head_out_p, non_image_token_indices=None):
    """Score uncertainty from the widths of reachable probability intervals.

    The last dimension of ``inter_head_out_p`` is split in half: the first
    ``C`` entries are treated as per-class lower bounds and the remaining
    entries as per-class upper bounds. The bounds are tightened with
    ``reachable_probability_intervals`` and the interval widths
    ``upper* - lower*`` are summarised as ``std(width) * sum(width) * 1e5``.

    Args:
        inter_head_out_p: Tensor of shape ``[..., 2C]`` holding lower bounds
            followed by upper bounds along the last dimension.
        non_image_token_indices: Optional indices into the class dimension.
            The widths at these positions are set to zero before the
            statistics are computed (they still count as zero-valued entries
            in the standard deviation).

    Returns:
        Tensor of shape ``[..., 1]`` with one uncertainty score per leading
        index.
    """
    C = inter_head_out_p.shape[-1] // 2
    p_lo = inter_head_out_p[..., :C]
    p_hi = inter_head_out_p[..., C:]

    # Tighten the raw bounds to the reachable probability intervals.
    p_lo_star, p_hi_star = reachable_probability_intervals(p_lo, p_hi)

    delta = p_hi_star - p_lo_star  # [..., C]

    if non_image_token_indices is not None:
        non_image_tokens_mask = torch.zeros(
            C, dtype=torch.bool, device=inter_head_out_p.device
        )
        non_image_tokens_mask[non_image_token_indices] = True
        # The [C] mask broadcasts over all leading dimensions.
        delta_masked = delta.masked_fill(non_image_tokens_mask, 0.0)
    else:
        delta_masked = delta

    # keepdim=True keeps both statistics at shape [..., 1] for any number of
    # leading dimensions. The previous `.unsqueeze(1)` only lined up with
    # `sum_w` for 2-D input: it raised on 1-D input and silently broadcast to
    # a wrong shape for inputs with more than one leading dimension.
    std_w = delta_masked.std(dim=-1, unbiased=False, keepdim=True)
    # Floor the total width so the score never multiplies by a non-positive sum.
    sum_w = delta_masked.sum(dim=-1, keepdim=True).clamp(min=1e-8)

    # 1e5 is a fixed scale factor carried over from the original code.
    uncertainty_score = std_w * sum_w * 1e5

    return uncertainty_score

def compute_uncertainty(logits: torch.Tensor,
                        logits_mid: torch.Tensor = None,
                        method: str = "inter",
                        non_image_token_indices=None,
                        current_threshold=None) -> torch.Tensor:
    """Dispatch to the uncertainty estimator selected by ``method``.

    Args:
        logits: Tensor forwarded to the selected estimator.
        logits_mid: Accepted but not used by this function.
        method: Name of the estimator; only ``"inter"`` is supported.
        non_image_token_indices: Forwarded to ``uncertainty_by_inter``.
        current_threshold: Accepted but not used by this function.

    Returns:
        The result of ``uncertainty_by_inter`` for ``method == "inter"``.

    Raises:
        ValueError: If ``method`` is anything other than ``"inter"``.
    """
    # Reject unsupported methods up front so the happy path stays flat.
    if method != "inter":
        raise ValueError(f"Unknown uncertainty method: {method}")

    return uncertainty_by_inter(
        logits,
        non_image_token_indices=non_image_token_indices,
    )
