# -*- coding: UTF-8 -*-
import os.path
import numpy as np
import torch
import torch.nn as nn
import random
import json
import logging
import argparse
import datetime
from PIL import Image
from torch.backends import cudnn
from torch.nn.parallel import DataParallel as DDP
from torchvision.transforms import ToTensor

from watermark.watermark import *
from watermark.fingerprint import LimeNet, ModelDiffNet

from benchmark import ImageBenchmark

def load_args(argv=None):
    """Parse the command-line options for the fingerprint evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the process command line (``sys.argv[1:]``),
            exactly as before; passing a list allows programmatic use/tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--image_size', type=int, default=224, help="length or width of images")
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--save_dir', type=str, default="./results/")
    # Comma-separated GPU ids; the script uses the first as the primary device
    # and falls back to CPU when it is -1 or CUDA is unavailable.
    parser.add_argument('--gpus', type=str, default='1,2')
    parser.add_argument('--fp_length', type=int, default=256)
    parser.add_argument('--target_path', type=str, default="./data/target.png")
    # The base model is looked up in the benchmark as "train(<base_model>,<base_dataset>)-".
    parser.add_argument('--base_model', type=str, default="resnet18")
    parser.add_argument('--base_dataset', type=str, default="SDog120")
    # The script handles "lime" and "modeldiff".
    parser.add_argument('--fp_method', type=str, default="lime")
    # Directory containing perturbed_image_<i>.png (and original_image_<i>.png for modeldiff).
    parser.add_argument("--test_image_path", type=str, default=None)
    parser.add_argument('--num_augu', type=int, default=None)
    parser.add_argument('--lam', type=float, default=0.0)
    return parser.parse_args(argv)

def find_base_model(models, base_model, base_dataset):
    """Return the benchmark entry whose ``str()`` is ``train(<arch>,<dataset>)-``.

    Args:
        models: Iterable of benchmark model entries; each is compared via ``str()``.
        base_model: Architecture name, e.g. ``"resnet18"``.
        base_dataset: Dataset name, e.g. ``"SDog120"``.

    Returns:
        The matching entry. If several entries match, the last one is returned.

    Raises:
        ValueError: if no entry matches.
    """
    wanted = "train({},{})-".format(base_model, base_dataset)
    match = None
    for model in models:
        if str(model) == wanted:
            match = model  # keep scanning: the last match wins
    if match is None:
        raise ValueError("no benchmark model named {!r} was found".format(wanted))
    return match


def load_image_tensor(image_path):
    """Read an image file and convert it with torchvision's ``ToTensor``.

    The image is opened in a ``with`` block so the file handle is released
    as soon as the pixel data has been converted.
    """
    with Image.open(image_path) as im:
        return ToTensor()(im)


if __name__ == "__main__":
    args = load_args()

    # Fail fast on options the evaluation branches below depend on, before any
    # output directory is created or any model is loaded.
    if args.fp_method not in ("lime", "modeldiff"):
        raise ValueError(
            "unsupported --fp_method {!r}; expected 'lime' or 'modeldiff'".format(args.fp_method)
        )
    if args.test_image_path is None:
        raise ValueError("--test_image_path is required for --fp_method {}".format(args.fp_method))

    args.save_path = os.path.join(
        args.save_dir,
        datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    )
    # create save dir
    os.makedirs(args.save_path, exist_ok=True)
    # save args (only the parsed, JSON-serialisable options exist at this point)
    with open(os.path.join(args.save_path, "args.json"), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    # set seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        random.seed(args.seed)
        cudnn.deterministic = True

    # set log (file only; nothing is echoed to the console)
    log_path = os.path.join(args.save_path, 'log.log')
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m-%d-%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path)
        ]
    )

    # set device: first listed GPU, or CPU when it is -1 / CUDA is unavailable
    gpus = [int(gpu) for gpu in args.gpus.split(",")]
    args.device = torch.device('cuda:{}'.format(gpus[0]) if torch.cuda.is_available() and gpus[0] != -1 else 'cpu')
    args.num_channels = 3

    # load the base model from the benchmark
    bench = ImageBenchmark()
    base_model = find_base_model(bench.list_models(), args.base_model, args.base_dataset)
    # NOTE: in this file ``DDP`` is torch.nn.DataParallel imported under an
    # alias, not DistributedDataParallel.
    base_model = DDP(base_model.torch_model, device_ids=gpus)
    base_model.to(args.device)
    base_model.eval()

    # get target image
    target = get_target_images(args.target_path, args.fp_length, args.save_path, args.device)
    if args.fp_method == "lime":
        # load testing images (only perturbed_image_0.png is evaluated)
        testing_images = [
            load_image_tensor(
                os.path.join(args.test_image_path, "perturbed_image_{}.png".format(idx))
            ).to(args.device)
            for idx in range(1)
        ]
        # construct lime model
        xai_model = LimeNet(args.num_augu, args.fp_length, args.image_size, args.num_channels,
                            args.device, args.lam)
        # NOTE(review): this directory is not used elsewhere in this script;
        # presumably consumed by LimeNet or later tooling - confirm.
        explained_image_path = os.path.join(args.save_path, "explained_images/")
        os.makedirs(explained_image_path, exist_ok=True)

        ori_bers = []
        for image in testing_images:
            with torch.no_grad():
                weights = xai_model.explain(base_model, image)
            ori_bers.append(evaluate_watermark(weights, target))
        logger.info("Original BERs:{}".format(ori_bers))
    elif args.fp_method == "modeldiff":
        # load paired perturbed/original testing images, one pair per fingerprint bit
        args.trigger_size = args.fp_length
        testing_images = []
        ori_testing_images = []
        for idx in range(args.trigger_size):
            testing_images.append(load_image_tensor(
                os.path.join(args.test_image_path, "perturbed_image_{}.png".format(idx))
            ))
            ori_testing_images.append(load_image_tensor(
                os.path.join(args.test_image_path, "original_image_{}.png".format(idx))
            ))
        images = torch.stack(testing_images).type(torch.float32).to(args.device)
        ori_images = torch.stack(ori_testing_images).type(torch.float32).to(args.device)
        # construct ModelDiff model
        xai_model = ModelDiffNet(args.fp_length, args.device)

        with torch.no_grad():
            weights = xai_model.explain(base_model, images, ori_images)
        ori_bers = [evaluate_watermark(weights, target)]
        logger.info("Original BERs:{}".format(ori_bers))
