import argparse, os, glob, datetime, yaml, sys
sys.path.append('.')
sys.path.append('./src/taming-transformers')
print(sys.path)
import logging
import torch
import torch.nn as nn
import time
import random
import numpy as np
from tqdm import trange
# from pytorch_lightning import seed_everything
from qdiff.utils import seed_everything
from omegaconf import OmegaConf
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from ldm.util import instantiate_from_config

from pytorch_fid.fid_score import calculate_fid_given_paths
from scripts.test import test_IS 
from evalution.sfid import test_fid_sfid
from quant.quant_layer import QuantModule, UniformAffineQuantizer    
from quant.quant_block import QuantSMVMatMul, QuantQKMatMul, BaseQuantBlock, QuantResBlock, QuantAttentionBlock
from quant import (
    QAModel,
    set_smooth_quantize_params_LDM,
    set_weight_quantize_params_LDM,
    set_act_quantize_params_LDM,
    Change_LDM_model_attnblock,
    recon_Qmodel,
)

logger = logging.getLogger(name=__name__)


def rescale(x):
    """Affinely map values from the [-1, 1] range onto [0, 1], i.e. ``(x + 1) / 2``."""
    shifted = x + 1.
    return shifted / 2.

def get_parser():
    """Build the command-line parser for LDM sampling with quantization-aware training.

    Returns:
        argparse.ArgumentParser: parser covering sampling/I-O options,
        linear-quantization options and QAT reconstruction hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    # sampling / I/O configs
    parser.add_argument("-r", "--resume", type=str,)
    parser.add_argument("-n", "--n_samples", type=int, default=10000)
    parser.add_argument("-e", "--eta", type=float, default=1.0)
    parser.add_argument("-v", "--vanilla_sample", default=False, action='store_true',)
    parser.add_argument("--seed", type=int, default=1234,)
    parser.add_argument("-l", "--logdir", type=str, default="none")
    parser.add_argument("--dataset", default="./LSUN/bedrooms/")
    # type=int: argparse only converts command-line values through `type`, so
    # without it a user-supplied step count would stay a str while the default is an int.
    parser.add_argument("-c", "--custom_steps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--dpm", action="store_true",)
    # linear quantization configs
    parser.add_argument("--qat", action="store_true", help="apply quantization-aware training")
    parser.add_argument("--quant_act", action="store_true", )
    parser.add_argument("--weight_bit", type=int, default=8,)
    parser.add_argument("--act_bit", type=int, default=8,)
    parser.add_argument("--quant_mode", type=str, default="qdiff", choices=["qdiff"], )
    # Each option string must be registered only once: argparse's default
    # conflict handler raises ArgumentError on a duplicate.
    parser.add_argument("--split", action="store_true",)

    # qdiff specific configs
    parser.add_argument("--device", type=str, default="cuda:0",)
    parser.add_argument("--a_sym", action="store_true",)
    parser.add_argument("--sm_abit", type=int, default=8,)
    parser.add_argument("--verbose", action="store_true",)
    parser.add_argument("--calib_num_samples", default=1024, type=int,)
    parser.add_argument("--batch_samples", default=1024, type=int,)
    parser.add_argument("--class_cond", action="store_true", help="class difusion")
    parser.add_argument("--lr_w", type=float, default=1e-2,)
    parser.add_argument("--lr_za", type=float, default=1e-1,)
    parser.add_argument("--lr_a", type=float, default=1e-4,)
    parser.add_argument("--lr_rw", type=float, default=1e-3,)
    parser.add_argument("--smooth_type", type=str, default="weight-aware",)
    return parser

def custom_to_pil(x):
    """Convert one image tensor with values in [-1, 1] into an RGB PIL image.

    ``x`` is treated as channels-first: ``permute(1, 2, 0)`` moves dim 0 to the
    end before handing the array to PIL. Values are clamped to [-1, 1], mapped
    to [0, 255] and truncated to ``uint8``; a non-RGB result is converted to RGB.
    """
    pixels = torch.clamp(x.detach().cpu(), -1., 1.)
    pixels = (pixels + 1.) / 2.
    hwc = pixels.permute(1, 2, 0).numpy()
    image = Image.fromarray((255 * hwc).astype(np.uint8))
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image

def custom_to_np(x):
    """Convert a channels-first batch in [-1, 1] to a channels-last ``uint8`` tensor.

    Mirrors the ADM sample-saving convention
    (https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py):
    values are mapped with ``(x + 1) * 127.5``, clamped to [0, 255], cast to
    ``torch.uint8`` (truncating) and permuted with ``(0, 2, 3, 1)``.

    Note: despite the name, the result is a contiguous CPU ``torch.Tensor``,
    not a NumPy array.
    """
    scaled = (x.detach().cpu() + 1) * 127.5
    as_bytes = scaled.clamp(0, 255).to(torch.uint8)
    return as_bytes.permute(0, 2, 3, 1).contiguous()

def logs2pil(logs, keys=None):
    """Best-effort conversion of every entry in ``logs`` to a PIL image.

    For each key: a 4-D tensor yields an image of its first element, a 3-D
    tensor yields an image of the whole tensor, anything else prints a warning
    and maps to ``None``. Entries that fail to convert (e.g. values without a
    ``.shape``) also map to ``None``.

    Args:
        logs: mapping from names to logged values (tensors, floats, ...).
        keys: unused; accepted for backward compatibility. (Formerly a mutable
            default list.)

    Returns:
        dict mapping every key of ``logs`` to a PIL image or ``None``.
    """
    imgs = {}
    for k, value in logs.items():
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Deliberately best-effort, but unlike a bare `except:` this no
            # longer swallows KeyboardInterrupt / SystemExit.
            img = None
        imgs[k] = img
    return imgs

@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
    """Run the model's own (non-DDIM) sampling loop without autograd.

    With ``make_prog_row`` the call goes to ``model.progressive_denoising``
    (always with ``verbose=True``); otherwise to ``model.p_sample_loop``,
    forwarding ``return_intermediates`` and ``verbose``. No conditioning is
    passed (``None``). The model's return value is passed through unchanged.
    """
    if make_prog_row:
        return model.progressive_denoising(None, shape, verbose=True)
    return model.p_sample_loop(
        None,
        shape,
        return_intermediates=return_intermediates,
        verbose=verbose,
    )

@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample one batch with a fresh ``DDIMSampler`` and no autograd.

    ``shape`` is ``[batch, *per_sample_shape]``: its first entry becomes the
    sampler's ``batch_size`` and the rest its ``shape``. ``model.quant_sample``
    is copied onto the sampler before sampling.

    Returns:
        The ``(samples, intermediates)`` pair produced by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    sampler.quant_sample = model.quant_sample
    batch_size, per_sample = shape[0], shape[1:]
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=per_sample,
        eta=eta,
        verbose=False,
    )
    return samples, intermediates

@torch.no_grad()
def convsample_dpm(model, steps, shape, eta=1.0):
    """Sample one batch with a fresh ``DPMSolverSampler`` and no autograd.

    ``shape`` is ``[batch, *per_sample_shape]``: its first entry becomes the
    sampler's ``batch_size`` and the rest its ``shape``.

    NOTE(review): unlike ``convsample_ddim`` this does not copy
    ``model.quant_sample`` onto the sampler; confirm whether DPM sampling is
    meant to honour that flag.

    Returns:
        The ``(samples, intermediates)`` pair produced by ``DPMSolverSampler.sample``.
    """
    solver = DPMSolverSampler(model)
    batch_size, per_sample = shape[0], shape[1:]
    samples, intermediates = solver.sample(
        steps,
        batch_size=batch_size,
        shape=per_sample,
        eta=eta,
        verbose=False,
    )
    return samples, intermediates

@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0, dpm=False):
    """Draw one batch of latents, decode it, and return images plus timing.

    The latent shape is ``[batch_size, in_channels, image_size, image_size]``
    taken from ``model.model.diffusion_model``. Sampling goes through the
    model's progressive-denoising path when ``vanilla``, otherwise DPM-Solver
    when ``dpm``, otherwise DDIM.

    Returns:
        dict with ``"sample"`` (output of ``model.decode_first_stage``),
        ``"time"`` (seconds spent sampling, excluding the decode) and
        ``"throughput"`` (samples per second of that sampling time).
    """
    unet = model.model.diffusion_model
    latent_shape = [batch_size, unet.in_channels, unet.image_size, unet.image_size]

    started = time.time()
    if vanilla:
        latents, _ = convsample(model, latent_shape, make_prog_row=True)
    elif dpm:
        logger.info(f'Using DPM sampling with {custom_steps} sampling steps and eta={eta}')
        latents, _ = convsample_dpm(model, steps=custom_steps, shape=latent_shape, eta=eta)
    else:
        latents, _ = convsample_ddim(model, steps=custom_steps, shape=latent_shape, eta=eta)
    elapsed = time.time() - started

    decoded = model.decode_first_stage(latents)
    torch.cuda.empty_cache()

    return {
        "sample": decoded,
        "time": elapsed,
        "throughput": latents.shape[0] / elapsed,
    }

def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None,
        n_samples=50000, dpm=False):
    """Generate ``n_samples`` unconditional samples and save them as PNGs in ``logdir``.

    Args:
        model: latent-diffusion model; only ``cond_stage_model is None`` is supported.
        logdir: directory the ``sample_XXXXXX.png`` files are written to.
        batch_size: images per sampler call. When ``n_samples`` is not a
            multiple of it, a final smaller batch tops the count up.
        vanilla: use the model's progressive-denoising path instead of DDIM/DPM.
        custom_steps: number of DDIM/DPM sampling steps.
        eta: DDIM/DPM ``eta``.
        n_samples: total number of images to generate.
        dpm: use DPM-Solver instead of DDIM (``vanilla`` takes precedence).

    Raises:
        ValueError: if ``batch_size`` is not positive.
        NotImplementedError: for conditional models.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if vanilla:
        logger.info(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    elif dpm:
        logger.info(f'Using DPM sampling with {custom_steps} sampling steps and eta={eta}')
    else:
        logger.info(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    tstart = time.time()
    n_saved = 0
    # Full batches plus one trailing partial batch, so exactly n_samples images
    # are produced (floor division alone would silently drop the remainder).
    n_full, remainder = divmod(n_samples, batch_size)
    batch_sizes = [batch_size] * n_full + ([remainder] if remainder else [])

    logger.info(f"Running unconditional sampling for {n_samples} samples")
    with torch.no_grad():
        for bs in tqdm(batch_sizes, desc="Sampling Batches (unconditional)"):
            logs = make_convolutional_sample(model, batch_size=bs,
                                             vanilla=vanilla, custom_steps=custom_steps,
                                             eta=eta, dpm=dpm)
            n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
            torch.cuda.empty_cache()
    logger.info(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")

def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``logs[key]`` to disk and return the new counter.

    When ``np_path`` is None, every item of the batch is written as
    ``{key}_{counter:06}.png`` inside ``path``. Otherwise the whole batch is
    converted with ``custom_to_np`` and stored as a single
    ``{counter}-{shape}-samples.npz`` in ``np_path``. If ``key`` is absent,
    nothing is written.

    Returns:
        ``n_saved`` advanced by the number of images written.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(str(dim) for dim in npbatch.shape)
        np.savez(os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    counter = n_saved
    for item in batch:
        image = custom_to_pil(item)
        image.save(os.path.join(path, f"{key}_{counter:06}.png"))
        counter += 1
    return counter

def load_model_from_config(config, sd):
    """Instantiate a model from ``config``, optionally load weights, then move it to CUDA in eval mode.

    Args:
        config: model config handed to ``instantiate_from_config``.
        sd: state dict loaded with ``strict=False``, or ``None`` to keep the
            freshly initialised weights (the no-checkpoint path of ``load_model``).

    Returns:
        The model, on CUDA, in eval mode.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        result = model.load_state_dict(sd, strict=False)
        # strict=False tolerates key mismatches; report them instead of hiding them.
        missing = getattr(result, "missing_keys", None) or []
        unexpected = getattr(result, "unexpected_keys", None) or []
        if missing:
            logger.info(f"{len(missing)} missing keys when loading state dict")
        if unexpected:
            logger.info(f"{len(unexpected)} unexpected keys when loading state dict")
    model.cuda()
    model.eval()
    return model

def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model`` and load ``ckpt`` into it if given.

    Args:
        config: full config; only ``config.model`` is used here.
        ckpt: path to a checkpoint containing a ``"state_dict"`` entry, or a
            falsy value to skip weight loading.
        gpu: unused; kept for interface compatibility.
        eval_mode: unused; kept for interface compatibility.

    Returns:
        ``(model, global_step)``. ``global_step`` is None when there is no
        checkpoint or when the checkpoint does not record one.
    """
    if ckpt:
        logger.info(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # global_step is informational only, so a missing entry must not abort loading.
        global_step = pl_sd.get("global_step")
        state_dict = pl_sd["state_dict"]
    else:
        global_step = None
        state_dict = None
    model = load_model_from_config(config.model, state_dict)
    return model, global_step

def block_train_w(q_unet, args, kwargs, cali_data, t, cali_t):
    """Run block-wise weight reconstruction (``recon_Qmodel.w_recon``) on a quantized UNet.

    Args:
        q_unet: quantized UNet. Its ``block_count`` is reset to 0 and
            ``set_steps_state(is_mix_steps=True)`` is called on it before
            reconstruction.
        args: parsed command-line options, forwarded to ``recon_Qmodel``.
        kwargs: reconstruction hyper-parameters. NOTE: mutated in place; the
            caller's dict gains ``'cali_data'`` and ``'cali_t'`` entries.
        cali_data: calibration inputs, stored as ``kwargs['cali_data'][0]``.
        t: timesteps paired with ``cali_data``, stored as ``kwargs['cali_data'][1]``.
        cali_t: per-sample step positions, stored as ``kwargs['cali_t']``.

    Returns:
        None. NOTE(review): the model returned by ``w_recon()`` is only bound
        locally (``set_steps_state(is_mix_steps=False)`` is called on it). The
        caller in this file keeps using the object it passed in, which
        presumably means ``w_recon`` updates that object in place; confirm.
    """
    # The reconstructor receives the same `kwargs` dict object that is filled
    # in below, so the statement order here is kept exactly as-is.
    recon_qnn = recon_Qmodel(args, q_unet, kwargs)

    q_unet.block_count = 0
    '''weight'''
    # Attach calibration data to the (shared) kwargs dict.
    kwargs['cali_data'] = (cali_data, t)
    kwargs['cali_t'] = cali_t
    recon_qnn.kwargs = kwargs
    recon_qnn.down_name = None
    # Mixed-step mode on for reconstruction, off again afterwards.
    q_unet.set_steps_state(is_mix_steps=True)
    q_unet = recon_qnn.w_recon()
    q_unet.set_steps_state(is_mix_steps=False)
    torch.cuda.empty_cache()

if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    # Unrecognised tokens are kept and merged into the config as an OmegaConf dotlist.
    opt, unknown = parser.parse_known_args()
    # argparse only applies `type` to command-line values; coerce so the step
    # count is an int whether it came from the default or from the CLI.
    opt.custom_steps = int(opt.custom_steps)
    ckpt = None

    # fix random seed
    seed_everything(opt.seed)
    print(torch.cuda.current_device())

    # --- Resolve the checkpoint and the directory holding config.yaml --------
    if not opt.resume or not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # A checkpoint file: its config.yaml is expected alongside it.
        logdir = os.path.dirname(opt.resume)
        print(f'Logdir is {logdir}')
        ckpt = opt.resume
    elif os.path.isdir(opt.resume):
        # A run directory: expect <dir>/model.ckpt and <dir>/config.yaml.
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")
    else:
        raise ValueError(f"{opt.resume} is not a directory")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    if not base_configs:
        raise FileNotFoundError(f"No config.yaml found in '{logdir or os.curdir}'")
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    device = "cuda"
    eval_mode = True

    # --- Output directory and logging (file + console) -----------------------
    logdir = os.path.join(opt.logdir, "samples", now)
    os.makedirs(logdir)
    log_path = os.path.join(logdir, "run.log")
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    logger.info(75 * "=")
    logger.info(f"Host {os.uname()[1]}")
    logger.info(f"Command: {command}")
    logger.info("logging to:")
    imglogdir = os.path.join(logdir, "img")
    opt.image_folder = imglogdir
    os.makedirs(imglogdir)
    logger.info(logdir)
    logger.info(75 * "=")

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    logger.info(f"global step: {global_step}")
    model.model_ema.store(model.model.parameters())
    model.model_ema.copy_to(model.model)
    logger.info("Switched to EMA weights")
    args = opt

    wq_params = {'n_bits': args.weight_bit, 'symmetric': False, 'channel_wise': True, 'scale_method': 'max'}
    # smooth_type comes from --smooth_type (default "weight-aware") instead of being hard-coded.
    aq_params = {'n_bits': args.act_bit, 'symmetric': args.a_sym, 'channel_wise': False, 'scale_method': 'mse',
                 'leaf_param': args.quant_act, "prob": 1.0, "num_timesteps": args.custom_steps,
                 "smooth_type": args.smooth_type}

    q_unet = QAModel(model.model.diffusion_model, args, wq_params=wq_params, aq_params=aq_params)

    if opt.qat:
        print("Setting the first and the last layer to 8-bit")
        q_unet.set_first_last_layer_to_8bit()
        q_unet.set_quant_state(False, False)

        model.model.diffusion_model = q_unet
        print("sampling calib data")
        model.model.diffusion_model.set_quant_state(False, False)

        # --- Calibration trajectories (full precision), cached on disk -------
        # One path is used for the existence check, the load and the save.
        # NOTE: the cache is not keyed on --custom_steps / --calib_num_samples /
        # --batch_samples; delete the file after changing any of them.
        calib_path = os.path.join('.', "Church100_Cali_data.pth")
        if os.path.exists(calib_path):
            samples, ts = torch.load(calib_path)
        else:
            n_calib_batches = args.calib_num_samples // args.batch_samples
            if n_calib_batches < 1:
                raise ValueError(
                    f"--calib_num_samples ({args.calib_num_samples}) must be at least "
                    f"--batch_samples ({args.batch_samples})")
            shape = [args.batch_samples,
                     model.model.diffusion_model.in_channels,
                     model.model.diffusion_model.image_size,
                     model.model.diffusion_model.image_size]
            ddim = DDIMSampler(model)
            bs = shape[0]
            shape = shape[1:]
            samples = []
            ts = []
            with torch.no_grad():
                for i in tqdm(range(n_calib_batches), desc="Generating image samples for cali-data"):
                    sample, intermediates = ddim.sample(args.custom_steps, batch_size=bs, shape=shape, eta=args.eta, verbose=False,)
                    # [:-1] drops the last recorded intermediate; presumably the
                    # final denoised latent, which is never a UNet input. TODO confirm.
                    samples.append(intermediates['x_inter'][:-1])
                    ts.append(intermediates['ts'])
                torch.cuda.empty_cache()
            torch.save((samples, ts), calib_path)

        # Regroup per step: all_samples[s] / all_ts[s] concatenate the s-th
        # intermediate (and its timesteps) of every calibration batch.
        all_samples = []
        all_ts = []
        for t_sample in range(args.custom_steps):
            t_samples = torch.cat([sample[t_sample].to(device) for sample in samples])
            all_samples.append(t_samples)
            t_ts = torch.cat([t[t_sample].to(device) for t in ts])
            all_ts.append(t_ts)
        samples = None
        torch.cuda.empty_cache()

        # Up to 64 random points per step. `cali_t` is the step position
        # (0..custom_steps-1) and `index` is its reverse, custom_steps-1-cali_t.
        all_cali_data = []
        all_t = []
        all_index = []
        all_cali_t = []
        for now_rt, sample_t in enumerate(all_samples):
            idx = torch.randperm(sample_t.size(0))[:64]
            cali_data = sample_t[idx]
            t = all_ts[now_rt][idx]
            cali_t = torch.full_like(t, now_rt)
            index = (args.custom_steps-1)-cali_t
            all_cali_data.append(cali_data.cpu())
            all_t.append(t.cpu())
            all_cali_t.append(cali_t.cpu())
            all_index.append(index.cpu())
        del all_samples

        if args.split:
            model.model.diffusion_model.model.split_shortcut = True

        # Random subset (at most 5120 points, all steps mixed) for scale initialisation.
        cali_data = torch.cat(all_cali_data)
        t = torch.cat(all_t)
        index = torch.cat(all_index)
        idx = torch.randperm(len(cali_data))[:5120]
        cali_data = cali_data[idx]
        t = t[idx]
        index = index[idx]
        cali_data = (cali_data, t, index)

        # Init scale_smooth
        set_smooth_quantize_params_LDM(model, cali_data, args)
        model.model.diffusion_model.set_smooth_state(set_smooth_weight=True, set_smooth_training=False)

        # Init scale_w
        set_weight_quantize_params_LDM(model, cali_data, args)
        # Init scale_a (receives the per-step lists, not the flattened subset)
        set_act_quantize_params_LDM(model, all_cali_data, all_t, all_index, args)

        Change_LDM_model_attnblock(model.model.diffusion_model, aq_params)

        # Block-wise reconstruction for the remaining layers.
        kwargs = dict(iters=5000,
                      act_quant=True,
                      weight_quant=True,
                      asym=True,
                      opt_mode='mse',
                      lr_a=args.lr_a,
                      lr_w=args.lr_w,
                      lr_rw=args.lr_rw,
                      lr_za=args.lr_za,
                      lr_smooth=1e-5,
                      p=2.0,
                      weight=0.01,
                      b_range=(20, 2),
                      warmup=0.2,
                      batch_size=32,
                      batch_size1=64,
                      input_prob=1.0,
                      recon_w=True,
                      recon_a=True,
                      recon_rw=True,
                      recon_smooth=False,
                      keep_gpu=False,
                      device_loss=1,
                      add_loss=0,
                      )
        model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)
        all_cali_data = torch.cat(all_cali_data).cpu()
        all_t = torch.cat(all_t).cpu()
        all_cali_t = torch.cat(all_cali_t).cpu()
        idx = torch.randperm(len(all_cali_data))[:5120]
        cali_data = all_cali_data[idx].clone()
        t = all_t[idx].clone()
        cali_t = all_cali_t[idx].clone()
        del all_cali_data, all_t, all_cali_t

        block_train_w(model.model.diffusion_model, args, kwargs, cali_data, t, cali_t)
        model.model.diffusion_model.set_quant_state(weight_quant=True, act_quant=args.quant_act)

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'a+') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    if opt.verbose:
        print(sampling_conf)
        logger.info("first_stage_model")
        logger.info(model.first_stage_model)
        logger.info("UNet model")
        logger.info(model.model)

    # Read by the DDIM path (copied onto the sampler in convsample_ddim).
    model.quant_sample = bool(opt.qat)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, dpm=opt.dpm)

    logger.info("done.")
