import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers.utils import export_to_video
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging

from models.customize_pipeline_flux import CustomizeFluxPipeline
from models.customize_flux_attn_processor import CustomizeFluxAttnProcessor2_0
from models.customize_flux_transformer_2d import CustomizeFluxSingleTransformerBlock, CustomizeFluxTransformerBlock, CustomizeFluxTransformer2DModel
from models.attn_eval_utils import evaluate_attention_maps

# DIRTY: apply monkey patch, since the from_pretrained() method is hard to hack
# The stock Flux classes are replaced by attribute assignment on the diffusers module
# objects. This MUST run before `from diffusers import FluxPipeline` below, so that the
# import picks up CustomizeFluxPipeline instead of the stock pipeline.
# NOTE(review): attribute patching only affects lookups made after these lines; any
# module that already bound the original classes (e.g. via `from ... import ...`)
# keeps the originals — confirm the custom classes are actually used where expected.
diffusers.models.attention_processor.FluxAttnProcessor2_0 = CustomizeFluxAttnProcessor2_0
diffusers.models.FluxTransformer2DModel = CustomizeFluxTransformer2DModel
diffusers.FluxPipeline = CustomizeFluxPipeline
from diffusers import FluxPipeline
from omegaconf import OmegaConf

import os
from pathlib import Path
# NOTE(review): hard-coded, machine-specific temp directory. Only code that reads TMPDIR
# after this line sees it (Python's `tempfile` caches its directory on first use).
os.environ["TMPDIR"] = "/home/models/tmp/" 

# Names of the six (F, H, W) token-axis orders, indexed by permutation id 0..5
# (outermost axis first).
permutation_names = "FHW FWH HWF HFW WHF WFH".split()

def get_empty_head(attn_out, max_threshold):
    """Flag a head as "empty" when no entry of its attention map reaches a threshold.

    Args:
        attn_out: 2-D attention map of a single head. Any other rank raises
            ``ValueError`` from the shape unpacking below.
        max_threshold: the global maximum of ``attn_out`` must be strictly
            below this value for the head to be flagged.

    Returns:
        0-d ``torch.int32`` tensor: 1 if the head is empty, 0 otherwise.
    """
    _rows, _cols = attn_out.shape  # rank check: only 2-D maps are accepted
    peak = attn_out.max()
    return torch.lt(peak, max_threshold).to(torch.int32)

def _split_into_patches(attn, patch_size):
    """Tile a square [N, N] map into non-overlapping patch_size x patch_size patches.

    Returns a [num_patches * num_patches, patch_size * patch_size] tensor, one row
    per patch (patches in row-major order). N must be divisible by patch_size.
    """
    n = attn.shape[-1]
    num_patches = n // patch_size
    tiles = attn.reshape([num_patches, patch_size, num_patches, patch_size]).permute([0, 2, 1, 3])
    return tiles.reshape([num_patches * num_patches, patch_size * patch_size])


