import argparse, os, datetime, yaml, sys
sys.path.append('.')
sys.path.append('./src/taming-transformers')
print(sys.path)
import logging
import cv2
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
import random
# from pytorch_lightning import seed_everything
from qdiff.utils import seed_everything
from torch import autocast
from contextlib import nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim_control import DDIMSampler_control
from ldm.models.diffusion.plms import PLMSSampler

from quant_control.coco_prompt import get_prompts, center_resize_image
from quant_control.recon_Qmodel import recon_Qmodel
from evalution.sfid import test_fid_sfid

from quant import (
    QAModel,
    set_smooth_quantize_params_Stable,
    set_act_quantize_params_Stable,
    set_weight_quantize_params_Stable,
    Change_LDM_model_SpatialTransformer,

)
from quant.quant_layer import QuantModule, UniformAffineQuantizer    
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

logger = logging.getLogger(__name__)

def chunk(it, size):
    """Split iterable ``it`` into consecutive tuples of at most ``size`` items.

    The final tuple may be shorter; iteration stops at the first empty tuple.
    ``iter(it)`` is taken immediately, so a non-iterable fails at call time.
    """
    source = iter(it)

    def _pieces():
        while True:
            piece = tuple(islice(source, size))
            if not piece:
                return
            yield piece

    return _pieces()

def numpy_to_pil(images):
    """
    Turn a single numpy image (3-D) or a batch of images (4-D) into a list of PIL images.

    Pixel values are scaled by 255, rounded and cast to uint8, so inputs are expected
    to lie in [0, 1].
    """
    batch = images if images.ndim != 3 else images[np.newaxis, ...]
    as_uint8 = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_uint8]

def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load the weights stored under ``"state_dict"`` in ``ckpt``.

    Loading is non-strict; with ``verbose`` the missing / unexpected keys are logged.
    The model is put in eval mode before being returned (still on CPU).
    """
    logging.info(f"Loading model from {ckpt}")
    # torch.load unpickles the file: only use trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        logging.info(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        for label, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                logging.info(label)
                logging.info(keys)

    model.eval()
    return model

def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into a PIL image with ``wm_encoder`` ('dwtDct' method).

    Returns ``img`` unchanged when no encoder is given. The encoder operates on a BGR
    array, so the image is converted RGB->BGR before encoding and flipped back after.
    """
    if wm_encoder is None:
        return img
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    encoded = wm_encoder.encode(bgr, 'dwtDct')
    return Image.fromarray(encoded[:, :, ::-1])

def load_replacement(x, path="assets/rick.jpeg"):
    """Return a stand-in image with the same shape and dtype as ``x``, or ``x`` itself on failure.

    The image at ``path`` is converted to RGB, resized to ``x``'s height/width, scaled by
    1/255 and cast to ``x.dtype``. Assumes ``x`` is an (H, W, 3) array in [0, 1] — TODO
    confirm against callers. This is deliberately best-effort: any failure (missing asset,
    unexpected input) keeps the original image, but is now logged instead of swallowed.

    Args:
        x: image array to replace.
        path: replacement image file (defaults to the previously hard-coded asset).
    """
    try:
        height, width = x.shape[0], x.shape[1]
        # Context manager closes the underlying file handle that Image.open keeps open.
        with Image.open(path) as src:
            y = src.convert("RGB").resize((width, height))
        y = (np.array(y) / 255.0).astype(x.dtype)
    except Exception:
        logging.warning("Could not load replacement image %r; keeping original image", path,
                        exc_info=True)
        return x
    # Explicit check instead of ``assert`` so it is still enforced under ``python -O``.
    if y.shape != x.shape:
        logging.warning("Replacement image shape %s does not match %s; keeping original image",
                        y.shape, x.shape)
        return x
    return y

# Safety-checker weights; the same Hugging Face model id the upstream CompVis txt2img script uses.
_SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
# Lazily-populated cache so the safety models are loaded at most once per process.
_safety_models = {}


