import torch
from typing import Union
from tqdm import tqdm

from quant.quant_layer import QuantModule
from quant.quant_block import BaseQuantBlock, QuantAttnBlock, QuantSMVMatMul, QuantQKMatMul
from quant.quant_model import Quant_Model
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ddim_control import DDIMSampler_control
from ldm.models.diffusion.plms import PLMSSampler

def set_smooth_quantize_params2(
    module: Union[Quant_Model, QuantModule, BaseQuantBlock], layer_list=None,
):
    """Reset the smooth-quantizer init flags without running calibration data.

    Quantization is switched off with ``set_quant_state(False, False)``, the
    ``smooth_quantizer`` of every ``QuantModule`` is flagged as not inited,
    and then the ones whose qualified name contains an entry of
    ``layer_list`` are flagged as inited again (per the original comment,
    those layers are not smoothed).

    Args:
        module: Object exposing ``set_quant_state``, ``modules`` and
            ``named_modules``.
        layer_list: Optional iterable of name substrings. ``None`` (the
            default) means no layer is excluded.
    """
    print("set_smooth_quantize_params")
    module.set_quant_state(False, False)
    for m in module.modules():
        if isinstance(m, QuantModule):
            m.smooth_quantizer.set_inited(False)

    # The default used to be iterated directly, which raised TypeError as
    # soon as the model contained a QuantModule. Treat None as "no exclusions",
    # matching the guard used by the sibling calibration helpers.
    if layer_list is None:
        return

    for name, m in module.named_modules():
        if isinstance(m, QuantModule):
            for layer_name in layer_list:
                if layer_name in name:
                    m.smooth_quantizer.set_inited(True)

def set_smooth_quantize_params(
    module: Union[Quant_Model, QuantModule, BaseQuantBlock],
    cali_data,
    layer_list=None,
    batch_size: int = 512,
):
    """Set or init the scales in the smooth quantizers by replaying calibration data.

    Args:
        module: Callable model exposing ``set_quant_state``, ``modules`` and
            ``named_modules``.
        cali_data: Either a single tensor or a tuple/list of tensors; each is
            sliced along dim 0, moved to CUDA and passed to ``module``.
        layer_list: Optional name substrings. Matching ``QuantModule``s have
            their smooth quantizer flagged as inited before the replay (per
            the original comment, these layers are not smoothed).
        batch_size: Upper bound on the slice length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    module.set_quant_state(False, False)

    for layer in module.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    if layer_list is not None:
        for qualified_name, layer in module.named_modules():
            if not isinstance(layer, QuantModule):
                continue
            for pattern in layer_list:
                if pattern in qualified_name:
                    layer.smooth_quantizer.set_inited(True)

    # A bare tensor is treated as a one-element argument group so both input
    # forms share a single replay loop.
    if isinstance(cali_data, (tuple, list)):
        tensors = cali_data
    else:
        tensors = (cali_data,)

    total = tensors[0].size(0)
    batch_size = min(batch_size, total)
    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in range(total // batch_size):
            window = slice(step * batch_size, (step + 1) * batch_size)
            module(*[tensor[window].cuda() for tensor in tensors])
    torch.cuda.empty_cache()

    for layer in module.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params(
    module: Union[Quant_Model, QuantModule, BaseQuantBlock],
    all_cali_data, 
    all_t,
    batch_size: int = 256,
):
    """Set or init step size and zero point in the activation quantizers.

    For every index of ``all_cali_data`` the model is told the current index
    via ``set_time`` and the matching ``(all_cali_data[i], all_t[i])`` pair is
    replayed in CUDA batches with activation quantization enabled.

    Args:
        module: Callable model exposing ``set_quant_state``,
            ``set_act_quantize_init`` and ``set_time``.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t: Indexable collection of tensors aligned with ``all_cali_data``.
        batch_size: Upper bound on the slice length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    module.set_quant_state(False, True)

    batch_size = min(batch_size, all_cali_data[0].size(0))
    module.set_act_quantize_init(act_init=False)
    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            module.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                group = (all_cali_data[step], all_t[step])
                module(*[tensor[window].cuda() for tensor in group])
            # The init flag is cleared again after every step's replay.
            module.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    module.set_act_quantize_init(act_init=True)

