import numpy as np 
import torch
from imageio import imread
from torch import nn
import random
import argparse
import os
from PIL import Image
import torch.nn.functional as F
from math import log10
from skimage.transform import resize
import torchvision
import glob
import cv2


from decoder import fixed_watermark_decoder_d3, fixed_watermark_decoder_DADW
from utils import calculate_ssim, calculate_psnr, init_weights, logging, logger_info, mkdirs, check_loss, find_index

from distortions import jpeg_compression, gaussian_blur, median_blur
from distortions import gaussian_noise, poisson_noise, salt_and_pepper
from distortions import brightness_shifting, contrast_shifting, saturation_shifting
from distortions import cropout, resize_wm_img, rotate_wm_img


# settings
steps = 250  # total LBFGS iteration budget per embedding run
max_iter = 10  # LBFGS iterations per optimizer.step(); the outer loop runs steps // max_iter times
eta = 0.05  # step size (the LBFGS optimizer itself is constructed with lr=alpha)
alpha = 0.05  # learning rate passed to torch.optim.LBFGS during embedding
eps = 0.005  # amplitude of the uniform noise added to the cover before each re-embedding attempt
num_bits = 36  # length of the random watermark message
img_size = 256  # every image is resized to img_size x img_size
gpu_idx = '3'  # exported as CUDA_VISIBLE_DEVICES
img_dir = './imagenet/test'
# img_dir = '/data/imagenet/test'
# Dataset name = parent directory of img_dir ('./imagenet/test' -> 'imagenet').
# normpath strips a trailing separator so 'imagenet/test/' gives the same answer.
img_dataset = os.path.basename(os.path.dirname(os.path.normpath(img_dir)))


# Results and the run log live under results/<dataset>/ (log file: asw-result.log).
logger_name = 'asw-'
image_save_dirs = os.path.join('results', img_dataset)
mkdirs(image_save_dirs)
logger_info(logger_name, log_path=os.path.join(image_save_dirs, logger_name + 'result' + '.log'))
logger = logging.getLogger(logger_name)
logger.info('img dataset: {:s}'.format(img_dataset))
# Log the learning rate the LBFGS optimizer is actually built with (lr=alpha),
# not the unused `eta`, so the log reflects the real run configuration.
logger.info('learning rate: {:.3f}'.format(alpha))
logger.info('number of iterations: {}'.format(steps))


# Loss terms of the embedding objective:
#   criterion_wm  -- BCE between decoder logits and the message bits, summed over bits
#   criterion_img -- mean squared error between the watermarked image and the cover
criterion_wm = nn.BCEWithLogitsLoss(reduction='sum')
criterion_img = nn.MSELoss()

# Expose only the selected GPU to this process; it is then addressed as cuda:0.
# NOTE: this only takes effect if CUDA has not been initialised before this line.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_idx
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if _use_cuda else 'cpu')


# Every regular file directly inside img_dir, in sorted (deterministic) order.
# Sub-directories are skipped so imread is never handed a directory path.
image_path_list = sorted(
    path for path in glob.glob(os.path.join(img_dir, '*')) if os.path.isfile(path)
)


# Per-image metrics of the undistorted watermarked image (wm_acc_list: identity layer).
wm_img_psnr_list, wm_img_ssim_list, wm_acc_list = [], [], []

# One entry per distortion: five independent lists, one per strength level,
# each collecting one value per image.
# -- filtering
wm_jc_acc_list = [[] for _ in range(5)]  # jpeg compression
wm_gb_acc_list = [[] for _ in range(5)]  # gaussian blur
wm_mb_acc_list = [[] for _ in range(5)]  # median blur
# -- noise
wm_gn_acc_list = [[] for _ in range(5)]  # gaussian noise
wm_pn_acc_list = [[] for _ in range(5)]  # poisson noise
wm_sp_acc_list = [[] for _ in range(5)]  # salt&pepper noise
# -- colour
wm_cs_acc_list = [[] for _ in range(5)]  # contrast shifting
wm_ss_acc_list = [[] for _ in range(5)]  # saturation shifting
wm_bs_acc_list = [[] for _ in range(5)]  # brightness shifting
# -- geometric
wm_co_acc_list = [[] for _ in range(5)]  # cropout
wm_rs_acc_list = [[] for _ in range(5)]  # resize
wm_rt_acc_list = [[] for _ in range(5)]  # rotation



