from args import get_args_parser
import os
import json
import alpha_clip
from HICE import HICEScore
from SGM.textGraphParsor import TextGraphParsor
import torch
import os
from mobilesamv2 import sam_model_registry, SamPredictor
from old_inference import create_model, batch_iterator, encoder_path



if __name__ == "__main__":
    # Demo entry point: builds the vision-language model, the text scene-graph
    # parser and the MobileSAMv2 segmenter, then scores every image found in
    # ``config.data_root`` against its candidate and reference captions and
    # prints the resulting metrics.

    # Local import: the file's import header does not provide numpy, which the
    # per-metric averaging in the report below needs.
    import numpy as np

    config = get_args_parser()
    config.device = "cuda:0"
    vlm_model, preprocess = alpha_clip.load("ViT-L/14",
                                        alpha_vision_ckpt_pth="huggingface/alpha-CLIP/clip_l14_grit20m_fultune_2xe.pth",
                                        device=config.device)  # change to your own ckpt path
    sg_model = TextGraphParsor(device=config.device,
                               parser_checkpoint=config.sg_model)

    # SAM: swap in the image encoder selected by ``config.encoder_type``.
    mobilesamv2, ObjAwareModel = create_model()
    image_encoder = sam_model_registry[config.encoder_type](encoder_path[config.encoder_type])
    mobilesamv2.image_encoder = image_encoder
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mobilesamv2.to(device=device)
    mobilesamv2.eval()
    predictor = SamPredictor(mobilesamv2)

    hice_scorer = HICEScore(vlm_model, preprocess, sg_model, mobilesamv2, ObjAwareModel,predictor, config.device, config)

    # Every directory entry is treated as an image; the part of the file name
    # before the first '.' is the key into both caption JSON files.
    image_ids = [img_id for img_id in os.listdir(config.data_root)]
    with open(config.candidates_json) as f:
        candidates = json.load(f)
    candidates = [candidates[cid.split('.')[0]] for cid in image_ids]
    with open(config.references_json) as f:
        references = json.load(f)
    references = [references[cid.split('.')[0]] for cid in image_ids]
    image_paths = [os.path.join(config.data_root, img_id) for img_id in image_ids]

    # Names of the metrics to report, in print order.
    # NOTE(review): assumes batch_compute_scores returns a mapping from these
    # metric names to a sequence of scores -- confirm against HICEScore.
    for image_path, cand, refer in zip(image_paths, candidates, references):
        # Strip directory and extension without assuming a 3-letter suffix or
        # a '/' path separator.
        img_id = os.path.splitext(os.path.basename(image_path))[0]
        print("--" * 20)
        print(f"Image: {img_id}, Caption:{cand}")
        hice_scores_results = hice_scorer.batch_compute_scores([image_path], [cand],[refer])
        # Report the scores returned for this image (the original referenced an
        # undefined ``total_res_dict`` here, which raised NameError).
        print(f"hice_IT:{np.mean(hice_scores_results['hice_IT'])}, "
              f"global_IT:{np.mean(hice_scores_results['global_IT'])}, "
              f"local_IT:{np.mean(hice_scores_results['local_IT'])}, "
              f"local_IT_precision:{np.mean(hice_scores_results['local_IT_precision'])}, "
              f"local_IT_recall:{np.mean(hice_scores_results['local_IT_recall'])},  "
              f"hice_TT:{np.mean(hice_scores_results['hice_TT'])},  "
              f"global_TT:{np.mean(hice_scores_results['global_TT'])},  "
              f"local_TT:{np.mean(hice_scores_results['local_TT'])},  "
              f"refhice:{np.mean(hice_scores_results['refhice'])},  "
              )


