import os
import logging
# Root-logger setup performed at import time. The root logger itself accepts
# everything from DEBUG upwards; a stream handler that only emits INFO and
# above is attached, but only when no handler has been configured already, so
# an application-provided logging configuration is left in place.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

from main.utils import watermark_prob


def detect_watermarks(cfgs, wm_path, imagename, ssim_threshold, pipe, wm_pipe, device, threshold=0.9):
    """Score watermark presence on a watermarked image and its attacked copies.

    The watermarked image is looked up in ``wm_path`` under the name
    ``<stem>_<last save iter>_SSIM<ssim_threshold>.png``, where ``<stem>`` is
    the part of ``imagename`` before its first ``'.'``. Each attacked copy is
    expected at ``wm_path/<attacker_name>/<same file name>``. The presence
    probability is computed as ``1 - watermark_prob(...)`` using the text
    embedding of an empty prompt.

    Args:
        cfgs: Mapping with a ``'save_iters'`` sequence; its last entry is
            used in the image file name.
        wm_path: Directory holding the watermarked image and one
            sub-directory per attacker.
        imagename: Name of the source image; only the text before the first
            ``'.'`` is used.
        ssim_threshold: Value embedded in the image file name.
        pipe: Object providing ``get_text_embedding``; also forwarded to
            ``watermark_prob``.
        wm_pipe: Forwarded to ``watermark_prob``.
        device: Forwarded to ``watermark_prob`` as ``device=``.
        threshold: A presence probability strictly above this value counts
            as a detection.

    Returns:
        Dict mapping each attacker name whose directory exists to a tuple
        ``(presence_probability, detected)``, where ``detected`` is ``1`` or
        ``0``. The score of the un-attacked watermarked image is only
        logged; it is not part of the returned dict.
    """
    stem = imagename.split('.')[0]
    final_iter = cfgs['save_iters'][-1]
    wm_image = os.path.join(wm_path, f"{stem}_{final_iter}_SSIM{ssim_threshold}.png")
    # Attacked copies share the watermarked image's file name.
    wm_basename = os.path.basename(wm_image)

    attack_names = (
        'diff_attacker_60', 'cheng2020-anchor_3', 'bmshj2018-factorized_3',
        'jpeg_attacker_50', 'brightness_0.5', 'contrast_0.5',
        'Gaussian_noise', 'Gaussian_blur', 'rotate_90', 'bm3d',
        'all', 'all_norot',
    )

    # The original prompt is assumed to be unknown at detection time, so the
    # embedding of an empty prompt is used for every image.
    empty_prompt_embedding = pipe.get_text_embedding('')

    def presence_prob(image_path):
        # Compute and log the watermark presence probability of one image.
        prob = 1 - watermark_prob(image_path, pipe, wm_pipe, empty_prompt_embedding, device=device)
        logging.info(f'Watermark Presence Prob.: {prob}')
        return prob

    logging.info(f'===== Testing the Watermarked Images {wm_image} =====')
    presence_prob(wm_image)

    logging.info('===== Testing the Attacked Watermarked Images =====')
    detections = {}
    for attacker_name in attack_names:
        attack_dir = os.path.join(wm_path, attacker_name)
        if os.path.exists(attack_dir):
            logging.info(f'=== Attacker Name: {attacker_name} ===')
            prob = presence_prob(os.path.join(attack_dir, wm_basename))
            detections[attacker_name] = (prob.item(), int(prob > threshold))
        else:
            # Only the attacker directory is checked, not the image inside it.
            logging.info(f'Attacked images under {attacker_name} not exist.')

    return detections