def set_weight_quantize_params(model, cali_data):
    """Initialise the weight quantizers with one forward pass on calibration data.

    Weight quantization is enabled, every ``QuantModule``'s
    ``weight_quantizer`` is flagged as not inited, the first 32 samples of
    each tensor in ``cali_data`` are run through ``model`` on CUDA, and the
    quantizers are flagged as inited afterwards.

    Args:
        model: Callable model exposing ``set_quant_state`` and ``named_modules``.
        cali_data: Iterable of tensors passed positionally to ``model``.
    """
    print("set_weight_quantize_params")

    model.set_quant_state(True, False)
    for _, layer in model.named_modules():
        if isinstance(layer, QuantModule):
            # NOTE(review): weight_quantizer_0 of split layers is not reset
            # here although it is marked inited below - confirm intended.
            layer.weight_quantizer.set_inited(False)

    batch_size = 32
    with torch.no_grad():
        model(*[tensor[:batch_size].cuda() for tensor in cali_data])
    torch.cuda.empty_cache()

    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)


def set_smooth_quantize_params_LDM(
    module,
    cali_data,
    args,
    layer_list=None,
    batch_size: int = 64,
):
    """Set or init the smooth-quantizer scales of ``module.model.diffusion_model``.

    Calibration batches are handed to ``DDIMSampler.sample`` through its
    ``cali_data`` argument with ``quant_unet=True``.

    Args:
        module: Wrapper whose ``model.diffusion_model`` holds the quantized
            network; also passed to ``DDIMSampler``.
        cali_data: Sequence of tensors, each sliced along dim 0 and moved to
            CUDA per batch.
        args: Namespace providing ``custom_steps`` and ``eta``.
        layer_list: Optional name substrings. Matching ``QuantModule``s have
            their smooth quantizer flagged as inited before sampling (per the
            original comment, these layers are not smoothed).
        batch_size: Upper bound on the batch length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, False)

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    if layer_list is not None:
        for qualified_name, layer in unet.named_modules():
            if not isinstance(layer, QuantModule):
                continue
            for pattern in layer_list:
                if pattern in qualified_name:
                    layer.smooth_quantizer.set_inited(True)

    total = cali_data[0].size(0)
    batch_size = min(batch_size, total)
    latent_shape = [unet.in_channels, unet.image_size, unet.image_size]
    ddim = DDIMSampler(module)

    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in tqdm(range(total // batch_size), desc="Inited smooth"):
            window = slice(step * batch_size, (step + 1) * batch_size)
            ddim.sample(
                args.custom_steps,
                batch_size=batch_size,
                shape=latent_shape,
                eta=args.eta,
                verbose=False,
                quant_unet=True,
                cali_data=[tensor[window].cuda() for tensor in cali_data],
            )
    torch.cuda.empty_cache()

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params_LDM(
    module,
    all_cali_data, 
    all_t,
    all_index,
    args,
    batch_size: int = 32,
):
    """Set or init step size and zero point in the activation quantizers (LDM).

    For each index of ``all_cali_data`` the network is told the index via
    ``set_time`` and the aligned ``(data, t, index)`` tensors are handed in
    CUDA batches to ``DDIMSampler.sample`` with ``quant_unet=True``.

    Args:
        module: Wrapper whose ``model.diffusion_model`` holds the quantized
            network; also passed to ``DDIMSampler``.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t: Collection aligned with ``all_cali_data``.
        all_index: Collection aligned with ``all_cali_data``.
        args: Namespace providing ``custom_steps`` and ``eta``.
        batch_size: Upper bound on the batch length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, True)

    batch_size = min(batch_size, all_cali_data[0].size(0))
    latent_shape = [unet.in_channels, unet.image_size, unet.image_size]
    ddim = DDIMSampler(module)

    unet.set_act_quantize_init(act_init=False)
    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            unet.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                group = (all_cali_data[step], all_t[step], all_index[step])
                ddim.sample(
                    args.custom_steps,
                    batch_size=batch_size,
                    shape=latent_shape,
                    eta=args.eta,
                    verbose=False,
                    quant_unet=True,
                    cali_data=[tensor[window].cuda() for tensor in group],
                )
            # The init flag is cleared again after every step's replay.
            unet.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    unet.set_act_quantize_init(act_init=True)

