import argparse
import yaml
import os
import logging
# Configure the root logger (getLogger() with no name returns the root), which
# is also the logger used by the module-level logging.info(...) calls below.
logger = logging.getLogger()
# Only attach a stream handler when none is configured yet, so records are not
# emitted twice if logging was already set up elsewhere.
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    # The handler filters at INFO: DEBUG records pass the logger (set to DEBUG
    # below) but are not emitted by this handler.
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from diffusers import DDIMScheduler
from datasets import load_from_disk

from main.wmdiffusion import WMDetectStableDiffusionPipeline
from main.wmpatch_cached import GTWatermarkCached
from main.wmpach_cached_learned_wm import GTWatermarkCachedLearnedWM
from main.attdiffusion import ReSDPipeline
from main.utils import save_img, compute_auc_tpr
from loss import pytorch_ssim

from detect.adaptive_enhancement import adaptive_enhancement, compute_quality_metrics
from detect.attacks import single_attacks, combined_attacks
from detect.detect import detect_watermarks

from main.nf_flow_models import *
from main.dataset import *
from data.data_loader_functions import *

from PIL import Image
import numpy as np

# Reproducibility settings: ask cuDNN for deterministic algorithms and disable
# its benchmark-based algorithm auto-selection.
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark = False


def init_pipelines(cfgs, args, device):
    """Create the watermark/detection pipeline and the attack pipeline.

    Args:
        cfgs: Config mapping; only ``cfgs['model_id']`` is read here.
        args: Parsed CLI arguments (not used by this function).
        device: Device both pipelines are moved to.

    Returns:
        Tuple ``(pipe, att_pipe)``: the ``WMDetectStableDiffusionPipeline``
        driven by a DDIM scheduler, and the half-precision ``ReSDPipeline``
        that is handed to the attack routines.
    """
    model_id = cfgs['model_id']

    # Main SD pipeline: scheduler and weights both come from the configured model.
    ddim_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = WMDetectStableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=ddim_scheduler,
    )
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)

    # Attack pipeline: fixed SD 2.1 checkpoint, loaded in fp16.
    att_pipe = ReSDPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        torch_dtype=torch.float16,
        revision="fp16",
    )
    att_pipe = att_pipe.to(device)
    att_pipe.set_progress_bar_config(disable=True)

    return pipe, att_pipe

