import watermark_attacks
from tools import helper
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w

import json
import logging
logging.basicConfig(encoding="utf-8", level=logging.WARNING)
logger = logging.getLogger(__name__)
import torch
from tqdm import tqdm
from infinity.models.basic import *
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
from detect_watermark import WatermarkInference, get_detector
from tools.helper import set_seeds
torch._dynamo.config.cache_size_limit = 64

from run_infinity import *
import pandas as pd
import os
import torch
from sklearn import metrics
import numpy as np

def _resolve_schedule(args):
    """Return ``(scale_schedule, tgt_h, tgt_w)`` for the resolution chosen by *args*.

    Looks up ``dynamic_resolution_h_w[args.h_div_w_template][args.pn]`` and
    rewrites every entry of its ``"scales"`` list so the first component is 1.
    """
    resolution = dynamic_resolution_h_w[args.h_div_w_template][args.pn]
    scale_schedule = [(1, h, w) for (_, h, w) in resolution["scales"]]
    tgt_h, tgt_w = resolution["pixel"]
    return scale_schedule, tgt_h, tgt_w


def _collect_abs_zscores(dataset, vae, watermark_detector, watermark_scales,
                         scale_schedule, tgt_h, tgt_w, args, first_image_path=None):
    """Encode every image of *dataset* and return the absolute detector z-scores.

    Args:
        dataset: Iterable of images as returned by ``watermark_attacks.apply_attack``.
        vae: Visual tokenizer handed to ``helper.joint_vi_vae_encode_decode``.
        watermark_detector: Object whose ``detect(tokenized_text=...)`` returns a
            mapping containing a ``"z_score"`` entry.
        watermark_scales: Indices of the scales whose bit indices are scored;
            a falsy value (``None`` / empty) keeps every scale.
        scale_schedule, tgt_h, tgt_w: Output of ``_resolve_schedule``.
        args: Namespace providing ``apply_spatial_patchify``.
        first_image_path: When given, the first image of the dataset is saved
            there (a visual sample of the attack).

    Returns:
        A list with ``abs(z_score)`` for every image, in dataset order.
    """
    z_scores = []
    for idx, image in enumerate(tqdm(dataset, miniters=50)):
        if idx == 0 and first_image_path is not None:
            image.save(first_image_path)

        # Only the per-scale bit indices (third return value) are needed here.
        _, _, encoding_bit_indices, _ = helper.joint_vi_vae_encode_decode(
            vae, image, scale_schedule, "cuda", tgt_h, tgt_w,
            apply_spatial_patchify=args.apply_spatial_patchify,
        )

        # Guarded so that an empty selection does not reduce the token list to
        # nothing; without explicit scales all of them are scored.
        if watermark_scales:
            encoding_bit_indices = [encoding_bit_indices[s] for s in watermark_scales]

        flattened = torch.cat([t.reshape(-1) for t in encoding_bit_indices], dim=0)
        watermark_metrics = watermark_detector.detect(tokenized_text=flattened)
        z_scores.append(abs(watermark_metrics["z_score"]))
    return z_scores


