import argparse, os, glob, datetime, yaml, sys
sys.path.append('.')
sys.path.append('./src/taming-transformers')
print(sys.path)
import logging
import math
import random
import copy
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.cuda import amp
# from pytorch_lightning import seed_everything
from quant.utils import seed_everything

from ddim.models.diffusion import Model
from ddim.datasets import inverse_data_transform
from ddim.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
from ddim.functions.ckpt_util import get_ckpt_path
from ddim.functions.denoising import generalized_steps, cali_generalized_steps, train_generalized_steps, important_generalized_last_steps, important_generalized_every_steps, important_generalized_every_steplayer

import torchvision.utils as tvu
from evalution.sfid import test_fid_sfid
from quant.quant_layer import QuantModule, UniformAffineQuantizer

from quant import (
    QAModel,
    set_act_quantize_params, 
    set_weight_quantize_params,
    set_smooth_quantize_params,
    set_smooth_quantize_params2,
    recon_Qmodel,
    recon_layer_Qmodel,
)

logger = logging.getLogger(__name__)

def get_parser():
    """Build the command-line parser for quantized DDIM sampling/calibration.

    Returns:
        argparse.ArgumentParser: parser with model, sampling, quantization and
        reconstruction options. ``--config`` is the only required argument.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the config file")
    parser.add_argument("--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("-l","--logdir", default="none")
    parser.add_argument("--use_pretrained", action="store_true")
    parser.add_argument("--sample_type",default="generalized",)
    parser.add_argument("--skip_type",default="uniform",)
    parser.add_argument("--timesteps", type=int, default=1000, help="number of steps involved")
    # type=float: without it a value given on the command line (e.g. "--eta 0.5")
    # would stay a str, while the default is a float.
    parser.add_argument("--eta", type=float, default=0.0,)
    parser.add_argument("--sequence", action="store_true")
    parser.add_argument("--qat", action="store_true", help="apply quantization-aware training")
    parser.add_argument("--quant_act", action="store_true", )
    parser.add_argument("--weight_bit",type=int,default=8,)
    parser.add_argument("--act_bit",type=int,default=8,)
    parser.add_argument("--quant_mode", type=str, default="qdiff", choices=["qdiff"], )
    parser.add_argument("--max_images", type=int, default=50000, help="number of images to sample")

    # qdiff specific configs
    parser.add_argument("--device", type=str,default="cuda:0",)
    parser.add_argument("--a_sym", action="store_true",)
    parser.add_argument("--sm_abit",type=int, default=8,)
    parser.add_argument("--split", action="store_true",)
    parser.add_argument("--verbose", action="store_true",)
    parser.add_argument("--calib_num_samples",default=1024,type=int,)
    parser.add_argument("--batch_samples",default=1024,type=int,)
    parser.add_argument("--class_cond", action="store_true",help="class difusion")
    parser.add_argument("--recon", action="store_true",)
    parser.add_argument("--block_recon", action="store_true",)
    parser.add_argument("--layer_recon", action="store_true",)
    parser.add_argument("--lr_w",type=float,default=1e-2,)
    parser.add_argument("--lr_za",type=float,default=1e-1,)
    parser.add_argument("--lr_a",type=float,default=1e-4,)
    parser.add_argument("--lr_rw",type=float,default=1e-3,)
    parser.add_argument("--smooth_type",type=str,default="weight-aware",)
    return parser

def dict2namespace(config):
    """Recursively turn a (possibly nested) mapping into an ``argparse.Namespace``.

    Each key becomes an attribute. Values that are ``dict`` instances are
    converted the same way; every other value is stored unchanged.
    """
    converted = {
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
    return argparse.Namespace(**converted)

def torch2hwcuint8(x, clip=False):
    """Rescale a tensor from the [-1, 1] range to [0, 1].

    Despite the name, no HWC permutation and no uint8 cast happen here; the
    result keeps the input's shape and floating dtype. With ``clip=True`` the
    input is clamped to [-1, 1] before rescaling.
    """
    source = torch.clamp(x, -1, 1) if clip else x
    return (source + 1.0) / 2.0

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Return the per-step noise variances (betas) as a 1-D float array.

    Supported schedules:
        "quad"    - squares of values evenly spaced between sqrt(beta_start)
                    and sqrt(beta_end)
        "linear"  - evenly spaced between beta_start and beta_end
        "const"   - beta_end at every step
        "jsd"     - 1/T, 1/(T-1), ..., 1
        "sigmoid" - logistic curve over [-6, 6] rescaled to [beta_start, beta_end]

    Raises:
        NotImplementedError: for any other ``beta_schedule`` name.
    """
    steps = num_diffusion_timesteps

    def _quad():
        roots = np.linspace(
            beta_start ** 0.5,
            beta_end ** 0.5,
            steps,
            dtype=np.float64,
        )
        return roots ** 2

    def _linear():
        return np.linspace(beta_start, beta_end, steps, dtype=np.float64)

    def _const():
        return beta_end * np.ones(steps, dtype=np.float64)

    def _jsd():
        return 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)

    def _sigmoid():
        grid = np.linspace(-6, 6, steps)
        logistic = 1 / (np.exp(-grid) + 1)
        return logistic * (beta_end - beta_start) + beta_start

    builders = (
        ("quad", _quad),
        ("linear", _linear),
        ("const", _const),
        ("jsd", _jsd),
        ("sigmoid", _sigmoid),
    )
    for schedule_name, build in builders:
        if beta_schedule == schedule_name:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (num_diffusion_timesteps,)
    return betas

