import os, re
from typing import List, Dict, Optional, Tuple
import torch, torch.nn.functional as F
from PIL import Image
from typing import Union
import torch
from torchvision.transforms.functional import to_pil_image
# Lazily-populated cache filled by _init_biomedclip(); None means "not loaded yet".
_BIO_MODEL = None
_BIO_PREPROC = None
_BIO_TOKENIZER = None
_DEVICE = None

def _init_biomedclip():
    """Lazily build and cache the BiomedCLIP model, preprocessing transform and tokenizer.

    The first call constructs everything and stores it in the module-level
    ``_BIO_MODEL`` / ``_BIO_PREPROC`` / ``_BIO_TOKENIZER`` / ``_DEVICE`` globals;
    later calls return the cached tuple without reloading.

    Environment variables:
        BIO_MEDCLIP_BACKEND:   "hf" (default) loads via the hub id; any other
                               value loads from a local checkpoint directory.
        BIO_MEDCLIP_HUB_ID:    model id passed to open_clip (hf backend).
        BIO_MEDCLIP_LOCAL_DIR: directory holding ``open_clip_config.json`` and
                               ``open_clip_pytorch_model.bin`` (local backend).
        BIO_MEDCLIP_DEVICE:    optional device override; when unset, "cuda" is
                               used if available, else "cpu".

    Returns:
        ``(model, preprocess, tokenizer, device)`` with the model moved to
        ``device`` and switched to eval mode.
    """
    global _BIO_MODEL, _BIO_PREPROC, _BIO_TOKENIZER, _DEVICE
    if _BIO_MODEL is not None:
        return _BIO_MODEL, _BIO_PREPROC, _BIO_TOKENIZER, _DEVICE
    backend = os.environ.get("BIO_MEDCLIP_BACKEND", "hf")  # "hf" | "local"
    hub_id = os.environ.get(
        "BIO_MEDCLIP_HUB_ID",
        "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
    )
    local_dir = os.environ.get("BIO_MEDCLIP_LOCAL_DIR", "checkpoints")
    device = os.environ.get("BIO_MEDCLIP_DEVICE") or (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    if backend == "hf":
        from open_clip import create_model_from_pretrained, get_tokenizer
        model, preprocess = create_model_from_pretrained(hub_id)
        tokenizer = get_tokenizer(hub_id)
    else:
        import json
        from open_clip import create_model_and_transforms, get_tokenizer
        # _MODEL_CONFIGS is open_clip's (private) model-config registry; registering
        # the local config under a fixed name lets the name-based factory calls
        # below resolve it.
        from open_clip.factory import _MODEL_CONFIGS
        # Context manager so the config file handle is closed promptly.
        with open(os.path.join(local_dir, "open_clip_config.json"), encoding="utf-8") as fh:
            cfg = json.load(fh)
        model_cfg = cfg["model_cfg"]
        preprocess_cfg = cfg["preprocess_cfg"]
        name = "biomedclip_local"
        if name not in _MODEL_CONFIGS:
            _MODEL_CONFIGS[name] = model_cfg
        tokenizer = get_tokenizer(name)
        model, _, preprocess = create_model_and_transforms(
            model_name=name,
            pretrained=os.path.join(local_dir, "open_clip_pytorch_model.bin"),
            # Each preprocess_cfg key is forwarded as an ``image_<key>`` keyword
            # (presumably image_mean / image_std / ...; confirm against the
            # installed open_clip version).
            **{f"image_{k}": v for k, v in preprocess_cfg.items()},
        )
    _BIO_MODEL = model.to(device).eval()
    _BIO_PREPROC, _BIO_TOKENIZER, _DEVICE = preprocess, tokenizer, device
    return _BIO_MODEL, _BIO_PREPROC, _BIO_TOKENIZER, _DEVICE

def _to_pil_from_pixel_values(px: Union[torch.Tensor, Image.Image]) -> Image.Image:
    """Convert a pixel tensor (or a PIL image) into a PIL image for preprocessing.

    Args:
        px: a PIL image (returned converted to RGB), or a tensor shaped
            (B, C, H, W) (only the first item is used), (C, H, W) or (H, W).
            Single-channel input is repeated to three channels.

    Value-range handling (decided from the tensor's global min/max):
        * any negative value  -> min-max rescale to [0, 1] (e.g. a tensor that
          was already mean/std-normalised);
        * max > 1.5           -> assumed 0-255, divided by 255 and clamped;
        * 1.0 < max <= 1.5    -> min-max rescale to [0, 1];
        * otherwise           -> clamped to [0, 1].

    Returns:
        A PIL image produced by ``to_pil_image``.
    """
    if isinstance(px, Image.Image):
        return px.convert("RGB")
    x = px.detach().float().cpu()
    if x.dim() == 4:
        x = x[0]
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if x.shape[0] == 1:
        x = x.repeat(3, 1, 1)
    vmin = float(x.min().item())
    vmax = float(x.max().item())
    if vmin < 0.0:
        # Check negatives first: a normalised tensor can exceed 1.5, and treating it
        # as 0-255 would divide by 255 and clamp it to an almost-black image.
        x = (x - vmin) / (vmax - vmin + 1e-6)
    elif vmax > 1.5:
        x = (x / 255.0).clamp(0, 1)
    elif vmax > 1.0:
        x = (x - vmin) / (vmax - vmin + 1e-6)
    else:
        x = x.clamp(0, 1)
    return to_pil_image(x)

def _bbox_from_mask(mask_hw: Optional[torch.Tensor]) -> Optional[Tuple[int, int, int, int]]:
    """Return the tight bounding box of the mask's "on" pixels (value > 0.5).

    Args:
        mask_hw: ``None`` or a tensor of rank >= 2. Leading dimensions beyond the
            last two (e.g. batch/channel) are reduced by taking index 0, so
            (H, W), (1, H, W) and (1, 1, H, W) all yield the same box.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` in mask pixel coordinates, with the max
        coordinates inclusive; ``None`` if ``mask_hw`` is None or no pixel is on.

    Raises:
        ValueError: if the mask has fewer than two dimensions.
    """
    if mask_hw is None:
        return None
    mask = mask_hw
    # Peel every leading dim, not just one: for a 4-D mask, nonzero() would
    # otherwise report batch/channel indices in the columns read as ys/xs.
    while mask.dim() > 2:
        mask = mask[0]
    if mask.dim() != 2:
        raise ValueError(f"mask must have at least 2 dims, got shape {tuple(mask_hw.shape)}")
    nz = (mask.float() > 0.5).nonzero(as_tuple=False)
    if nz.numel() == 0:
        return None
    ys, xs = nz[:, 0], nz[:, 1]
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def _crop(img: Image.Image, xyxy):
    """Crop ``img`` to the box ``xyxy``, or return it unchanged if the box is unusable.

    Args:
        img: source PIL image.
        xyxy: ``None`` or ``(x1, y1, x2, y2)``. If all four values lie in [0, 1]
            they are treated as fractions of the image width/height; otherwise
            as pixel coordinates.

    Returns:
        The cropped image, or ``img`` itself when ``xyxy`` is None or the
        clamped box is empty.
    """
    if xyxy is None:
        return img
    w, h = img.size
    x1, y1, x2, y2 = xyxy
    # NOTE(review): a pixel-space box whose coordinates all happen to be <= 1 is
    # indistinguishable from a normalised one and will be scaled up.
    if all(0 <= v <= 1 for v in (x1, y1, x2, y2)):
        x1, x2 = int(x1 * w), int(x2 * w)
        y1, y2 = int(y1 * h), int(y2 * h)
    # PIL crop boxes are half-open (right/lower edges excluded), so x2/y2 may reach
    # w/h; clamping them to w-1/h-1 dropped the last column/row of every crop.
    x1 = max(0, min(x1, w - 1))
    y1 = max(0, min(y1, h - 1))
    x2 = max(0, min(x2, w))
    y2 = max(0, min(y2, h))
    if x2 <= x1 or y2 <= y1:
        return img
    return img.crop((x1, y1, x2, y2))

def _get_text_modules(model):
    """Locate the text-tower submodules of a CLIP-style model.

    Each attribute is looked up on ``model`` first and, if absent, on
    ``model.text`` (when that attribute exists).

    Returns:
        ``(token_embedding, positional_embedding, transformer, ln_final)``.

    Raises:
        AttributeError: if any of the four cannot be found.
    """
    text_enc = getattr(model, "text", None)

    def _lookup(attr):
        # Explicit ``is None`` tests instead of ``a or b``: ``or`` calls bool() on the
        # found object, which raises for multi-element tensors/Parameters such as
        # positional_embedding and is False for empty containers.
        value = getattr(model, attr, None)
        if value is None and text_enc is not None:
            value = getattr(text_enc, attr, None)
        return value

    tok_emb = _lookup("token_embedding")
    pos_emb = _lookup("positional_embedding")
    transformer = _lookup("transformer")
    ln_final = _lookup("ln_final")
    if tok_emb is None or pos_emb is None or transformer is None or ln_final is None:
        raise AttributeError("BiomedCLIP text submodules not found (token_embedding/positional_embedding/transformer/ln_final).")
    return tok_emb, pos_emb, transformer, ln_final

def _question_from_info(info: Dict) -> str:
    """Return the question text for one sample.

    Uses ``info["question"]`` when it is non-empty. Otherwise it takes the text
    following the first ``Question:`` / ``question:`` in ``info["prompt_str"]`` (up
    to the end of that line). It falls back to the whole prompt string, and then to
    the literal ``"question"``.
    """
    question = info.get("question")
    if question:
        return question
    prompt = info.get("prompt_str", "")
    match = re.search(r"[Qq]uestion\s*:\s*(.*)", prompt)
    return (match.group(1).strip() if match else prompt) or "question"


def _bbox_for_info(info: Dict):
    """Choose the crop box for one sample, or ``None`` to keep the full image.

    Priority: ``attn_mask_hw`` tensor, then ``evidence_mask`` tensor (both via
    ``_bbox_from_mask``), then an explicit ``attn_box`` sequence.
    """
    if isinstance(info.get("attn_mask_hw"), torch.Tensor):
        return _bbox_from_mask(info["attn_mask_hw"])
    if isinstance(info.get("evidence_mask"), torch.Tensor):
        return _bbox_from_mask(info["evidence_mask"])
    if info.get("attn_box") is not None:
        return tuple(info["attn_box"])
    return None


@torch.no_grad()
def compute_semantic_reward_biomedclip_batch(
    solution_strs: List[str],
    extra_infos: List[Optional[Dict]],
) -> List[float]:
    """Score each sample by BiomedCLIP image/question similarity.

    For every sample whose ``extra_infos[i]["multi_modal_inputs"]["pixel_values"]``
    is a tensor (or a list/tuple whose first item is), the image is optionally
    cropped (see ``_bbox_for_info``). The crop is then compared with the sample's
    question (see ``_question_from_info``).

    Args:
        solution_strs: one response per sample; only its length is used here.
            NOTE(review): the response text itself does not influence the score;
            confirm that is intended.
        extra_infos: per-sample metadata dicts (``None`` is treated as ``{}``).

    Returns:
        A list of ``len(solution_strs)`` floats in [0, 1] (cosine similarity mapped
        from [-1, 1]). Samples that are skipped or fail to process score 0.0.
    """
    import logging
    log = logging.getLogger(__name__)

    model, preprocess, tokenizer, device = _init_biomedclip()
    ctx_len = int(os.environ.get("BIO_MEDCLIP_CTX_LEN", "256"))
    out = [0.0] * len(solution_strs)
    imgs, txts, idx_keep = [], [], []
    for i, (_resp, info) in enumerate(zip(solution_strs, extra_infos)):
        info = info or {}
        try:
            px = info.get("multi_modal_inputs", {}).get("pixel_values")
            if isinstance(px, (list, tuple)):
                px = px[0]
            if not isinstance(px, torch.Tensor):
                continue
            question = _question_from_info(info)
            pil = _crop(_to_pil_from_pixel_values(px), _bbox_for_info(info))
            img_t = preprocess(pil).unsqueeze(0)
            txt_t = tokenizer([question], context_length=ctx_len)
        except Exception:
            # Best-effort: a malformed sample keeps reward 0.0 instead of failing the batch.
            log.debug("BiomedCLIP reward: skipping sample %d", i, exc_info=True)
            continue
        # Append only after every step succeeded so imgs/txts/idx_keep stay aligned.
        imgs.append(img_t)
        txts.append(txt_t)
        idx_keep.append(i)

    if not idx_keep:
        return out
    # Encode each tower directly (rather than unpacking model(...)) so the code does
    # not depend on how many values the model's forward returns.
    img_f = F.normalize(model.encode_image(torch.cat(imgs, 0).to(device)), dim=-1)
    txt_f = F.normalize(model.encode_text(torch.cat(txts, 0).to(device)), dim=-1)
    # Map cosine similarity from [-1, 1] to a [0, 1] reward.
    scores = ((img_f * txt_f).sum(dim=-1) + 1.0) * 0.5
    for i, score in zip(idx_keep, scores.float().cpu().tolist()):
        out[i] = float(score)
    return out

@torch.no_grad()
def encode_text_batch(texts):
    """Embed strings with the BiomedCLIP text tower.

    texts: List[str] → (B, D) tensor, L2-normalized along the last dimension.
    The context length comes from BIO_MEDCLIP_CTX_LEN (default 256).
    """
    model, _unused_preproc, tokenizer, device = _init_biomedclip()
    context_length = int(os.environ.get("BIO_MEDCLIP_CTX_LEN", "256"))
    token_ids = tokenizer(texts, context_length=context_length).to(device)
    embeddings = model.eval().encode_text(token_ids)
    return F.normalize(embeddings, dim=-1)

# ---- Image encoding and occlusion-saliency helpers ----
@torch.no_grad()
def encode_image_from_pixel_values(px: torch.Tensor):
    """Encode one image or a batch of images with the BiomedCLIP vision tower.

    Args:
        px: (C, H, W) or (B, C, H, W) tensor. A 224x224 tensor whose min and max
            both lie in (-5, 5) is assumed to be already preprocessed and is fed
            to the model as-is. Anything else is converted to PIL one image at a
            time (via ``_to_pil_from_pixel_values``) and run through the model's
            ``preprocess`` transform.

    Returns:
        (B, D) tensor of L2-normalized image embeddings on the model device.

    Raises:
        TypeError: if ``px`` is not a tensor.
        ValueError: if ``px`` is not 3-D or 4-D.
    """
    model, preprocess, _, device = _init_biomedclip()
    if not isinstance(px, torch.Tensor):
        raise TypeError("pixel_values must be a torch.Tensor")

    # Normalize the layout to (B, C, H, W).
    if px.dim() == 3:
        px = px.unsqueeze(0)
    if px.dim() != 4:
        raise ValueError(f"encode_image_from_pixel_values expects 4D, got {px.shape}")
    batch_size, _, height, width = px.shape

    # NOTE(review): this heuristic cannot tell a raw [0, 1] 224x224 image from a
    # preprocessed one; such input skips the transform's mean/std normalization.
    already_preprocessed = False
    if (height, width) == (224, 224):
        vmin = float(px.min().item())
        vmax = float(px.max().item())
        already_preprocessed = -5.0 < vmin < 5.0 and -5.0 < vmax < 5.0

    if already_preprocessed:
        x = px.to(device)
    else:
        # Share the tensor->PIL conversion with the reward path so value-range
        # handling (0-255 / 0-1 / negative-valued input) is identical.
        proc_list = [
            preprocess(_to_pil_from_pixel_values(px[b])).unsqueeze(0)
            for b in range(batch_size)
        ]
        x = torch.cat(proc_list, dim=0).to(device)

    feat = model.eval().encode_image(x)
    return F.normalize(feat, dim=-1)


@torch.no_grad()
def image_focus_mask_from_text_and_pixels(px: torch.Tensor, text: str,
                                          grid_hw=(8,8), topk_frac=0.2,
                                          batch_size: int = 64):
    """Occlusion saliency: which grid cells of the image matter most for ``text``.

    The image is preprocessed once. Each cell of a ``grid_hw`` grid is then zeroed
    in the preprocessed tensor, and the cell's score is the drop in image/text
    similarity caused by hiding it.

    Args:
        px: pixel tensor (or PIL image) accepted by ``_to_pil_from_pixel_values``.
        text: query text.
        grid_hw: (rows, cols) of the occlusion grid. Cell edges tile the full
            image, so any remainder pixels go into the last row/column.
        topk_frac: fraction of cells to keep (at least 1, at most all cells).
        batch_size: number of occluded images encoded per forward pass.

    Returns:
        (rows, cols) float tensor of 0/1 on the model device. 1 marks the
        highest-saliency cells; ties at the threshold are all kept.
    """
    model, preprocess, _, device = _init_biomedclip()
    # 1) Text embedding, shape (1, D).
    z_t = encode_text_batch([text])[0:1]
    pil = _to_pil_from_pixel_values(px).convert("RGB")
    base = preprocess(pil).unsqueeze(0).to(device)  # (1, C, H0, W0)

    def _embed(images):
        # ``base`` is already preprocessed, so encode it directly rather than via
        # encode_image_from_pixel_values, whose "already preprocessed?" check only
        # recognizes 224x224 input.
        return F.normalize(model.encode_image(images), dim=-1)

    ref_sim = (_embed(base) @ z_t.T).reshape(())  # scalar similarity, unoccluded

    rows, cols = grid_hw
    _, _, H0, W0 = base.shape
    y_edges = [r * H0 // rows for r in range(rows + 1)]
    x_edges = [c * W0 // cols for c in range(cols + 1)]
    cells = [(r, c) for r in range(rows) for c in range(cols)]

    step = max(1, int(batch_size))
    sal = torch.zeros(len(cells), device=device)
    # Occlude cells in chunks instead of one forward pass per cell.
    for start in range(0, len(cells), step):
        chunk = cells[start:start + step]
        masked = base.repeat(len(chunk), 1, 1, 1)
        for k, (r, c) in enumerate(chunk):
            masked[k, :, y_edges[r]:y_edges[r + 1], x_edges[c]:x_edges[c + 1]] = 0.0
        sims = (_embed(masked) @ z_t.T).reshape(-1)
        sal[start:start + len(chunk)] = (ref_sim - sims).to(sal.dtype)
    sal = sal.view(rows, cols)

    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-6)
    # Clamp so topk_frac > 1 cannot request more cells than exist.
    k = min(rows * cols, max(1, int(rows * cols * topk_frac)))
    th = torch.topk(sal.flatten(), k).values.min()
    return (sal >= th).float()