def _get_safety_models():
    """Return ``(feature_extractor, checker)``, loading them via ``from_pretrained`` on first call."""
    if not _safety_models:
        _safety_models["feature_extractor"] = AutoFeatureExtractor.from_pretrained(_SAFETY_MODEL_ID)
        _safety_models["checker"] = StableDiffusionSafetyChecker.from_pretrained(_SAFETY_MODEL_ID)
    return _safety_models["feature_extractor"], _safety_models["checker"]


def check_safety(x_image, safety_feature_extractor=None, safety_checker=None):
    """Run the Stable Diffusion safety checker over a batch of numpy images.

    The feature extractor and checker can be injected; when omitted they are loaded once
    from ``_SAFETY_MODEL_ID`` and cached (this file never defines module-level instances,
    so the previous global lookup always raised NameError).

    Args:
        x_image: batch of images; converted with ``numpy_to_pil`` for the extractor.
        safety_feature_extractor: optional callable producing ``pixel_values``.
        safety_checker: optional callable ``(images=..., clip_input=...) -> (images, flags)``.

    Returns:
        ``(x_checked_image, has_nsfw_concept)``; flagged entries are replaced via
        ``load_replacement``.
    """
    if safety_feature_extractor is None or safety_checker is None:
        default_extractor, default_checker = _get_safety_models()
        if safety_feature_extractor is None:
            safety_feature_extractor = default_extractor
        if safety_checker is None:
            safety_checker = default_checker

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[i] = load_replacement(x_checked_image[i])
    return x_checked_image, has_nsfw_concept