class Diffusion(object):
    """DDIM sampling helper: noise schedule, timestep sub-sequence and sampling.

    Args:
        args: parsed CLI namespace. Reads ``timesteps``, ``skip_type``, ``eta``,
            ``seed`` and ``max_images``; ``sample_fid`` also reads
            ``image_folder``, which the script sets after parsing.
        config: nested namespace built from the YAML config (the ``model``,
            ``diffusion``, ``data`` and ``sampling`` sections are read).
        device: torch device; defaults to CUDA when available, else CPU.
    """

    def __init__(self, args, config, device=None):
        self.args = args
        self.config = config
        self.model = Model(self.config)
        self.num_timesteps = args.timesteps
        if device is None:
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device

        self.model_var_type = config.model.var_type
        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        self.betas = torch.from_numpy(betas).float()
        self.betas = self.betas.to(self.device)
        betas = self.betas
        self.num_diffusion_timesteps = betas.shape[0]

        alphas = 1.0 - betas
        alphas_cumprod = alphas.cumprod(dim=0)
        # Cumulative product shifted right by one, with 1 prepended for t=0.
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
        )
        # beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        if self.model_var_type == "fixedlarge":
            self.logvar = betas.log()
        elif self.model_var_type == "fixedsmall":
            self.logvar = posterior_variance.clamp(min=1e-20).log()
        # NOTE(review): any other var_type leaves self.logvar unset.

        if self.args.skip_type == "uniform":
            skip = self.num_diffusion_timesteps // self.args.timesteps
            if skip < 1:
                # Otherwise range() fails with "arg 3 must not be zero"
                # (or silently yields nothing for a negative --timesteps).
                raise ValueError(
                    f"--timesteps ({self.args.timesteps}) must be between 1 and "
                    f"the number of diffusion timesteps ({self.num_diffusion_timesteps})"
                )
            seq = range(0, self.num_diffusion_timesteps, skip)
            self.seq = seq
        elif self.args.skip_type == "quad":
            seq = (
                np.linspace(
                    0, np.sqrt(self.num_diffusion_timesteps * 0.8), self.args.timesteps
                )
                ** 2
            )
            seq = [int(s) for s in list(seq)]
            self.seq = seq
        else:
            raise NotImplementedError

    def QModel(self):
        """Load the pretrained EMA DDPM checkpoint into ``self.model``.

        The checkpoint name is ``ema_cifar10`` for CIFAR10 or
        ``ema_lsun_<category>`` for LSUN. The model is moved to ``self.device``
        and put in eval mode.

        Raises:
            ValueError: for any other ``config.data.dataset``.
        """
        model = self.model
        # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
        if self.config.data.dataset == "CIFAR10":
            name = "cifar10"
        elif self.config.data.dataset == "LSUN":
            name = f"lsun_{self.config.data.category}"
        else:
            raise ValueError
        ckpt = get_ckpt_path(f"ema_{name}")
        logger.info("Loading checkpoint {}".format(ckpt))
        model.load_state_dict(torch.load(ckpt, map_location=self.device))

        model.to(self.device)
        model.eval()

        return model

    def sample_fid(self, model):
        """Sample ``args.max_images`` images with ``model`` and save them as PNGs.

        Files are written as ``<args.image_folder>/<index>.png`` with indices
        starting at 0, so existing files with those names are overwritten.
        Noise is seeded from ``args.seed``.
        """
        config = self.config
        img_id = 0
        logger.info(f"starting from image {img_id}")
        total_n_samples = self.args.max_images
        n_rounds = math.ceil((total_n_samples - img_id) / config.sampling.batch_size)

        torch.manual_seed(self.args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.args.seed)
        with torch.no_grad():
            for round_idx in tqdm(
                range(n_rounds), desc="Generating image samples for FID evaluation."
            ):
                n = config.sampling.batch_size
                x = torch.randn(
                    n,
                    config.data.channels,
                    config.data.image_size,
                    config.data.image_size,
                    device=self.device,
                )

                with amp.autocast(enabled=False):
                    x = self.sample_image(x, model)
                x = inverse_data_transform(config, x)

                # Only the final round may overshoot max_images; save just the remainder.
                if img_id + x.shape[0] > self.args.max_images:
                    assert round_idx == n_rounds - 1
                    n = self.args.max_images - img_id
                # Separate index so the round counter above is not shadowed.
                for k in range(n):
                    tvu.save_image(
                        x[k], os.path.join(self.args.image_folder, f"{img_id}.png")
                    )
                    img_id += 1
                torch.cuda.empty_cache()

    def ddim_generalized_steps(self, model, x, seq):
        """Run ``generalized_steps`` over ``seq`` and return ``xs[0][:-1]``.

        NOTE(review): presumably ``xs[0]`` is the list of intermediate samples,
        so this drops the final one; confirm against ``generalized_steps``.
        """
        betas = self.betas
        xs = generalized_steps(
            x, seq, model, betas, eta=self.args.eta, args=self.args)
        x = xs[0][:-1]
        return x

    def sample_image(self, x, model, last=True):
        """Denoise ``x`` along ``self.seq`` with ``generalized_steps``.

        Returns ``xs[0][-1]`` when ``last`` is true, otherwise the raw output
        of ``generalized_steps``.
        """
        seq = self.seq
        betas = self.betas
        xs = generalized_steps(
            x, seq, model, betas, eta=self.args.eta, args=self.args)
        x = xs
        if last:
            x = x[0][-1]
        return x

