import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers import PixArtSigmaPipeline
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging, DataSaverHook, StopForwardException

from models.customize_pipeline_pixart_sigma import CustomizePixArtSigmaPipeline
from models.customize_pipeline_pixart_transformer_2d import CustomizePixArtTransformer2DModel
# DIRTY: apply monkey patch, since the from_pretrained() method is hard to hack
diffusers.models.PixArtTransformer2DModel = CustomizePixArtTransformer2DModel
diffusers.PixArtSigmaPipeline = CustomizePixArtSigmaPipeline
from diffusers import PixArtSigmaPipeline
from omegaconf import OmegaConf

from qdiff.ours.ours_quant_layer import OursQuantizedLinear
from qdit.quant import quantize_activation_wrapper, Quantizer
from functools import partial
from optimize.train import optimize_rotation_matrix
from types import SimpleNamespace
from evolution.evolution import evolution_search
        

def main(args):
    """Quantize the PixArt-Sigma transformer post-training, then sample images.

    The stages run in this order on every ``OursQuantizedLinear`` submodule of
    ``pipe.transformer``:

    1. initialise a rotation matrix (``get_rotation_matrix``);
    2. with quantization switched off, capture each layer's inputs/outputs on
       the calibration set and call ``optimize_rotation_matrix``;
    3. fold the rotation into the weights (``update_weights_rotation``);
    4. with quantization and calibration mode switched on, run the full
       calibration set through the model once;
    5. update the weight scaling (``update_weights_scale``);
    6. search a channel permutation per layer with ``evolution_search`` on the
       first half of the calibration set, then reorder and re-quantize the
       weights accordingly.

    Afterwards the prompts file is read, images are generated batch by batch
    (a trailing partial batch is dropped) and written to
    ``<args.log>/generated_images``, and the quantization parameters are saved
    to ``<args.log>/quant_params_w4a8.pth``.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``log``, ``ckpt``,
            ``quant_config``, ``num_sampling_steps``, ``prompt`` and
            ``batch_size``.
    """
    seed_everything(args.seed)
    # Autograd is left globally enabled; the forward passes below disable it
    # locally with torch.no_grad().
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.log is not None:
        os.makedirs(args.log, exist_ok=True)
    log_file = os.path.join(args.log, 'run.log')
    setup_logging(log_file)
    logger = logging.getLogger(__name__)

    ckpt_path = args.ckpt if args.ckpt is not None else "./pretrained_models/"
    pipe = PixArtSigmaPipeline.from_pretrained(
        # "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        ckpt_path,
        torch_dtype=torch.float16
    ).to(device)

    # ---- assign quant configs ------
    quant_config = OmegaConf.load(args.quant_config)
    print(quant_config)
    pipe.convert_quant(quant_config)
    model = pipe.transformer

    # INFO (original author's note): for simple PTQ with dynamic activation
    # quantization, the weights are quantized when the quantized model is
    # initialised and the activation quantization parameters are computed
    # online.

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def for_each_quant_linear_(function, **extra):
        # Apply `function` to every OursQuantizedLinear submodule of the model.
        apply_func_to_submodules(model,
                                 class_type=OursQuantizedLinear,
                                 function=function,
                                 **extra)

    def forward_transformer_(calib_x_batch, calib_t_batch, calib_prompt_embeds):
        # One gradient-free forward pass of the transformer on a calibration
        # batch. The prompt embeddings are tiled along dim 0 so that their
        # batch size matches the latent batch.
        repeats = calib_x_batch.shape[0] // calib_prompt_embeds.shape[0]
        with torch.no_grad():
            return model(
                calib_x_batch,
                encoder_hidden_states=calib_prompt_embeds.repeat([repeats, 1, 1]),
                encoder_attention_mask=None,
                timestep=calib_t_batch,
                added_cond_kwargs={"resolution": None, "aspect_ratio": None},
                return_dict=False,
            )[0]

    def capture_layer_io_(module, calib_data, input_transform=None):
        # Record `module`'s input and output for every calibration batch.
        # The hook is created with stop_forward=True and the resulting
        # StopForwardException is swallowed, so the forward pass is cut short
        # at `module`. `input_transform`, when given, is applied to each
        # captured input right after its batch, before the next forward pass.
        ins_list, outs_list = [], []
        for calib_x_batch, calib_t_batch, calib_prompt_embeds in calib_data:
            data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=True)
            handle = module.register_forward_hook(data_saver)
            try:
                forward_transformer_(calib_x_batch, calib_t_batch, calib_prompt_embeds)
            except StopForwardException:
                pass
            finally:
                # Always detach the hook, even if the forward pass failed for
                # an unrelated reason, so it cannot leak into later passes.
                handle.remove()

            ins = data_saver.input_store[0].float().detach()
            if input_transform is not None:
                ins = input_transform(ins)
            ins_list.append(ins)
            outs_list.append(data_saver.output_store.float().detach())
        return ins_list, outs_list

    # ------------------------------------------------------------------ #
    # Per-module callbacks
    # ------------------------------------------------------------------ #
    def init_rotation_(module):
        module.get_rotation_matrix()

    def turn_off_quant_(module):
        module.quant_mode = False

    def turn_on_quant_(module):
        module.quant_mode = True

    def optimize_rotation_matrix_(module, full_name, calib_data, act_quant, samples_index_list):
        logging.info(f' Optimize rotation matrix for {full_name}')

        ins_list, outs_list = capture_layer_io_(module, calib_data)
        batch_count = len(ins_list)

        # Wrap the rotation matrix in a Parameter before handing it to the
        # optimizer routine.
        module.rotation_matrix = torch.nn.Parameter(module.rotation_matrix)
        # NOTE(review): assumes every quantized linear has a bias — confirm.
        optimize_rotation_matrix(module.rotation_matrix,
                                 module.fp_weight.float().clone(), module.bias.float().clone(),
                                 act_quant,
                                 batch_count, samples_index_list,
                                 quant_config.optimize,
                                 quant_config.weight,
                                 full_name,
                                 ins_list, outs_list)

        # Drop the captured activations before moving on to the next layer.
        del ins_list, outs_list
        torch.cuda.empty_cache()

    def update_weights_rotation_(module):
        module.update_weights_rotation()

    def turn_on_calib_(module):
        module.is_calib = True

    def update_weights_scale_(module):
        module.update_weights_scale()

    def search_channel_permuation_(module, full_name, calib_data, act_quant, reorder_index_dict):
        logging.info(f' Searching channel permuation for {full_name}')

        # Captured inputs are multiplied by the layer's rotate/scale matrix
        # before being passed to the search.
        ins_list, outs_list = capture_layer_io_(
            module,
            calib_data,
            input_transform=lambda ins: torch.matmul(ins, module.rotate_scale_matrix.float()),
        )

        reorder_index = evolution_search(ins_list, outs_list,
                                         act_quant,
                                         module.weight.clone(), module.bias.float().clone(),
                                         quant_config.evo,
                                         quant_config.weight,
                                         full_name,
                                         device)
        reorder_index_dict[full_name] = reorder_index

    def channel_permuation_quant_(module, full_name, reorder_index_dict):
        module.reorder_index = reorder_index_dict[full_name]

        # Reorder the weight along dim 1, then re-quantize it.
        module.weight.data = torch.index_select(module.weight.data, 1, module.reorder_index)
        module.weight.data = module.w_quantizer(module.weight.data)

    # ================================================================== #
    # Calibration data
    # ================================================================== #
    calib_data = torch.load(os.path.join(args.log, quant_config.calib_data.save_path), weights_only=True)
    # Each entry unpacks to (calib_x_batch, calib_t_batch, calib_prompt_embeds).
    # Shapes noted by the original author: [160, 4, 64, 64], [160] and
    # [8, 300, 4096] respectively.

    # One index tensor per offset in [0, 2 * batch_size): the indices
    # offset, offset + stride, ... with `num_sampling_steps` entries each.
    # NOTE(review): the factor 2 presumably accounts for classifier-free
    # guidance doubling the batch — confirm.
    batch_size = quant_config.calib_data.batch_size
    stride = batch_size * 2
    samples_index_list = [
        torch.arange(0, stride * args.num_sampling_steps, stride) + offset
        for offset in range(stride)
    ]
    print('samples_index_list: ', samples_index_list)

    # ================================================================== #
    # PTQ stages
    # ================================================================== #
    for_each_quant_linear_(init_rotation_)

    logging.info("Optimizing Hadamard Matrix ...")
    # Activation quantizer settings; grouping is disabled (group size 0) for
    # the rotation optimization and switched to the configured group size
    # before the channel-permutation search.
    dynamic_group_act_quant_params = SimpleNamespace(
        abits=quant_config.act['n_bits'],
        act_group_size=0,
        tiling=0,
        a_sym=quant_config.act['sym'],
        a_clip_ratio=1.0,
        quant_type='int',
        static=False,
    )
    act_quant = partial(quantize_activation_wrapper, args=dynamic_group_act_quant_params)

    for_each_quant_linear_(turn_off_quant_)
    for_each_quant_linear_(optimize_rotation_matrix_,
                           full_name='',
                           calib_data=calib_data,
                           act_quant=act_quant,
                           samples_index_list=samples_index_list)

    logging.info("Updating Weights Rotation ...")
    for_each_quant_linear_(update_weights_rotation_)

    logging.info("Calibrating Activation Ranges ...")
    for_each_quant_linear_(turn_on_quant_)
    for_each_quant_linear_(turn_on_calib_)
    for calib_x_batch, calib_t_batch, calib_prompt_embeds in calib_data:
        print('Calibrating new batch...')
        forward_transformer_(calib_x_batch, calib_t_batch, calib_prompt_embeds)

    logging.info("Updating Weights Scaling ...")
    for_each_quant_linear_(update_weights_scale_)

    logging.info("Searching Channel Permutation ...")
    dynamic_group_act_quant_params.act_group_size = quant_config.act['group_size']
    act_quant = partial(quantize_activation_wrapper, args=dynamic_group_act_quant_params)
    reorder_index_dict = {}

    # Only the first half of the calibration set is used for the search.
    for_each_quant_linear_(search_channel_permuation_,
                           full_name='',
                           calib_data=calib_data[:len(calib_data) // 2],
                           act_quant=act_quant,
                           reorder_index_dict=reorder_index_dict)
    for_each_quant_linear_(channel_permuation_quant_,
                           full_name='',
                           reorder_index_dict=reorder_index_dict)

    model.set_init_done()

    logger.info(str(model))

    # ================================================================== #
    # Sampling
    # ================================================================== #
    # Read the prompts, one per line.
    prompt_path = args.prompt if args.prompt is not None else "./prompts.txt"
    with open(prompt_path, 'r') as f:
        prompts = [line.strip() for line in f]

    save_path = os.path.join(args.log, "generated_images")
    N_batch = len(prompts) // args.batch_size  # drop_last
    for i in range(N_batch):
        images = pipe(
            prompt=prompts[i*args.batch_size: (i+1)*args.batch_size],
            num_inference_steps=args.num_sampling_steps,
            # Build the generator on the device the pipeline runs on (it was
            # hard-coded to "cuda" before); it is re-seeded for every batch.
            generator=torch.Generator(device=device).manual_seed(args.seed),
        ).images
        print(f"Export image of batch {i}")

        os.makedirs(save_path, exist_ok=True)

        for i_image in range(args.batch_size):
            images[i_image].save(os.path.join(save_path, f"output_{i_image + args.batch_size*i}.jpg"))

    model.save_quant_param_dict()
    torch.save(pipe.transformer.quant_param_dict, os.path.join(args.log, 'quant_params_w4a8.pth'))
    logger.info(f'saved quant params into {args.log}')

if __name__ == "__main__":
    # Command-line options as (flag, value type, default) triples; they are
    # registered in this order.
    cli_options = [
        ("--log", str, './log/w4a8'),
        ('--quant-config', str, './configs/w4a8_gp_ours.yaml'),
        ("--cfg-scale", float, 4.5),
        ("--num-sampling-steps", int, 20),
        ("--prompt", str, './assets/coco_1024.txt'),
        ("--seed", int, 42),
        ("--ckpt", str, 'path/to/your/PixArt-XL-2-512x512/'),
        ("--batch-size", int, 1),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Parse sys.argv and hand the namespace to the quantization entry point.
    args = parser.parse_args()
    main(args)
