# -*- coding: UTF-8 -*-

import os.path

import numpy as np
import cv2 as cv
import torch
from PIL import Image


def get_target_images(target_path, wm_length, save_path, device):
    """Load a watermark image and turn it into a flat tensor of -1/+1 bits.

    The image is converted to 8-bit grayscale, resized to a square whose side
    is ``int(sqrt(wm_length))`` and thresholded at 200. A black/white preview
    of the thresholded image is written to ``<save_path>/target.png``.

    Args:
        target_path: Path of the source image file.
        wm_length: Requested watermark length. The result holds
            ``int(sqrt(wm_length)) ** 2`` entries, which equals ``wm_length``
            only when it is a perfect square.
        save_path: Directory that receives ``target.png``; it is not created
            here.
        device: Torch device the returned tensor is moved to.

    Returns:
        A 1-D ``torch.int64`` tensor on ``device`` with ``1`` where the
        resized pixel is >= 200 and ``-1`` elsewhere.
    """
    threshold = 200
    side = int(np.sqrt(wm_length))

    # Use a context manager so the underlying file handle is released as soon
    # as the pixel data has been copied into the NumPy array.
    with Image.open(target_path) as source_img:
        gray = np.array(source_img.convert("L"))

    resized = cv.resize(gray, (side, side))
    bright = resized >= threshold

    # Save the binarized preview (0 = below threshold, 255 = at/above it).
    preview = Image.fromarray(np.where(bright, 255, 0).astype(np.uint8))
    preview.save(os.path.join(save_path, "target.png"))

    # Binarize into -1/+1 and flatten row by row.
    bits = np.where(bright.flatten(), 1, -1).astype(np.int64)
    return torch.from_numpy(bits).to(device)


def bit_error_rate(source, target):
    """Return the fraction of positions where ``source`` and ``target`` differ.

    Both inputs are passed through ``np.asarray`` before comparing, so plain
    Python sequences are compared element by element (a bare ``list != list``
    would collapse to a single bool). NumPy broadcasting rules apply to the
    comparison.

    Args:
        source: Array-like of extracted bits.
        target: Array-like of reference bits.

    Returns:
        The mean of the element-wise inequality mask as a NumPy float.
    """
    mismatches = np.asarray(source) != np.asarray(target)
    return np.mean(mismatches)


def evaluate_watermark(weights, target):
    """Compute the bit error rate between the sign of ``weights`` and ``target``.

    Weights are binarized to ``1`` where strictly positive and ``-1``
    otherwise, singleton dimensions are squeezed away, and the result is
    compared element-wise with ``target`` via ``bit_error_rate``.

    Args:
        weights: Torch tensor holding the extracted watermark values.
        target: Torch tensor of reference bits; presumably the -1/+1 tensor
            produced by ``get_target_images`` (verify against the caller).

    Returns:
        The bit error rate returned by ``bit_error_rate``.
    """
    raw = weights.detach().cpu().numpy()
    # For a CPU tensor ``raw`` shares memory with ``weights``; np.where builds
    # a fresh array so the caller's weights are not overwritten with +/-1.
    signs = np.where(raw > 0, 1, -1).astype(np.int16)
    reference = target.detach().cpu().numpy()
    return bit_error_rate(signs.squeeze(), reference)
