"""
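Quantize a DiT model (learned rotation matrices, activation-range calibration,
channel-permutation search, then GPTQ or plain weight quantization) and sample
from it: by default a small validation grid saved as sample.png, or with
--sample a large batch of images packed into a numpy .npz array for FID evaluation.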
"""

import argparse
import os
from pathlib import Path
import numpy as np
import torch
import logging
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm
import math
import gc

from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from utils.download import find_model
from models.models import DiT_models
from utils.logger_setup import create_logger
from glob import glob
from copy import deepcopy

from qdit.quant import *
from qdit.outlier import *
from qdit.datautils import *
from collections import defaultdict
from qdit.modelutils import quantize_model, quantize_model_gptq, add_act_quant_wrapper, \
    replace_model, update_weights_rotation, update_weights_scale, get_hadamard_matrix, set_calib_state, get_update_hadamard_matrix_fc2


from tqdm import tqdm
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from utils.utils import *
from optimize.train import optimize_rotation_matrix
from evolution.evolution import evolution_search

class StopForwardException(Exception):
    """Raised by a forward hook to abort the remainder of a forward pass.

    The caller wraps the forward call in ``try/except StopForwardException``
    and treats the exception as "capture finished", not as an error.
    """

class DataSaverHook:
    """Callable forward hook that snapshots what flows through a module.

    Attach it with ``module.register_forward_hook(hook)``. Depending on the
    flags it keeps the module's input argument and/or its output, and it can
    abort the remaining forward pass by raising ``StopForwardException`` once
    the capture is done.
    """

    def __init__(self, store_input=False, store_output=False, stop_forward=False):
        # Behaviour flags.
        self.store_input = store_input
        self.store_output = store_output
        self.stop_forward = stop_forward
        # Captures; they stay None until the hook fires with the matching flag set.
        self.input_store = self.output_store = None

    def __call__(self, module, input_batch, output_batch):
        captures = (
            (self.store_input, 'input_store', input_batch),
            (self.store_output, 'output_store', output_batch),
        )
        for enabled, attr_name, value in captures:
            if enabled:
                setattr(self, attr_name, value)
        if self.stop_forward:
            raise StopForwardException
        

def validate_model(args, model, diffusion, vae):
    """Render a fixed 8-image validation grid from ``model`` and save it as ``sample.png``.

    The RNG is re-seeded with ``args.seed`` so runs are comparable. When
    ``args.cfg_scale > 1`` classifier-free guidance is used: the latent batch is
    duplicated with null-class labels (1000) and sampled through
    ``model.forward_with_cfg``. Otherwise ``model.forward`` is called directly.
    The grid is written relative to the current working directory.
    """
    seed_everything(args.seed)
    device = next(model.parameters()).device
    using_cfg = args.cfg_scale > 1.0
    # Labels to condition the model with (feel free to change):
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    n = len(class_labels)
    # Create sampling noise:
    z = torch.randn(n, 4, model.input_size, model.input_size, device=device)
    y = torch.tensor(class_labels, device=device)
    if using_cfg:
        # Classifier-free guidance: conditional half first, null-label half second.
        z = torch.cat([z, z], 0)
        y_null = torch.tensor([1000] * n, device=device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        sample_fn = model.forward_with_cfg
    else:
        # forward_with_cfg takes a cfg_scale argument (it is called with four
        # positional arguments elsewhere in this file) and presumably, as in
        # upstream DiT, expects the doubled batch built above. Without guidance
        # the plain forward is the right entry point.
        model_kwargs = dict(y=y)
        sample_fn = model.forward
    z = z.half()
    with autocast():
        samples = diffusion.p_sample_loop(
            sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
        )
    if using_cfg:
        samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
    samples = vae.decode(samples / 0.18215).sample
    # Save and display images:
    save_image(samples, 'sample.png', nrow=4, normalize=True, value_range=(-1, 1))
    print("Finish validating samples!")
    
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Build a single .npz file from a folder of .png samples.

    Reads ``{sample_dir}/000000.png`` .. ``{sample_dir}/{num-1:06d}.png`` and stacks
    them into one uint8 array of shape ``(num, H, W, 3)``. The array is stored
    under key ``arr_0`` in ``{sample_dir}.npz``.

    Returns:
        The path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Context manager closes each image file deterministically; astype()
        # copies the pixels, so the array outlives the closed image.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            samples.append(np.asarray(sample_pil).astype(np.uint8))
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path

def sample_fid(args, model, diffusion, vae):
    """Generate ``args.num_fid_samples`` images for FID evaluation.

    Images are saved as PNGs under ``{args.experiment_dir}/{folder_name}`` and
    then packed into ``{sample_folder_dir}.npz`` by
    ``create_npz_from_sample_folder``. Class labels are drawn uniformly from
    ``[0, args.num_classes)``.

    Sampling runs in whole batches of ``args.batch_size``. The last batch may
    overshoot: surplus PNGs are written but not packed into the .npz.
    """
    seed_everything(args.seed)
    device = next(model.parameters()).device
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.seed}"
    sample_folder_dir = f"{args.experiment_dir}/{folder_name}"
    os.makedirs(sample_folder_dir, exist_ok=True)
    print(f"Saving .png samples at {sample_folder_dir}")

    # Round up to a whole number of batches; extra samples are discarded when building the .npz.
    n = args.batch_size
    iterations = math.ceil(args.num_fid_samples / n)
    total_samples = iterations * n
    print(f"Total number of images that will be sampled: {total_samples}")

    # forward_with_cfg needs cfg_scale and the doubled (cond + null) batch, so it
    # is only used with guidance; otherwise the plain forward is sampled.
    sample_fn = model.forward_with_cfg if using_cfg else model.forward

    total = 0
    for _ in tqdm(range(iterations)):
        # Sample inputs:
        z = torch.randn(n, model.in_channels, model.input_size, model.input_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance:
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        else:
            model_kwargs = dict(y=y)

        z = z.half()
        with autocast():
            samples = diffusion.p_sample_loop(
                sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
            )
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        samples = vae.decode(samples / 0.18215).sample
        # Map [-1, 1] to uint8 HWC images.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            Image.fromarray(sample).save(f"{sample_folder_dir}/{total + i:06d}.png")
        total += n

    create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
    print("Done.")


def _load_dit_model(args, latent_size, device):
    """Instantiate ``args.model`` on ``device``, load the full-precision DiT weights and set eval mode.

    NOTE(review): the checkpoint path is a hard-coded placeholder (``args.ckpt``
    is not consulted here); point it at a real checkpoint before running.
    """
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    ).to(device)
    ckpt_path = f'path/to/your/DiT-XL-2-{args.image_size}x{args.image_size}.pt'
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    return model


def _run_until_layer(model, layer, batch, device, cfg_scale, store_input, store_output):
    """Run ``model.forward_with_cfg`` on one ``(xs, ts, ys)`` batch, stopping once ``layer`` fires.

    Returns the ``DataSaverHook`` holding the requested captures. The hook is
    always removed, even if the forward pass fails for another reason.
    """
    xs, ts, ys = batch
    saver = DataSaverHook(store_input=store_input, store_output=store_output, stop_forward=True)
    handle = layer.register_forward_hook(saver)
    try:
        with torch.no_grad():
            model.forward_with_cfg(xs.to(device), ts.to(device), ys.to(device), cfg_scale)
    except StopForwardException:
        pass  # expected: the hook aborts the pass once it has captured ``layer``
    finally:
        handle.remove()
    return saver


def _collect_layer_io(model, layer, batches, device, cfg_scale):
    """Capture ``layer``'s first positional input and its output in ``model`` for every batch."""
    ins_list, outs_list = [], []
    for batch in batches:
        saver = _run_until_layer(model, layer, batch, device, cfg_scale, store_input=True, store_output=True)
        ins_list.append(saver.input_store[0].detach())
        outs_list.append(saver.output_store.detach())
    return ins_list, outs_list


def _collect_quant_inputs_fp_outputs(model, layer, fp_model, fp_layer, batches, device, cfg_scale):
    """Per batch, capture ``layer``'s input in ``model`` and ``fp_layer``'s output in ``fp_model``."""
    ins_list, outs_list = [], []
    for batch in batches:
        saver = _run_until_layer(model, layer, batch, device, cfg_scale, store_input=True, store_output=False)
        ins_list.append(saver.input_store[0].detach())
        saver = _run_until_layer(fp_model, fp_layer, batch, device, cfg_scale, store_input=False, store_output=True)
        outs_list.append(saver.output_store.detach())
    return ins_list, outs_list


def _build_act_calib_batches(cali_xs, cali_ts, cali_ys, timesteps, n_iterations, batch_size):
    """Build ``n_iterations`` calibration batches that each cover every timestep in ``timesteps``.

    Batch ``i`` takes the ``i``-th slice of ``batch_size`` samples for each
    timestep. It is then reordered so samples with label != 1000 come before
    those with label == 1000 (the null label used for guidance).
    """
    xs_list, ts_list, ys_list = [], [], []
    for i in range(n_iterations):
        inds = []
        for t in timesteps:
            idx = torch.where(cali_ts == t)[0][i * batch_size:(i + 1) * batch_size]
            inds.extend(idx.tolist())
        xs_batch, ts_batch, ys_batch = cali_xs[inds], cali_ts[inds], cali_ys[inds]
        normal_index = torch.where(ys_batch != 1000)[0]
        null_index = torch.where(ys_batch == 1000)[0]
        xs_list.append(torch.cat([xs_batch[normal_index], xs_batch[null_index]], 0))
        ts_list.append(torch.cat([ts_batch[normal_index], ts_batch[null_index]], 0))
        ys_list.append(torch.cat([ys_batch[normal_index], ys_batch[null_index]], 0))
    return xs_list, ts_list, ys_list


def _reorder_input_channels(layer, index):
    """Permute ``layer.weight`` along dim 1 by ``index``.

    Dim 1 is the input-feature axis for an nn.Linear-style ``(out, in)``
    weight. ``nn.Module`` refuses to assign a plain tensor to a name registered
    as a parameter, so a Parameter weight is re-wrapped as a Parameter, keeping
    its ``requires_grad``. Any other weight is assigned the plain tensor directly.
    """
    weight = layer.weight
    reordered = torch.index_select(weight, 1, index)
    if isinstance(weight, nn.Parameter):
        reordered = nn.Parameter(reordered.detach(), requires_grad=weight.requires_grad)
    layer.weight = reordered


def main():
    """Run the full pipeline: quantize a DiT model, then sample from it.

    Steps, in order:
      1. Load the full-precision model and wrap it with ``replace_model``.
      2. Group calibration timesteps per layer.
      3. Learn rotation matrices for each block's attn.qkv / attn.proj / mlp.fc1
         and apply them via ``update_weights_rotation``.
      4. Calibrate activation ranges and apply ``update_weights_scale``.
      5. Search input-channel permutations for qkv / proj / fc1 / fc2 against a
         full-precision reference model, and permute the weights.
      6. Insert activation quantizers and quantize the weights (GPTQ with
         ``--use_gptq``).
      7. Render a validation grid, or FID samples with ``--sample``.

    Intermediate artifacts (rotation and scale matrices, ``reorder_index_dict.pt``)
    are saved under the experiment directory.
    """
    args = create_argparser().parse_args()
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
    device = f"cuda:{args.cuda}"
    print(f'device: {device}')
    seed_everything(0)

    # Setup an experiment folder:
    os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
    experiment_index = len(glob(f"{args.results_dir}/*"))
    quant_name = args.quant_name
    quant_string_name = f"{quant_name}_w{args.wbits}a{args.abits}_w-group-{args.weight_group_size}_{args.image_size}x{args.image_size}"
    experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{quant_string_name}"  # Create an experiment folder
    args.experiment_dir = experiment_dir
    args.quant_string_name = quant_string_name
    os.makedirs(experiment_dir, exist_ok=True)
    create_logger(experiment_dir)
    logging.info(f"Experiment directory created at {experiment_dir}")
    logging.info(f"""wbits: {args.wbits}, abits: {args.abits}, w_sym: {args.w_sym}, a_sym: {args.a_sym},
                 weight_group_size: {args.weight_group_size}, act_group_size: {args.act_group_size},
                 quant_method: {args.quant_method}, use_gptq: {args.use_gptq}, static: {args.static},
                 image_size: {args.image_size}, cfg_scale: {args.cfg_scale}, quant_string_name: {args.quant_string_name},
                 epochs: {args.epochs}, lr: {args.learning_rate}, bata: {args.beta}, topk: {args.topk}, alpha: {args.alpha}""")

    # Load model:
    latent_size = args.image_size // 8
    model = _load_dit_model(args, latent_size, device)
    diffusion = create_diffusion(str(args.num_sampling_steps))
    # SECURITY NOTE: eval() executes arbitrary Python taken from the command line.
    # ast.literal_eval would suffice if only an int or a list of ints is expected.
    args.weight_group_size = eval(args.weight_group_size)
    args.act_group_size = eval(args.act_group_size)
    # Broadcast a single int to one group size per transformer block.
    n_blocks = len(model.blocks)
    if isinstance(args.weight_group_size, int):
        args.weight_group_size = [args.weight_group_size] * n_blocks
    if isinstance(args.act_group_size, int):
        args.act_group_size = [args.act_group_size] * n_blocks

    logging.info("Replacing Moudle ...")
    model = replace_model(model, args=args)

    logging.info("Sampling data for calibration")
    sample_data = torch.load(args.calib_data_path, map_location=device)
    cali_data = get_train_samples(args.calib_n, sample_data)
    del sample_data
    gc.collect()
    logging.info(f"Calibration data shape: {cali_data[0].shape} {cali_data[1].shape} {cali_data[2].shape}")
    cali_xs, cali_ts, cali_ys = cali_data
    # One timestep value per group of calib_n samples, for 100 groups.
    # Presumably the calibration file stores calib_n samples per step for 100 steps - TODO confirm.
    timesteps = [cali_ts[args.calib_n*i] for i in range(100)]
    print('cali_ts: ', cali_ts)
    print('cali_ys: ', cali_ys)
    print('timesteps: ', timesteps)

    logging.info("Grouping Timesteps ...")
    group_partitions_dict = group_timesteps(cali_xs.to(device), cali_ts.to(device), cali_ys.to(device), timesteps, model.to(device), args)
    layers_select_timestep = select_timesteps(group_partitions_dict, timesteps)

    logging.info("Getting Hadamard Matrix ...")
    model.to(device)
    model = get_hadamard_matrix(model)

    logging.info("Optimizing Hadamard Matrix ...")
    model = set_calib_state(model, False)
    cached_path = os.path.join(args.experiment_dir, 'tmp_cached/')
    os.makedirs(cached_path, exist_ok=True)
    n_iterations = args.calib_n // args.calib_batch_size

    # Activation quantizer used while learning rotations: act_group_size=0, 'max' method.
    args_i = deepcopy(args)
    args_i.weight_group_size = args.weight_group_size[0]
    args_i.act_group_size = 0
    args_i.quant_method = 'max'
    act_quant = Quantizer(args=deepcopy(args_i))
    act_quant.configure(
        partial(quantize_activation_wrapper, args=args_i),
        None
    )

    # (parent submodule, linear layer) pairs whose rotation matrix is learned per block.
    rotated_layers = (('attn', 'qkv'), ('attn', 'proj'), ('mlp', 'fc1'))
    for block_index in tqdm(range(len(model.blocks))):
        m = model.blocks[block_index]
        print('\n' + '-'*50 + ' Start block %d ' % block_index + '-'*50)
        for parent_name, layer_name in rotated_layers:
            print('-'*30 + f' Block {block_index} Optimization for {layer_name}_rotation Matrix ' + '-'*30)
            parent = getattr(m, parent_name)
            layer = getattr(parent, layer_name)
            timesteps_list = layers_select_timestep[f'blocks.{block_index}.{parent_name}.{layer_name}']
            xs_list, ts_list, ys_list, median_timestep_index = get_layer_calib_data(
                cali_xs, cali_ts, cali_ys, timesteps_list, args, median_timestep=515)
            print('median_timestep_index: ', median_timestep_index)
            ins_list, outs_list = _collect_layer_io(
                model, layer, zip(xs_list, ts_list, ys_list), device, args.cfg_scale)

            # Make the rotation a float32 Parameter on the compute device before optimizing it.
            rot_attr = f'{layer_name}_rotation_matrix'
            setattr(parent, rot_attr, nn.Parameter(getattr(parent, rot_attr).to(torch.float32).to(device)))
            prefix = f"Block {block_index}'s {layer_name} "
            optimize_rotation_matrix(getattr(parent, rot_attr), layer.weight.clone(), layer.bias, act_quant,
                                     cached_path, n_iterations, args, prefix, median_timestep_index,
                                     ins_list, outs_list)
            del ins_list, outs_list
        torch.cuda.empty_cache()
    del group_partitions_dict, layers_select_timestep

    logging.info("Updating Weights Rotation ...")
    model = update_weights_rotation(model)

    logging.info("Updating Weights Rotation fc2...")
    model = get_update_hadamard_matrix_fc2(model)

    logging.info("Saving Rotation Matrix...")
    save_rotation_matrix(model, args.experiment_dir)

    logging.info("Calibrating Activation Ranges ...")
    model = set_calib_state(model, True)
    n_iterations = args.calib_n // args.calib_batch_size
    # Keep every (100 // n_timesteps_group)-th timestep.
    timesteps = [timesteps[i] for i in range(0, 100, 100 // args.n_timesteps_group)]
    cali_xs_batch_list, cali_ts_batch_list, cali_ys_batch_list = _build_act_calib_batches(
        cali_xs, cali_ts, cali_ys, timesteps, n_iterations, args.calib_batch_size)
    for cali_xs_batch, cali_ts_batch, cali_ys_batch in zip(cali_xs_batch_list, cali_ts_batch_list, cali_ys_batch_list):
        with torch.no_grad():
            _ = model.forward_with_cfg(cali_xs_batch.to(device), cali_ts_batch.to(device), cali_ys_batch.to(device), args.cfg_scale)

    logging.info("Updating Weights Scaling ...")
    model = update_weights_scale(model)

    logging.info("Saving Scaling Matrix...")
    save_scale_matrix(model, args.experiment_dir)

    # Full-precision reference: its layer outputs are the targets of the permutation search.
    fp_model = _load_dit_model(args, latent_size, device)

    args_i.act_group_size = args.act_group_size[0]
    act_quant = Quantizer(args=deepcopy(args_i))
    act_quant.configure(
        partial(quantize_activation_wrapper, args=args_i),
        None
    )

    calib_batches = list(zip(cali_xs_batch_list, cali_ts_batch_list, cali_ys_batch_list))
    # The search uses the first 1/8 of the batches; keep at least one so it never receives empty data.
    evo_batches = max(1, len(calib_batches) // 8)
    # (parent submodule, linear layer, search over all batches?)
    # NOTE(review): fc2 searches over every calibration batch while the other
    # layers only use the first evo_batches; confirm this asymmetry is intended.
    permuted_layers = (('attn', 'qkv', False), ('attn', 'proj', False),
                       ('mlp', 'fc1', False), ('mlp', 'fc2', True))
    reorder_index_dict = {}
    for block_index in tqdm(range(len(model.blocks))):
        m = model.blocks[block_index]
        fp_m = fp_model.blocks[block_index]
        print('\n' + '-'*50 + ' Start block %d ' % block_index + '-'*50)
        for parent_name, layer_name, use_all_batches in permuted_layers:
            print('-'*30 + f' Block {block_index} Search for {layer_name} permutation ' + '-'*30)
            layer = getattr(getattr(m, parent_name), layer_name)
            fp_layer = getattr(getattr(fp_m, parent_name), layer_name)
            batches = calib_batches if use_all_batches else calib_batches[:evo_batches]
            ins_list, outs_list = _collect_quant_inputs_fp_outputs(
                model, layer, fp_model, fp_layer, batches, device, args.cfg_scale)
            # Each layer is searched and stored under its own name.
            layer_key = f'blocks.{block_index}.{parent_name}.{layer_name}'
            reorder_index_dict[layer_key] = evolution_search(
                ins_list, outs_list, act_quant, layer.weight.clone(), layer.bias, args, layer_key, device)
            del ins_list, outs_list
        torch.cuda.empty_cache()

    torch.save(reorder_index_dict, os.path.join(args.experiment_dir, 'reorder_index_dict.pt'))

    # Apply the permutations: record each index on its parent module and permute the weight columns.
    for block_index in tqdm(range(len(model.blocks))):
        m = model.blocks[block_index]
        for parent_name, layer_name, _ in permuted_layers:
            parent = getattr(m, parent_name)
            reorder_index = reorder_index_dict[f'blocks.{block_index}.{parent_name}.{layer_name}']
            setattr(parent, f'reorder_index_{layer_name}', reorder_index)
            _reorder_input_channels(getattr(parent, layer_name), reorder_index)
        torch.cuda.empty_cache()

    del fp_model

    logging.info("Inserting activations quantizers ...")
    if args.static:
        dataloader = get_loader(args.calib_data_path, nsamples=1024, batch_size=16)
        print("Getting activation stats...")
        scales = get_act_scales(
            model, diffusion, dataloader, device, args
        )
    else:
        scales = defaultdict(lambda: None)
    model = add_act_quant_wrapper(model, device=device, args=args, scales=scales)

    logging.info("Quantizing ...")
    if args.use_gptq:
        calib_data = (cali_xs_batch_list, cali_ts_batch_list, cali_ys_batch_list)
        model = quantize_model_gptq(model, device=device, args=args, dataloader=None, calib_data=calib_data)
    else:
        model = quantize_model(model, device=device, args=args)

    logging.info("Finish quant!")
    logging.info(model)

    # generate some sample images
    model.to(device)
    model.half()
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    torch.set_grad_enabled(False)

    vae = AutoencoderKL.from_pretrained("path/to/your/stabilityai_sd-vae-ft-mse").to(device)

    if not args.sample:
        validate_model(args, model, diffusion, vae)
    else:
        sample_fid(args, model, diffusion, vae)

def create_argparser():
    """Build the command-line parser for the DiT quantization and sampling script."""
    parser = argparse.ArgumentParser()

    # my params
    parser.add_argument(
        '--cuda', type=int, default=0,
    )
    parser.add_argument(
        '--quant_name', type=str, default='qdit',
    )
    parser.add_argument(
        '--sample', action='store_true',
    )

    parser.add_argument(
        '--epochs', type=int, default=20,
    )
    parser.add_argument(
        '--learning_rate', type=float, default=1.5,
    )
    parser.add_argument(
        '--beta', type=int, default=50,
    )
    parser.add_argument(
        '--topk', type=int, default=50,
    )
    parser.add_argument(
        '--alpha', type=float, default=1.0,
    )

    # calib
    parser.add_argument(
        "--calib_data_path", type=str, default="path/to/your/imagenet_DiT-256_sample32_100steps_allst.pt", help="calibration dataset name"
    )
    parser.add_argument(
        "--calib_n", type=int, default=64, help="number of samples for each timestep for calibration"
    )
    parser.add_argument(
        "--n_timesteps_group", type=int, default=25, help="number of timesteps used for calibration"
    )
    parser.add_argument(
        "--calib_batch_size", type=int, default=2, help="batch size for calibration"
    )

    # quantization parameters
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 5, 6, 8, 16],
        help='#bits to use for quantizing weight; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--abits', type=int, default=16, choices=[2, 3, 4, 5, 6, 8, 16],
        help='#bits to use for quantizing activation; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--exponential', action='store_true',
        help='Whether to use exponent-only for weight quantization.'
    )
    parser.add_argument(
        '--quantize_bmm_input', action='store_true',
        help='Whether to perform bmm input activation quantization. Default is not.'
    )
    parser.add_argument(
        '--a_sym', action='store_true',
        help='Whether to perform symmetric quantization. Default is asymmetric.'
    )
    parser.add_argument(
        '--w_sym', action='store_true',
        help='Whether to perform symmetric quantization. Default is asymmetric.'
    )
    parser.add_argument(
        '--static', action='store_true',
        help='Whether to perform static quantization (For activtions). Default is dynamic. (Deprecated in Atom)'
    )
    parser.add_argument(
        '--weight_group_size', type=str,
        help='Group size when quantizing weights. Using 128 as default quantization group.'
    )
    parser.add_argument(
        '--weight_channel_group', type=int, default=1,
        help='Group size of channels that will quantize together. (only for weights now)'
    )
    parser.add_argument(
        '--act_group_size', type=str,
        help='Group size when quantizing activations. Using 128 as default quantization group.'
    )
    parser.add_argument(
        '--tiling', type=int, default=0, choices=[0, 16],
        help='Tile-wise quantization granularity (Deprecated in Atom).'
    )
    parser.add_argument(
        '--percdamp', type=float, default=.01,
        help='Percent of the average Hessian diagonal to use for dampening.'
    )
    parser.add_argument(
        '--use_gptq', action='store_true',
        help='Whether to use GPTQ for weight quantization.'
    )
    parser.add_argument(
        '--quant_method', type=str, default='max', choices=['max', 'mse'],
        help='The method to quantize weight.'
    )
    parser.add_argument(
        '--a_clip_ratio', type=float, default=1.0,
        help='Clip ratio for activation quantization. new_max = max * clip_ratio'
    )
    parser.add_argument(
        '--w_clip_ratio', type=float, default=1.0,
        help='Clip ratio for weight quantization. new_max = max * clip_ratio'
    )
    parser.add_argument(
        '--save_dir', type=str, default='./saved',
        help='Path to store the reordering indices and quantized weights.'
    )
    parser.add_argument(
        '--quant_type', type=str, default='int', choices=['int', 'fp'],
        help='Determine the mapped data format by quant_type + n_bits. e.g. int8, fp4.'
    )
    # Inherited from DiT
    parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--num-sampling-steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    # The default is a concrete path, so no auto-download happens.
    parser.add_argument("--ckpt", type=str, default='path/to/your/DiT-XL-2-256x256.pt',
                        help="Path to a DiT checkpoint; its basename is used to name the FID sample folder.")
    parser.add_argument("--results-dir", type=str, default="./results")
    parser.add_argument(
        "--save_ckpt", action="store_true", help="choose to save the qnn checkpoint"
    )
    # sample_ddp.py
    # store_true means TF32 is OFF unless the flag is passed.
    parser.add_argument("--tf32", action="store_true",
                        help="Use TF32 matmuls (off unless this flag is given). This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    return parser


if __name__ == "__main__":
    # Script entry point: quantize the model, then validate or sample from it.
    main()
