import os
import sys
import time
import logging
import argparse
import shutil
import math

import numpy as np
from scipy import stats
import torch
import diffusers
from omegaconf import OmegaConf, ListConfig

from models.customize_pipeline_pixart_sigma import CustomizePixArtSigmaPipeline
from models.customize_pipeline_pixart_transformer_2d import CustomizePixArtTransformer2DModel
# Prepend a path two levels above the first sys.path entry; presumably this is
# what makes the `quant_utils` package below importable -- verify against the
# repository layout.
sys.path.insert(0, sys.path[0] + '/../../')
from quant_utils.utils import apply_func_to_submodules, seed_everything, setup_logging
# Monkey-patch diffusers: rebind the stock PixArt transformer / pipeline
# attributes to the project's customized classes. This must run before the
# `from diffusers import PixArtSigmaPipeline` below so that the name imported
# there is the customized pipeline.
diffusers.models.PixArtTransformer2DModel = CustomizePixArtTransformer2DModel
diffusers.PixArtSigmaPipeline = CustomizePixArtSigmaPipeline
from diffusers import PixArtSigmaPipeline


def get_incoherence_norm(activation):
    """Return softmax-normalised incoherence weights, one per leading-dim slice.

    For every slice ``activation[i]`` (flattened) the incoherence score is
    ``max(x) * sqrt(numel(x)) / ||x||_2``. The scores are then passed through
    a softmax so that the returned weights are positive and sum to 1.

    Args:
        activation: a ``torch.Tensor`` whose first dimension indexes the
            samples to be weighted.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``activation.shape[0]`` holding the
        softmax of the per-sample incoherence scores (empty if there are no
        samples).

    Raises:
        ZeroDivisionError: if a slice has zero norm (score is undefined).
    """
    # NOTE(review): the score uses the signed maximum, not max(|x|); kept
    # as-is to preserve existing results -- confirm this is intended.
    scores = []
    for sample in activation:
        flat = sample.flatten()
        peak = flat.max().item()
        l2 = torch.norm(flat, p='fro').item()
        scores.append(peak * math.sqrt(flat.numel()) / l2)
    scores = np.array(scores, dtype=np.float64)

    if scores.size == 0:
        return scores

    # Numerically stable softmax: subtracting the maximum leaves the result
    # unchanged mathematically but keeps exp() from overflowing. Scores can
    # be as large as sqrt(numel), which overflows a plain exp() for large
    # activations. The denominator is also computed once instead of once per
    # element.
    exps = np.exp(scores - scores.max())
    return exps / exps.sum()


