# scripts/test_smoke.py
import argparse
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data.dataset_coarse import CoarseDataset
from data.dataset_fine import FineDataset
from models.medclip import MedCLIP
from models.losses import BilingualLoss
from models.text_encoder import TextEncoder
from models.openclip_vit import build_openclip_vit
from models.region_head import RegionHead
from utils.yaml_loader import load_yaml

def test_coarse(cfg_path, sample_n=5, batch_size=2):
    """Smoke-test the coarse stage: one batch, one forward pass, one backward.

    Loads the YAML config at ``cfg_path``, builds a ``CoarseDataset`` and a
    ``MedCLIP`` model from it, encodes the first batch (``image`` plus the
    ``pos_ids_en``/``pos_mask_en`` and ``pos_ids_zh``/``pos_mask_zh`` text
    tensors), evaluates ``BilingualLoss`` on the projected features and calls
    ``backward()`` on the result. Progress is reported via ``print``.

    Args:
        cfg_path: Path handed to ``load_yaml``. The config must provide
            ``bert_path``, ``dataset_root``, ``openclip``, ``text_encoder``,
            ``projection_dim`` and ``temperature``; ``max_text_length``,
            ``image_size``, ``max_negatives``, ``region_num_heads`` and
            ``bilingual_loss_weight`` fall back to defaults when absent.
        sample_n: Forwarded to ``CoarseDataset`` as ``sample_n``.
        batch_size: Batch size used by the ``DataLoader``.

    Raises:
        KeyError: If a required config key is missing.
    """
    cfg = load_yaml(cfg_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(cfg['bert_path'])

    dataset = CoarseDataset(
        dataset_root=cfg['dataset_root'],
        tokenizer=tokenizer,
        max_length=cfg.get('max_text_length', 128),
        sample_n=sample_n,
        image_size=cfg.get('image_size', 336),
        max_negatives=cfg.get('max_negatives', 10)
    )
    print(f"[Coarse] samples: {len(dataset)}")
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a CPU-only machine is useless and makes PyTorch emit a warning.
        pin_memory=(device.type == 'cuda'),
    )
    batch = next(iter(loader))

    # Build model
    vision_encoder = build_openclip_vit(cfg['openclip'])
    text_encoder = TextEncoder(cfg['text_encoder'])
    region_head = RegionHead(
        cfg['projection_dim'],
        cfg['projection_dim'],
        num_heads=cfg.get('region_num_heads', 8),
    )
    model = MedCLIP(
        vision_encoder,
        text_encoder,
        region_head,
        projection_dim=cfg['projection_dim'],
    ).to(device)

    # Simple forward + loss.
    # Only tensor entries are moved to the device; any non-tensor entries of
    # the collated batch are dropped because they are not used below.
    batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    img_f = model.vision_encoder(batch['image'])
    img_f = model.proj_img(img_f)
    out_en = model.text_encoder(batch['pos_ids_en'], batch['pos_mask_en'])
    txt_en_f = model.proj_txt(out_en)
    out_zh = model.text_encoder(batch['pos_ids_zh'], batch['pos_mask_zh'])
    txt_zh_f = model.proj_txt(out_zh)

    loss_fn = BilingualLoss(
        en_weight=1.0,
        zh_weight=cfg.get('bilingual_loss_weight', 0.5),
        temperature=cfg['temperature']
    )
    loss = loss_fn(img_f, txt_en_f, txt_zh_f)
    print("Coarse loss:", loss.item())
    loss.backward()
    print("Backward OK")

def test_fine(cfg_path, sample_n=5, batch_size=1):
    """Smoke-test the fine stage: one batch through ``MedCLIP.__call__``.

    Loads the YAML config at ``cfg_path``, builds a ``FineDataset`` and a
    ``MedCLIP`` model from it, feeds the first batch to the model together
    with a single ``BilingualLoss``, then prints the ``'bilingual_loss'``
    entry of the returned mapping and calls ``backward()`` on it.

    Args:
        cfg_path: Path handed to ``load_yaml``. The config must provide
            ``bert_path``, ``dataset_root``, ``openclip``, ``text_encoder``,
            ``projection_dim`` and ``temperature``; ``max_text_length``,
            ``image_size``, ``max_rois``, ``region_num_heads`` and
            ``bilingual_loss_weight`` fall back to defaults when absent.
        sample_n: Forwarded to ``FineDataset`` as ``sample_n``.
        batch_size: Batch size used by the ``DataLoader``.

    Raises:
        KeyError: If a required config key is missing, or if the model output
            has no ``'bilingual_loss'`` entry.
    """
    cfg = load_yaml(cfg_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(cfg['bert_path'])

    dataset = FineDataset(
        dataset_root=cfg['dataset_root'],
        tokenizer=tokenizer,
        max_text_length=cfg.get('max_text_length', 128),
        sample_n=sample_n,
        image_size=cfg.get('image_size', 336),
        max_rois=cfg.get('max_rois', 8)
    )
    print(f"[Fine] samples: {len(dataset)}")
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=2,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a CPU-only machine is useless and makes PyTorch emit a warning.
        pin_memory=(device.type == 'cuda'),
    )
    batch = next(iter(loader))

    vision_encoder = build_openclip_vit(cfg['openclip'])
    text_encoder = TextEncoder(cfg['text_encoder'])
    region_head = RegionHead(
        cfg['projection_dim'],
        cfg['projection_dim'],
        num_heads=cfg.get('region_num_heads', 8),
    )
    model = MedCLIP(
        vision_encoder,
        text_encoder,
        region_head,
        projection_dim=cfg['projection_dim'],
    ).to(device)

    # Only tensor entries are moved to the device; non-tensor entries of the
    # collated batch are dropped before the batch is handed to the model.
    batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    out = model(batch, losses=[
        # Simple test with global+region only
        BilingualLoss(1.0, cfg.get('bilingual_loss_weight', 0.5), cfg['temperature']),
        # ShortCaptionLoss and SelfDistillLoss can also be added
    ])
    loss = out['bilingual_loss']
    # Print the plain Python number, matching the coarse smoke test, instead
    # of the tensor repr. backward() without arguments already requires a
    # single-element tensor, so .item() adds no new failure mode.
    print("Fine loss:", loss.item())
    loss.backward()
    print("Backward OK")

if __name__ == '__main__':
    # Command-line entry point: choose which smoke test to run and with
    # which config, sample count and batch size.
    cli = argparse.ArgumentParser()
    cli.add_argument('--mode', choices=['coarse', 'fine'], default='coarse')
    cli.add_argument('--cfg', required=True)
    cli.add_argument('--sample_n', type=int, default=5)
    cli.add_argument('--batch_size', type=int, default=None)
    opts = cli.parse_args()

    # Each mode has its own fallback batch size, used when --batch_size is
    # omitted (or given as 0, which is falsy).
    if opts.mode == 'coarse':
        runner, fallback_bs = test_coarse, 2
    else:
        runner, fallback_bs = test_fine, 1

    runner(opts.cfg, opts.sample_n, opts.batch_size or fallback_bs)