import os
import torch
from torch.utils.data import DataLoader
from data.dataset_coarse import CoarseDataset
from models.medclip import MedCLIP
from models.text_encoder import TextEncoder
from models.openclip_vit import build_openclip_vit
from models.region_head import RegionHead   # Fix: Must be consistent with training
from utils.yaml_loader import load_yaml
from utils.logger import setup_logger
from evaluate.metrics import compute_retrieval, compute_bilingual_consistency
from transformers import AutoTokenizer

def extract_global_feats(model, dataloader, device):
    """Encode every batch of ``dataloader`` into projected global features.

    For each batch the image goes through ``model.vision_encoder`` followed by
    ``model.proj_img``; the English and Chinese positive captions each go
    through ``model.text_encoder`` followed by ``model.proj_txt``. All forward
    passes run under ``torch.no_grad()`` and every per-batch result is moved
    to the CPU before being collected.

    Args:
        model: Object exposing ``vision_encoder``, ``proj_img``,
            ``text_encoder`` and ``proj_txt`` callables.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'image'``, ``'pos_ids_en'``, ``'pos_mask_en'``,
            ``'pos_ids_zh'`` and ``'pos_mask_zh'``.
        device: Device the batch tensors are moved to before encoding.

    Returns:
        Tuple ``(image_feats, text_en_feats, text_zh_feats)``; each entry is
        the per-batch outputs concatenated along dim 0.
    """
    text_keys = ('pos_ids_en', 'pos_mask_en', 'pos_ids_zh', 'pos_mask_zh')
    # One bucket per modality, in return order: image, English, Chinese.
    buckets = ([], [], [])

    for batch in dataloader:
        with torch.no_grad():
            # Global image embedding.
            pixels = batch['image'].to(device)
            image_emb = model.proj_img(model.vision_encoder(pixels))

            # Move all caption tensors to the device, then embed each language.
            ids_en, mask_en, ids_zh, mask_zh = (
                batch[key].to(device) for key in text_keys
            )
            en_emb = model.proj_txt(model.text_encoder(ids_en, mask_en))
            zh_emb = model.proj_txt(model.text_encoder(ids_zh, mask_zh))

            for bucket, emb in zip(buckets, (image_emb, en_emb, zh_emb)):
                bucket.append(emb.cpu())

    return tuple(torch.cat(bucket, dim=0) for bucket in buckets)

def evaluate_all(cfg_path, ckpt_path, sample_n=None):
    """Run global image-text retrieval evaluation for a trained checkpoint.

    Loads the YAML config, rebuilds the ``MedCLIP`` model (vision encoder,
    text encoder and region head), restores the weights from ``ckpt_path``,
    extracts global image / English / Chinese features over ``CoarseDataset``
    and logs and returns the resulting metrics.

    Config keys read: ``bert_path``, ``openclip``, ``text_encoder``,
    ``projection_dim`` and ``dataset_root`` (required); ``log_dir``
    (default ``'logs/eval'``), ``region_num_heads`` (8), ``max_text_length``
    (128), ``image_size`` (224), ``max_negatives`` (10), ``eval_batch_size``
    (8) and ``num_workers`` (4).

    Args:
        cfg_path: Path to the YAML configuration file.
        ckpt_path: Path to the saved model ``state_dict``.
        sample_n: Forwarded to ``CoarseDataset`` as ``sample_n``; ``None`` by
            default.

    Returns:
        Dict with the keys ``'global_retrieval_en'``,
        ``'global_retrieval_zh'`` (outputs of ``compute_retrieval`` on image
        vs. English / Chinese text features) and
        ``'global_bilingual_consistency'`` (output of
        ``compute_bilingual_consistency`` on English vs. Chinese features).
    """
    cfg = load_yaml(cfg_path)
    logger = setup_logger(cfg.get('log_dir', 'logs/eval'))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(cfg['bert_path'])

    # 1. Build model and load weights (region_head must be consistent with training)
    vision_encoder = build_openclip_vit(cfg['openclip'])
    text_encoder = TextEncoder(cfg['text_encoder'])
    region_head = RegionHead(
        in_dim=cfg['projection_dim'],
        out_dim=cfg['projection_dim'],
        num_heads=cfg.get('region_num_heads', 8),
    )
    model = MedCLIP(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        region_head=region_head,
        projection_dim=cfg['projection_dim'],
    ).to(device)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider weights_only=True where the installed
    # PyTorch version supports it).
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    # 2. Global retrieval evaluation
    logger.info("==== Global retrieval on CoarseDataset ====")
    ds_coarse = CoarseDataset(
        dataset_root=cfg['dataset_root'],
        tokenizer=tokenizer,
        max_length=cfg.get('max_text_length', 128),
        sample_n=sample_n,
        image_size=cfg.get('image_size', 224),
        max_negatives=cfg.get('max_negatives', 10),
    )
    dl_coarse = DataLoader(
        ds_coarse,
        batch_size=cfg.get('eval_batch_size', 8),
        num_workers=cfg.get('num_workers', 4),
        # Pinned host memory only helps host-to-GPU copies; on a CPU-only run
        # it has no benefit and PyTorch warns about it.
        pin_memory=device.type == 'cuda',
    )
    img_g, en_g, zh_g = extract_global_feats(model, dl_coarse, device)
    ret_en = compute_retrieval(img_g, en_g)
    ret_zh = compute_retrieval(img_g, zh_g)
    bi_g = compute_bilingual_consistency(en_g, zh_g)
    logger.info(f"[Global EN] {ret_en}")
    logger.info(f"[Global ZH] {ret_zh}")
    logger.info(f"[Global Bilingual Consistency] {bi_g}")

    return {
        "global_retrieval_en": ret_en,
        "global_retrieval_zh": ret_zh,
        "global_bilingual_consistency": bi_g,
    }

if __name__ == '__main__':
    import argparse

    # Command-line entry point: register the options, run the evaluation
    # once, and print the returned metrics dict.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ('--cfg', dict(required=True, help='Configuration file path')),
        ('--ckpt', dict(required=True, help='Model weights path')),
        ('--sample_n', dict(type=int, default=None, help='Number of test samples')),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()
    print(evaluate_all(opts.cfg, opts.ckpt, sample_n=opts.sample_n))