import json
import logging
logging.basicConfig(encoding="utf-8", level=logging.WARNING)
logger = logging.getLogger(__name__)

import cv2
import torch
import traceback
from tqdm import tqdm
from infinity.models.basic import *
import detect_watermark
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
from detect_watermark import WatermarkInference, get_detector
from tools.helper import set_seeds,save_single_image, count_matching_n_bit_sequences, save_images, get_stripped_delta, count_match_after_reencoding
from tools.scales_injector import ScalesInjector
torch._dynamo.config.cache_size_limit = 64
from pydantic.utils import deep_update
from torch.utils.data import Dataset, DataLoader
from run_infinity import *
import random

def process_short_text(short_text):
    """Return the part of ``short_text`` that precedes the first ``'--'``.

    The input is returned unchanged when it contains no ``'--'`` or when
    nothing precedes the first ``'--'`` (i.e. the prefix would be empty).
    """
    if '--' not in short_text:
        return short_text
    prefix, _, _ = short_text.partition('--')
    # An empty prefix means the text starts with '--'; keep the original then.
    return prefix if prefix else short_text



# Module-level defaults: two independent empty lists. ``lines4infer`` is
# rebuilt by the script entry point when a prompt source is selected.
lines4infer, prompt_list = [], []

# Define your dataset class
class CustomDataset(Dataset):
    """Minimal ``Dataset`` that exposes an indexable collection as samples.

    ``data`` is stored as-is on ``self.data``; the original author's note
    assumes it is a list of dictionaries, but nothing here enforces that.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return how many samples are stored."""
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` without any transformation."""
        samples = self.data
        return samples[idx]


if __name__ == "__main__":
    # Script entry point. Steps, in order:
    #   1. parse CLI arguments and seed the RNGs,
    #   2. build ``lines4infer`` (one dict per image) from COCO-30k captions
    #      and/or an annotation file,
    #   3. load tokenizer/text encoder, VAE and transformer,
    #   4. per batch: optionally generate and save images, then run
    #      watermark detection on each image and append one JSON line of
    #      metrics to a file in ``--out_dir``.
    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument("--rewrite_prompt", type=int, default=0, choices=[0, 1])
    parser.add_argument("--coco30k_prompts", type=int, default=0, choices=[0, 1])
    parser.add_argument("--save4fid_eval", type=int, default=0, choices=[0, 1])
    parser.add_argument("--save_recons_img", type=int, default=0, choices=[0, 1])
    parser.add_argument("--jsonl_filepath", type=str, default="")
    parser.add_argument("--long_caption_fid", type=int, default=0, choices=[0, 1])
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--out_dir", type=str, default="./")
    parser.add_argument("--dataset_path", type=str)
    args = parser.parse_args()
    set_seeds(args.seed)
    args.cfg = list(map(float, args.cfg.split(",")))
    out_dir = args.out_dir
    stripped_delta = get_stripped_delta(args.watermark_delta)
    os.makedirs(out_dir, exist_ok=True)

    print(f"save to {out_dir}")

    # A single cfg value is passed on as a scalar rather than a 1-element list.
    if len(args.cfg) == 1:
        args.cfg = args.cfg[0]

    # Initialised up front: the summary print below runs regardless of which
    # prompt source was selected (previously a NameError without
    # --jsonl_filepath).
    skipped_images = 0

    if args.coco30k_prompts:
        id2caption = get_coco_30k_captions()
        captions = list(id2caption.values())
        random.shuffle(captions)
        lines4infer = [
            {"prompt": prompt, "h_div_w": 1.0, "infer_type": "infer/coco30k_prompt"}
            for prompt in captions
        ]
    if args.jsonl_filepath:
        # NOTE: despite the argument name, the file is read as one JSON
        # document with a top-level "annotations" list.
        lines4infer = []
        seen_image_ids = set()  # set membership is O(1); was a list scan
        with open(args.jsonl_filepath, "r", encoding="utf-8") as f:
            meta = json.load(f)
        annotations = meta["annotations"]
        random.shuffle(annotations)
        cnt = 0
        for annotation in annotations:
            if cnt == args.max_samples:
                break
            image_id = annotation["image_id"]
            # Use only the first caption encountered for each image.
            if image_id in seen_image_ids:
                continue
            seen_image_ids.add(image_id)
            if args.long_caption_fid:
                prompt = annotation["long_caption"]
            else:
                prompt = annotation["caption"]
            if not prompt:
                continue
            image_path = osp.join(out_dir, f"{image_id}.png")
            image_exists = os.path.exists(image_path)
            if args.watermark_gen_image:
                # Generation mode: an existing output is not regenerated but
                # still counts towards --max_samples.
                if image_exists:
                    skipped_images += 1
                    cnt += 1
                    continue
            elif not image_exists:
                # Detection-only mode: there is nothing to analyse without
                # an image on disk.
                skipped_images += 1
                continue

            lines4infer.append(
                {
                    "image_id": image_id,
                    "prompt_id": annotation["id"],
                    "prompt": prompt,
                    "h_div_w": 1.0,
                    "infer_type": "val/coco2014",
                    'image_path': image_path
                }
            )
            cnt += 1

    print(f"Skipped {skipped_images} images as these already existed")

    print(f"Totally {len(lines4infer)} items for infer")

    # load text encoder
    text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
    # load vae
    vae = load_visual_tokenizer(args)
    # load infinity
    infinity = load_transformer(vae, args)

    watermark_inference = WatermarkInference(args)
    if watermark_inference.message is None:
        message = None
    else:
        message = watermark_inference.message[0, ...]
    watermark_detector = get_detector(args, message)
    # Generation and detection-only runs write to different metrics files.
    # Each file receives one JSON object per line (appended).
    if args.watermark_gen_image:
        jsonl_file = f"{out_dir}/metrics.json"
    else:
        jsonl_file = f"{out_dir}/metrics_detect.json"

    batch_size = args.batch_size

    batches = [lines4infer[i:i + batch_size] for i in range(0, len(lines4infer), batch_size)]
    # NOTE(review): these two lists are never appended to anywhere in this
    # script; they are only read by the unreachable statistics code below.
    inference_time = []
    detection_time = []

    # The resolution/scale schedule does not depend on the batch (the aspect
    # ratio is fixed to 1), so it is computed once instead of per iteration.
    h_div_w_template = h_div_w_templates[
        np.argmin(np.abs(h_div_w_templates - 1))
    ]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]["scales"]
    scale_schedule = [(1, h, w) for (t, h, w) in scale_schedule]
    tgt_h, tgt_w = dynamic_resolution_h_w[h_div_w_template][args.pn]["pixel"]
    scales_injector = None

    # Comparing re-encoded bits against generated bits needs the bit indices
    # returned by gen_one_img, which only exist in generation mode.
    # Previously this combination raised TypeError for every batch.
    count_reencoding_loss = bool(args.watermark_count_bit_loss_after_reencoding)
    if count_reencoding_loss and not args.watermark_gen_image:
        logger.warning(
            "watermark_count_bit_loss_after_reencoding requires "
            "watermark_gen_image; skipping the re-encoding comparison"
        )
        count_reencoding_loss = False

    for batch_idx, batch in enumerate(tqdm(batches, miniters=50)):
        try:
            logger.info("%s", batch)
            prompts = [process_short_text(entry["prompt"]) for entry in batch]

            ret = None
            gen_bit_indices = None
            if args.watermark_gen_image:
                ret, gen_bit_indices, image = gen_one_img(
                    infinity,
                    vae,
                    text_tokenizer,
                    text_encoder,
                    prompt=prompts,
                    g_seed=args.seed,
                    gt_leak=0,
                    gt_ls_Bl=None,
                    cfg_list=args.cfg,
                    tau_list=args.tau,
                    scale_schedule=scale_schedule,
                    cfg_insertion_layer=[args.cfg_insertion_layer],
                    vae_type=args.vae_type,
                    sampling_per_bits=args.sampling_per_bits,
                    enable_positive_prompt=args.enable_positive_prompt,
                    watermark=watermark_inference,
                    scales_injector=scales_injector
                    )

                save_single_image(image, [entry["image_path"] for entry in batch])

            # A separate index name keeps ``batch_idx`` intact for the error
            # message in the except clause.
            for item_idx, metadata in enumerate(batch):
                detect_results = detect_watermark.detect(
                    args,
                    metadata['image_path'],
                    watermark_detector=watermark_detector,
                    vae=vae,
                    watermark_scales=watermark_inference.scales,
                    detect_on_each_scale=True,
                )

                metadata["z_score"] = round(detect_results["z_score"], 5)
                metadata["green_fraction"] = round(detect_results["green_fraction"], 3)
                metadata["stat_data"] = {}

                if args.watermark_gen_image:
                    metadata["stat_data"] = deep_update(metadata["stat_data"], ret[item_idx]["stat_data"])

                metadata["stat_data"] = deep_update(metadata["stat_data"], detect_results["stat_data"])
                if count_reencoding_loss:
                    _, _, encoding_bit_indices, _ = joint_vi_vae_encode_decode(
                        vae, metadata["image_path"], scale_schedule, "cuda", tgt_h=tgt_h, tgt_w=tgt_w, apply_spatial_patchify=args.apply_spatial_patchify
                    )
                    # Select this item's slice from every per-scale tensor.
                    current_gen_bit_indices = [indices[item_idx, ::] for indices in gen_bit_indices]

                    ret_count, _, _ = count_match_after_reencoding(
                        encoding_bit_indices, current_gen_bit_indices, watermark_inference.scales, compare_only_on_watermarked_scales=False # Maybe set to something else?
                        )
                    metadata["stat_data"] = deep_update(metadata["stat_data"], ret_count)

                # Append immediately so finished items survive a later failure.
                with open(jsonl_file, 'a') as f:
                    f.write(json.dumps(metadata, default=str) + "\n")
        except Exception:
            # Best effort: log the failure with its traceback and continue
            # with the next batch.
            logger.warning("Error at batch %d: %s", batch_idx, batch, exc_info=True)
    exit(0)
    # NOTE(review): everything below is unreachable because of the exit(0)
    # above. It is kept unchanged; re-enabling it as-is would fail on the
    # empty timing lists (max()/min() of an empty sequence).
    tmp_inf = inference_time
    tmp_det = detection_time
    print("Without warmup")
    inference_time = tmp_inf[:1000]
    detection_time = tmp_det[:1000]

    print(f"Time/Image: {np.mean(inference_time)} ({np.std(inference_time)}); max time: {max(inference_time)}, min time: {min(inference_time)}")
    print(f"Detection/Image: {np.mean(detection_time)} ({np.std(detection_time)}), max time: {max(detection_time)}, min time: {min(detection_time)}")

    print("With warmup")
    inference_time = tmp_inf[50:]
    detection_time = tmp_det[50:]
    print(f"Time/Image: {np.mean(inference_time)} ({np.std(inference_time)}); max time: {max(inference_time)}, min time: {min(inference_time)}")
    print(f"Detection/Image: {np.mean(detection_time)} ({np.std(detection_time)}), max time: {max(detection_time)}, min time: {min(detection_time)}")



    if args.save4fid_eval:
        print(out_dir)
        fid, _ = calculate_fid(
            out_dir, get_coco_fid_stats()
        )

