import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers.utils import export_to_video
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging

from models.customize_pipeline_flux import CustomizeFluxPipeline
from models.customize_flux_attn_processor import CustomizeFluxAttnProcessor2_0
from models.customize_flux_transformer_2d import CustomizeFluxSingleTransformerBlock, CustomizeFluxTransformerBlock, CustomizeFluxTransformer2DModel
from models.attn_eval_utils import evaluate_attention_maps

# DIRTY: apply monkey patch, since the from_pretrained() method is hard to hack.
# Rebind diffusers' Flux attention processor, transformer and pipeline attributes to the
# customized implementations from ./models. These assignments must run before the
# `from diffusers import FluxPipeline` line that follows, so that import resolves to
# CustomizeFluxPipeline.
# NOTE(review): any module that bound the original classes by name before this point keeps
# the original objects; confirm the diffusers loading path looks these names up through the
# patched module attributes.
diffusers.models.attention_processor.FluxAttnProcessor2_0 = CustomizeFluxAttnProcessor2_0
diffusers.models.FluxTransformer2DModel = CustomizeFluxTransformer2DModel
diffusers.FluxPipeline = CustomizeFluxPipeline
from diffusers import FluxPipeline
from omegaconf import OmegaConf

import os
from pathlib import Path
# Hard-coded, machine-specific temp directory. Use setdefault so that a TMPDIR already
# exported by the caller takes precedence instead of being silently overwritten.
os.environ.setdefault("TMPDIR", "/home/models/tmp/")

# All six orderings of the (F)rame, (H)eight and (W)idth token axes.
# NOTE(review): for Flux (single images) presumably only the relative H/W order matters,
# so these collapse to 2 distinct permutations; confirm against the permute-plan code.
permutation_names = ["FHW", "FWH", "HWF", "HFW", "WHF", "WFH"]

import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots

def _reduce_calib_block(block_data, timestep_wise):
    """Collapse the prompt and CFG axes of one block's calibration attention maps.

    Args:
        block_data: tensor laid out as [N_prompt, N_timestep, N_cfg, N_head, N, N]
            (the layout ``main`` unpacks from the calibration file).
        timestep_wise: keep one map per timestep when True; otherwise also fold
            the timestep axis.

    Returns:
        The element-wise max over the folded axes, shaped [N_timestep, N_head, N, N]
        when ``timestep_wise`` is True, else [1, N_head, N, N].
    """
    n_prompt, n_timestep, n_cfg, n_head, n_q, n_k = block_data.shape
    # Put timesteps first so the prompt and CFG axes (and optionally timesteps) are adjacent.
    data = block_data.permute(1, 0, 2, 3, 4, 5)
    if timestep_wise:
        # Prompt and CFG axes fold into n_prompt * n_cfg entries.
        return data.reshape(n_timestep, n_prompt * n_cfg, n_head, n_q, n_k).amax(dim=1)
    # Keep a leading singleton "timestep" axis so callers can index [0] uniformly.
    return data.reshape(n_timestep * n_prompt * n_cfg, n_head, n_q, n_k).amax(dim=0).unsqueeze(0)


def _upsample_block_mask(mask, block_size):
    """Expand a [..., B, B] tile mask to [..., B*block_size, B*block_size] by repeating entries."""
    return mask.repeat_interleave(block_size, dim=-2).repeat_interleave(block_size, dim=-1)


def _plot_sparse_block(panels, panel_titles, downsample, save_path):
    """Write a PDF heatmap grid with one row per head and one column per panel.

    Args:
        panels: list of tensors, each [N_head, N, N]; column ``c`` shows ``panels[c]``.
        panel_titles: list (same length as ``panels``) of per-head x-axis titles.
        downsample: stride used to thin each map before plotting.
        save_path: destination PDF path; parent directories are created.
    """
    num_cols = len(panels)
    num_rows = panels[0].shape[0]
    fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=[""])
    for col, (panel, titles) in enumerate(zip(panels, panel_titles)):
        for row in range(num_rows):
            fig.add_trace(
                go.Heatmap(
                    z=panel[row, ::downsample, ::downsample].float().cpu().numpy(),
                    colorscale='viridis',  # use the Viridis colormap
                    showscale=False,
                    zmin=1.e-10,
                    zmax=1.,
                ),
                row=row + 1,
                col=col + 1,
            )
            fig.update_xaxes(title_text=titles[row], scaleratio=1, row=row + 1, col=col + 1)
            fig.update_yaxes(title_text=f"Head: {row}", scaleratio=1, row=row + 1, col=col + 1)
    fig.update_layout(
        title="Sparse Mask Visualization",
        autosize=False,
        width=200 * num_cols,
        height=280 * num_rows,
    )
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(save_path, format="pdf")