def main(args):
    """Initialise VETA rotation matrices from calibration data, then sample images.

    Steps, in order:
      1. Seed RNGs, disable autograd and set up logging under ``args.log``.
      2. Load the PixArt-Sigma pipeline in fp16 and convert it with the
         quantization config at ``args.quant_config``.
      3. For every ``VetaQuarotQuantizedLinear`` submodule of the transformer,
         load that layer's saved calibration tensor and initialise its
         (optionally smooth-quant scaled) rotation matrix and rotated weights.
      4. Read prompts (one per line) and generate images batch by batch,
         saving them to ``<args.log>/generated_images``.

    Args:
        args: parsed command-line namespace. Fields read here: ``seed``,
            ``log``, ``quant_config``, ``smooth_quant``, ``gptq``,
            ``argument_method``, ``prompt``, ``batch_size`` and
            ``num_sampling_steps``.

    Raises:
        ValueError: if ``args.log`` is ``None`` or ``args.argument_method`` is
            neither ``None`` nor ``"inco"``.
        FileNotFoundError: if the calibration ``metadata.pt`` file is missing.
        NotImplementedError: if the quant config has no ``veta`` section.
    """
    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Every output (log file, calibration data, images) lives under args.log,
    # so fail with a clear message instead of a TypeError from os.path.join.
    if args.log is None:
        raise ValueError("--log must be provided: it is the directory for logs, calibration data and images")
    os.makedirs(args.log, exist_ok=True)
    log_file = os.path.join(args.log, 'run.log')
    setup_logging(log_file)
    logger = logging.getLogger(__name__)

    pipe = PixArtSigmaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        torch_dtype=torch.float16  # the CUDA kernel only supports fp16, so bfloat16 is not used here
    ).to(device)

    # INFO: if memory intense
    # pipe.enable_model_cpu_offload()
    # pipe.vae.enable_tiling()

    # ---- assign quant configs ------
    quant_config = OmegaConf.load(args.quant_config)
    pipe.convert_quant(quant_config)
    model = pipe.transformer

    def init_kl_rotation_matrix_(module, full_name, save_dir, logger, argument_method):
        """Initialise one quantized linear layer from its saved calibration data.

        Looks for ``<save_dir>/<full_name with '.' -> '_'>.pt``; if the file
        does not exist the layer is skipped with a warning.
        """
        assert isinstance(module, VetaQuarotQuantizedLinear)

        file_path = os.path.join(save_dir, full_name.replace('.', '_') + '.pt')
        if not os.path.exists(file_path):
            logger.warning(f"File {file_path} not found, skip.")
            return

        logger.info(f"Loading data for layer {full_name} from {file_path}")
        # NOTE: torch.load unpickles the file; only load calibration data
        # produced by a trusted run.
        kl_calib_data = torch.load(file_path)

        if args.smooth_quant:
            # Per-channel max of |activation| over all leading dimensions.
            act_mask = kl_calib_data.reshape(-1, kl_calib_data.shape[-1]).abs().max(dim=0)[0]
            module.get_channel_mask(act_mask)  # set self.channel_mask
            module.update_quantized_weight_scaled()

        if argument_method is not None and argument_method != "inco":
            raise ValueError(f"Unsupported argument method: {argument_method}")

        if args.smooth_quant:
            # The [1, 1, -1] reshape broadcasts the mask over the last
            # (channel) dim; assumes the calibration tensor is 3-D -- TODO confirm.
            channel_mask = module.channel_mask.reshape([1, 1, -1]).to(kl_calib_data.device)
            calib = kl_calib_data * channel_mask
        else:
            calib = kl_calib_data

        if argument_method is None:
            # No weighting: flatten all leading dims into rows.
            activation_norm = calib.reshape(-1, calib.shape[-1])
        else:
            # "inco": incoherence-weighted sum over the first dimension.
            norm = get_incoherence_norm(calib)
            # Allocate the float32 accumulator on the data's device; a CPU
            # accumulator breaks when the saved tensor is restored on CUDA.
            activation_norm = torch.zeros(calib[0].shape, device=calib.device)
            for weight, sample in zip(norm, calib):
                activation_norm += float(weight) * sample

        module.get_rotation_matrix(activation_norm)
        module.update_quantized_weight_rotated(args.gptq)

        # Drop the references to the calibration data before the next layer.
        del kl_calib_data, calib
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info(f"Layer {full_name} finished.")

    if quant_config.get("veta", None) is not None:
        # Imported lazily; the nested function above resolves this name from
        # main's scope at call time, so it must be bound before the traversal.
        from quant_utils.veta.veta_quant_layer import VetaQuarotQuantizedLinear

        assert quant_config.calib_data.save_path is not None
        calib_data_dir = os.path.join(args.log, quant_config.calib_data.save_path)

        # Make sure the metadata file exists
        metadata_file = os.path.join(calib_data_dir, 'metadata.pt')
        if not os.path.exists(metadata_file):
            logger.error(f"Metadata file {metadata_file} not found.")
            raise FileNotFoundError(f"Metadata file {metadata_file} not found.")

        # get the rotation matrix, iter through all layers
        apply_func_to_submodules(
            model,
            class_type=VetaQuarotQuantizedLinear,  # apply the function to all objects of this cls
            function=init_kl_rotation_matrix_,
            save_dir=calib_data_dir,
            logger=logger,
            argument_method=args.argument_method,
            full_name='',
        )
    else:
        raise NotImplementedError("Only VETA is supported for now")

    model.set_init_done()

    pipe.model = model

    logger.info(str(model))

    # read the prompts, one per line
    prompt_path = args.prompt if args.prompt is not None else "./prompts.txt"
    with open(prompt_path, 'r') as f:
        prompts = [line.strip() for line in f]

    # NOTE(review): args.cfg_scale is parsed by the CLI but is not passed to
    # the pipeline call below -- confirm whether it should be.
    save_path = os.path.join(args.log, "generated_images")
    os.makedirs(save_path, exist_ok=True)

    N_batch = len(prompts) // args.batch_size  # drop_last
    if N_batch == 0:
        logger.warning(
            f"{len(prompts)} prompt(s) is fewer than batch size {args.batch_size}; no images generated."
        )
    for i in range(N_batch):
        images = pipe(
            prompt=prompts[i*args.batch_size: (i+1)*args.batch_size],
            num_inference_steps=args.num_sampling_steps,
            # Build the generator on the device actually in use instead of a
            # hard-coded "cuda", which fails on CPU-only machines.
            generator=torch.Generator(device=device).manual_seed(args.seed),
        ).images
        print(f"Export image of batch {i}")

        for i_image, image in enumerate(images):
            image.save(os.path.join(save_path, f"output_{i_image + args.batch_size*i}.jpg"))


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    # --log is the root directory for the log file, the calibration data and
    # the generated images; main() joins paths onto it, so it must be given.
    parser.add_argument("--log", type=str, required=True)
    parser.add_argument('--quant-config', required=True, type=str)
    parser.add_argument("--quant_param_ckpt", type=str, default="./quant_params.pth")
    parser.add_argument("--cfg-scale", type=float, default=4.5)
    parser.add_argument("--num-sampling-steps", type=int, default=20)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--hardware", action='store_true', help='whether to use_cuda_kernel')
    parser.add_argument("--profile", action='store_true', help='profile mode, measure the e2e latency')
    parser.add_argument("--quant_weight_ckpt", type=str, default=None)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--smooth_quant", action='store_true', help="Use smooth quantization")
    parser.add_argument("--gptq", action='store_true', default=False, help="Use gptq quantization")
    # main() accepts only None (no weighting) or "inco" for this option.
    parser.add_argument("--argument_method", type=str, default=None)
    args = parser.parse_args()
    main(args)