import logging
# Module-level logging bootstrap, applied to the ROOT logger (getLogger() with
# no name returns the root logger).
# - The root logger itself is opened to DEBUG, so records of every level are
#   created and handed to whatever handlers are attached.
# - A console StreamHandler that only emits INFO and above is attached solely
#   when the root logger has no handlers yet (presumably so an already-configured
#   application does not get duplicate console output — confirm with callers).
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    logger.addHandler(handler)
    handler.setLevel(logging.INFO)

from main.wmattacker import *

def single_attacks(cfgs, device, wm_path, imagename, ssim_threshold, att_pipe):
    """Attack one watermarked image with each attacker independently.

    The input image is
    ``<wm_path>/<stem>_<cfgs['save_iters'][-1]>_SSIM<ssim_threshold>.png``,
    where ``<stem>`` is ``imagename`` up to its first ``'.'``. Every attacker
    reads that same input and writes its result to
    ``<wm_path>/<attacker_name>/<input basename>``; each sub-directory is
    created if missing.

    Args:
        cfgs: Config mapping; only ``cfgs['save_iters'][-1]`` is read, to build
            the input filename.
        device: Forwarded to the ``VAEWMAttacker`` instances.
        wm_path: Directory holding the watermarked image; per-attacker output
            sub-directories are created inside it.
        imagename: Original image file name, used to derive the file stem.
        ssim_threshold: Value embedded in the watermarked image's filename.
        att_pipe: Forwarded to ``DiffWMAttacker`` (presumably a diffusion
            pipeline — confirm against main.wmattacker).

    Returns:
        None. Results are written to disk by the attackers.
    """
    # Import explicitly instead of relying on the star import of
    # main.wmattacker happening to re-export ``os``.
    import os

    logging.info('===== Init Attackers =====')
    # NOTE(review): every attacker, including the model-backed VAE and diffusion
    # attackers, is constructed up front before any attack runs.
    attackers = {
        'diff_attacker_60': DiffWMAttacker(att_pipe, batch_size=5, noise_step=60, captions={}),
        'cheng2020-anchor_3': VAEWMAttacker('cheng2020-anchor', quality=3, metric='mse', device=device),
        'bmshj2018-factorized_3': VAEWMAttacker('bmshj2018-factorized', quality=3, metric='mse', device=device),
        'jpeg_attacker_50': JPEGAttacker(quality=50),
        'rotate_90': RotateAttacker(degree=90),
        'brightness_0.5': BrightnessAttacker(brightness=0.5),
        'contrast_0.5': ContrastAttacker(contrast=0.5),
        'Gaussian_noise': GaussianNoiseAttacker(std=0.05),
        'Gaussian_blur': GaussianBlurAttacker(kernel_size=5, sigma=1),
        'bm3d': BM3DAttacker(),
    }

    logging.info('===== Start Attacking... =====')

    # NOTE(review): split('.')[0] truncates stems containing dots
    # ("a.b.png" -> "a"). Kept unchanged because it presumably has to match how
    # the watermarked file was named when it was saved — confirm before changing.
    stem = imagename.split('.')[0]
    post_img = os.path.join(wm_path, f"{stem}_{cfgs['save_iters'][-1]}_SSIM{ssim_threshold}.png")
    out_name = os.path.basename(post_img)  # same output file name for every attacker

    for attacker_name, attacker in attackers.items():
        logging.debug('Attacking with %s', attacker_name)
        out_dir = os.path.join(wm_path, attacker_name)
        os.makedirs(out_dir, exist_ok=True)
        # Every attacker starts from the untouched watermarked image.
        attacker.attack([post_img], [os.path.join(out_dir, out_name)])