def eval(args):
    """Run the watermark evaluation loop and report quality/detection metrics.

    Steps, in order:
      1. Load the YAML config at ``args.cfg_path`` and build both pipelines.
      2. Build the evaluation dataloader (``args.dataset``).
      3. Load the two invertible residual maps (real / imaginary part).
      4. For each evaluated batch: write the ground-truth patch into the
         masked region of the shifted FFT of the cached latents, map real and
         imaginary parts through the learned maps, inject the result into the
         latents, generate the image, optionally apply adaptive enhancement,
         compute quality metrics, run the attacks and the detection.
         The same metrics are computed for the un-watermarked ground-truth
         image ("base") for comparison.
      5. Append running totals and per-attack AUC lines to
         ``args.write_file`` after every image, and the final averages at the
         end.

    Args:
        args: Parsed CLI namespace. Attributes read here: ``device``,
            ``cfg_path``, ``dataset``, ``method``, ``layers``,
            ``save_model_dir_real``, ``save_model_dir_imag``, ``batch_size``,
            ``wm_path``, ``ssim_threshold``, ``starting_image``,
            ``num_images``, ``image_num``, ``wm_pattern``, ``wm_radius``,
            ``shape``, ``use_adaptive_enhancement``, ``detection_threshold``
            and ``write_file``. The namespace is also forwarded to the
            dataloader factory and to ``init_pipelines``.

    Raises:
        NotImplementedError: If ``args.method`` is not ``"residual"``.
    """
    # NOTE: this function shadows the builtin ``eval``; the name is kept
    # because the script entry point calls it.
    device = torch.device(args.device)

    logging.info('===== Load Config =====')
    with open(args.cfg_path, 'r') as file:
        cfgs = yaml.safe_load(file)
    logging.info(cfgs)

    logging.info('===== Init Pipeline =====')
    pipe, att_pipe = init_pipelines(cfgs, args, device)

    logging.info('===== Load Dataset =====')
    if args.dataset == 'all':
        dataloader = create_dataloder_all(args, is_train=False)
    else:
        dataloader = create_dataloder(args, is_train=False)

    logging.info('===== Eval Pipeline =====')

    # Running totals. Despite the "avg_" prefix these are plain sums while the
    # loop runs; psnr/ssim/det_prob are divided by the image count at the end.
    avg_psnr = 0
    avg_ssim = 0
    avg_det_prob = 0
    avg_lpips = 0
    avg_wm_classification = 0
    avg_l1 = 0

    avg_psnr_base = 0
    avg_ssim_base = 0
    avg_det_prob_base = 0
    avg_lpips_base = 0
    avg_wm_classification_base = 0
    avg_l1_base = 0
    attackers = ['diff_attacker_60', 'cheng2020-anchor_3', 'bmshj2018-factorized_3', 'jpeg_attacker_50',
                 'brightness_0.5', 'contrast_0.5', 'Gaussian_noise', 'Gaussian_blur', 'rotate_90', 'bm3d',
                 'all', 'all_norot']
    avg_attack_det_metrics = {attack_type: 0 for attack_type in attackers}
    avg_attack_wm_metrics = {attack_type: 0 for attack_type in attackers}

    avg_attack_det_metrics_base = {attack_type: 0 for attack_type in attackers}
    avg_attack_wm_metrics_base = {attack_type: 0 for attack_type in attackers}

    # Per-image detection scores, kept per attack so AUC can be recomputed
    # after every image ('no_attack' holds the scores of the clean images).
    list_of_all_attacks = {attack_type: [] for attack_type in attackers}
    list_of_all_attacks['no_attack'] = []

    list_of_all_attacks_base = {attack_type: [] for attack_type in attackers}
    list_of_all_attacks_base['no_attack'] = []

    if args.method != "residual":
        raise NotImplementedError()

    # One invertible map per complex component. A dummy forward pass is run
    # before load(), presumably to materialise lazily-built parameters --
    # TODO confirm against create_invertible_residual_basic.
    invertible_map_real = create_invertible_residual_basic(args.layers).to(device)
    invertible_map_real.forward((torch.randn(4, 64, 64).to(device)))
    invertible_map_real.load(args.save_model_dir_real)

    invertible_map_imag = create_invertible_residual_basic(args.layers).to(device)
    invertible_map_imag.forward((torch.randn(4, 64, 64).to(device)))
    invertible_map_imag.load(args.save_model_dir_imag)
    invertible_map_real.eval()
    invertible_map_imag.eval()

    prompt = [''] * args.batch_size
    wm_path = args.wm_path

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512, 512))
    ])

    ssim_threshold = args.ssim_threshold
    # argparse hands CLI-supplied values over as str when the option is
    # declared with type=str; coerce so the index comparison below is int/int.
    starting_image = int(args.starting_image)
    wm_pipe = None
    # Number of loop iterations that were actually evaluated; this (not the
    # dataloader length) is the denominator of the final averages.
    num_evaluated = 0
    with torch.no_grad():
        for img_idx, batch in enumerate(dataloader):
            if img_idx < starting_image:
                continue

            if img_idx == args.num_images:
                break
            imagename = f'{img_idx}.png'
            wm_img_path = os.path.join(wm_path, f"{imagename.split('.')[0]}_{cfgs['save_iters'][-1]}.png")
            enhanced_img_path = os.path.join(wm_path, f"{os.path.basename(wm_img_path).split('.')[0]}_SSIM{ssim_threshold}.png")
            original_img_path = os.path.join(wm_path, f"{os.path.basename(wm_img_path).split('.')[0]}_original.png")
            # The watermark helper is built once, from the mask of the first
            # batch that reaches this point.
            if wm_pipe is None:
                wm_pipe = GTWatermarkCachedLearnedWM(device, batch['watermarking_mask'], wm_pattern=args.wm_pattern, w_radius=args.wm_radius,shape = args.shape)

            if args.wm_radius == 10:
                gt_patch = batch['gt_patch'].detach().clone().to(device)
            else:
                # NOTE(review): wm_pipe.gt_patch is overwritten with the learned
                # watermark further down, so from the second evaluated image on
                # this branch reads the previous image's learned watermark --
                # confirm that this is intended.
                gt_patch = wm_pipe.gt_patch

            # Images below image_num are skipped after the watermark helper
            # has been initialised.
            if img_idx < args.image_num:
                continue

            logging.info(f'===== Image Idx: {img_idx} =====')

            gt_img_tensor = batch['image'].to(device)
            init_latents_approx = batch['latents'].detach().clone().to(device)

            fft_init_latents_approx = torch.fft.fftshift(torch.fft.fft2(init_latents_approx), dim=(-1, -2))

            # Same tensor object, modified in place: the masked frequency
            # entries are replaced by the ground-truth patch.
            fft_init_latents_approx_with_patch = fft_init_latents_approx
            fft_init_latents_approx_with_patch[wm_pipe.watermarking_mask] = gt_patch[wm_pipe.watermarking_mask]

            learned_wm_real = invertible_map_real.forward(fft_init_latents_approx_with_patch.real)
            learned_wm_imag = invertible_map_imag.forward(fft_init_latents_approx_with_patch.imag)

            learned_wm = torch.complex(learned_wm_real, learned_wm_imag)

            init_latents_wm = wm_pipe.inject_watermark(learned_wm, init_latents_approx)

            # Detection below goes through wm_pipe, so it must see the patch
            # that was actually injected for this image.
            wm_pipe.gt_patch = learned_wm

            if cfgs['empty_prompt']:
                wm_img_tensor = pipe(prompt, guidance_scale=1.0, num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images
            else:
                wm_img_tensor = pipe(prompt, num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images

            logging.info('===== Adaptive Enhancement =====')
            save_img(original_img_path, wm_img_tensor, pipe)
            ssim_value_orig = pytorch_ssim.ssim(wm_img_tensor, gt_img_tensor).item()
            logging.info(f'Original SSIM {ssim_value_orig}')

            if args.use_adaptive_enhancement:
                # ssim_threshold is given on a 0-100 scale on the CLI.
                enhanced_img_tensor = adaptive_enhancement(gt_img_tensor, wm_img_tensor, float(ssim_threshold/100))
            else:
                enhanced_img_tensor = wm_img_tensor

            ssim_value, psnr_value, det_prob, l1_value, wm_classification, lpips = compute_quality_metrics(enhanced_img_tensor, gt_img_tensor, pipe, wm_pipe, threshold=args.detection_threshold, device=device)
            # "Base": the ground-truth image scored against itself, i.e. the
            # un-watermarked reference for the detector.
            ssim_value_base, psnr_value_base, det_prob_base, l1_value_base, wm_classification_base, lpips_base = compute_quality_metrics(gt_img_tensor, gt_img_tensor, pipe, wm_pipe, threshold=args.detection_threshold, device=device)
            logging.info(f'SSIM {ssim_value}, PSNR, {psnr_value}, Detect Prob: {det_prob}, LPIPS:{lpips}, WM_Class: {wm_classification}, L1: {l1_value} after postprocessing')
            logging.info(f'SSIM {ssim_value_base}, PSNR, {psnr_value_base}, Detect Prob: {det_prob_base}, LPIPS:{lpips_base}, WM_Class: {wm_classification_base}, L1: {l1_value_base} after postprocessing')

            list_of_all_attacks['no_attack'].append(det_prob)
            list_of_all_attacks_base['no_attack'].append(det_prob_base)

            imagename_base = f'{img_idx}_base.png'
            base_img_path = os.path.join(wm_path, f"{imagename_base.split('.')[0]}_100_SSIM{ssim_threshold}.png")

            save_img(enhanced_img_path, enhanced_img_tensor, pipe)
            save_img(base_img_path, gt_img_tensor, pipe)

            avg_psnr += psnr_value
            avg_ssim += ssim_value
            avg_det_prob += det_prob
            avg_wm_classification += wm_classification
            avg_lpips += lpips
            avg_l1 += l1_value

            avg_psnr_base += psnr_value_base
            avg_ssim_base += ssim_value_base
            avg_det_prob_base += det_prob_base
            avg_wm_classification_base += wm_classification_base
            avg_lpips_base += lpips_base
            avg_l1_base += l1_value_base

            logging.info(f'Avg SSIM {avg_ssim}, PSNR, {avg_psnr}, Detect Prob: {avg_det_prob}, LPIPS:{avg_lpips}, WM_Class: {avg_wm_classification} L1: {avg_l1} after postprocessing')
            logging.info(f'Avg Base SSIM {avg_ssim_base}, PSNR, {avg_psnr_base}, Detect Prob: {avg_det_prob_base}, LPIPS:{avg_lpips_base}, WM_Class: {avg_wm_classification_base}, L1: {avg_l1_base} after postprocessing')

            # Single attacks, for the watermarked and the base image.
            single_attacks(cfgs, device, wm_path, imagename, ssim_threshold, att_pipe)
            single_attacks(cfgs, device, wm_path, imagename_base, ssim_threshold, att_pipe)

            # Combined attacks, for the watermarked and the base image.
            combined_attacks(cfgs, device, wm_path, imagename, ssim_threshold, att_pipe)
            combined_attacks(cfgs, device, wm_path, imagename_base, ssim_threshold, att_pipe)

            results_detection_dict = detect_watermarks(cfgs, wm_path, imagename, ssim_threshold, pipe, wm_pipe, device, threshold=args.detection_threshold)
            results_detection_dict_base = detect_watermarks(cfgs, wm_path, imagename_base, ssim_threshold, pipe, wm_pipe, device, threshold=args.detection_threshold)

            num_evaluated += 1

            logging.info('===== Watermark Detection Results =====')
            # Each result entry is indexed as [0] -> detection metric and
            # [1] -> the value accumulated under the "WDR" label.
            for key in avg_attack_det_metrics:
                det_entry = results_detection_dict[key]
                det_entry_base = results_detection_dict_base[key]

                avg_attack_det_metrics[key] += det_entry[0]
                avg_attack_wm_metrics[key] += det_entry[1]

                avg_attack_det_metrics_base[key] += det_entry_base[0]
                avg_attack_wm_metrics_base[key] += det_entry_base[1]

                list_of_all_attacks[key].append(det_entry[0])
                list_of_all_attacks_base[key].append(det_entry_base[0])
            logging.info(f'Post Attacks Detection Metrics: {avg_attack_det_metrics}')
            logging.info(f'Post Attacks Detection Metrics Base: {avg_attack_det_metrics_base}')
            logging.info(f'Post Attacks WDR Metrics: {avg_attack_wm_metrics}')
            logging.info(f'Post Attacks WDR Metrics Base: {avg_attack_wm_metrics_base}')

            with open(args.write_file, "a") as file1:
                file1.write(f'{img_idx}\n')
                file1.write(f'Original SSIM {ssim_value_orig}\n')
                file1.write(f'AvgSSIM {avg_ssim}, PSNR, {avg_psnr}, Detect Prob: {avg_det_prob}, LPIPS:{avg_lpips}, WM_Class: {avg_wm_classification}, L1: {avg_l1} after postprocessing\n')
                # NOTE(review): this line mixes the running SSIM total with the
                # current image's PSNR / Detect Prob / LPIPS / WM_Class values.
                # Left unchanged because the intended report format is not
                # visible here.
                file1.write(f'Base - SSIM {avg_ssim_base}, PSNR, {psnr_value_base}, Detect Prob: {det_prob_base}, LPIPS:{lpips_base}, WM_Class: {wm_classification_base} after postprocessing\n')
                file1.write(f'Average Post Attacks Detection Metrics: {avg_attack_det_metrics}\n')
                file1.write(f'Post Attacks Detection Metrics Base: {avg_attack_det_metrics_base}\n')
                file1.write(f'Post Attacks WDR Metrics: {avg_attack_wm_metrics}\n')
                file1.write(f'Post Attacks WDR Metrics Base: {avg_attack_wm_metrics_base}\n\n')

            # Per-attack AUC over everything seen so far: base-image scores are
            # the negatives (label 0), watermarked-image scores the positives.
            with open(args.write_file, "a") as file1:
                for key in list_of_all_attacks:
                    no_w_metrics = list_of_all_attacks_base[key]
                    w_metrics = list_of_all_attacks[key]
                    preds = no_w_metrics + w_metrics
                    t_labels = [0] * len(no_w_metrics) + [1] * len(w_metrics)

                    auc, acc, low = compute_auc_tpr(t_labels, preds)
                    logging.info(f'{key} - auc: {auc}, acc: {acc}, TPR@1%FPR: {low}')
                    file1.write(f'{key} - auc: {auc}, acc: {acc}, TPR@1%FPR: {low}\n')

    # Average over the images that were actually evaluated. The loop may skip
    # leading indices and stop early at num_images, so len(dataloader) would
    # under-report every average.
    if num_evaluated == 0:
        logging.warning('No images were evaluated; final averages are reported as 0.')
    N = max(num_evaluated, 1)
    avg_psnr /= N
    avg_ssim /= N
    avg_det_prob /= N
    logging.info(f'Average SSIM {avg_ssim}, Average PSNR, {avg_psnr}, Average Detect Prob: {avg_det_prob} after postprocessing')

    for key in avg_attack_det_metrics:
        avg_attack_det_metrics[key] /= N
    logging.info(f'Average Post Attacks Detection Metrics: {avg_attack_det_metrics}')

    with open(args.write_file, "a") as file1:
        file1.write("FINAL.\n")
        file1.write(f'Average SSIM {avg_ssim}, Average PSNR, {avg_psnr}, Average Detect Prob: {avg_det_prob} after postprocessing\n')
        file1.write(f'Average Post Attacks Detection Metrics: {avg_attack_det_metrics}\n\n')


if __name__ == '__main__':
    # Script entry point: parse CLI options, seed torch, run the evaluation.
    parser = argparse.ArgumentParser(description='diffusion watermark')
    # setup params
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--dataset', default='diffusiondb', choices=['coco', 'diffusiondb', 'wikiart', 'all'])
    parser.add_argument('--dataset_path', default='/localhome/data/datasets/watermarking')
    parser.add_argument('--cfg_path', default='./example/config/config_cached.yaml')
    parser.add_argument('--num_workers', default=1, type=int)
    parser.add_argument('--wm_pattern', default='rings', type=str)
    parser.add_argument('--num_images', default=300, type=int)
    # Must be an int: the value is compared against the integer batch index.
    # With type=str any value given on the command line raised
    # "TypeError: '<' not supported between instances of 'int' and 'str'".
    parser.add_argument('--starting_image', default=0, type=int)
    parser.add_argument("--save_model_dir_real", type=str)
    parser.add_argument("--save_model_dir_imag", type=str)
    parser.add_argument("--use_adaptive_enhancement", action="store_true", default=False)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--wm_path", type=str)
    parser.add_argument("--write_file", type=str, default='temp')
    parser.add_argument("--method", type=str, default='residual')
    parser.add_argument('--ssim_threshold', type=float, default=92)
    parser.add_argument('--wm_radius', type=int, default=10)
    parser.add_argument('--shape', type=str, default='circle')
    parser.add_argument("--image_num", type=int, default=0)
    parser.add_argument("--use_cached_latents", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=1)
    # NOTE(review): with nargs='+' a value passed on the command line arrives
    # as a list of floats, while the default stays a bare float (0.9).
    # Consumers of args.detection_threshold must accept both -- confirm.
    parser.add_argument("--detection_threshold", type=float, nargs='+', default=0.9, help='detection threshold for attacks')

    args = parser.parse_args()

    # Seed the CPU and all CUDA generators for reproducible runs.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    eval(args)