def get_parser():
    """Build the CLI parser for sampling and quantization-aware calibration.

    ``--split`` was previously registered twice, which makes argparse raise
    ``ArgumentError: conflicting option string`` before any argument is parsed;
    it is now defined exactly once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt",type=str,default="a painting of a virus monster playing guitar",)
    parser.add_argument("-l","--logdir",default="none")
    parser.add_argument("--dataset",default="/dataset/imagenet/train_resize")
    parser.add_argument("--skip_grid",action='store_true',)
    parser.add_argument("--skip_save",action='store_true',)
    parser.add_argument("--ddim_steps",type=int,default=20,)
    parser.add_argument("--plms",action='store_true',help="use plms sampling",)
    parser.add_argument("--laion400m",action='store_true',help="uses the LAION400M model",)
    parser.add_argument("--fixed_code",action='store_true',help="if enabled, uses the same starting code across samples ",)
    parser.add_argument("--ddim_eta",type=float,default=0.0,)
    parser.add_argument("--n_iter",type=int,default=1,help="sample this often",)
    parser.add_argument("--H",type=int,default=512)
    parser.add_argument("--W", type=int,default=512,)
    parser.add_argument("--C",type=int,default=4,)
    parser.add_argument("--f",type=int,default=8,help="downsampling factor",)
    parser.add_argument("--n_samples",type=int,default=3,)
    parser.add_argument("--n_batch",type=int,default=2,)
    parser.add_argument("--n_rows",type=int,default=0,)
    parser.add_argument("--scale",type=float,default=7.5,)
    parser.add_argument("--from-file",type=str,help="if specified, load prompts from this file",)
    parser.add_argument("--config",type=str,default="configs/stable-diffusion/v1-inference.yaml",help="path to config which constructs model",)
    parser.add_argument("--ckpt",type=str,default="models/ldm/stable-diffusion-v1/model.ckpt",help="path to checkpoint of model",)
    parser.add_argument("--seed",type=int,default=1234,)
    parser.add_argument("--precision",type=str,choices=["full", "autocast"],default="autocast")
    parser.add_argument("--no_grad_ckpt", action="store_true",help="disable gradient checkpointing")
    # linear quantization configs
    parser.add_argument("--qat", action="store_true", help="apply quantization-aware training")
    parser.add_argument("--quant_act", action="store_true", )
    parser.add_argument("--weight_bit",type=int,default=8,)
    parser.add_argument("--act_bit",type=int,default=8,)
    parser.add_argument("--quant_mode", type=str, default="qdiff", choices=["qdiff"], )
    parser.add_argument("--split", action="store_true",)
    # qdiff specific configs
    parser.add_argument("--device", type=str,default="cuda:0",)
    parser.add_argument("--a_sym", action="store_true",)
    parser.add_argument("--sm_abit",type=int, default=8,)
    parser.add_argument("--verbose", action="store_true",)
    parser.add_argument("--calib_num_samples",default=256,type=int,)
    parser.add_argument("--batch_samples",default=4,type=int,)
    parser.add_argument("--cond", action="store_true",help="class difusion")
    parser.add_argument("--lr_w",type=float,default=1e-2,)
    parser.add_argument("--lr_za",type=float,default=1e-1,)
    parser.add_argument("--lr_a",type=float,default=1e-4,)
    parser.add_argument("--lr_rw",type=float,default=1e-3,)
    parser.add_argument("--smooth_type",type=str,default="weight-aware",)
    return parser

def block_train_w(q_unet, args, kwargs, cali_data, t, index, cond, uncond, ts_next, cali_t):
    """Run block-wise weight reconstruction (``recon_Qmodel.w_recon``) on ``q_unet``.

    Side effects: ``kwargs`` is mutated in place (``cali_data`` and ``cali_t`` are added),
    ``q_unet.block_count`` is reset to 0, and the mix-steps state is switched on for the
    duration of the reconstruction.
    """
    reconstructor = recon_Qmodel(args, q_unet, kwargs)

    q_unet.block_count = 0
    # Weight reconstruction: calibration tensors travel to the reconstructor via kwargs.
    kwargs.update(cali_data=(cali_data, t, index, cond, uncond, ts_next), cali_t=cali_t)
    reconstructor.kwargs = kwargs
    reconstructor.down_name = None

    q_unet.set_steps_state(is_mix_steps=True)
    # NOTE(review): the local name is rebound to whatever w_recon() returns; the caller
    # keeps its own reference, so confirm w_recon() updates the model in place.
    q_unet = reconstructor.w_recon()
    q_unet.set_steps_state(is_mix_steps=False)
    torch.cuda.empty_cache()

def _load_prompts(file_path):
    """Return the prompts in ``file_path`` (UTF-8), one per line with surrounding whitespace stripped."""
    with open(file_path, 'r', encoding='utf-8') as fh:
        return [line.strip() for line in fh]


def _dump_prompts(prompts, prompt_path):
    """Write prompt ``k`` to ``<prompt_path>/<k:05>.txt``, creating ``prompt_path`` if needed.

    The zero-padded index scheme is the same one used for the saved sample images.
    """
    os.makedirs(prompt_path, exist_ok=True)
    for count, prompt in enumerate(prompts):
        with open(os.path.join(prompt_path, f"{count:05}.txt"), 'w', encoding='utf-8') as fh:
            fh.write(prompt)


def _generate_calib_data(model, args):
    """Sample the first ``args.calib_num_samples`` prompts with the current model and
    collect the sampler's per-step intermediates.

    Returns five lists (one entry per batch of ``args.batch_samples`` prompts): the sampler's
    ``x_inter`` (last entry dropped), ``ts``, ``cond``, ``uncond`` and ``ts_next``.
    """
    uc = model.get_learned_conditioning(args.calib_num_samples * [""])
    prompts = args.list_prompts[:args.calib_num_samples]
    c = model.get_learned_conditioning(prompts)
    shape = [args.C, args.H // args.f, args.W // args.f]
    sampler = PLMSSampler(model) if args.plms else DDIMSampler_control(model)

    samples, ts, conds, unconds, ts_nexts = [], [], [], [], []
    bs = args.batch_samples
    with torch.no_grad():
        for i in tqdm(range(args.calib_num_samples // bs), desc="Generating image samples for cali-data"):
            batch = slice(i * bs, (i + 1) * bs)
            _, intermediates = sampler.sample(S=args.ddim_steps,
                                              conditioning=c[batch],
                                              batch_size=bs,
                                              shape=shape,
                                              verbose=False,
                                              unconditional_guidance_scale=args.scale,
                                              unconditional_conditioning=uc[batch],
                                              eta=args.ddim_eta,
                                              x_T=None)
            samples.append(intermediates['x_inter'][:-1])
            ts.append(intermediates['ts'])
            conds.append(intermediates['cond'])
            unconds.append(intermediates['uncond'])
            ts_nexts.append(intermediates['ts_next'])
        torch.cuda.empty_cache()
    return samples, ts, conds, unconds, ts_nexts


def _load_or_generate_calib_data(model, args, cache_path='./Coco50_Cali_data.pth'):
    """Return cached calibration intermediates from ``cache_path``, generating and saving them if absent."""
    # NOTE(review): the cache is reused regardless of ddim_steps / calib_num_samples / prompts;
    # delete it after changing those, or the per-step regrouping may index past the stored steps.
    if os.path.exists(cache_path):
        # torch.load unpickles: only load cache files this script produced.
        return torch.load(cache_path)
    data = _generate_calib_data(model, args)
    torch.save(data, cache_path)
    return data


def _regroup_by_timestep(per_batch, num_steps):
    """Turn ``per_batch[b][s]`` into one CPU tensor per step ``s``, concatenated over batches ``b``."""
    return [torch.cat([batch[step].cpu() for batch in per_batch]) for step in range(num_steps)]


def _subsample_per_timestep(all_samples, all_ts, all_conds, all_unconds, all_ts_nexts,
                            num_steps, per_step=16):
    """Pick up to ``per_step`` random rows at every step.

    Returns seven lists (one CPU tensor per step): latents, timesteps, step index
    ``num_steps - 1 - step``, cond, uncond, next timesteps, and the step position.
    """
    cali_datas, ts, indices, conds, unconds, ts_nexts, cali_ts = ([] for _ in range(7))
    for step, sample_t in enumerate(all_samples):
        idx = torch.randperm(sample_t.size(0))[:per_step]
        t = all_ts[step][idx]
        cali_t = torch.full_like(t, step)
        cali_datas.append(sample_t[idx].cpu())
        ts.append(t.cpu())
        indices.append(((num_steps - 1) - cali_t).cpu())
        conds.append(all_conds[step][idx].cpu())
        unconds.append(all_unconds[step][idx].cpu())
        ts_nexts.append(all_ts_nexts[step][idx].cpu())
        cali_ts.append(cali_t.cpu())
    return cali_datas, ts, indices, conds, unconds, ts_nexts, cali_ts


def _concat_and_subsample(per_step_lists, limit=512):
    """Concatenate each list of per-step tensors and keep the same random ``limit`` rows of each."""
    merged = [torch.cat(parts) for parts in per_step_lists]
    idx = torch.randperm(len(merged[0]))[:limit]
    return tuple(tensor[idx] for tensor in merged)


def _calibrate_and_reconstruct(model, q_unet, args, aq_params):
    """Install ``q_unet`` in ``model``, initialise its quantizers and run block-wise weight reconstruction.

    Statement order matters: RNG draws (random subsets) and quantizer state flow from one
    step to the next exactly as in the original inline implementation.
    """
    print("Setting the first and the last layer to 8-bit")
    q_unet.set_first_last_layer_to_8bit()
    q_unet.set_quant_state(False, False)

    model.model.diffusion_model = q_unet
    print("sampling calib data")
    model.model.diffusion_model.set_quant_state(False, False)

    samples, ts, conds, unconds, ts_nexts = _load_or_generate_calib_data(model, args)

    num_steps = args.custom_steps
    all_samples = _regroup_by_timestep(samples, num_steps)
    all_ts = _regroup_by_timestep(ts, num_steps)
    all_conds = _regroup_by_timestep(conds, num_steps)
    all_unconds = _regroup_by_timestep(unconds, num_steps)
    all_ts_nexts = _regroup_by_timestep(ts_nexts, num_steps)
    del samples, ts, conds, unconds, ts_nexts
    torch.cuda.empty_cache()

    (all_cali_data, all_t, all_index, all_cond,
     all_uncond, all_ts_next, all_cali_t) = _subsample_per_timestep(
        all_samples, all_ts, all_conds, all_unconds, all_ts_nexts, num_steps)
    del all_samples, all_conds, all_unconds, all_ts, all_ts_nexts

    if args.split:
        model.model.diffusion_model.model.split_shortcut = True

    # (cali_data, t, index, cond, uncond, ts_next) subset used to initialise quantizer params.
    cali_data = _concat_and_subsample(
        [all_cali_data, all_t, all_index, all_cond, all_uncond, all_ts_next])

    # Init scale_smooth
    set_smooth_quantize_params_Stable(model, cali_data, args)
    model.model.diffusion_model.set_smooth_state(set_smooth_weight=True, set_smooth_training=False)
    # Init scale_w
    set_weight_quantize_params_Stable(model, cali_data, args)
    # Init scale_a (takes the per-step lists, not the flat subset)
    set_act_quantize_params_Stable(model, all_cali_data, all_t, all_index, all_cond, all_uncond, all_ts_next, args)

    Change_LDM_model_SpatialTransformer(model.model.diffusion_model, aq_params)

    # block-wise training for other layers
    kwargs = dict(iters=5000,
                  act_quant=True,
                  weight_quant=True,
                  asym=True,
                  opt_mode='mse',
                  lr_a=args.lr_a,
                  lr_w=args.lr_w,
                  lr_rw=args.lr_rw,
                  lr_za=args.lr_za,
                  lr_smooth=1e-5,
                  p=2.0,
                  weight=0.01,
                  b_range=(20, 2),
                  warmup=0.2,
                  batch_size=4,
                  batch_size1=2,
                  input_prob=1.0,
                  recon_w=True,
                  recon_a=True,
                  recon_rw=True,
                  recon_smooth=False,
                  keep_gpu=False,
                  )
    model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)
    cali_data, t, index, cond, uncond, ts_next, cali_t = _concat_and_subsample(
        [all_cali_data, all_t, all_index, all_cond, all_uncond, all_ts_next, all_cali_t])
    del all_cali_data, all_t, all_index, all_cond, all_uncond, all_cali_t, all_ts_next

    block_train_w(model.model.diffusion_model, args, kwargs, cali_data, t, index, cond, uncond, ts_next, cali_t)
    model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)


def main():
    """Command-line entry point.

    Loads the Stable Diffusion model, writes each prompt to ``./image_prompts``, optionally
    (``--qat``) calibrates and block-reconstructs a quantized UNet, then samples every prompt
    for ``--n_iter`` rounds, saving watermarked PNGs (and an optional grid) under
    ``<logdir>/samples/<timestamp>``.
    """
    parser = get_parser()
    opt, _unknown = parser.parse_known_args()
    # Fail fast before creating directories or loading the checkpoint; an ``assert``
    # here would also be stripped under ``python -O``.
    if not opt.cond:
        parser.error("--cond is required")

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    seed_everything(opt.seed)
    # NOTE(review): ``device`` is only used for ``start_code``; the model goes to the default
    # CUDA device via ``.cuda()`` below. Keep them consistent before using a non-default GPU.
    device = torch.device(opt.device) if torch.cuda.is_available() else torch.device("cpu")

    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    logdir = os.path.join(opt.logdir, "samples", now)
    os.makedirs(logdir)
    log_path = os.path.join(logdir, "run.log")
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    logger.info(75 * "=")
    logger.info(f"Host {os.uname()[1]}")
    logger.info("logging to:")
    imglogdir = os.path.join(logdir, "img")
    opt.image_folder = imglogdir
    os.makedirs(imglogdir)
    logger.info(logdir)
    logger.info(75 * "=")

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")
    model = model.cuda()

    batch_size = opt.n_batch
    n_rows = opt.n_rows if opt.n_rows > 0 else opt.n_samples

    # ``--from-file`` overrides the default COCO caption list.
    opt.list_prompts = _load_prompts(opt.from_file or './COCO.txt')
    data = opt.list_prompts
    _dump_prompts(data, "./image_prompts")

    args = opt
    args.custom_steps = args.ddim_steps

    wq_params = {'n_bits': args.weight_bit, 'symmetric': False, 'channel_wise': True, 'scale_method': 'max'}
    # smooth_type now follows ``--smooth_type`` (default 'weight-aware', as previously hard-coded);
    # NOTE(review): an earlier comment listed 'time_mean' as an alternative value.
    aq_params = {'n_bits': args.act_bit, 'symmetric': False, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.quant_act, "prob": 1.0, "num_timesteps": args.custom_steps, "smooth_type": args.smooth_type}

    q_unet = QAModel(model.model.diffusion_model, args, wq_params=wq_params, aq_params=aq_params)

    if opt.qat:
        _calibrate_and_reconstruct(model, q_unet, args, aq_params)

    logger.info("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "StableDiffusionV1"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    base_count = 0
    grid_count = 0

    if opt.verbose:
        logger.info("UNet model")
        logger.info(model.model)

    sampler = PLMSSampler(model) if args.plms else DDIMSampler_control(model)
    if opt.qat:
        sampler.quant_sample = True
    seed_everything(1234+9)

    start_code = None
    if opt.fixed_code:
        # One starting code per batch row (previously n_samples rows, which did not match
        # the n_batch-sized conditioning).
        start_code = torch.randn([batch_size, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
    # Ceil division so a trailing partial batch of prompts is sampled too.
    num_batches = (len(data) + batch_size - 1) // batch_size
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                # The unconditional embedding is identical for every batch: compute it once.
                uc_full = None
                if opt.scale != 1.0:
                    uc_full = model.get_learned_conditioning(batch_size * [""])
                all_samples = []
                for _ in trange(opt.n_iter, desc="Sampling"):
                    for i in tqdm(range(num_batches), desc="samples"):
                        prompts = data[i * batch_size:(i + 1) * batch_size]
                        cur_bs = len(prompts)
                        uc = None if uc_full is None else uc_full[:cur_bs]
                        x_T = None if start_code is None else start_code[:cur_bs]
                        c = model.get_learned_conditioning(prompts)
                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=cur_bs,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=opt.scale,
                                                         unconditional_conditioning=uc,
                                                         eta=opt.ddim_eta,
                                                         x_T=x_T)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                        # Safety checking is disabled; check_safety(x_samples_ddim) can be used instead.
                        x_checked_image = x_samples_ddim
                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

                        if not opt.skip_save:
                            for x_sample in x_checked_image_torch:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                img = Image.fromarray(x_sample.astype(np.uint8))
                                img = put_watermark(img, wm_encoder)
                                img.save(os.path.join(imglogdir, f"{base_count:05}.png"))
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_checked_image_torch)

                if not opt.skip_grid and all_samples:
                    # additionally, save as grid; cat (not stack) tolerates a smaller last batch
                    grid = torch.cat(all_samples, 0)
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    img = Image.fromarray(grid.astype(np.uint8))
                    img = put_watermark(img, wm_encoder)
                    img.save(os.path.join(logdir, f'grid-{grid_count:04}.png'))
                    grid_count += 1
                toc = time.time()

    logger.info(f"Sampling took {toc - tic:.2f}s")
    print("down!")

if __name__ == "__main__":
    # Script entry point: parse CLI args, optionally calibrate the quantized UNet, then sample.
    main()