def _report_detection_metrics(clean_zscores, wmarked_z_scores, watermark_scales,
                              column, value, description, file_path):
    """Compute ROC statistics for clean vs. watermarked scores and append a CSV row.

    Clean scores are labelled 0 and watermarked scores 1. The row holds
    ``"Scale"``, the caller-specific *column* (set to *value*), AUC, the best
    balanced accuracy over all thresholds, TPR at FPR < 1%, and mean/std of both
    score lists. *description* is only used in the printed summary line. The
    header is written only when *file_path* does not exist yet.
    """
    all_labels = [0] * len(clean_zscores) + [1] * len(wmarked_z_scores)
    all_scores = clean_zscores + wmarked_z_scores

    fpr, tpr, threshold = metrics.roc_curve(all_labels, all_scores)
    auc = metrics.auc(fpr, tpr)
    # 1 - (FPR + FNR) / 2, maximised over the ROC operating points.
    acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    # Last operating point whose FPR is still strictly below 1%.
    # NOTE(review): relies on roc_curve starting at fpr == 0 so that the
    # selection is never empty - confirm against the scikit-learn docs.
    low_fpr_idx = np.where(fpr < .01)[0][-1]
    print(f"Threshold: {threshold[low_fpr_idx]}")
    low = tpr[low_fpr_idx]

    print(f'For {description} {value}: auc: {auc}, acc: {acc}, TPR@1%FPR: {low}')

    df = pd.DataFrame([{
        "Scale": watermark_scales,
        column: value,
        "AUC": auc,
        "Accuracy": acc,
        "TPR@1%FPR": low,
        "z_score": round(np.mean(wmarked_z_scores).item(), 4),
        "z_score_std": round(np.std(wmarked_z_scores).item(), 4),
        "clean_z_score": round(np.mean(clean_zscores).item(), 4),
        "clean_z_score_std": round(np.std(clean_zscores).item(), 4),
    }])

    write_header = not os.path.exists(file_path)
    df.to_csv(file_path, mode='a', header=write_header, index=False)


def analyze_rotation(watermarked_folder_path, clean_folder_path, vae, watermark_detector, watermark_scales, args):
    """Measure watermark detectability under rotations of 0, 10, ..., 350 degrees.

    Clean images are scored once; for every angle the watermarked images are
    passed through the ``"rotate"`` attack, scored, and compared against the
    clean scores. One row per angle is appended to
    ``<watermarked_folder_path>/rotation_results.csv``.

    Args:
        watermarked_folder_path: Folder with watermarked images; also receives the CSV.
        clean_folder_path: Folder with non-watermarked reference images.
        vae: Visual tokenizer used to re-encode the images.
        watermark_detector: Detector exposing ``detect(tokenized_text=...)``.
        watermark_scales: Scale indices to score (falsy keeps all scales).
        args: Namespace with ``h_div_w_template``, ``pn`` and
            ``apply_spatial_patchify``. ``args.mini`` / ``args.maxi`` are set to
            the current angle while the sweep runs and restored afterwards.
    """
    scale_schedule, tgt_h, tgt_w = _resolve_schedule(args)

    clean_dataset = watermark_attacks.apply_attack(clean_folder_path, "none", args)
    clean_zscores = _collect_abs_zscores(
        clean_dataset, vae, watermark_detector, watermark_scales,
        scale_schedule, tgt_h, tgt_w, args,
    )

    file_path = f"{watermarked_folder_path}/rotation_results.csv"

    # The sweep overwrites args.mini / args.maxi; remember the caller's values
    # so the shared namespace is not left modified once the sweep is done.
    missing = object()
    saved = {name: getattr(args, name, missing) for name in ("mini", "maxi")}
    try:
        for degree in tqdm(range(0, 360, 10), miniters=50):
            # Presumably pins the rotate attack to exactly this angle
            # (min == max) - verify against watermark_attacks.apply_attack.
            args.mini = degree
            args.maxi = degree

            attacked_dataset = watermark_attacks.apply_attack(watermarked_folder_path, "rotate", args)
            wmarked_z_scores = _collect_abs_zscores(
                attacked_dataset, vae, watermark_detector, watermark_scales,
                scale_schedule, tgt_h, tgt_w, args,
            )

            _report_detection_metrics(
                clean_zscores, wmarked_z_scores, watermark_scales,
                column="Degree", value=degree, description="rotation",
                file_path=file_path,
            )
    finally:
        for name, previous in saved.items():
            if previous is missing:
                if hasattr(args, name):
                    delattr(args, name)
            else:
                setattr(args, name, previous)


