import argparse, os, datetime, yaml, sys
sys.path.append('.')
sys.path.append('./src/taming-transformers')
print(sys.path)
import logging
import cv2
import gc
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
# from pytorch_lightning import seed_everything
from qdiff.utils import seed_everything
from torch import autocast
from contextlib import nullcontext
import gc

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim_control import DDIMSampler_control
from quant_control.recon_Qmodel import recon_Qmodel
from evalution.sfid import test_fid_sfid

from quant import (
    QAModel,
    set_smooth_quantize_params_Conditional,
    set_act_quantize_params_Conditional,
    set_weight_quantize_params_Conditional,
    Change_LDM_model_SpatialTransformer,
)

from quant.quant_layer import QuantModule, UniformAffineQuantizer    
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The last tuple may hold fewer than ``size`` items. Iteration ends as soon
    as an empty tuple would be produced.
    """
    source = iter(it)

    def next_group():
        # Pull up to ``size`` items from the shared underlying iterator.
        return tuple(islice(source, size))

    # Two-argument iter(): call next_group() repeatedly until it returns ().
    return iter(next_group, ())

def numpy_to_pil(images):
    """
    Turn a numpy image, or a batch of numpy images, into a list of PIL images.

    A 3-D input is treated as a batch of one. Values are multiplied by 255,
    rounded and cast to ``uint8`` before conversion, so the input is
    presumably expected to lie in [0, 1] (this is not checked here).
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch_uint8 = (batch * 255).round().astype("uint8")

    pil_images = []
    for frame in batch_uint8:
        pil_images.append(Image.fromarray(frame))
    return pil_images

