import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers.utils import export_to_video
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging

from models.customize_pipeline_flux import CustomizeFluxPipeline
from models.customize_flux_attn_processor import CustomizeFluxAttnProcessor2_0
from models.customize_flux_transformer_2d import CustomizeFluxSingleTransformerBlock, CustomizeFluxTransformerBlock, CustomizeFluxTransformer2DModel

# DIRTY: monkey-patch diffusers because FluxPipeline.from_pretrained() is hard to
# customize directly. The assignments below replace diffusers' Flux attention
# processor, Flux transformer model and Flux pipeline classes with this
# project's Customize* versions.
# NOTE: statement order matters -- `from diffusers import FluxPipeline` must run
# after `diffusers.FluxPipeline = CustomizeFluxPipeline`, so that the name bound
# in this module is the customized pipeline class.
# NOTE(review): whether from_pretrained() actually instantiates the patched
# transformer / attention-processor classes depends on how the installed
# diffusers version resolves class names -- confirm when upgrading diffusers.
diffusers.models.attention_processor.FluxAttnProcessor2_0 = CustomizeFluxAttnProcessor2_0
diffusers.models.FluxTransformer2DModel = CustomizeFluxTransformer2DModel
diffusers.FluxPipeline = CustomizeFluxPipeline
from diffusers import FluxPipeline
from omegaconf import OmegaConf

def _resolve_default_plan_paths(quant_config, log_dir):
    """Point enabled sparse-attention plans at their default files in ``log_dir``.

    When ``quant_config.attn.sparse`` exists and ``permute`` (resp.
    ``block_sparse``) is enabled without an explicit ``permute_plan``
    (resp. ``sparse_plan``), the plan path defaults to
    ``<log_dir>/permute_plan.pth`` (resp. ``<log_dir>/sparse_plan.pth``).
    ``quant_config`` is modified in place.

    Raises:
        FileNotFoundError: if a default plan file is needed but does not exist.
    """
    if quant_config.attn.get("sparse", None) is None:
        return
    sparse_cfg = quant_config.attn.sparse
    # (flag enabling the mode, config key holding its plan, default file name)
    for mode_flag, plan_key, file_name in (
        ("permute", "permute_plan", "permute_plan.pth"),
        ("block_sparse", "sparse_plan", "sparse_plan.pth"),
    ):
        if not sparse_cfg.get(mode_flag, False):
            continue
        if sparse_cfg.get(plan_key, None) is not None:
            continue  # an explicit path was given in the config
        default_path = os.path.join(log_dir, file_name)
        if not os.path.isfile(default_path):
            raise FileNotFoundError(
                f"attn.sparse.{mode_flag} is enabled but no {plan_key} is configured "
                f"and the default plan {default_path} does not exist"
            )
        setattr(sparse_cfg, plan_key, default_path)