def tpr_fpr_robustness(watermarked_folder_path, clean_folder_path, vae, watermark_detector, watermark_scales, args, attack=None):
    """Measure watermark detectability under one or several image attacks.

    Clean images are scored once; for every attack the watermarked images are
    attacked, scored, and compared against the clean scores. One row per attack
    is appended to ``<watermarked_folder_path>/robustness_results.csv`` and the
    first image of every dataset is saved under
    ``<watermarked_folder_path>/attacks/``.

    Args:
        watermarked_folder_path: Folder with watermarked images; also receives
            the CSV and the ``attacks`` sample folder.
        clean_folder_path: Folder with non-watermarked reference images.
        vae: Visual tokenizer used to re-encode the images.
        watermark_detector: Detector exposing ``detect(tokenized_text=...)``.
        watermark_scales: Scale indices to score (falsy keeps all scales).
        args: Namespace with ``h_div_w_template``, ``pn`` and
            ``apply_spatial_patchify``; also forwarded to ``apply_attack``.
        attack: A single attack name, an iterable of attack names, or a falsy
            value to run the built-in list of attacks.
    """
    scale_schedule, tgt_h, tgt_w = _resolve_schedule(args)

    clean_dataset = watermark_attacks.apply_attack(clean_folder_path, "none", args)

    if not attack:
        all_attacks = ["none", "noise", "blur", "color", "rotate", "crop", "jpeg"]
    elif isinstance(attack, str):
        all_attacks = [attack]
    else:
        all_attacks = attack

    # The sample images below are written into this folder; create it here so
    # the function also works when called without the script entry point.
    os.makedirs(f"{watermarked_folder_path}/attacks", exist_ok=True)

    clean_zscores = _collect_abs_zscores(
        clean_dataset, vae, watermark_detector, watermark_scales,
        scale_schedule, tgt_h, tgt_w, args,
        first_image_path=f"{watermarked_folder_path}/attacks/clean.png",
    )

    file_path = f"{watermarked_folder_path}/robustness_results.csv"

    for attack_name in tqdm(all_attacks):
        attacked_dataset = watermark_attacks.apply_attack(watermarked_folder_path, attack_name, args)
        wmarked_z_scores = _collect_abs_zscores(
            attacked_dataset, vae, watermark_detector, watermark_scales,
            scale_schedule, tgt_h, tgt_w, args,
            first_image_path=f"{watermarked_folder_path}/attacks/{attack_name}.png",
        )

        _report_detection_metrics(
            clean_zscores, wmarked_z_scores, watermark_scales,
            column="Attack", value=attack_name, description="attack",
            file_path=file_path,
        )
        
        
        
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the models and run every attack
    # against the watermarked / clean image folders.

    # These two names are used below but were only reachable if one of the
    # star imports at the top of the file happened to re-export them; import
    # them explicitly so the script does not depend on that.
    import argparse
    import detect_watermark

    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument("--watermarked_dir", type=str)
    parser.add_argument("--clean_dir", type=str)
    parser.add_argument("--stable_diff_vae", type=str, default='')
    # Number of samples the attacks are supposed to run on.
    parser.add_argument("--num_samples", type=int, default=-1)

    args = parser.parse_args()
    set_seeds(args.seed)

    # load text encoder
    # NOTE(review): neither returned object is used further down in this
    # script - confirm whether this load is still required.
    text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
    # load vae
    vae = load_visual_tokenizer(args)

    watermark_detector = get_detector(args)
    watermark_scales = detect_watermark.get_watermark_scales(args.watermark_scales, args.h_div_w_template, args.pn)

    # Folder that receives one sample image per attack.
    os.makedirs(f"{args.watermarked_dir}/attacks", exist_ok=True)
    print(f"Test robustness between clean dir:{args.clean_dir}, and watermarked dir:{args.watermarked_dir} for the watermark scales: {args.watermark_scales}")

    # Seed again right before the evaluation starts.
    set_seeds(args.seed)
    # attack=None runs the full built-in list of attacks.
    tpr_fpr_robustness(
        watermarked_folder_path=args.watermarked_dir,
        clean_folder_path=args.clean_dir,
        vae=vae,
        watermark_detector=watermark_detector,
        watermark_scales=watermark_scales,
        args=args,
        attack=None,
    )