def set_weight_quantize_params_LDM(model, cali_data, args):
    """Initialise the weight quantizers with a single DDIM sampling call (LDM).

    Args:
        model: Wrapper whose ``model.diffusion_model`` holds the quantized
            network; also passed to ``DDIMSampler``.
        cali_data: Iterable of tensors; the first 8 samples of each are moved
            to CUDA and handed to the sampler.
        args: Namespace providing ``custom_steps`` and ``eta``.
    """
    print("set_weight_quantize_params")
    unet = model.model.diffusion_model
    unet.set_quant_state(True, False)

    for _, layer in unet.named_modules():
        if isinstance(layer, QuantModule):
            layer.weight_quantizer.set_inited(False)

    batch_size = 8
    latent_shape = [unet.in_channels, unet.image_size, unet.image_size]
    ddim = DDIMSampler(model)

    with torch.no_grad():
        ddim.sample(
            args.custom_steps,
            batch_size=batch_size,
            shape=latent_shape,
            eta=args.eta,
            verbose=False,
            quant_unet=True,
            cali_data=[tensor[:batch_size].cuda() for tensor in cali_data],
        )
    torch.cuda.empty_cache()

    # NOTE(review): the reset above walks only the diffusion model, while this
    # pass walks the whole wrapper - confirm the wider scope is intended.
    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)