def _read_prompts(prompt_path):
    """Return one whitespace-stripped prompt per line of the file at ``prompt_path``."""
    with open(prompt_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _attn_hooks(transformer):
    """Yield ``(hook_type, hook)`` for every attention-processor hook.

    ``transformer_blocks`` are visited before ``single_transformer_blocks``,
    each in module order; within a block, hooks follow ``hooks.keys()`` order.
    """
    for blocks in (transformer.transformer_blocks, transformer.single_transformer_blocks):
        for block in blocks:
            hooks = block.attn.processor.hooks
            for hook_type in hooks.keys():
                yield hook_type, hooks[hook_type]


def _stack_hook_outputs(outputs, num_prompts, num_steps, head_split_num, num_heads):
    """Stack one hook's recorded tensors and merge the head splits into one axis.

    ``outputs`` must contain ``num_prompts * num_steps * head_split_num``
    tensors recorded in (prompt, step, split) order, each 4-D
    ``[B, num_heads // head_split_num, N, D]`` (B is noted as 2 for CFG --
    TODO confirm). Returns a tensor of shape
    ``[num_prompts, num_steps, B, num_heads, N, D]``.

    Raises:
        ValueError: if the number of recorded outputs does not match.
    """
    expected = num_prompts * num_steps * head_split_num
    if len(outputs) != expected:
        raise ValueError(
            f"expected {expected} hook outputs ({num_prompts} prompts x {num_steps} "
            f"steps x {head_split_num} head splits), got {len(outputs)}"
        )
    inner_shape = list(outputs[0].shape)
    stacked = torch.stack(outputs, dim=0).reshape(
        [num_prompts, num_steps, head_split_num] + inner_shape
    )
    # Move the split axis behind the batch axis so (split, heads-per-split) are
    # adjacent and fuse into a single num_heads axis in the reshape below.
    stacked = stacked.permute([0, 1, 3, 2, 4, 5, 6])
    return stacked.reshape(
        [num_prompts, num_steps, inner_shape[0], num_heads, inner_shape[2], inner_shape[3]]
    )


def _log_dense_rate(transformer, logger):
    """Log the mean ``dense_rate_accumulator`` over all attention processors."""
    dense_rates = []
    for blocks in (transformer.transformer_blocks, transformer.single_transformer_blocks):
        for block in blocks:
            accumulator = block.attn.processor.attn_map_sparse_processor.dense_rate_accumulator
            dense_rates.append(torch.tensor(accumulator))
    logger.info(f'overall dense rate: {torch.stack(dense_rates, dim=0).mean():.4f}')


def _export_calib_data(pipe, quant_config, num_prompts, num_steps, logger):
    """Save hooked attention maps and/or QKV tensors as calibration data.

    The file is ``./visualization/calib_data/<calib_data.save_path>.pth``.
    With ``calib_data.attn_map`` it holds a flat list of tensors (one per block
    and hook); with ``calib_data.qkv`` it holds a dict mapping hook type to a
    per-block list of tensors. Both options write the same file, so when both
    are enabled the QKV dict overwrites the attention-map list.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by
            ``attn.head_split_num`` or the hooks recorded an unexpected count.
    """
    want_attn_map = quant_config.calib_data.attn_map
    want_qkv = quant_config.calib_data.qkv
    if not (want_attn_map or want_qkv):
        return

    num_heads = pipe.transformer.num_attention_heads
    head_split_num = quant_config.attn.head_split_num
    if num_heads % head_split_num != 0:
        raise ValueError(
            f"num_attention_heads ({num_heads}) is not divisible by "
            f"attn.head_split_num ({head_split_num})"
        )
    if want_attn_map and want_qkv:
        logger.warning(
            "calib_data.attn_map and calib_data.qkv share one save path; "
            "the qkv export overwrites the attn_map export"
        )

    # Reshape every hook once; both export formats reuse the same tensors.
    reshaped = [
        (hook_type,
         _stack_hook_outputs(hook.outputs, num_prompts, num_steps, head_split_num, num_heads))
        for hook_type, hook in _attn_hooks(pipe.transformer)
    ]

    save_path = f'./visualization/calib_data/{quant_config.calib_data.save_path}.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if want_attn_map:
        torch.save([data for _, data in reshaped], save_path)
    if want_qkv:
        by_type = {}
        for hook_type, data in reshaped:
            # setdefault: single blocks may expose hook types block 0 lacks.
            by_type.setdefault(hook_type, []).append(data)
        torch.save(by_type, save_path)


def main(args):
    """Quantize a FLUX pipeline, generate images, and optionally export calib data.

    Everything except calibration data is written under ``args.log``: a copy
    of the quant config, a snapshot of ``./models``, ``run.log``,
    ``quant_params.pth`` and ``generated_images_<steps>/output_<k>.jpg``.
    Calibration data (``--export-calib-data``) goes to
    ``./visualization/calib_data/``.

    Raises:
        ValueError: if ``args.log`` / ``args.quant_config`` is missing,
            ``args.batch_size`` < 1, or the calibration options conflict.
        FileNotFoundError: if a default sparse/permute plan file is missing.
    """
    if args.log is None:
        raise ValueError("--log is required: all outputs are written under it")
    if args.quant_config is None:
        raise ValueError("--quant-config is required")
    if args.batch_size < 1:
        raise ValueError(f"--batch-size must be >= 1, got {args.batch_size}")

    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(args.log, exist_ok=True)
    setup_logging(os.path.join(args.log, 'run.log'))
    logger = logging.getLogger(__name__)

    # INFO: back up the quant config and the model sources next to the results.
    shutil.copy(args.quant_config, args.log)
    models_backup = os.path.join(args.log, 'models')
    if os.path.exists(models_backup):
        shutil.rmtree(models_backup)  # replace, so no stale files survive
    shutil.copytree('./models', models_backup)

    ckpt_path = args.ckpt if args.ckpt is not None else "/home/models/flux/"
    pipe = FluxPipeline.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16).to(device)

    # ---- assign quant configs ------
    logger.info(f"loading config from {args.quant_config}")
    quant_config = OmegaConf.load(args.quant_config)
    logger.info(quant_config)
    _resolve_default_plan_paths(quant_config, args.log)

    if args.export_calib_data is not None:
        if quant_config.calib_data.get("save_path", None) is not None:
            raise ValueError(
                "calib_data.save_path must not be set in the config "
                "when --export-calib-data is given"
            )
        quant_config.calib_data.save_path = args.export_calib_data
        quant_config.export_calib_data = True
    pipe.convert_quant(quant_config)

    pipe.transformer.set_init_done()
    pipe.transformer.save_quant_param_dict()
    torch.save(pipe.transformer.quant_param_dict, os.path.join(args.log, 'quant_params.pth'))

    # INFO: if memory intense
    # pipe.enable_model_cpu_offload()
    # pipe.vae.enable_tiling()

    prompts = _read_prompts(args.prompt if args.prompt is not None else "./prompts.txt")

    sparse_cfg = quant_config.attn.get("sparse", None)
    track_dense_rate = sparse_cfg is not None and sparse_cfg.get("block_sparse", False)

    batch_size = args.batch_size
    save_dir = os.path.join(args.log, f"generated_images_{args.num_sampling_steps}")
    os.makedirs(save_dir, exist_ok=True)

    num_batches = len(prompts) // batch_size  # drop_last
    num_dropped = len(prompts) - num_batches * batch_size
    if num_dropped:
        logger.warning(
            f"{num_dropped} trailing prompt(s) do not fill a batch of {batch_size} and are skipped"
        )

    for i_batch in range(num_batches):
        first = i_batch * batch_size
        images = pipe(
            prompt=prompts[first:first + batch_size],
            num_inference_steps=args.num_sampling_steps,
            # A fresh generator with the same seed for every batch, on the
            # device the pipeline actually runs on (CPU fallback included).
            generator=torch.Generator(device=device).manual_seed(args.seed),
        ).images
        print(f"Export image of batch {i_batch}")
        print(save_dir)

        for offset in range(batch_size):
            image_path = os.path.join(save_dir, f"output_{first + offset}.jpg")
            images[offset].save(image_path)
            logger.info(f"Export image to {image_path}")

        if track_dense_rate:
            # INFO: averaged over all sparse attention processors; the
            # accumulators are cumulative across batches.
            _log_dense_rate(pipe.transformer, logger)

    if args.export_calib_data is not None:
        # Export the Attention Map / QKV for SparseAttn and QuantizedAttn.
        _export_calib_data(pipe, quant_config, len(prompts), args.num_sampling_steps, logger)

                    
                
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Quantize a FLUX pipeline, generate images and optionally export calibration data."
    )
    # main() writes every output under --log and loads --quant-config
    # unconditionally, so both are mandatory.
    parser.add_argument(
        "--log", type=str, required=True,
        help="output directory for run.log, backups, quant params and generated images",
    )
    parser.add_argument(
        '--quant-config', type=str, required=True,
        help="path to the OmegaConf quantization config",
    )
    # NOTE(review): --cfg-scale is parsed but never read by main(); kept so
    # existing command lines keep working.
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--num-sampling-steps", type=int, default=30)  # default: 30
    parser.add_argument(
        "--prompt", type=str, default=None,
        help="prompt file, one prompt per line (default: ./prompts.txt)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--ckpt", type=str, default=None,
        help="pipeline checkpoint to load (default: /home/models/flux/)",
    )
    parser.add_argument("--export-calib-data", type=str, help='store the intermediate activations as calib_data. ')
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="prompts per pipeline call; trailing prompts that do not fill a batch are skipped",
    )
    args = parser.parse_args()
    main(args)