# Decoder built once; its weights are re-initialised from a seed for every embedding attempt.
# NOTE(review): arguments presumably are (input channels, message length) -- confirm in decoder.py.
model = fixed_watermark_decoder_d3(3, num_bits)
total_running_times = list()  # number of decoder seeds tried, one entry per image


# --------------------------------------------
# per-image embedding + robustness evaluation
# --------------------------------------------
target_psnr = 37         # re-embed with a fresh decoder seed until PSNR (dB) reaches this
max_restarts = 5         # at most this many decoder seeds are tried per image
max_image_index = 50     # images with index > 50 are skipped, i.e. at most 51 images are processed
img_loss_weight = 0.75   # weight of the image MSE term relative to the watermark BCE term
stable_check_start = 12  # check_loss() is consulted only once the outer step index exceeds this
resi_scale = 15          # amplification of |wm_img - img| for the saved residual image

# Distortion strengths swept at test time. Defined once up front; the summary
# section at the end of the script reads these same names.
qf_list = [70, 60, 50, 40, 30]  # jpeg quality factor
ks_list = [3, 5, 7, 9, 11]  # kernel size (used for both gaussian and median blur)
sigma_list = [10, 20, 30, 40, 50]  # standard deviation of zero-mean gaussian noise
lam_list = [10, 20, 30, 40, 50]  # poisson noise parameter
sp_list = [0.02, 0.04, 0.06, 0.08, 0.10]  # prop of salt&pepper (i.e., 255&0) pixels
brightness_factor_list = [0.1, 0.2, 0.3, 0.4, 0.5]  # brightness factors
contrast_factor_list = [0.1, 0.2, 0.3, 0.4, 0.5]  # contrast factors
saturation_factor_list = [0.1, 0.2, 0.3, 0.4, 0.5]  # saturation factors
co_list = [0.15, 0.3, 0.45, 0.6, 0.75]  # prop of pixels substituted by cover-image pixels
rs_list = [0.9, 0.8, 0.7, 0.6, 0.5]  # resize factors passed to resize_wm_img
rt_list = [0, 10, 15, 20, 25]  # rotation angles passed to rotate_wm_img


def _bit_error_rate(decoder, x, bits):
    """Return the fraction of message bits that `decoder` gets wrong on image `x`.

    A logit > 0 is read as bit 1. The script's `*_acc` variables hold this
    value, so 0.0 means every bit was recovered. The forward pass runs under
    no_grad because the result is only measured, never differentiated.
    """
    with torch.no_grad():
        decoded = (decoder(x) > 0).float().view(-1)
    return (decoded != bits.view(-1)).sum().item() / bits.numel()


def _embedding_loss(decoder, x, cover, bits):
    """Embedding objective: summed watermark BCE plus weighted image MSE to the cover."""
    return criterion_wm(decoder(x), bits) + img_loss_weight * criterion_img(x, cover)


def _to_numpy_hwc(x):
    """Convert a (1, C, W, H) tensor back to an (H, W, C) numpy array on the 0-255 scale.

    This undoes the permute(2, 1, 0) applied when the image was loaded.
    """
    return x.detach().squeeze().permute(2, 1, 0).cpu().numpy() * 255


def _to_uint8(arr):
    """Round and clip a 0-255 float array before casting it to uint8.

    A bare astype() truncates (k - 1e-5 becomes k - 1) and overflows for
    values outside [0, 255], such as the amplified residual map.
    """
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