def main(args):
    """Derive a block-sparse attention mask plan from calibration attention maps.

    Reads ``args.config`` (OmegaConf) and ``args.calib_data`` (a per-block sequence of
    attention tensors). It also reads the permute plan, either from
    ``config.attn.sparse.permute_plan`` or from ``{args.log}/permute_plan.pth``.

    Writes ``{args.log}/sparse.log`` and backups of the config and ``./models``. It saves
    ``{args.log}/sparse_plan.pth`` with keys ``sparse`` (masks) and ``dense_rate``. When
    ``config.attn.sparse.rescale_text_embeds`` is set, the file also holds
    ``rescale_rows``/``rescale_cols``. With ``args.plot``, one PDF per (block, timestep) is
    written as well.

    Raises:
        ValueError: missing ``config.attn.sparse``, non-square maps, or a block size
            that does not divide the map size.
        FileNotFoundError: the permute plan file does not exist.
        NotImplementedError: unknown ``sparse_type``.
    """
    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(args.log, exist_ok=True)
    setup_logging(os.path.join(args.log, 'sparse.log'))
    logger = logging.getLogger(__name__)

    config = OmegaConf.load(args.config)
    sparse_cfg = config.sparse_plan.sparse

    # Back up the config and the current ./models sources into the log folder.
    shutil.copy(args.config, args.log)
    models_backup = os.path.join(args.log, 'models')
    if os.path.exists(models_backup):
        shutil.rmtree(models_backup)
    shutil.copytree('./models', models_backup)

    # A sparse plan is determined on top of an existing permute plan.
    if config.attn.get("sparse", None) is None:
        raise ValueError("config.attn.sparse must be set when determining the sparse plan.")
    permute_plan_path = config.attn.sparse.get("permute_plan", None)
    if permute_plan_path is None:
        permute_plan_path = os.path.join(args.log, "permute_plan.pth")
    if not os.path.isfile(permute_plan_path):
        raise FileNotFoundError(
            f"when determine sparse plan, should have permute plan. Missing: {permute_plan_path}")
    # map_location keeps `is_empty` on the same device as the rate tensors below.
    permute_plan = torch.load(permute_plan_path, map_location=device)
    is_empty = permute_plan["empty"]  # indexed as [i_block, i_head] below

    # The raw calibration data may be too big for GPU memory: load it on the CPU and move
    # one block at a time to `device` for the (much smaller) reduced copy.
    calib_data = torch.load(args.calib_data, map_location="cpu")
    N_block = len(calib_data)
    N_prompt, N_timestep, N_cfg, N_head, N, N_k = calib_data[0].shape
    if N != N_k:
        raise ValueError(f"expected square attention maps, got {N}x{N_k}")
    timestep_wise = sparse_cfg.get("timestep_wise", False)
    reordered_data = [
        _reduce_calib_block(calib_data[i_block].to(device), timestep_wise)
        for i_block in range(N_block)
    ]
    del calib_data
    # N_timestep when timestep-wise, else 1: only these leading rows of the plan are filled.
    n_plan_timesteps = reordered_data[0].shape[0]

    downsample_rate_for_plot = 8

    # Block-sparse mask. Each attention map is tiled into sparse_block_size x sparse_block_size
    # tiles. The effective tile size in the original attention map is the calibration
    # downsample rate times sparse_block_size (e.g. 5 * 2 = 10 gives 10x10 tiles).
    # A tile is kept (dense) if either:
    #   (1) it contains any value above `max_threshold`, or
    #   (2) its sum is not negligible, judged per `sparse_type`:
    #       "mean": sum > per-head mean tile sum / block_sum_k
    #       "threshold": sum > sum_threshold (offline-determined)
    sparse_block_size = sparse_cfg.sparse_block_size
    if N % sparse_block_size != 0:
        raise ValueError(f"sparse_block_size={sparse_block_size} must divide N={N}")
    num_sparse_block = N // sparse_block_size
    sparse_block_sum_k = sparse_cfg.block_sum_k
    sparse_block_max_threshold = sparse_cfg.max_threshold
    sparse_type = sparse_cfg.sparse_type
    if sparse_type not in ("mean", "threshold"):
        raise NotImplementedError(f"unsupported sparse_type: {sparse_type}")
    sum_threshold = sparse_cfg.sum_threshold if sparse_type == "threshold" else None

    sparse_mask_rates = torch.zeros([N_timestep, N_block, N_head], device=device)
    sparse_masks = torch.zeros([N_timestep, N_block, N_head, N, N], device=device)

    for i_timestep in range(n_plan_timesteps):
        logger.info(f"---- for {i_timestep}-th timestep -----")
        for i_block in range(N_block):
            attn = reordered_data[i_block][i_timestep]  # [N_head, N, N]
            # [N_head, nb, nb, block, block]: tile (i, j) covers rows/cols of that tile.
            tiles = attn.reshape([
                N_head, num_sparse_block, sparse_block_size, num_sparse_block, sparse_block_size,
            ]).permute([0, 1, 3, 2, 4])

            # Criterion (1); True marks a kept tile. [N_head, nb, nb]
            sparse_mask_large_value = tiles.amax(dim=(-2, -1)) > sparse_block_max_threshold
            large_value_rate_per_head = (
                sparse_mask_large_value.sum(dim=(1, 2)) / sparse_mask_large_value[0].numel())

            sparse_block_sum = tiles.sum(dim=-1).sum(dim=-1)  # [N_head, nb, nb]
            # Per-head mean tile sum, broadcastable against sparse_block_sum. [N_head, 1, 1]
            sparse_block_mean = (
                sparse_block_sum.reshape([N_head, -1]).mean(dim=-1).unsqueeze(-1).unsqueeze(-1))

            # Criterion (2); True marks a tile whose sum is large enough to keep.
            if sparse_type == "mean":
                sparse_mask_small_sum = sparse_block_sum > (sparse_block_mean / sparse_block_sum_k)
            else:  # "threshold"
                sparse_mask_small_sum = sparse_block_sum > sum_threshold
            small_sum_rate_per_head = (
                sparse_mask_small_sum.sum(dim=(1, 2)) / sparse_mask_small_sum[0].numel())

            sparse_mask = torch.logical_or(sparse_mask_large_value, sparse_mask_small_sum)
            sparse_mask_rate_per_head = sparse_mask.sum(dim=(1, 2)) / sparse_mask[0].numel()
            # Expand the tile-level mask back to attention-map resolution [N_head, N, N].
            sparse_mask_full = _upsample_block_mask(sparse_mask, sparse_block_size)
            sparse_mask_rates[i_timestep, i_block] = sparse_mask_rate_per_head
            sparse_masks[i_timestep, i_block] = sparse_mask_full
            masked_data = attn * sparse_mask_full

            results = evaluate_attention_maps(masked_data.unsqueeze(0), attn.unsqueeze(0))
            logger.info(f"The {i_block}-th block metrics: L1 Norm {results['Relative L1 Norm']:.4f}, RMSE: {results['RMSE']:.4f}, Cos: {results['Cosine Similarity']:.4f}")

            if args.plot:
                # Per head, 5 panels: before, large_value_mask, small_sum_mask, sparse_mask, after.
                heads = range(N_head)
                results_per_head = [
                    evaluate_attention_maps(masked_data[h].reshape([1, 1, N, N]),
                                            attn[h].reshape([1, 1, N, N]))
                    for h in heads
                ]
                panels = [
                    attn,
                    _upsample_block_mask(sparse_mask_large_value, sparse_block_size).int(),
                    _upsample_block_mask(sparse_mask_small_sum, sparse_block_size).int(),
                    sparse_mask_full.int(),
                    masked_data,
                ]
                panel_titles = [
                    ["Empty" if is_empty[i_block, h] else '' for h in heads],  # note empty heads
                    [f"Large Value <br>Dense: {100*large_value_rate_per_head[h]:.2f} %" for h in heads],
                    [f"[{sparse_type}] Small Sum <br>Dense: {100*small_sum_rate_per_head[h]:.2f} %" for h in heads],
                    [f"Overall <br>Dense: {100*sparse_mask_rate_per_head[h]:.2f} %" for h in heads],
                    [
                        f"L1 Norm: {r['Relative L1 Norm']:.4f} <br> RMSE: {r['RMSE']:.4f} <br> Cos: {r['Cosine Similarity']:.4f}"
                        for r in results_per_head
                    ],
                ]
                save_path = f"{args.log}/sparse_plan/sparse_mask/sparse_mask_block_{i_block}_t{i_timestep}.pdf"
                _plot_sparse_block(panels, panel_titles, downsample_rate_for_plot, save_path)
                logger.info(f"finished saving: {save_path}")

    # Only the first n_plan_timesteps rows were filled; average over those rows only.
    planned_rates = sparse_mask_rates[:n_plan_timesteps]
    logger.info(f"Finished the Sparse Mask Plan, the dense rates: {sparse_mask_rates}")
    logger.info(f"The fine-grained dense rate: {planned_rates.mean():.4f}")
    # Heads flagged empty by the permute plan count as fully sparse (dense rate 0).
    planned_rates_with_empty = planned_rates * (is_empty == 0).int().to(planned_rates.device)
    logger.info(f"The overall dense rate: {planned_rates_with_empty.mean():.4f}")

    rescale_text_embeds = config.attn.sparse.get("rescale_text_embeds", False)
    if rescale_text_embeds:
        # (optional) image-token relative change, for rescale text embed compensation.
        # One value per query row / key column of the [N, N] map.
        rescale_rows = torch.zeros([N_timestep, N_block, N_head, N], device=device)
        rescale_cols = torch.zeros([N_timestep, N_block, N_head, N], device=device)
        for i_timestep in range(n_plan_timesteps):
            for i_block in range(N_block):
                attn = reordered_data[i_block][i_timestep]
                # NOTE(review): (attn * mask) / attn equals the mask wherever attn != 0 (and is
                # NaN where attn == 0); confirm whether a mass ratio such as
                # (attn * mask).sum(-1) / attn.sum(-1) was intended.
                kept_ratio = (attn * sparse_masks[i_timestep, i_block]) / attn
                rescale_rows_ = kept_ratio.mean(dim=-1)
                rescale_cols_ = kept_ratio.mean(dim=-2)
                rescale_rows[i_timestep, i_block] = rescale_rows_
                rescale_cols[i_timestep, i_block] = rescale_cols_
                logger.info(f't-{i_timestep}, block-{i_block}, no_rescale rows:{100 * (rescale_rows_ == 1).sum() / rescale_rows_.numel():.2f} %')
                logger.info(f't-{i_timestep}, block-{i_block}, no_rescale cols:{100 * (rescale_cols_ == 1).sum() / rescale_cols_.numel():.2f} %')
                logger.info("-------------")

    # Final: save the sparse plan.
    save_d = {'sparse': sparse_masks, 'dense_rate': sparse_mask_rates}
    if rescale_text_embeds:
        save_d['rescale_rows'] = rescale_rows
        save_d['rescale_cols'] = rescale_cols
    torch.save(save_d, f'{args.log}/sparse_plan.pth')

if __name__ == "__main__":
    # CLI entry point: build the sparse attention plan described by --config.
    parser = argparse.ArgumentParser(
        description="Determine a block-sparse attention plan from calibration attention maps.")
    parser.add_argument("--log", type=str, default='debug',
                        help="output folder for logs, backups and the sparse plan")
    # --config and --calib_data have no usable default: fail early with a clear CLI error.
    parser.add_argument('--config', type=str, required=True,
                        help="OmegaConf YAML config with `attn` and `sparse_plan` sections")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--calib_data", type=str, required=True,
                        help="path to the saved calibration attention maps")
    parser.add_argument("--plot", action="store_true",
                        help="write per-block/timestep sparse-mask PDFs")
    args = parser.parse_args()
    main(args)