def block_train_w(q_unet, device, diffusion, kwargs, all_cali_data, all_t, all_cali_t, num_samples=5120):
    """Run block-wise weight reconstruction of ``q_unet`` on a random calibration subset.

    Args:
        q_unet: quantized model wrapper to reconstruct.
        device: unused; kept for interface compatibility.
        diffusion: only ``diffusion.args`` is read.
        kwargs: reconstruction hyper-parameters handed to ``recon_Qmodel``.
            NOTE: mutated in place — ``'cali_data'`` and ``'cali_t'`` are added.
        all_cali_data, all_t, all_cali_t: lists of tensors, each concatenated
            along dim 0. The script passes calibration inputs, model timesteps
            and step indices here.
        num_samples: maximum number of rows drawn, without replacement, from
            the concatenated data (default 5120).

    Returns:
        The model returned by ``recon_qnn.w_recon()``. Previously this was
        discarded; callers that ignore the return value are unaffected.
    """
    recon_qnn = recon_Qmodel(diffusion.args, q_unet, kwargs)

    all_cali_data = torch.cat(all_cali_data)
    all_t = torch.cat(all_t)
    all_cali_t = torch.cat(all_cali_t)
    # The same random rows are used for inputs, timesteps and step indices.
    idx = torch.randperm(len(all_cali_data))[:num_samples]
    cali_data = all_cali_data[idx]
    t = all_t[idx]
    cali_t = all_cali_t[idx]
    q_unet.block_count = 0

    # weight reconstruction
    kwargs['cali_data'] = (cali_data, t)
    kwargs['cali_t'] = cali_t
    recon_qnn.kwargs = kwargs
    recon_qnn.down_name = None
    q_unet.set_steps_state(is_mix_steps=True)
    q_unet = recon_qnn.w_recon()
    # NOTE(review): if w_recon returns a different object than the input model,
    # the input model is left with is_mix_steps=True — confirm w_recon is in-place.
    q_unet.set_steps_state(is_mix_steps=False)
    torch.cuda.empty_cache()
    return q_unet