def _sweep(description, params, distort, per_level, decoder, bits):
    """Evaluate the bit error rate of `distort(p)` for each strength p in `params`.

    Each result is appended to per_level[k] for the k-th strength. The
    per-image results are printed after `description` and returned as a list.
    """
    errors = []
    for level, param in enumerate(params):
        err = _bit_error_rate(decoder, distort(param), bits)
        errors.append(err)
        per_level[level].append(err)
    print(description, '[', *params, '] are [', *errors, ']')
    return errors


for _sub in ('ori_img', 'wm_img', 'resi'):
    mkdirs(os.path.join(image_save_dirs, _sub))

for i, image_path in enumerate(image_path_list):

    if i > max_image_index:
        break

    logger.info('*'*60)
    logger.info('watermarking {}-th image'.format(i))

    # load clean image: RGB in [0, 1], resized, then stored as a (1, C, W, H) tensor
    img = imread(image_path, pilmode='RGB') / 255.0
    img = resize(img, (img_size, img_size))
    img = torch.FloatTensor(img).permute(2, 1, 0).unsqueeze(0).to(device)

    # random binary message of num_bits bits (as floats, the BCE target)
    wm = torch.bernoulli(torch.empty(num_bits).uniform_(0, 1)).to(device)

    # one entry per embedding attempt (one attempt per decoder seed)
    cur_wm_list = []
    cur_psnr_list = []
    cur_jpeg_30_acc_list = []
    seed_list = []
    cur_psnr = 0
    cur_running_times = 0
    search_success = True
    test_seed = None

    while cur_psnr < target_psnr:

        if cur_running_times >= max_restarts:
            search_success = False
            break

        cur_running_times += 1
        print('running {} times for {}-th image'.format(cur_running_times, i))

        # fresh random decoder for this attempt; its seed is kept so it can be rebuilt for testing
        seed_for_decoder = random.randint(0, 100000000)
        logger.info('seed_for_init_decodor(receiver): {:s}'.format(str(seed_for_decoder)))
        init_weights(model, seed_for_decoder)
        model = model.to(device)
        model.eval()
        seed_list.append(seed_for_decoder)
        test_seed = seed_for_decoder

        # start from the cover; later attempts start from a slightly perturbed cover
        wm_img = img.clone().detach().contiguous()
        if cur_running_times > 1:
            delta = eps * torch.tensor(np.random.uniform(0, 1, img.shape)).float().to(img.device)
            wm_img = wm_img + delta

        loss_history = []
        loss_is_stable = False
        for j in range(steps // max_iter):
            # optimise the image pixels directly; a new LBFGS instance is created each outer step
            wm_img.requires_grad = True
            optimizer = torch.optim.LBFGS([wm_img], lr=alpha, max_iter=max_iter)

            def closure():
                loss = _embedding_loss(model, wm_img, img, wm)
                optimizer.zero_grad()
                loss.backward()
                return loss

            optimizer.step(closure)
            # keep the perturbation within [-1, 1] and the image within [0, 1]
            delta = torch.clamp(wm_img - img, min=-1.0, max=1.0)
            wm_img = torch.clamp(img + delta, min=0, max=1).detach().contiguous()

            acc = _bit_error_rate(model, wm_img, wm)
            with torch.no_grad():
                loss = _embedding_loss(model, wm_img, img, wm)
            print(j, acc, loss.item())
            loss_history.append(loss.item())
            if j > stable_check_start:
                loss_is_stable = check_loss(loss_history)

            # stop once every bit decodes correctly and check_loss reports a stable loss
            if acc == 0 and loss_is_stable:
                break

        # rounding and clipping operations: quantise to 8-bit pixel levels
        wm_img = torch.round(torch.clamp(wm_img*255, min=0., max=255.))/255

        cur_wm_list.append(wm_img)
        print('cur_wm_list:', len(cur_wm_list))

        cur_psnr = calculate_psnr(_to_numpy_hwc(img), _to_numpy_hwc(wm_img))
        cur_jpeg_30_acc = _bit_error_rate(model, jpeg_compression(wm_img, 30), wm)

        cur_jpeg_30_acc_list.append(cur_jpeg_30_acc)
        cur_psnr_list.append(cur_psnr)

        print('psnr:', cur_psnr, 'cur_jpeg_30_acc:', cur_jpeg_30_acc,)

    if not search_success:
        # no attempt reached target_psnr: fall back to the attempt chosen by
        # find_index (selection criterion defined in utils -- not visible here)
        idx = find_index(cur_psnr_list, cur_jpeg_30_acc_list)
        wm_img = cur_wm_list[idx]
        test_seed = seed_list[idx]
        print(idx)

    total_running_times.append(cur_running_times)

    # --------------------------------------------
    # testing
    # --------------------------------------------

    # identity layer: rebuild the decoder from the seed of the selected wm_img
    init_weights(model, test_seed)
    model = model.to(device)
    model.eval()

    wm_acc = _bit_error_rate(model, wm_img, wm)
    resi = (wm_img - img).abs() * resi_scale
    resi_numpy = _to_numpy_hwc(resi)
    wm_img_numpy = _to_numpy_hwc(wm_img)
    img_numpy = _to_numpy_hwc(img)
    wm_img_psnr = calculate_psnr(img_numpy, wm_img_numpy)
    wm_img_ssim = calculate_ssim(img_numpy, wm_img_numpy)
    logger.info('wm_img_psnr: {:.2f}, wm_img_ssim: {:.4f}, identity_wm_acc: {:.4f}'.format(wm_img_psnr, wm_img_ssim, wm_acc))
    wm_img_psnr_list.append(wm_img_psnr)
    wm_img_ssim_list.append(wm_img_ssim)
    wm_acc_list.append(wm_acc)

    # distortion sweeps, in a fixed order (the distortion functions may consume RNG state)
    _sweep('wm acc under jpeg-compression with various qf', qf_list,
           lambda q: jpeg_compression(wm_img, q), wm_jc_acc_list, model, wm)
    _sweep('wm acc under gaussian blur with various ks', ks_list,
           lambda ks: gaussian_blur(wm_img, ks), wm_gb_acc_list, model, wm)
    _sweep('wm acc under median blur with various ks', ks_list,
           lambda ks: median_blur(wm_img, ks), wm_mb_acc_list, model, wm)
    _sweep('wm acc under gaussian noise with various standard deviation', sigma_list,
           lambda s: gaussian_noise(wm_img, s), wm_gn_acc_list, model, wm)
    _sweep('wm acc under poisson noise with various lam', lam_list,
           lambda lam: poisson_noise(wm_img, lam), wm_pn_acc_list, model, wm)
    _sweep('wm acc under salt&pepper noise with various prop', sp_list,
           lambda p: salt_and_pepper(wm_img.clone(), p), wm_sp_acc_list, model, wm)
    _sweep('wm acc under brightness shifting with various distances', brightness_factor_list,
           lambda f: brightness_shifting(wm_img, f), wm_bs_acc_list, model, wm)
    _sweep('wm acc under contrast shifting with various distances', contrast_factor_list,
           lambda f: contrast_shifting(wm_img, f), wm_cs_acc_list, model, wm)
    _sweep('wm acc under saturation shifting with various distances', saturation_factor_list,
           lambda f: saturation_shifting(wm_img, f), wm_ss_acc_list, model, wm)
    _sweep('wm acc under cropout with various prop', co_list,
           lambda p: cropout(img, wm_img, p), wm_co_acc_list, model, wm)
    _sweep('wm acc under resize attack with various shape', rs_list,
           lambda r: resize_wm_img(wm_img, r, img_size), wm_rs_acc_list, model, wm)
    _sweep('wm acc under rotation attack with various angle', rt_list,
           lambda a: rotate_wm_img(wm_img, a), wm_rt_acc_list, model, wm)

    # save original, watermarked and amplified-residual images as PNG,
    # named after the source file (splitext keeps dotted names intact)
    stem = os.path.splitext(os.path.basename(image_path))[0]
    img_save_path = os.path.join(image_save_dirs, 'ori_img', stem + '.png')
    wm_img_save_path = os.path.join(image_save_dirs, 'wm_img', stem + '.png')
    resi_save_path = os.path.join(image_save_dirs, 'resi', stem + '.png')
    print('saving images...')
    Image.fromarray(_to_uint8(img_numpy)).save(img_save_path)
    Image.fromarray(_to_uint8(wm_img_numpy)).save(wm_img_save_path)
    Image.fromarray(_to_uint8(resi_numpy)).save(resi_save_path)



# --------------------------------------------
# record results
# --------------------------------------------


def _log_sweep(prefix, tag, params, per_level):
    """Log the mean/max/min over images for each strength of one distortion.

    prefix    -- text logged before the strength value, e.g. 'jpeg qf'
    tag       -- short key used in the metric names, e.g. 'jpeg' -> wm_avg_jpeg_acc
    params    -- distortion strengths, in the order they were evaluated
    per_level -- per_level[k] holds one value per image for params[k]
    """
    for param, values in zip(params, per_level):
        arr = np.array(values)
        logger.info(
            '{}: {}, wm_avg_{tag}_acc: {:.4f}, wm_max_{tag}_acc: {:.4f}, wm_min_{tag}_acc: {:.4f},'.format(
                prefix, param, arr.mean(), arr.max(), arr.min(), tag=tag))


def _log_summary():
    """Log aggregate image-quality, identity and per-distortion results over all images."""
    if not wm_img_psnr_list:
        # ndarray.max()/min() raise ValueError on an empty array, so there is nothing to report.
        logger.info('No images were watermarked; nothing to summarise.')
        return

    psnr = np.array(wm_img_psnr_list)
    ssim = np.array(wm_img_ssim_list)
    identity = np.array(wm_acc_list)
    runs = np.array(total_running_times)

    logger.info('\n')
    logger.info('Average results are as follows:')
    logger.info('wm_img_psnr: {:.2f}, wm_img_ssim: {:.4f},'.format(psnr.mean(), ssim.mean()))
    logger.info('wm_max_psnr: {:.2f}, wm_max_ssim: {:.4f},'.format(psnr.max(), ssim.max()))
    logger.info('wm_min_psnr: {:.2f}, wm_min_ssim: {:.4f},'.format(psnr.min(), ssim.min()))
    logger.info('wm_avg_identity_acc: {:.4f},'.format(identity.mean()))
    logger.info('wm_max_identity_acc: {:.4f},'.format(identity.max()))
    logger.info('wm_min_identity_acc: {:.4f},'.format(identity.min()))

    # (log prefix, metric tag, strengths, per-strength results); prefixes are
    # kept byte-identical to earlier logs, including their spelling.
    sweeps = (
        ('jpeg qf', 'jpeg', qf_list, wm_jc_acc_list),
        ('gaussian blur ks', 'gb', ks_list, wm_gb_acc_list),
        ('median blur ks', 'mb', ks_list, wm_mb_acc_list),
        ('guassian noise sigma', 'gn', sigma_list, wm_gn_acc_list),
        ('possion noise lam', 'pn', lam_list, wm_pn_acc_list),
        ('salt&pepper noise prop', 'sp', sp_list, wm_sp_acc_list),
        ('brightness shifting distance', 'bs', brightness_factor_list, wm_bs_acc_list),
        ('contrast shifting distance', 'cs', contrast_factor_list, wm_cs_acc_list),
        ('saturation shifting distance', 'ss', saturation_factor_list, wm_ss_acc_list),
        ('cropout prop', 'co', co_list, wm_co_acc_list),
        ('resize to', 'rs', rs_list, wm_rs_acc_list),
        ('rotate angle', 'rt', rt_list, wm_rt_acc_list),
    )
    for prefix, tag, params, per_level in sweeps:
        _log_sweep(prefix, tag, params, per_level)

    logger.info('total_running_times: {}, max_running_times: {}, min_running_times: {},'.format(runs.mean(), runs.max(), runs.min()))


_log_summary()