def _build_combined_attackers(att_pipe, device, include_rotation):
    """Instantiate the attackers chained by ``combined_attacks``, in application order.

    Dict insertion order is the order in which the attacks are applied, and the
    attackers are constructed in that same order. ``'rotate_90'`` is inserted
    between the JPEG and brightness attackers only when ``include_rotation`` is
    true. Fresh instances are built on every call.
    """
    attackers = {
        'diff_attacker_60': DiffWMAttacker(att_pipe, batch_size=5, noise_step=60, captions={}),
        'cheng2020-anchor_3': VAEWMAttacker('cheng2020-anchor', quality=3, metric='mse', device=device),
        'bmshj2018-factorized_3': VAEWMAttacker('bmshj2018-factorized', quality=3, metric='mse', device=device),
        'jpeg_attacker_50': JPEGAttacker(quality=50),
    }
    if include_rotation:
        attackers['rotate_90'] = RotateAttacker(degree=90)
    attackers['brightness_0.5'] = BrightnessAttacker(brightness=0.5)
    attackers['contrast_0.5'] = ContrastAttacker(contrast=0.5)
    attackers['Gaussian_noise'] = GaussianNoiseAttacker(std=0.05)
    attackers['Gaussian_blur'] = GaussianBlurAttacker(kernel_size=5, sigma=1)
    attackers['bm3d'] = BM3DAttacker()
    return attackers


def combined_attacks(cfgs, device, wm_path, imagename, ssim_threshold, att_pipe):
    """Apply chains of attacks, one after another, to one watermarked image.

    The input image is
    ``<wm_path>/<stem>_<cfgs['save_iters'][-1]>_SSIM<ssim_threshold>.png``,
    where ``<stem>`` is ``imagename`` up to its first ``'.'``. Two chains are run:

    * ``'w/ rot'``  -> output in ``<wm_path>/all/``       (includes 90° rotation)
    * ``'w/o rot'`` -> output in ``<wm_path>/all_norot/`` (same chain, no rotation)

    In each chain, the first attacker reads the watermarked image and writes the
    chain's output file. Every later attacker reads that output file and writes
    back to the same path. Every ``attack`` call is made with ``multi=True``;
    its meaning is defined in main.wmattacker.

    Args:
        cfgs: Config mapping; only ``cfgs['save_iters'][-1]`` is read, to build
            the input filename.
        device: Forwarded to the ``VAEWMAttacker`` instances.
        wm_path: Directory holding the watermarked image; output
            sub-directories are created inside it.
        imagename: Original image file name, used to derive the file stem.
        ssim_threshold: Value embedded in the watermarked image's filename.
        att_pipe: Forwarded to ``DiffWMAttacker`` (presumably a diffusion
            pipeline — confirm against main.wmattacker).

    Returns:
        None. Results are written to disk by the attackers.
    """
    # Import explicitly instead of relying on the star import of
    # main.wmattacker happening to re-export ``os``.
    import os

    # (case label, output sub-directory, whether the chain includes rotation).
    # A table, so every case always has a matching attacker chain and directory.
    cases = (
        ('w/ rot', 'all', True),
        ('w/o rot', 'all_norot', False),
    )

    logging.info('===== Init Attackers =====')
    # NOTE(review): split('.')[0] truncates stems containing dots; kept unchanged
    # because it presumably has to match how the watermarked file was saved.
    stem = imagename.split('.')[0]
    post_img = os.path.join(wm_path, f"{stem}_{cfgs['save_iters'][-1]}_SSIM{ssim_threshold}.png")

    for case, multi_name, include_rotation in cases:
        logging.info('Case: %s', case)
        attackers = _build_combined_attackers(att_pipe, device, include_rotation)

        os.makedirs(os.path.join(wm_path, multi_name), exist_ok=True)
        att_img_path = os.path.join(wm_path, multi_name, os.path.basename(post_img))

        # Chain the attacks: start from the watermarked image, then keep feeding
        # the previous output back in (read and overwrite att_img_path in place).
        src_path = post_img
        for attacker_name, attacker in attackers.items():
            logging.debug('Attacking with %s', attacker_name)
            attacker.attack([src_path], [att_img_path], multi=True)
            src_path = att_img_path
