import os
import os.path as osp
import hashlib
import time
import argparse
import json
import shutil
import glob
import re
import sys

import cv2
from tqdm.auto import tqdm
import torch
import numpy as np
from pytorch_lightning import seed_everything

from run_infinity_b import *
from conf import HF_TOKEN, HF_HOME
from transformers import BlipForConditionalGeneration,BlipProcessor

# Export the Hugging Face token and cache directory from conf.py, and tell
# xformers not to use its Triton kernels, through process environment variables.
# NOTE(review): these run after torch, transformers and run_infinity_b have
# already been imported above, so any library that reads one of these variables
# only at import time (possibly the HF cache location) will not see the new
# value. Consider moving this block above the third-party imports — confirm.
os.environ.update(
    HF_TOKEN=HF_TOKEN,
    HF_HOME=HF_HOME,
    XFORMERS_FORCE_DISABLE_TRITON='1',
)


# Architecture of the SwinIR model handed to gen_one_img_eval_long (presumably
# used to restore the low-quality input image). It must match the checkpoint in
# _SWINIR_CKPT exactly, because weights are loaded with strict=True.
_SWINIR_CONFIG = {
    "target": "infinity.models.swinir.SwinIR",
    "params": {
        "img_size": 64,
        "patch_size": 1,
        "in_chans": 3,
        "embed_dim": 180,
        "depths": [6, 6, 6, 6, 6, 6, 6, 6],
        "num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
        "window_size": 8,
        "mlp_ratio": 2,
        "sf": 8,
        "img_range": 1.0,
        "upsampler": "nearest+conv",
        "resi_connection": "1conv",
        "unshuffle": True,
        "unshuffle_scale": 8,
    },
}
_SWINIR_CKPT = 'weights/general_swinir_v1.ckpt'
_BLIP_PATH = 'blip-image-captioning-large'


def _parse_args():
    """Parse the CLI: the common Infinity arguments plus this script's own flags.

    ``args.cfg`` is converted from a comma-separated string to a float, or to a
    list of floats when more than one value is given.
    Exits via ``parser.error`` when ``--out_dir`` is empty.
    """
    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument('--out_dir', type=str, default='')
    parser.add_argument('--n_samples', type=int, default=1)
    parser.add_argument('--metadata_file', type=str, default='evaluation/image_reward/benchmark-prompts.json')
    parser.add_argument('--rewrite_prompt', type=int, default=0, choices=[0, 1])
    # noise / bitwise-self-correction experiment knobs (consumed downstream via args)
    parser.add_argument('--noise_apply_layers', type=int, default=0)
    parser.add_argument('--noise_apply_requant', type=int, default=1)
    parser.add_argument('--noise_apply_strength', type=float, default=0.3)
    parser.add_argument('--debug_bsc', type=int, default=0)
    args = parser.parse_args()

    # Fail fast: os.makedirs('') would otherwise raise only after every model is loaded.
    if not args.out_dir:
        parser.error('--out_dir must be a non-empty directory path')

    args.cfg = [float(c) for c in args.cfg.split(',')]
    if len(args.cfg) == 1:
        args.cfg = args.cfg[0]
    return args


def _load_swinir(ckpt_path=_SWINIR_CKPT, device='cuda'):
    """Build SwinIR from _SWINIR_CONFIG, load *ckpt_path*, and return it frozen in eval mode on *device*."""
    swinir = instantiate_from_config(_SWINIR_CONFIG)
    sd = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in sd:
        sd = sd['state_dict']
    # Strip a leading "module." prefix (typically present when the model was saved
    # from a DataParallel/DDP wrapper) so the keys match the bare model.
    sd = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in sd.items()
    }
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    swinir.eval().to(device)  # nn.Module.to moves the module in place
    return swinir


def _load_blip(path=_BLIP_PATH, device='cuda'):
    """Return ``(processor, model)`` for the BLIP captioning model at *path*, with the model in fp16 on *device*."""
    processor = BlipProcessor.from_pretrained(path)
    model = BlipForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to(device)
    return processor, model


def _sample_path(out_dir, img_name, sample_idx, n_samples):
    """Return where sample *sample_idx* of the input image *img_name* is written.

    With a single sample, the input file name is reused unchanged (the historical
    layout). With several samples, ``_<idx>`` is inserted before the extension so
    the samples do not overwrite each other.
    """
    if n_samples <= 1:
        return os.path.join(out_dir, img_name)
    stem, ext = os.path.splitext(img_name)
    return os.path.join(out_dir, f'{stem}_{sample_idx}{ext}')