if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    parser = get_parser()
    args = parser.parse_args()
    # parse config file
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    config = dict2namespace(config)
    # fix random seed
    seed_everything(args.seed)

    # setup logger: run.log is written to <logdir>/samples/<timestamp>/
    logdir = os.path.join(args.logdir, "samples", now)
    os.makedirs(logdir)
    log_path = os.path.join(logdir, "run.log")
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    logger.info(75 * "=")
    logger.info(f"Host {os.uname()[1]}")
    logger.info("logging to:")
    # Sampled images go to a fixed folder, not under logdir.
    # exist_ok=True: re-running the script must not fail with FileExistsError.
    args.image_folder = "./image"
    os.makedirs(args.image_folder, exist_ok=True)
    logger.info(logdir)
    logger.info(75 * "=")

    diffusion = Diffusion(args, config)
    unet = diffusion.QModel()

    wq_params = {'n_bits': args.weight_bit, 'symmetric': False, 'channel_wise': True, 'scale_method': 'max'}
    aq_params = {'n_bits': args.act_bit, 'symmetric': args.a_sym, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.quant_act, "prob": 1.0, "num_timesteps": 100, "smooth_type": args.smooth_type}

    q_unet = QAModel(unet, args, wq_params=wq_params, aq_params=aq_params)
    layer_list = None

    if args.qat:
        print("Setting the first and the last layer to 8-bit")
        q_unet.set_first_last_layer_to_8bit()
        q_unet.set_quant_state(False, False)

        # Calibration trajectories are cached on disk and reused across runs.
        # NOTE: torch.load unpickles this file; only use a trusted cache.
        if os.path.exists('./Cifar100_Cali_data.pth'):
            samples = torch.load('./Cifar100_Cali_data.pth')
        else:
            loop_fn = cali_generalized_steps
            seq = diffusion.seq
            samples = []
            for _ in tqdm(range(5), desc="get train set"):
                shape = (args.batch_samples, 3, config.data.image_size, config.data.image_size)
                img = torch.randn(*shape, device=args.device)
                # Reset per batch so a short generator cannot silently reuse
                # the previous batch's samples (or hit an unbound name).
                all_sample = None
                for now_rt, sample_t in enumerate(
                        loop_fn(
                            model=q_unet.model, seq=seq, x=img, b=diffusion.betas, eta=diffusion.args.eta, args=diffusion.args
                        )
                    ):
                    if len(seq) == now_rt:
                        all_sample = sample_t[:-1]
                        break
                if all_sample is None:
                    raise RuntimeError(
                        f"cali_generalized_steps yielded fewer than {len(seq) + 1} items; "
                        "cannot collect calibration samples"
                    )
                samples.append(all_sample)
                torch.cuda.empty_cache()
            torch.save(samples, os.path.join('./', "Cifar100_Cali_data.pth"))

        # Regroup the cached batches per sampling step.
        all_samples = []
        for t_sample in range(args.timesteps):
            t_samples = torch.cat([sample[t_sample].cpu() for sample in samples])
            all_samples.append(t_samples)
        del samples
        torch.cuda.empty_cache()

        # Keep up to 256 random samples per step. Step index now_rt maps to
        # model timestep seq[-1 - now_rt] (the sampled trajectory runs backwards).
        seq = diffusion.seq
        all_cali_data = []
        all_t = []
        all_cali_t = []
        for now_rt, sample_t in enumerate(all_samples):
            idx = torch.randperm(sample_t.size(0))[:256]
            cali_data = sample_t[idx]
            t = torch.full((cali_data.size(0),), seq[(len(seq)-1)-now_rt])
            cali_t = torch.full_like(t, now_rt)
            all_cali_data.append(cali_data)
            all_t.append(t)
            all_cali_t.append(cali_t)
        del all_samples

        if args.split:
            q_unet.model.config.split_shortcut = True

        # smoothquant normal
        cali_data = torch.cat(all_cali_data)
        t = torch.cat(all_t)
        idx = torch.randperm(len(cali_data))[:5120]
        cali_data = cali_data[idx]
        t = t[idx]
        cali_data = (cali_data, t)
        set_smooth_quantize_params(q_unet, cali_data, layer_list=layer_list)
        q_unet.set_smooth_state(set_smooth_weight=True, set_smooth_training=False)

        # Init scale_w
        set_weight_quantize_params(q_unet, (all_cali_data[0], all_t[0]))
        # Init scale_a
        set_act_quantize_params(q_unet, all_cali_data, all_t)
        torch.cuda.empty_cache()

        # block-wise training for other layers
        kwargs = dict(iters=5000,
                        act_quant=True,
                        weight_quant=True,
                        asym=True,
                        opt_mode='mse',
                        lr_a=args.lr_a,
                        lr_w=args.lr_w,
                        lr_rw=args.lr_rw,
                        lr_za=args.lr_za,
                        lr_smooth=1e-7,
                        p=2.0,
                        weight=0.01,
                        b_range=(20,2),
                        warmup=0.2,
                        batch_size=32,
                        batch_size1=32,
                        input_prob=1.0,
                        recon_w=True,
                        recon_a=True,
                        recon_rw=True,
                        recon_smooth=False,
                        keep_gpu=False,
                        )
        q_unet.set_quant_state(weight_quant=True, act_quant=args.quant_act)
        block_train_w(q_unet, args.device, diffusion, kwargs, all_cali_data, all_t, all_cali_t)
        q_unet.set_quant_state(weight_quant=True, act_quant=args.quant_act)

    diffusion.sample_fid(q_unet)
    print("sample down!")