def get_permute_plan(attn_out, max_threshold, sparse_percentage, *, choose_permute_type='local_uniform'):
    """Score the six (F, H, W) token permutations of one head's attention map and pick one.

    For every permutation id (0..5: FHW, FWH, HWF, HFW, WHF, WFH), the map is reordered
    with ``reorder_attn_out_from_calib_data`` and two scores are computed:

    * block sparse rate: the fraction of blocks (block side = N // size of the
      permutation's outermost axis) in which more than ``sparse_percentage`` of the
      entries are below ``max_threshold``;
    * incoherence: the mean, over all 2x2 patches, of max / mean inside the patch.
      A value of 1.0 means a uniform patch; lower is better.

    Args:
        attn_out: [N, N] attention map of a single head (image tokens only).
        max_threshold: entries below this value count as negligible.
        sparse_percentage: share of negligible entries above which a block is sparse.
        choose_permute_type: selection rule.
            'local_uniform' (default): lowest incoherence.
            'global_sparse': highest block sparse rate.
            'mixed': highest sparse rate, but switch to its pair partner
            (0<->1, 2<->3, 4<->5) when the partner is at least as coherent.

    Returns:
        chosen_permute: 0-d integer tensor with the chosen permutation id in [0, 5].
        sparse_rates: [6] float32 tensor, the block sparse rate per permutation.
        permuted_attns: [6, N, N] float32 tensor of the reordered maps.

    Raises:
        NotImplementedError: for an unknown ``choose_permute_type``.
        ValueError: if N is not divisible by the 2x2 incoherence patch size.
    """
    if choose_permute_type not in ('mixed', 'global_sparse', 'local_uniform'):
        raise NotImplementedError(choose_permute_type)

    _, N = attn_out.shape

    # DIRTY: fixed token grid; must agree with the grid used by reorder_attn_out_from_calib_data.
    F, H, W = 1, 64 // 4, 64 // 4
    # Natural block side per permutation id: N divided by the size of the outermost
    # axis of that permutation (ids 0,1 lead with F; 2,3 with H; 4,5 with W).
    block_sizes = (N // F, N // F, N // H, N // H, N // W, N // W)
    # Side of the small patches used for the incoherence score. Within a pair
    # (0,1)/(2,3)/(4,5) the block size is the same, so this finer score is what
    # tells the pair members apart.
    coherence_patch = 2
    if N % coherence_patch != 0:
        raise ValueError(f"N={N} must be divisible by {coherence_patch}")

    sparse_rates = torch.zeros([6], device=attn_out.device)
    incoherences = torch.zeros([6], device=attn_out.device)
    permuted_attns = torch.zeros([6, N, N], device=attn_out.device)
    for i_permute, block_size in enumerate(block_sizes):
        permuted_attn = reorder_attn_out_from_calib_data(
            attn_out.reshape([1, 1, N, N]), permute_order=i_permute
        ).reshape([N, N])
        permuted_attns[i_permute] = permuted_attn

        # Block sparse rate: fraction of blocks whose share of negligible entries
        # exceeds sparse_percentage.
        blocks = _split_into_patches(permuted_attn, block_size)
        negligible_share = (blocks < max_threshold).sum(dim=-1) / (block_size * block_size)
        block_is_sparse = (negligible_share > sparse_percentage).int()
        sparse_rates[i_permute] = block_is_sparse.sum() / block_is_sparse.numel()

        # Incoherence: max / mean inside each small patch. An all-zero patch is
        # perfectly uniform, so define its ratio as 1 instead of 0/0 = NaN. A NaN
        # would poison the mean and the argmin below.
        patches = _split_into_patches(permuted_attn, coherence_patch)
        patch_max = patches.amax(dim=-1)
        patch_mean = patches.mean(dim=-1)
        ratio = torch.where(patch_mean != 0, patch_max / patch_mean, torch.ones_like(patch_mean))
        incoherences[i_permute] = ratio.mean()

    most_sparse_id = torch.argmax(sparse_rates)  # first index on ties
    if choose_permute_type == 'mixed':
        # Pair partner via XOR (0<->1, 2<->3, 4<->5); unlike `+ 1` it never leaves [0, 5].
        partner_id = most_sparse_id ^ 1
        if incoherences[most_sparse_id] >= incoherences[partner_id]:
            chosen_permute = partner_id
        else:
            chosen_permute = most_sparse_id
    elif choose_permute_type == 'global_sparse':
        chosen_permute = most_sparse_id
    else:  # 'local_uniform'
        chosen_permute = torch.argmin(incoherences)

    return chosen_permute, sparse_rates, permuted_attns

def reorder_attn_out_from_calib_data(attn_out, permute_order=0, *, grid_shape=(1, 64 // 4, 64 // 4)):
    """Reorder image-token attention maps according to one of six (F, H, W) axis orders.

    Image tokens are laid out as an ``F x H x W`` grid. Both the query axis and the key
    axis of every map are viewed as that grid, their three grid axes are permuted, and
    the result is flattened back. The same token order therefore applies to rows and
    columns, and every head uses the same permutation.

    Args:
        attn_out: [BS, N_head, N_token, N_token] attention maps over image tokens only
            (no text tokens).
        permute_order: int (or 0-d integer tensor) selecting the order:
            0 FHW, 1 FWH, 2 HWF, 3 HFW, 4 WHF, 5 WFH (outermost axis first).
        grid_shape: (F, H, W) token grid. The default (1, 16, 16) equals the
            previously hard-coded grid.

    Returns:
        A new contiguous tensor with the same shape, dtype and device as ``attn_out``.
        It is never a view of ``attn_out``.

    Raises:
        ValueError: if the last two dims are not both ``F * H * W``.
        IndexError: if ``permute_order`` is outside the permutation table.
    """
    BS, N_head, N_token, N_dim = attn_out.shape
    F, H, W = (int(s) for s in grid_shape)
    if N_token != F * H * W or N_dim != N_token:
        raise ValueError(
            f"expected attention maps of shape [..., {F * H * W}, {F * H * W}] for "
            f"grid (F, H, W)=({F}, {H}, {W}), got {tuple(attn_out.shape)}"
        )

    # Grid-axis order per permutation id, as indices into (F, H, W). The first entry is
    # outermost; the last one varies fastest after flattening.
    permutations = (
        (0, 1, 2),  # 0: FHW
        (0, 2, 1),  # 1: FWH
        (1, 2, 0),  # 2: HWF
        (1, 0, 2),  # 3: HFW
        (2, 1, 0),  # 4: WHF
        (2, 0, 1),  # 5: WFH
    )
    order = permutations[int(permute_order)]

    # Axes of the 8-D view: 0 = batch, 1 = head, 2..4 = query (F, H, W), 5..7 = key (F, H, W).
    dims = (0, 1, *(2 + d for d in order), *(5 + d for d in order))
    grid = attn_out.reshape(BS, N_head, F, H, W, F, H, W).permute(*dims)
    # clone() always materializes a fresh contiguous copy (even for the identity order),
    # so the final reshape is a cheap view of the new tensor.
    return grid.clone(memory_format=torch.contiguous_format).reshape(BS, N_head, N_token, N_token)


import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
    
def _plot_empty_heads(head_maps, head_is_empty, save_path, downsample):
    """Save one block's per-head attention maps as a single-column heatmap grid (PDF).

    Args:
        head_maps: [N_head, N, N] max-reduced attention maps of one block.
        head_is_empty: [N_head] 0/1 flags; flagged heads get an "Empty" x-axis label.
        save_path: output PDF path. Its parent directory is created if missing.
        downsample: stride used to subsample both axes before plotting.
    """
    num_rows = head_maps.shape[0]
    maps = head_maps.detach().cpu().to(torch.float32).numpy()
    fig = make_subplots(rows=num_rows, cols=1, subplot_titles=[""])
    for row in range(num_rows):
        data_to_plot = maps[row]
        fig.add_trace(
            go.Heatmap(
                z=data_to_plot[::downsample, ::downsample],
                colorscale='viridis',
                showscale=False,
                zmin=1.e-10,
                zmax=1.0,
            ),
            row=row + 1,
            col=1,
        )
        x_title = f"Empty <br> Max: {data_to_plot.max():.5f}" if head_is_empty[row] else ""
        fig.update_xaxes(title_text=x_title, scaleratio=1, row=row + 1, col=1)
        fig.update_yaxes(title_text=f"Head: {row}", scaleratio=1, row=row + 1, col=1)
    # The layout is figure-wide, so set it once instead of once per subplot.
    fig.update_layout(title="Empty Attn Maps", autosize=False, width=300, height=280 * num_rows)
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(save_path, format="pdf")
    print('finished saving:', save_path)


def _plot_permutations(i_block, reordered_maps, sparse_rates, chosen_permutes, save_path, downsample):
    """Save one block's permuted maps as a heatmap grid (rows = heads, columns = permutations).

    Args:
        i_block: block index, used only in progress messages.
        reordered_maps: [N_head, 6, N, N] permuted attention maps.
        sparse_rates: [N_head, 6] block sparse rate per head and permutation.
        chosen_permutes: per-head chosen permutation id; that column's x-axis is labelled.
        save_path: output PDF path. Its parent directory is created if missing.
        downsample: stride used to subsample both axes before plotting.
    """
    num_rows, num_cols = reordered_maps.shape[0], reordered_maps.shape[1]
    maps = reordered_maps.detach().cpu().to(torch.float32).numpy()
    rates = sparse_rates.detach().cpu()
    fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=[""])
    for col in range(num_cols):
        for row in range(num_rows):
            fig.add_trace(
                go.Heatmap(
                    z=maps[row, col, ::downsample, ::downsample],
                    colorscale='viridis',
                    showscale=False,
                    zmin=1.e-10,
                    zmax=1.0,
                ),
                row=row + 1,
                col=col + 1,
            )
            if col == chosen_permutes[row]:
                x_title = f"Selected Permute <br>Sparse: {100 * float(rates[row, col]):.2f} %"
            else:
                x_title = ''
            fig.update_xaxes(title_text=x_title, scaleratio=1, row=row + 1, col=col + 1)
            fig.update_yaxes(title_text=f"Head: {row}", scaleratio=1, row=row + 1, col=col + 1)
            print(f'ploting {i_block} block, {col} permute, {row} head')
    fig.update_layout(
        title="Permuted Attn Maps",
        autosize=False,
        width=200 * num_cols,
        height=280 * num_rows,
    )
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(save_path, format="pdf")
    print('finished saving:', save_path)


def main(args):
    """Build the empty-head and permutation sparse plan from calibration attention maps.

    Reads ``args.config``, using ``sparse_plan.empty.max_threshold``,
    ``sparse_plan.permute.max_threshold`` and ``sparse_plan.permute.sparse_percentage``.
    Also reads ``args.calib_data``, an indexable collection with one
    [N_prompt, N_timestep, N_cfg, N_head, N, N] tensor per block.

    Writes into ``args.log``:
    * ``sparse.log``;
    * a copy of the config;
    * a fresh copy of ``./models`` (relative to the working directory);
    * optional PDF plots (when ``args.plot``);
    * ``permute_plan.pth``, a dict with 'empty' and 'permute' [N_block, N_head] int tensors.

    Raises:
        ValueError: if ``args.log``, ``args.config`` or ``args.calib_data`` is None.
    """
    for arg_name in ("log", "config", "calib_data"):
        if getattr(args, arg_name) is None:
            raise ValueError(f"--{arg_name} must be provided")

    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(args.log, exist_ok=True)
    setup_logging(os.path.join(args.log, 'sparse.log'))
    logger = logging.getLogger(__name__)

    config = OmegaConf.load(args.config)

    # Back up the config and the model sources next to the results.
    shutil.copy(args.config, args.log)
    models_backup = os.path.join(args.log, 'models')
    if os.path.exists(models_backup):
        shutil.rmtree(models_backup)
    shutil.copytree('./models', models_backup)

    # The calibration data may not fit in GPU memory. Load it on the CPU and move one
    # block at a time to the device.
    calib_data = torch.load(args.calib_data, map_location='cpu')
    N_block = len(calib_data)
    N_prompt, N_timestep, N_cfg, N_head, N, _ = calib_data[0].shape

    # Per block: element-wise max over prompts, timesteps and CFG branches -> [N_head, N, N].
    # Both planning steps below use only these reduced maps, so compute them once.
    block_max_maps = [
        calib_data[i_block].to(device).reshape([N_timestep * N_prompt * N_cfg, N_head, N, N]).amax(dim=0)
        for i_block in range(N_block)
    ]
    del calib_data

    downsample_rate_for_plot = 16
    # WARNING: plotting takes a long time.
    plot_empty = plot_permute = bool(args.plot)

    # Step 1: empty heads. A head is empty when the maximum of its reduced map is below
    # sparse_plan.empty.max_threshold.
    empty_head_max_threshold = config.sparse_plan.empty.max_threshold
    is_empty = torch.zeros([N_block, N_head], dtype=torch.int, device=device)
    for i_block, block_max in enumerate(block_max_maps):
        for i_head in range(N_head):
            is_empty[i_block, i_head] = get_empty_head(block_max[i_head], empty_head_max_threshold)
    logger.info(f"The empty plan: {is_empty}")
    logger.info('The empty head proportion: {:.4f}% '.format((is_empty.sum() / is_empty.numel()) * 100))

    if plot_empty:
        for i_block, block_max in enumerate(block_max_maps):
            _plot_empty_heads(
                block_max,
                is_empty[i_block],
                f"{args.log}/sparse_plan/empty_heads/empty_heads_block_{i_block}.pdf",
                downsample_rate_for_plot,
            )

    # Step 2: permutation plan per head, chosen by get_permute_plan using
    # sparse_plan.permute.max_threshold and sparse_plan.permute.sparse_percentage.
    permutation_max_threshold = config.sparse_plan.permute.max_threshold
    permutation_sparse_percentage = config.sparse_plan.permute.sparse_percentage
    permute_plan = torch.zeros([N_block, N_head], dtype=torch.int)
    coarse_grained_sparse_rate = []
    for i_block, block_max in enumerate(block_max_maps):
        sparse_rate_cur_block = []
        chosen_permute_cur_block = []
        reordered_maps_cur_block = []  # filled only when plotting, to save memory
        for i_head in range(N_head):
            if is_empty[i_block, i_head]:
                # Empty heads are not reordered: keep the identity order and count
                # every block as sparse.
                chosen_permute = 0
                sparse_rates = torch.ones([6], device=device)
                reordered_attn_map = block_max[i_head].unsqueeze(0).expand(6, N, N)
            else:
                chosen_permute, sparse_rates, reordered_attn_map = get_permute_plan(
                    block_max[i_head],
                    permutation_max_threshold,
                    permutation_sparse_percentage,
                )
            chosen_permute = int(chosen_permute)
            permute_plan[i_block, i_head] = chosen_permute
            chosen_permute_cur_block.append(chosen_permute)
            sparse_rate_cur_block.append(sparse_rates)
            if plot_permute:
                # Cast explicitly so empty (input-dtype) and permuted (float32) maps stack cleanly.
                reordered_maps_cur_block.append(reordered_attn_map.to(torch.float32))

        sparse_rate_cur_block = torch.stack(sparse_rate_cur_block, dim=0)  # [N_head, 6]
        block_sparse_rate = sparse_rate_cur_block.mean()
        logger.info(f"The coarse-grained attn_map sparse_rate for {i_block}-block after permute: {block_sparse_rate:.3f}")
        coarse_grained_sparse_rate.append(block_sparse_rate)

        # Plot block by block: all six permutations of every head would not fit in memory at once.
        if plot_permute:
            _plot_permutations(
                i_block,
                torch.stack(reordered_maps_cur_block, dim=0),  # [N_head, 6, N, N]
                sparse_rate_cur_block,
                chosen_permute_cur_block,
                f"{args.log}/sparse_plan/permutations/permutation_block_{i_block}.pdf",
                downsample_rate_for_plot,
            )

    logger.info(f"Finished the Reorder Plan, the permute plan: {permute_plan}")
    logger.info(f"The overall coarse grained sparse rate: {torch.stack(coarse_grained_sparse_rate).mean():.4f}")

    # Final: save the sparse plan.
    save_path = f'{args.log}/permute_plan.pth'
    torch.save({'empty': is_empty, 'permute': permute_plan}, save_path)
                
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build the empty-head / permutation sparse plan from calibration attention maps."
    )
    parser.add_argument("--log", type=str, default='debug',
                        help="output directory for the log, backups, plots and permute_plan.pth")
    # main() cannot run without these two inputs, so let argparse report a missing
    # value instead of failing later with an obscure error.
    parser.add_argument('--config', type=str, required=True,
                        help="config file (loaded with OmegaConf) holding the sparse_plan thresholds")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--calib_data", type=str, required=True,
                        help="path to the calibration attention maps loaded with torch.load")
    parser.add_argument("--plot", action="store_true", help="also save (slow) PDF plots")
    args = parser.parse_args()
    main(args)