import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import torch
import logging
logger = logging.getLogger(__name__)
torch._dynamo.config.cache_size_limit = 64
import cv2
import os.path as osp
import logging
from transformers import LogitsProcessor,LogitsProcessorList
from collections import Counter
from infinity.models.basic import *
from extended_watermark_processor import WatermarkDetector, WatermarkLookupProcessor, WatermarkNBitProcessor
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w
from tools.helper import save_single_image, joint_vi_vae_encode_decode, set_seeds
from torchvision.transforms import v2
from pydantic.utils import deep_update
import random
import time
def detect(
    args,
    img_path: str,
    watermark_detector: WatermarkDetector,
    vae=None,
    watermark_scales: list | None = None,
    detect_on_each_scale: bool = False
) -> dict:
    """Run watermark detection on a single image.

    The image is encoded with ``joint_vi_vae_encode_decode`` (device ``"cuda"``)
    using the scale schedule selected by ``args.h_div_w_template`` / ``args.pn``.
    The resulting per-scale bit indices, optionally restricted to
    ``watermark_scales``, are flattened into one sequence and scored by
    ``watermark_detector.detect``.

    Args:
        args: Parsed script arguments; reads ``h_div_w_template``, ``pn`` and
            ``apply_spatial_patchify``.
        img_path (str): Path to the image to detect the watermark on.
        watermark_detector (WatermarkDetector): Detector used to score the bits.
        vae (optional): VAE encoder/quantizer forwarded to
            ``joint_vi_vae_encode_decode``. Defaults to None.
        watermark_scales (list, optional): Indices of the scales on which the
            watermark has been applied. ``None`` or empty means all scales are
            scored (otherwise nothing would be detected at all).
        detect_on_each_scale (bool): If True, additionally score every selected
            scale on its own and store its rounded ``z_score`` /
            ``green_fraction`` under ``"stat_data"["scale_<index>"]["scale"]``.
            This increases inference time.

    Returns:
        dict: The detector's metrics (e.g. ``z_score``, ``green_fraction``)
        merged with ``"time"`` (seconds spent encoding + scoring the full
        sequence) and ``"stat_data"`` (per-scale results; empty unless
        ``detect_on_each_scale``).
    """
    if watermark_scales is None:
        watermark_scales = []

    metrics = {"stat_data": {}}
    scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]["scales"]
    # Keep the spatial (h, w) of each scale; the first entry is forced to 1.
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
    tgt_h, tgt_w = dynamic_resolution_h_w[args.h_div_w_template][args.pn]["pixel"]

    start = time.time()
    _, _, encoding_bit_indices, _ = joint_vi_vae_encode_decode(
        vae, img_path, scale_schedule, "cuda", tgt_h, tgt_w,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if watermark_scales:
        encoding_bit_indices = [encoding_bit_indices[i] for i in watermark_scales]
        # Label per-scale results with the real scale index, which stays correct
        # even when the selected scales are not contiguous.
        scale_ids = list(watermark_scales)
    else:
        scale_ids = list(range(len(encoding_bit_indices)))

    try:
        encoding_bit_indices_flattened = torch.cat(
            [t.reshape(-1) for t in encoding_bit_indices], dim=0
        )
        watermark_metrics = watermark_detector.detect(
            tokenized_text=encoding_bit_indices_flattened
        )
        metrics["time"] = time.time() - start

        if detect_on_each_scale:  # Scores each scale separately; slower.
            watermark_detector.last_message_index = 0
            for scale_id, scale in zip(scale_ids, encoding_bit_indices):
                flattened_scale = scale.reshape(-1)
                result = watermark_detector.detect(tokenized_text=flattened_scale)
                # Advance by the number of bits just scored so the next scale
                # starts where this one ended (presumably an offset into the
                # detector's message — confirm against WatermarkDetector).
                watermark_detector.last_message_index += len(flattened_scale)
                metrics["stat_data"][f"scale_{scale_id}"] = {
                    "scale": {
                        "z_score": round(result["z_score"], 5),
                        "green_fraction": round(result["green_fraction"], 5),
                    }
                }
    finally:
        # Always leave the detector reset, even if detection raised, so the
        # next image does not start from a stale offset.
        watermark_detector.last_message_index = 0

    # Merge unconditionally: previously "time" was measured but discarded
    # whenever detect_on_each_scale was False.
    return deep_update(watermark_metrics, metrics)

def get_watermark_scales(watermark_scales_config, h_div_w_template, pn):
        scale_schedule = dynamic_resolution_h_w[h_div_w_template][pn]["scales"]
        match watermark_scales_config: 
            case 7: watermark_scales = range(len(scale_schedule))[:12]
            case 6: watermark_scales = range(len(scale_schedule))[:11]
            case 5: watermark_scales = range(len(scale_schedule))[5:]
            case 4: watermark_scales = range(len(scale_schedule))[10:]
            case 3: watermark_scales = range(len(scale_schedule))[:10]
            case 2: watermark_scales = range(len(scale_schedule)) # all scales
            case 1: watermark_scales = [len(scale_schedule)-1] # last scale
            case 0: watermark_scales = [] # No watermarking
            case _: raise(NotImplementedError)
        return watermark_scales

class WatermarkInference():
    def __init__(self,args):
        self.delta = args.watermark_delta
        self.context_width = args.watermark_context_width
        self.green_list= set(args.set.split(','))
        self.scales = get_watermark_scales(args.watermark_scales, args.h_div_w_template, args.pn)
        if self.scales == []:
            assert self.delta == 0, 'Delta must be 0 if no watermark is applied.'
        self.count_bit_flip=args.watermark_count_bit_flip

        self.method = args.watermark_method
        match(self.method):
            case '2-bit_pattern':
                self.message = None
                self.logits_processor = WatermarkLookupProcessor(vocab=[0,1], device = "cuda", delta=self.delta, green_list=self.green_list, context_width = self.context_width) 
            case _:
                raise NotImplementedError
        if self.scales == []:
            self.logits_processor = None
        
        

def get_detector(args, message = None):
    """Create a ``WatermarkDetector`` over the binary vocabulary ``[0, 1]``.

    Delta, seeding scheme, context width and green list are read from ``args``;
    gamma (0.5), the z-score threshold (4.0) and the device ("cuda") are fixed.
    No tokenizer or normalizers are used, and repeated n-grams are not ignored.

    Args:
        args: Namespace providing ``watermark_delta``,
            ``watermark_seeding_scheme``, ``watermark_context_width`` and ``set``.
        message: Optional value forwarded unchanged as the detector's ``message``.

    Returns:
        WatermarkDetector: The configured detector.
    """
    detector_config = dict(
        vocab=[0, 1],
        gamma=0.5,
        delta=args.watermark_delta,
        seeding_scheme=args.watermark_seeding_scheme,
        device="cuda",
        tokenizer=None,
        z_threshold=4.0,
        normalizers=[],
        ignore_repeated_ngrams=False,
        context_width=args.watermark_context_width,
        message=message,
        green_list=args.set,
    )
    return WatermarkDetector(**detector_config)


def load_visual_tokenizer(args):
    """Build the VAE from ``infinity.models.bsq_vae`` used to tokenize images.

    Args:
        args: Namespace providing ``vae_type`` (used as the codebook dimension;
            one of 14, 16, 18, 20, 24, 32, 64), ``vae_path`` and
            ``apply_spatial_patchify``.

    Returns:
        The model returned by ``vae_model(...)`` (built with ``test_mode=True``
        and a "dynamic" schedule), moved to CUDA when available, else CPU.

    Raises:
        ValueError: If ``args.vae_type`` is not a supported codebook dimension.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.vae_type not in (14, 16, 18, 20, 24, 32, 64):
        raise ValueError(f'vae_type={args.vae_type} not supported')

    from infinity.models.bsq_vae.vae import vae_model

    # With spatial patchify the patch size is 8 and there is one fewer
    # channel-multiplier stage than without it (16, five stages).
    if args.apply_spatial_patchify:
        patch, mults = 8, [1, 2, 4, 4]
    else:
        patch, mults = 16, [1, 2, 4, 4, 4]

    dim = args.vae_type
    model = vae_model(
        args.vae_path,
        "dynamic",
        dim,
        2 ** dim,
        patch_size=patch,
        encoder_ch_mult=mults,
        decoder_ch_mult=list(mults),  # separate list object, as before
        test_mode=True,
    )
    return model.to(target_device)

if __name__ == '__main__':
    # Script entry point: score every image found in --dataset_path (top level
    # plus one level of sub-directories) with the watermark detector, append the
    # per-image z-scores as one JSON line to <out_dir>/z_scores.json and the
    # mean/std of z-scores and green fractions to <out_dir>/z_scores_stats.csv.
    import csv
    import json
    from random import shuffle

    import numpy as np  # explicit: `np` is used below; don't rely on a star import
    from tqdm import tqdm

    # Leading dots so names like "foopng" are not mistaken for images.
    image_extensions = ('.png', '.jpg', '.jpeg')

    parser = argparse.ArgumentParser()
    # NOTE(review): the default (1) is rejected by load_visual_tokenizer; pass a supported value.
    parser.add_argument('--vae_type', type=int, default=1)
    parser.add_argument('--vae_path', type=str, default='')
    parser.add_argument('--pn', type=str, required=True, choices=['0.06M', '0.25M', '1M'])
    parser.add_argument('--h_div_w_template', type=float, default=1.000)
    parser.add_argument('--apply_spatial_patchify', type=int, default=0, choices=[0, 1])
    parser.add_argument("--watermark_context_width", type=int, default=4)
    parser.add_argument("--watermark_seeding_scheme", type=str, default="selfhash")
    # get_detector() reads args.set (passed as WatermarkDetector's green_list);
    # without this option the script crashed with AttributeError.
    parser.add_argument("--set", type=str, required=True,
                        help="Green-list specification forwarded to WatermarkDetector(green_list=...).")
    parser.add_argument("--num_samples", type=int, default=0,
                        help="Number of randomly chosen images to score; <= 0 means all.")
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args()
    set_seeds(args.seed)
    args.watermark_delta = 0  # forwarded to WatermarkDetector as `delta`

    # Collect image paths before loading any model so a bad path fails fast.
    if not os.path.isdir(args.dataset_path):
        parser.error(f"--dataset_path is not a directory: {args.dataset_path}")
    img_paths = []
    for entry in os.listdir(args.dataset_path):
        entry_path = os.path.join(args.dataset_path, entry)
        if os.path.isdir(entry_path):
            img_paths.extend(
                os.path.join(entry_path, file)
                for file in os.listdir(entry_path)
                if file.lower().endswith(image_extensions)
            )
        elif entry.lower().endswith(image_extensions):
            img_paths.append(entry_path)
    if not img_paths:
        parser.error(f"no {'/'.join(image_extensions)} images found in {args.dataset_path}")

    # os.listdir order is arbitrary; sort first so the seeded shuffle (and hence
    # the --num_samples subset) is reproducible across machines.
    img_paths.sort()
    shuffle(img_paths)
    if args.num_samples > 0:
        img_paths = img_paths[:args.num_samples]

    message = None
    vae = load_visual_tokenizer(args)
    watermark_detector = get_detector(args, message)
    watermark_scales = get_watermark_scales(2, args.h_div_w_template, args.pn)  # 2 = all scales

    z_scores = []
    green_fractions = []
    for img_path in tqdm(img_paths):
        metrics = detect(args, img_path, watermark_detector, vae, watermark_scales, False)
        z_scores.append(metrics['z_score'])
        green_fractions.append(metrics["green_fraction"])

    os.makedirs(args.out_dir, exist_ok=True)

    # Appends one JSON line per run (the file accumulates JSON Lines).
    output = {'Generated MS-COCO': z_scores}
    with open(os.path.join(args.out_dir, "z_scores.json"), "a") as f:
        f.write(json.dumps(output, default=str) + "\n")

    z_arr = np.array(z_scores)
    g_arr = np.array(green_fractions)
    z_mean = round(z_arr.mean().item(), 4)
    z_std = round(z_arr.std().item(), 4)
    g_mean = round(g_arr.mean().item(), 4)
    g_std = round(g_arr.std().item(), 4)

    stats_path = os.path.join(args.out_dir, "z_scores_stats.csv")
    # Write the header only for a new/empty file so repeated runs append clean rows.
    write_header = not os.path.isfile(stats_path) or os.path.getsize(stats_path) == 0
    with open(stats_path, "a", newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["z_mean", "z_std", "g_mean", "g_std"])
        writer.writerow([z_mean, z_std, g_mean, g_std])