def main():
    """Generate output images for every entry of ``--metadata_file``.

    Each metadata entry must have ``lq_img_path`` and may have ``prompt``. For each
    entry:
    - the generator selected by ``--model_type`` is run ``--n_samples`` times;
    - results are written to ``--out_dir`` under the input image's file name;
    - the final prompt is stored back in the entry.

    All entries are then dumped to ``metadata_w_prompt.json`` next to
    ``--metadata_file``.

    Raises:
        ValueError: if ``--model_type`` is not supported.
        OSError: if ``cv2.imwrite`` reports that it could not write an output image.
    """
    args = _parse_args()

    with open(args.metadata_file, encoding='utf-8') as fp:
        metadatas = json.load(fp)

    use_infinity = 'infinity' in args.model_type

    # ---- load the generator selected by --model_type ----
    if args.model_type == 'sdxl':
        from diffusers import DiffusionPipeline
        base = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        ).to("cuda")
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=base.text_encoder_2,
            vae=base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
    elif args.model_type == 'sd3':
        from diffusers import StableDiffusion3Pipeline
        pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    elif args.model_type == 'pixart_sigma':
        from diffusers import PixArtSigmaPipeline
        pipe = PixArtSigmaPipeline.from_pretrained(
            "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
        ).to("cuda")
    elif args.model_type == 'flux_1_dev':
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    elif args.model_type == 'flux_1_dev_schnell':
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
    elif use_infinity:
        text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
        vae = load_visual_tokenizer(args)
        infinity = load_transformer(vae, args)
    else:
        raise ValueError(f'unsupported --model_type: {args.model_type!r}')

    # The rewriter is used for every model type in the loop below, so it must not
    # be created only on the Infinity path.
    prompt_rewriter = None
    if args.rewrite_prompt:
        from tools.prompt_rewriter import PromptRewriter
        prompt_rewriter = PromptRewriter(system='', few_shot_history=[])

    os.makedirs(args.out_dir, exist_ok=True)

    if use_infinity:
        # Auxiliary models and the (loop-invariant) scale schedule are only
        # needed by gen_one_img_eval_long.
        blip_processor, blip_model = _load_blip()
        swinir = _load_swinir()
        h_div_w_template = 1.000
        scale_schedule = [
            (1, h, w)
            for (_, h, w) in dynamic_resolution_h_w[h_div_w_template][args.pn]['scales']
        ]

    save_metadatas = []
    for metadata in tqdm(metadatas):
        seed_everything(args.seed)

        lq_img_path = metadata['lq_img_path']
        prompt = metadata.get('prompt', None)
        img_name = os.path.basename(lq_img_path)

        if args.rewrite_prompt:
            refined_prompt = prompt_rewriter.rewrite(prompt)
            input_key_val = extract_key_val(refined_prompt)
            prompt = input_key_val['prompt']
            print(f'prompt: {prompt}, refined_prompt: {refined_prompt}')

        if use_infinity:
            # Re-created per image, as before. NOTE(review): it could be hoisted
            # out of the loop if BitwiseSelfCorrection keeps no per-image state.
            bitwise_self_correction = BitwiseSelfCorrection(vae, args)

        images = []
        for _ in range(args.n_samples):
            if args.model_type == 'sdxl':
                image = base(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_end=0.8,
                    output_type="latent",
                ).images
                image = refiner(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_start=0.8,
                    image=image,
                ).images[0]
            elif args.model_type == 'sd3':
                image = pipe(
                    prompt,
                    negative_prompt="",
                    num_inference_steps=28,
                    guidance_scale=7.0,
                    num_images_per_prompt=1,
                ).images[0]
            elif args.model_type == 'flux_1_dev':
                image = pipe(
                    prompt,
                    height=1024,
                    width=1024,
                    guidance_scale=3.5,
                    num_inference_steps=50,
                    max_sequence_length=512,
                    num_images_per_prompt=1,
                ).images[0]
            elif args.model_type == 'flux_1_dev_schnell':
                image = pipe(
                    prompt,
                    height=1024,
                    width=1024,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0)
                ).images[0]
            elif args.model_type == 'pixart_sigma':
                image = pipe(prompt).images[0]
            elif use_infinity:
                # The returned prompt replaces the current one and is fed into
                # any further samples.
                image, prompt = gen_one_img_eval_long(
                    infinity,
                    vae,
                    text_tokenizer,
                    text_encoder,
                    prompt,
                    tau_list=args.tau,
                    cfg_sc=3,
                    cfg_list=args.cfg,
                    scale_schedule=scale_schedule,
                    cfg_insertion_layer=[args.cfg_insertion_layer],
                    vae_type=args.vae_type,
                    lq_img_path=lq_img_path,
                    args=args,
                    blip_model=blip_model,
                    blip_processor=blip_processor,
                    swinir=swinir,
                    bitwise_self_correction=bitwise_self_correction,
                )
            images.append(image)

        for sample_idx, image in enumerate(images):
            sample_path = _sample_path(args.out_dir, img_name, sample_idx, len(images))
            if use_infinity:
                # cv2.imwrite reports failure by returning False rather than raising.
                if not cv2.imwrite(sample_path, image.cpu().numpy()):
                    raise OSError(f'cv2.imwrite failed to write {sample_path}')
            else:
                image.save(sample_path)

        metadata['prompt'] = prompt
        save_metadatas.append(metadata)

    save_metadata_file_path = os.path.join(os.path.dirname(args.metadata_file), "metadata_w_prompt.json")
    with open(save_metadata_file_path, "w", encoding='utf-8') as fp:
        json.dump(save_metadatas, fp)


if __name__ == '__main__':
    main()