def set_smooth_quantize_params_Conditional(
    module,
    cali_data,
    args,
    layer_list=None,
    batch_size: int = 64,
):
    """Set or init the smooth-quantizer scales for a conditional model.

    Calibration batches are handed to ``DDIMSampler_control.sample`` through
    its ``cali_data`` argument with ``quant_unet=True``.

    Args:
        module: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        cali_data: Sequence of tensors, each sliced along dim 0 and moved to
            CUDA per batch.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
        layer_list: Optional name substrings, e.g.
            ``["transformer_blocks", "attn2"]``. Matching ``QuantModule``s
            have their smooth quantizer flagged as inited before sampling
            (per the original comment, these layers are not smoothed).
        batch_size: Upper bound on the batch length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, False)

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    if layer_list is not None:
        for qualified_name, layer in unet.named_modules():
            if not isinstance(layer, QuantModule):
                continue
            for pattern in layer_list:
                if pattern in qualified_name:
                    layer.smooth_quantizer.set_inited(True)

    total = cali_data[0].size(0)
    batch_size = min(batch_size, total)

    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(module.device)
    uc = module.get_learned_conditioning({module.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = module.get_learned_conditioning({module.cond_stage_key: xc.to(module.device)})
    latent_shape = [3, 64, 64]
    sampler = DDIMSampler_control(module)

    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in tqdm(range(total // batch_size), desc="Inited smooth"):
            window = slice(step * batch_size, (step + 1) * batch_size)
            sampler.sample(
                S=args.ddim_steps,
                conditioning=c,
                batch_size=batch_size,
                shape=latent_shape,
                verbose=False,
                unconditional_guidance_scale=args.scale,
                unconditional_conditioning=uc,
                eta=args.ddim_eta,
                quant_unet=True,
                cali_data=[tensor[window].cuda() for tensor in cali_data],
            )
    torch.cuda.empty_cache()

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params_Conditional(
    module,
    all_cali_data, 
    all_t,
    all_index,
    all_cond,
    all_uncond,
    args,
    batch_size: int = 32,
):
    """Set or init the activation quantizers for a conditional model.

    For each index of ``all_cali_data`` the network is told the index via
    ``set_time`` and the aligned ``(data, t, index, cond, uncond)`` tensors
    are handed in CUDA batches to ``DDIMSampler_control.sample`` with
    ``quant_unet=True``.

    Args:
        module: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t, all_index, all_cond, all_uncond: Collections aligned with
            ``all_cali_data``.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
        batch_size: Upper bound on the batch length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, True)

    unet.set_act_quantize_init(act_init=False)

    batch_size = min(batch_size, all_cali_data[0].size(0))
    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(module.device)
    uc = module.get_learned_conditioning({module.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = module.get_learned_conditioning({module.cond_stage_key: xc.to(module.device)})
    latent_shape = [3, 64, 64]

    sampler = DDIMSampler_control(module)

    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            unet.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                group = (
                    all_cali_data[step],
                    all_t[step],
                    all_index[step],
                    all_cond[step],
                    all_uncond[step],
                )
                sampler.sample(
                    S=args.ddim_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=latent_shape,
                    verbose=False,
                    unconditional_guidance_scale=args.scale,
                    unconditional_conditioning=uc,
                    eta=args.ddim_eta,
                    quant_unet=True,
                    cali_data=[tensor[window].cuda() for tensor in group],
                )
            # The init flag is cleared again after every step's replay.
            unet.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    unet.set_act_quantize_init(act_init=True)

def set_weight_quantize_params_Conditional(model, cali_data, args):
    """Initialise the weight quantizers with one sampling call (conditional model).

    Args:
        model: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        cali_data: Iterable of tensors; the first 2 samples of each are moved
            to CUDA and handed to ``DDIMSampler_control.sample``.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
    """
    print("set_weight_quantize_params")
    unet = model.model.diffusion_model
    unet.set_quant_state(True, False)

    for _, layer in unet.named_modules():
        if isinstance(layer, QuantModule):
            layer.weight_quantizer.set_inited(False)

    batch_size = 2
    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(model.device)
    uc = model.get_learned_conditioning({model.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
    latent_shape = [3, 64, 64]
    sampler = DDIMSampler_control(model)

    with torch.no_grad():
        sampler.sample(
            S=args.ddim_steps,
            conditioning=c,
            batch_size=batch_size,
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=args.scale,
            unconditional_conditioning=uc,
            eta=args.ddim_eta,
            quant_unet=True,
            cali_data=[tensor[:batch_size].cuda() for tensor in cali_data],
        )
    torch.cuda.empty_cache()

    # NOTE(review): the reset above walks only the diffusion model, while this
    # pass walks the whole wrapper - confirm the wider scope is intended.
    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)

def set_smooth_quantize_params_Stable(
    module,
    cali_data,
    args,
    layer_list=None,
    batch_size: int = 8,
):
    """Set or init the smooth-quantizer scales for a prompt-conditioned model.

    Calibration batches are handed to the sampler's ``sample`` method through
    its ``cali_data`` argument with ``quant_unet=True``. The sampler is
    ``PLMSSampler`` when ``args.plms`` is truthy, else ``DDIMSampler_control``.

    Args:
        module: Wrapper exposing ``model.diffusion_model`` and
            ``get_learned_conditioning``.
        cali_data: Sequence of tensors, each sliced along dim 0 and moved to
            CUDA per batch.
        args: Namespace providing ``list_prompts``, ``C``, ``H``, ``W``,
            ``f``, ``plms``, ``ddim_steps``, ``scale`` and ``ddim_eta``.
        layer_list: Optional name substrings, e.g.
            ``["transformer_blocks", "attn2"]``. Matching ``QuantModule``s
            have their smooth quantizer flagged as inited before sampling
            (per the original comment, these layers are not smoothed).
        batch_size: Upper bound on the batch length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, False)

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    if layer_list is not None:
        for qualified_name, layer in unet.named_modules():
            if not isinstance(layer, QuantModule):
                continue
            for pattern in layer_list:
                if pattern in qualified_name:
                    layer.smooth_quantizer.set_inited(True)

    total = cali_data[0].size(0)
    batch_size = min(batch_size, total)
    # Empty prompts provide the unconditional conditioning.
    uc = module.get_learned_conditioning(batch_size * [""])
    c = module.get_learned_conditioning(args.list_prompts[:batch_size])
    latent_shape = [args.C, args.H // args.f, args.W // args.f]
    sampler = PLMSSampler(module) if args.plms else DDIMSampler_control(module)

    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in tqdm(range(total // batch_size), desc="Init scale_smooth"):
            window = slice(step * batch_size, (step + 1) * batch_size)
            sampler.sample(
                S=args.ddim_steps,
                conditioning=c,
                batch_size=batch_size,
                shape=latent_shape,
                verbose=False,
                unconditional_guidance_scale=args.scale,
                unconditional_conditioning=uc,
                eta=args.ddim_eta,
                x_T=None,
                quant_unet=True,
                cali_data=[tensor[window].cuda() for tensor in cali_data],
            )
    torch.cuda.empty_cache()

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params_Stable(
    module,
    all_cali_data, 
    all_t,
    all_index,
    all_cond,
    all_uncond,
    all_ts_next,
    args,
    batch_size: int = 2,
):
    """Set or init step size and zero point in the activation quantizers.

    For each index of ``all_cali_data`` the network is told the index via
    ``set_time`` and the aligned ``(data, t, index, cond, uncond, ts_next)``
    tensors are handed in CUDA batches to the sampler with ``quant_unet=True``.
    The sampler is ``PLMSSampler`` when ``args.plms`` is truthy, else
    ``DDIMSampler_control``.

    Args:
        module: Wrapper exposing ``model.diffusion_model`` and
            ``get_learned_conditioning``.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t, all_index, all_cond, all_uncond, all_ts_next: Collections
            aligned with ``all_cali_data``.
        args: Namespace providing ``list_prompts``, ``C``, ``H``, ``W``,
            ``f``, ``plms``, ``ddim_steps``, ``scale`` and ``ddim_eta``.
        batch_size: Upper bound on the batch length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, True)
    unet.set_act_quantize_init(act_init=False)

    batch_size = min(batch_size, all_cali_data[0].size(0))
    # Empty prompts provide the unconditional conditioning.
    uc = module.get_learned_conditioning(batch_size * [""])
    c = module.get_learned_conditioning(args.list_prompts[:batch_size])
    latent_shape = [args.C, args.H // args.f, args.W // args.f]
    sampler = PLMSSampler(module) if args.plms else DDIMSampler_control(module)

    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            unet.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                group = (
                    all_cali_data[step],
                    all_t[step],
                    all_index[step],
                    all_cond[step],
                    all_uncond[step],
                    all_ts_next[step],
                )
                sampler.sample(
                    S=args.ddim_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=latent_shape,
                    verbose=False,
                    unconditional_guidance_scale=args.scale,
                    unconditional_conditioning=uc,
                    eta=args.ddim_eta,
                    x_T=None,
                    quant_unet=True,
                    cali_data=[tensor[window].cuda() for tensor in group],
                )
            # The init flag is cleared again after every step's replay.
            unet.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    unet.set_act_quantize_init(act_init=True)

def set_weight_quantize_params_Stable(model, cali_data, args):
    """Initialise the weight quantizers with one sampling call (prompt-conditioned model).

    The sampler is ``PLMSSampler`` when ``args.plms`` is truthy, else
    ``DDIMSampler_control``.

    Args:
        model: Wrapper exposing ``model.diffusion_model`` and
            ``get_learned_conditioning``.
        cali_data: Iterable of tensors; the first 2 samples of each are moved
            to CUDA and handed to the sampler.
        args: Namespace providing ``list_prompts``, ``C``, ``H``, ``W``,
            ``f``, ``plms``, ``ddim_steps``, ``scale`` and ``ddim_eta``.
    """
    print("set_weight_quantize_params")
    unet = model.model.diffusion_model
    unet.set_quant_state(True, False)

    for _, layer in unet.named_modules():
        if isinstance(layer, QuantModule):
            layer.weight_quantizer.set_inited(False)

    batch_size = 2
    # Empty prompts provide the unconditional conditioning.
    uc = model.get_learned_conditioning(batch_size * [""])
    c = model.get_learned_conditioning(args.list_prompts[:batch_size])
    latent_shape = [args.C, args.H // args.f, args.W // args.f]
    sampler = PLMSSampler(model) if args.plms else DDIMSampler_control(model)

    with torch.no_grad():
        sampler.sample(
            S=args.ddim_steps,
            conditioning=c,
            batch_size=batch_size,
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=args.scale,
            unconditional_conditioning=uc,
            eta=args.ddim_eta,
            x_T=None,
            quant_unet=True,
            cali_data=[tensor[:batch_size].cuda() for tensor in cali_data],
        )
    torch.cuda.empty_cache()

    # NOTE(review): the reset above walks only the diffusion model, while this
    # pass walks the whole wrapper - confirm the wider scope is intended.
    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)

def set_smooth_quantize_params_Conditional_plms(
    module,
    cali_data,
    args,
    layer_list=None,
    batch_size: int = 64,
):
    """Set or init the smooth-quantizer scales for a conditional model (PLMS).

    Calibration batches are handed to ``PLMSSampler.sample`` through its
    ``cali_data`` argument with ``quant_unet=True``.

    Args:
        module: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        cali_data: Sequence of tensors, each sliced along dim 0 and moved to
            CUDA per batch.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
        layer_list: Optional name substrings. Matching ``QuantModule``s have
            their smooth quantizer flagged as inited before sampling (per the
            original comment, these layers are not smoothed).
        batch_size: Upper bound on the batch length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, False)

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    if layer_list is not None:
        for qualified_name, layer in unet.named_modules():
            if not isinstance(layer, QuantModule):
                continue
            for pattern in layer_list:
                if pattern in qualified_name:
                    layer.smooth_quantizer.set_inited(True)

    total = cali_data[0].size(0)
    batch_size = min(batch_size, total)

    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(module.device)
    uc = module.get_learned_conditioning({module.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = module.get_learned_conditioning({module.cond_stage_key: xc.to(module.device)})
    latent_shape = [3, 64, 64]
    sampler = PLMSSampler(module)

    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in tqdm(range(total // batch_size), desc="Inited smooth"):
            window = slice(step * batch_size, (step + 1) * batch_size)
            sampler.sample(
                S=args.ddim_steps,
                conditioning=c,
                batch_size=batch_size,
                shape=latent_shape,
                verbose=False,
                unconditional_guidance_scale=args.scale,
                unconditional_conditioning=uc,
                eta=args.ddim_eta,
                quant_unet=True,
                cali_data=[tensor[window].cuda() for tensor in cali_data],
            )
    torch.cuda.empty_cache()

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params_Conditional_plms(
    module,
    all_cali_data, 
    all_t,
    all_index,
    all_cond,
    all_uncond,
    all_ts_next,
    args,
    batch_size: int = 32,
):
    """Set or init the activation quantizers for a conditional model (PLMS).

    For each index of ``all_cali_data`` the network is told the index via
    ``set_time`` and the aligned ``(data, t, index, cond, uncond, ts_next)``
    tensors are handed in CUDA batches to ``PLMSSampler.sample`` with
    ``quant_unet=True``.

    Args:
        module: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t, all_index, all_cond, all_uncond, all_ts_next: Collections
            aligned with ``all_cali_data``.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
        batch_size: Upper bound on the batch length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, True)

    unet.set_act_quantize_init(act_init=False)

    batch_size = min(batch_size, all_cali_data[0].size(0))
    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(module.device)
    uc = module.get_learned_conditioning({module.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = module.get_learned_conditioning({module.cond_stage_key: xc.to(module.device)})
    latent_shape = [3, 64, 64]

    sampler = PLMSSampler(module)

    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            unet.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                group = (
                    all_cali_data[step],
                    all_t[step],
                    all_index[step],
                    all_cond[step],
                    all_uncond[step],
                    all_ts_next[step],
                )
                sampler.sample(
                    S=args.ddim_steps,
                    conditioning=c,
                    batch_size=batch_size,
                    shape=latent_shape,
                    verbose=False,
                    unconditional_guidance_scale=args.scale,
                    unconditional_conditioning=uc,
                    eta=args.ddim_eta,
                    quant_unet=True,
                    cali_data=[tensor[window].cuda() for tensor in group],
                )
            # The init flag is cleared again after every step's replay.
            unet.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    unet.set_act_quantize_init(act_init=True)

def set_weight_quantize_params_Conditional_plms(model, cali_data, args):
    """Initialise the weight quantizers with one PLMS sampling call (conditional model).

    Args:
        model: Wrapper exposing ``model.diffusion_model``,
            ``get_learned_conditioning``, ``cond_stage_key`` and ``device``.
        cali_data: Iterable of tensors; the first 2 samples of each are moved
            to CUDA and handed to ``PLMSSampler.sample``.
        args: Namespace providing ``data``, ``ddim_steps``, ``scale`` and
            ``ddim_eta``.
    """
    print("set_weight_quantize_params")
    unet = model.model.diffusion_model
    unet.set_quant_state(True, False)

    for _, layer in unet.named_modules():
        if isinstance(layer, QuantModule):
            layer.weight_quantizer.set_inited(False)

    batch_size = 2
    # Unconditional branch uses the id 1000 for every sample
    # (presumably the "null" class - confirm against the conditioner).
    null_labels = torch.tensor(batch_size * [1000]).to(model.device)
    uc = model.get_learned_conditioning({model.cond_stage_key: null_labels})
    xc = args.data[:batch_size]
    c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
    latent_shape = [3, 64, 64]
    sampler = PLMSSampler(model)

    with torch.no_grad():
        sampler.sample(
            S=args.ddim_steps,
            conditioning=c,
            batch_size=batch_size,
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=args.scale,
            unconditional_conditioning=uc,
            eta=args.ddim_eta,
            quant_unet=True,
            cali_data=[tensor[:batch_size].cuda() for tensor in cali_data],
        )
    torch.cuda.empty_cache()

    # NOTE(review): the reset above walks only the diffusion model, while this
    # pass walks the whole wrapper - confirm the wider scope is intended.
    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)


def set_smooth_quantize_params_Conditional_dpm(
    module,
    cali_data,
    args,
    batch_size: int = 64,
):
    """Set or init the smooth-quantizer scales by calling the inner network directly.

    Each batch is duplicated along dim 0 for the input and timestep tensors,
    and two conditioning tensors are concatenated, before
    ``module.model.diffusion_model.model`` is called on CUDA.

    Args:
        module: Wrapper whose ``model.diffusion_model`` holds the quantized
            network.
        cali_data: Indexable collection of tensors. Entries 0 and 1 feed the
            input and timestep; entries 4 and 3 (in that order) form the
            conditioning - presumably (uncond, cond), matching the
            activation variant; confirm against the data collector.
        args: Unused here.
        batch_size: Upper bound on the batch length; clamped to the number of
            calibration samples.
    """
    print("set_smooth_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, False)
    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(False)

    total = cali_data[0].size(0)
    batch_size = min(batch_size, total)

    with torch.no_grad():
        # Only full batches are replayed; a trailing partial batch is skipped.
        for step in tqdm(range(total // batch_size), desc="Inited smooth"):
            window = slice(step * batch_size, (step + 1) * batch_size)
            x_in = torch.cat([cali_data[0][window]] * 2)
            t_in = torch.cat([cali_data[1][window]] * 2)
            c_in = torch.cat([cali_data[4][window], cali_data[3][window]])
            unet.model(x_in.cuda(), t_in.cuda(), c_in.cuda())
    torch.cuda.empty_cache()

    for layer in unet.modules():
        if isinstance(layer, QuantModule):
            layer.smooth_quantizer.set_inited(True)
            layer.set_smooth_training(smooth_is_training=True)

def set_act_quantize_params_Conditional_dpm(
    module,
    all_cali_data, 
    all_t,
    all_index,
    all_cond,
    all_uncond,
    args,
    batch_size: int = 32,
):
    """Set or init the activation quantizers by calling the inner network directly.

    For each index of ``all_cali_data`` the network is told the index via
    ``set_time``; each batch of inputs and timesteps is duplicated along
    dim 0 and paired with the concatenated ``(uncond, cond)`` conditioning
    before ``module.model.diffusion_model.model`` is called on CUDA.

    Args:
        module: Wrapper whose ``model.diffusion_model`` holds the quantized
            network.
        all_cali_data: Indexable collection of tensors, one per step.
        all_t, all_cond, all_uncond: Collections aligned with
            ``all_cali_data``.
        all_index: Unused here.
        args: Unused here.
        batch_size: Upper bound on the batch length; clamped to the size of
            the first calibration tensor.
    """
    print("set_act_quantize_params")
    unet = module.model.diffusion_model
    unet.set_quant_state(False, True)

    unet.set_act_quantize_init(act_init=False)

    batch_size = min(batch_size, all_cali_data[0].size(0))

    with torch.no_grad():
        for step in tqdm(range(len(all_cali_data)), desc="Init scale_a"):
            unet.set_time(step)
            # Only full batches are replayed; a trailing partial batch is skipped.
            for chunk in range(all_cali_data[step].size(0) // batch_size):
                window = slice(chunk * batch_size, (chunk + 1) * batch_size)
                x_in = torch.cat([all_cali_data[step][window]] * 2)
                t_in = torch.cat([all_t[step][window]] * 2)
                c_in = torch.cat([all_uncond[step][window], all_cond[step][window]])
                unet.model(x_in.cuda(), t_in.cuda(), c_in.cuda())
            # The init flag is cleared again after every step's replay.
            unet.set_act_quantize_init(act_init=False)

    torch.cuda.empty_cache()
    unet.set_act_quantize_init(act_init=True)

def set_weight_quantize_params_Conditional_dpm(model, cali_data, args):
    """Initialise the weight quantizers with one direct call of the inner network.

    The first 2 samples of the input and timestep tensors are duplicated
    along dim 0 and paired with two concatenated conditioning tensors before
    ``model.model.diffusion_model.model`` is called on CUDA.

    Args:
        model: Wrapper whose ``model.diffusion_model`` holds the quantized
            network.
        cali_data: Indexable collection of tensors. Entries 0 and 1 feed the
            input and timestep; entries 4 and 3 (in that order) form the
            conditioning - presumably (uncond, cond); confirm against the
            data collector.
        args: Unused here.
    """
    print("set_weight_quantize_params")
    unet = model.model.diffusion_model
    unet.set_quant_state(True, False)

    for _, layer in unet.named_modules():
        if isinstance(layer, QuantModule):
            layer.weight_quantizer.set_inited(False)

    batch_size = 2
    head = slice(None, batch_size)
    with torch.no_grad():
        x_in = torch.cat([cali_data[0][head]] * 2)
        t_in = torch.cat([cali_data[1][head]] * 2)
        c_in = torch.cat([cali_data[4][head], cali_data[3][head]])
        unet.model(x_in.cuda(), t_in.cuda(), c_in.cuda())
    torch.cuda.empty_cache()

    # NOTE(review): the reset above walks only the diffusion model, while this
    # pass walks the whole wrapper - confirm the wider scope is intended.
    for _, layer in model.named_modules():
        if not isinstance(layer, QuantModule):
            continue
        is_split = layer.split != 0
        layer.weight_quantizer.set_inited(True)
        if is_split:
            # Split layers carry a second weight quantizer.
            layer.weight_quantizer_0.set_inited(True)