def load_model_from_config(config, ckpt, device, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: configuration object whose ``model`` entry is handed to
            ``instantiate_from_config``.
        ckpt: path of a checkpoint file containing a ``"state_dict"`` entry.
        device: device the model is moved to once the weights are loaded.
        verbose: when true, log the missing / unexpected state-dict keys.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    logging.info(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first; the model is moved to ``device`` below.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        logging.info(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are tolerated and only reported on request.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        for title, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                logging.info(title)
                logging.info(keys)

    model.to(device)
    model.eval()
    return model

def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into ``img`` when an encoder is given.

    With ``wm_encoder=None`` the input is returned untouched. Otherwise the
    image is converted RGB -> BGR, encoded with the ``'dwtDct'`` method and
    converted back into a PIL image.
    """
    if wm_encoder is None:
        return img

    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr, 'dwtDct')
    # Reverse the channel axis to undo the RGB -> BGR conversion above.
    return Image.fromarray(marked[:, :, ::-1])

def load_replacement(x):
    """Return a placeholder image shaped like ``x``, or ``x`` itself on failure.

    The placeholder is read from ``assets/rick.jpeg``, resized to the first
    two dimensions of ``x``, scaled by 1/255 and cast to ``x.dtype``. This is
    best-effort: any error along the way (missing file, shape mismatch, ...)
    results in ``x`` being returned unchanged.
    """
    try:
        height, width = x.shape[0], x.shape[1]
        picture = Image.open("assets/rick.jpeg").convert("RGB")
        picture = picture.resize((width, height))
        y = (np.array(picture) / 255.0).astype(x.dtype)
        assert y.shape == x.shape
    except Exception:
        # Deliberate fallback: keep the original image when anything goes wrong.
        return x
    return y

def check_safety(x_image):
    """Run the safety checker on a batch and swap out every flagged image.

    NOTE(review): this reads the module-level names
    ``safety_feature_extractor`` and ``safety_checker``; neither is defined in
    the visible part of this file - confirm they are created somewhere before
    enabling a call to this function.

    Returns:
        A ``(checked_images, flags)`` tuple as produced by ``safety_checker``,
        with each flagged image replaced via ``load_replacement``.
    """
    clip_features = safety_feature_extractor(
        numpy_to_pil(x_image),
        return_tensors="pt",
    )
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for position, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[position] = load_replacement(x_checked_image[position])
    return x_checked_image, has_nsfw_concept

def get_parser():
    """Build the command-line parser for this script.

    ``--split`` used to be registered twice, which made argparse raise
    ``ArgumentError`` (conflicting option string) on every call; it is now
    registered exactly once. All option names, types and defaults are
    otherwise unchanged.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # logging / sampling configs
    parser.add_argument("-l", "--logdir", default="none")
    parser.add_argument("--dataset", default="/dataset/imagenet/train_resize")
    parser.add_argument("--skip_grid", action='store_true')
    parser.add_argument("--skip_save", action='store_true')
    parser.add_argument("--ddim_steps", type=int, default=20)
    parser.add_argument("--ddim_eta", type=float, default=0.0)
    parser.add_argument("--n_samples", type=int, default=3)
    parser.add_argument("--n_batch", type=int, default=2)
    parser.add_argument("--n_rows", type=int, default=0)
    parser.add_argument("--scale", type=float, default=3.0)
    parser.add_argument("--config", type=str, default="configs/latent-diffusion/cin256-v2.yaml")
    parser.add_argument("--ckpt", type=str, default="models/ldm/cin256/model.ckpt")
    parser.add_argument("--seed", type=int, default=1234, help="the seed (for reproducible sampling)")
    parser.add_argument("--precision", type=str, choices=["full", "autocast"], default="autocast")
    parser.add_argument("--no_grad_ckpt", action="store_true", help="disable gradient checkpointing")

    # linear quantization configs
    parser.add_argument("--qat", action="store_true", help="apply quantization-aware training")
    parser.add_argument("--quant_act", action="store_true")
    parser.add_argument("--weight_bit", type=int, default=8)
    parser.add_argument("--act_bit", type=int, default=8)
    parser.add_argument("--quant_mode", type=str, default="qdiff", choices=["qdiff"])
    parser.add_argument("--split", action="store_true")

    # qdiff specific configs
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--a_sym", action="store_true")
    parser.add_argument("--sm_abit", type=int, default=8)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--calib_num_samples", default=1024, type=int)
    parser.add_argument("--batch_samples", default=1024, type=int)
    parser.add_argument("--cond", action="store_true", help="class difusion")
    parser.add_argument("--lr_w", type=float, default=1e-2)
    parser.add_argument("--lr_za", type=float, default=1e-1)
    parser.add_argument("--lr_a", type=float, default=1e-4)
    parser.add_argument("--lr_rw", type=float, default=1e-3)
    parser.add_argument("--smooth_type", type=str, default="weight-aware")
    return parser

def block_train_w(q_unet, args, kwargs, cali_data, t, index, cond, uncond, cali_t):
    """Run the weight-reconstruction pass of ``recon_Qmodel`` on ``q_unet``.

    The calibration tensors are packed into ``kwargs`` (modified in place)
    under ``'cali_data'`` and ``'cali_t'`` before ``w_recon()`` is called.
    Mixed-steps mode is switched on for the reconstruction and off again
    afterwards. Nothing is returned.
    """
    reconstructor = recon_Qmodel(args, q_unet, kwargs)
    q_unet.block_count = 0

    # weight
    kwargs.update(
        cali_data=(cali_data, t, index, cond, uncond),
        cali_t=cali_t,
    )
    reconstructor.kwargs = kwargs
    reconstructor.down_name = None

    q_unet.set_steps_state(is_mix_steps=True)
    # From here on ``q_unet`` refers to whatever w_recon() returned.
    q_unet = reconstructor.w_recon()
    q_unet.set_steps_state(is_mix_steps=False)
    torch.cuda.empty_cache()

def main():
    """Optionally quantize a class-conditional LDM, then sample images from it.

    Steps, as implemented below:
      1. Parse the CLI options, seed the RNGs, create
         ``<logdir>/samples/<timestamp>/img`` and configure logging to a file
         plus the console.
      2. Load the model given by ``--config`` / ``--ckpt``.
      3. Wrap the diffusion UNet in ``QAModel``. With ``--qat``, additionally
         load or generate calibration data, initialise the smooth / weight /
         activation quantizer parameters and run block-wise weight
         reconstruction.
      4. For class labels 0..999, sample ``--n_batch`` images per class with
         DDIM, watermark them and save PNGs; stop once at least
         ``--n_samples`` images have been written.

    Raises:
        ValueError: if ``--cond`` was not passed.
    """
    parser = get_parser()
    # Unrecognised command-line arguments are tolerated and ignored.
    opt, _ = parser.parse_known_args()

    seed_everything(opt.seed)
    device = torch.device(opt.device) if torch.cuda.is_available() else torch.device("cpu")

    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    logdir = os.path.join(opt.logdir, "samples", now)
    os.makedirs(logdir)
    log_path = os.path.join(logdir, "run.log")
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    logger.info(75 * "=")
    logger.info(f"Host {os.uname()[1]}")
    logger.info("logging to:")
    imglogdir = os.path.join(logdir, "img")
    opt.image_folder = imglogdir
    os.makedirs(imglogdir)
    logger.info(logdir)
    logger.info(75 * "=")

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}", device=device)
    model = model.to(device)

    # Build a shuffled tensor of class labels, n_samples // 1000 per class.
    classes = range(1000)   # define classes to be sampled here
    n_samples_per_class = int(opt.n_samples/len(classes))
    xc_all = []
    for class_label in classes:
        xc = torch.tensor(n_samples_per_class*[class_label])
        xc_all.append(xc)
    data = torch.hstack(xc_all)
    data_randperm = torch.randperm(data.size(0))
    # Indexing already yields a new tensor; re-wrapping it in torch.tensor()
    # only added a copy and a UserWarning.
    data = data[data_randperm].to(device)
    opt.data = data
    # Explicit check instead of ``assert`` so it also holds under ``python -O``.
    if not opt.cond:
        raise ValueError("only class-conditional sampling is supported; pass --cond")
    args = opt
    args.custom_steps = args.ddim_steps

    wq_params = {'n_bits': args.weight_bit, 'symmetric': False, 'channel_wise': True, 'scale_method': 'max'}
    # NOTE(review): "smooth_type" is hard-coded here (the original comment also
    # mentions "time_mean"); --smooth_type is not read in this function -
    # confirm whether that is intended.
    aq_params = {'n_bits': args.act_bit, 'symmetric': False, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.quant_act, "prob": 1.0, "num_timesteps": args.custom_steps, "smooth_type": 'weight-aware'}

    q_unet = QAModel(model.model.diffusion_model, args, wq_params=wq_params, aq_params=aq_params)

    if opt.qat:
        print("Setting the first and the last layer to 8-bit")
        q_unet.set_first_last_layer_to_8bit()
        q_unet.set_quant_state(False, False)

        model.model.diffusion_model = q_unet
        print("sampling calib data")
        model.model.diffusion_model.set_quant_state(False, False)

        # Calibration data is cached in the working directory between runs.
        cali_cache_path = './ImageNet20_Cali_data.pth'
        if os.path.exists(cali_cache_path):
            # NOTE: torch.load unpickles the file - only use a cache file that
            # was produced by this script.
            samples, ts, conds, unconds = torch.load(cali_cache_path)
        else:
            uc = model.get_learned_conditioning(
                    {model.cond_stage_key: torch.tensor(args.calib_num_samples*[1000]).to(model.device)}
                    )
            xc = args.data[:args.calib_num_samples]
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            shape = [3, 64, 64]

            sampler = DDIMSampler_control(model)
            samples = []
            ts = []
            conds = []
            unconds = []

            with torch.no_grad():
                for i in tqdm(range(int(args.calib_num_samples/args.batch_samples)), desc="Generating image samples for cali-data"):
                    _, intermediates = sampler.sample(S=args.ddim_steps,
                                                    conditioning=c[i*args.batch_samples:(i+1)*args.batch_samples],
                                                    batch_size=args.batch_samples,
                                                    shape=shape,
                                                    verbose=False,
                                                    unconditional_guidance_scale=args.scale,
                                                    unconditional_conditioning=uc[i*args.batch_samples:(i+1)*args.batch_samples],
                                                    eta=args.ddim_eta)
                    # The last entry of 'x_inter' is dropped.
                    samples.append(intermediates['x_inter'][:-1])
                    ts.append(intermediates['ts'])
                    conds.append(intermediates['cond'])
                    unconds.append(intermediates['uncond'])
                torch.cuda.empty_cache()
            torch.save((samples, ts, conds, unconds), cali_cache_path)

        # Regroup the per-batch intermediates by step: entry k concatenates
        # step k of every generated batch (on the CPU).
        all_samples = []
        all_ts = []
        all_conds = []
        all_unconds = []
        for t_sample in range(args.custom_steps):
            t_samples = torch.cat([sample[t_sample].cpu() for sample in samples])
            all_samples.append(t_samples)
            t_ts = torch.cat([t[t_sample].cpu() for t in ts])
            all_ts.append(t_ts)
            t_conds = torch.cat([cond[t_sample].cpu() for cond in conds])
            all_conds.append(t_conds)
            t_unconds = torch.cat([uncond[t_sample].cpu() for uncond in unconds])
            all_unconds.append(t_unconds)
        # Drop the references to the raw intermediates so they can be freed.
        samples, ts, conds, unconds = None, None, None, None
        torch.cuda.empty_cache()

        # Randomly keep at most 256 calibration samples per step.
        all_cali_data = []
        all_t = []
        all_index = []
        all_cond = []
        all_uncond = []
        all_cali_t = []
        for now_rt, sample_t in enumerate(all_samples):
            idx = torch.randperm(sample_t.size(0))[:256]
            cali_data = sample_t[idx]
            t = all_ts[now_rt][idx]
            cond = all_conds[now_rt][idx]
            uncond = all_unconds[now_rt][idx]
            # cali_t is the position in the step list; index counts down from
            # custom_steps - 1.
            cali_t = torch.full_like(t, now_rt)
            index = (args.custom_steps-1)-cali_t
            all_cali_data.append(cali_data.cpu())
            all_t.append(t.detach().cpu())
            all_index.append(index.detach().cpu())
            all_cond.append(cond.detach().cpu())
            all_uncond.append(uncond.detach().cpu())
            all_cali_t.append(cali_t.detach().cpu())
        del all_samples, all_conds, all_unconds, all_ts

        if args.split:
            model.model.diffusion_model.model.split_shortcut = True

        # Mixed-step subset (at most 5120 samples) used to initialise the
        # smooth and weight quantizer parameters.
        cali_data = torch.cat(all_cali_data)
        t = torch.cat(all_t)
        index = torch.cat(all_index)
        cond = torch.cat(all_cond)
        uncond = torch.cat(all_uncond)
        idx = torch.randperm(len(cali_data))[:5120]
        cali_data = cali_data[idx]
        t = t[idx]
        index = index[idx]
        cond = cond[idx]
        uncond = uncond[idx]
        cali_data = (cali_data, t, index, cond, uncond)

        # Init scale_smooth
        set_smooth_quantize_params_Conditional(model, cali_data, args)
        model.model.diffusion_model.set_smooth_state(set_smooth_weight=True, set_smooth_training=False)

        # Init scale_w
        set_weight_quantize_params_Conditional(model, cali_data, args)
        # Init scale_a (receives the per-step lists, not the mixed subset)
        set_act_quantize_params_Conditional(model, all_cali_data, all_t, all_index, all_cond, all_uncond, args)

        Change_LDM_model_SpatialTransformer(model.model.diffusion_model, aq_params)

        # block-wise training For other layers
        kwargs = dict(iters=5000,
                        act_quant=True,
                        weight_quant=True,
                        asym=True,
                        opt_mode='mse',
                        lr_a=args.lr_a,
                        lr_w=args.lr_w,
                        lr_rw=args.lr_rw,
                        lr_za=args.lr_za,
                        lr_smooth=1e-5,
                        p=2.0,
                        weight=0.01,
                        b_range=(20,2),
                        warmup=0.2,
                        batch_size=32,
                        batch_size1=32,
                        input_prob=1.0,
                        recon_w=True,
                        recon_a=True,
                        recon_rw=True,
                        recon_smooth=False,
                        keep_gpu=False,
                        )
        model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)
        # Flatten the per-step lists and draw at most 1024 samples for the
        # block-wise reconstruction.
        all_cali_data = torch.cat(all_cali_data).cpu()
        all_t = torch.cat(all_t).cpu()
        all_index = torch.cat(all_index).cpu()
        all_cond = torch.cat(all_cond).cpu()
        all_uncond = torch.cat(all_uncond).cpu()
        all_cali_t = torch.cat(all_cali_t).cpu()
        idx = torch.randperm(len(all_cali_data))[:1024]
        cali_data = all_cali_data[idx].clone()
        t = all_t[idx].clone()
        index = all_index[idx].clone()
        cond = all_cond[idx].clone()
        uncond = all_uncond[idx].clone()
        cali_t = all_cali_t[idx].clone()
        del all_cali_data, all_t, all_index, all_cond, all_uncond, all_cali_t

        block_train_w(model.model.diffusion_model, args, kwargs, cali_data, t, index, cond, uncond, cali_t)
        model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)

    logging.info("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "StableDiffusionV1"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    base_count = 0  # number of images written so far; also used as file name
    if opt.verbose:
        logger.info("UNet model")
        logger.info(model.model)
    sampler = DDIMSampler_control(model)
    if opt.qat:
        sampler.quant_sample = True

    # seed_everything(1234+9)
    iterator = tqdm(range(1000), desc='DDIM Sampler')
    with torch.no_grad(), model.ema_scope():
        # Conditioning for label 1000, passed as the unconditional branch of
        # the guidance.
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(opt.n_batch*[1000]).to(model.device)}
            )
        for class_label in iterator:
            print(f"rendering {opt.n_batch} examples of class '{class_label}' in {opt.ddim_steps} steps and using s={opt.scale:.2f}.")
            xc = torch.tensor(opt.n_batch*[class_label])
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})

            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                            conditioning=c,
                                            batch_size=opt.n_batch,
                                            shape=[3, 64, 64],
                                            verbose=False,
                                            unconditional_guidance_scale=opt.scale,
                                            unconditional_conditioning=uc,
                                            eta=opt.ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map the decoder output from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                        min=0.0, max=1.0)
            x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

            # The safety check is disabled; images are passed through as is.
            x_checked_image = x_samples_ddim
            # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)

            x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

            for x_sample in x_checked_image_torch:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                img = Image.fromarray(x_sample.astype(np.uint8))
                img = put_watermark(img, wm_encoder)
                img.save(os.path.join(imglogdir, f"{base_count:05}.png"))
                base_count += 1
            # Stop as soon as the target is reached; the former ``>`` test
            # rendered one extra batch when the count hit n_samples exactly.
            if base_count >= opt.n_samples:
                break

    print("down!")

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
