from typing import Any, Dict, Optional, Tuple, Union
from pipeline_flux import FluxPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
import torch
import numpy as np
import os
import tqdm
from safetensors.torch import save_file, load_file
import json
import pandas as pd
import os
import sys
from adjust_utils import FlowMatchEulerDiscreteScheduler


# Module-level logger obtained through diffusers' logging helpers; used for the warnings
# emitted by this file.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def ertacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        Drop-in replacement for the [`FluxTransformer2DModel`] forward method with TeaCache / ERTACache.

        It is installed by assigning it to ``FluxTransformer2DModel.forward``. Besides the regular
        forward pass it keeps per-denoising-step state on ``self`` (``cnt``, ``skip_cnt``,
        ``cache_list``, ``previous_residual``, ``previous_modulated_input``,
        ``accumulated_rel_l1_distance``). Flags expected on the model/class control its behavior:
        ``enable_teacache``, ``enable_ertacache``, ``calibrate``, ``skip_list``, ``num_steps``,
        ``rel_l1_thresh`` and ``calibrate_rel_l1_thresh``. Depending on them, it either runs all
        transformer blocks or reuses the residual cached from the last computed step. The residual
        is the block stack's output minus its input.

        ``self.cnt`` counts calls. It is reset to 0, together with ``skip_cnt`` and ``cache_list``,
        once it reaches ``self.num_steps``. Each ``num_steps`` consecutive calls are therefore
        treated as one denoising run.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input latent tokens. They go through ``self.x_embedder`` and are later concatenated
                with the text tokens along dim 1. The shape is presumably `(batch, seq_len, channels)`;
                TODO confirm against the pipeline's latent packing.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate the denoising step. Multiplied by 1000 before the time embedding.
            img_ids / txt_ids (`torch.Tensor`):
                Positional ids for image / text tokens. 3-D inputs are accepted but deprecated; the
                batch dimension is dropped.
            guidance (`torch.Tensor`, *optional*):
                When given, it is multiplied by 1000 and passed to ``self.time_text_embed``.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                An optional ``"scale"`` entry sets the LoRA scale. An ``"ip_adapter_image_embeds"`` entry
                is projected by ``self.encoder_hid_proj`` and forwarded as ``"ip_hidden_states"``.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image stream after each dual-stream transformer block.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image tokens after each single-stream transformer block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            controlnet_blocks_repeat (`bool`, *optional*, defaults to `False`):
                If True (XLabs ControlNet), ``controlnet_block_samples`` is indexed cyclically
                (``index_block % len(samples)``) instead of by block interval.

        Environment:
            SAVE_RESIDUAL: when ``1``, reference residuals are loaded with ``load_file`` from the path
                in the ``real_cache_dict`` environment variable. The residual error (printed, and used
                for calibration) is then normalised by the current step's reference residual instead
                of the cached one.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """

        # Optional reference residuals (keys ``output_<step>``) used to normalise the residual error.
        # NOTE(review): the file is re-read on every call, i.e. once per denoising step.
        if int(os.getenv('SAVE_RESIDUAL',0))==1:
            real_cache_dict = load_file(os.environ['real_cache_dict'])

        # Start of a new denoising run: the previous run has already made ``num_steps`` calls.
        # NOTE(review): these assignments (and ``self.cnt += 1`` below) create *instance*
        # attributes that shadow the class-level defaults configured by the driver code.
        if self.cnt  >= self.num_steps:
            self.cnt = 0
            self.skip_cnt = 0
            self.cache_list = []


        # Copy before popping so the caller's dict is not mutated.
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)

        # Time / guidance values are scaled by 1000 before embedding.
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Accept the deprecated batched (3-D) id tensors by dropping the batch dimension.
        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        # Rotary embedding over text ids followed by image ids (same order as the token concat below).
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # TeaCache decision: measure how much the first block's modulated input changed since the
        # previous step and accumulate a rescaled version of that relative L1 change.
        if self.enable_teacache:
            inp = hidden_states.clone()
            temb_ = temb.clone()
            modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)
            # Previous step's modulated input (None on the first step); used below for ``l1_inp``.
            old_previous_inp = self.previous_modulated_input.clone() if self.cnt >0 else None
            # The first and the last step of a run are always computed.
            if self.cnt == 0 or self.cnt == self.num_steps-1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                # Polynomial rescaling of the relative L1 change; coefficients presumably fitted offline (TODO confirm source).
                coefficients = [4.98651651e+02, -2.83781631e+02,  5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
                rescale_func = np.poly1d(coefficients)
                self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
            self.previous_modulated_input = modulated_inp



        if self.enable_teacache:
            if not should_calc:
                # TeaCache skip: reuse the cached residual instead of running the blocks.
                hidden_states += self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                # NOTE(review): when ``enable_ertacache`` is False this branch runs no transformer
                # block at all. ``hidden_states`` passes through unchanged and ``previous_residual``
                # is never updated. The drivers in this file always set both flags to the same value.
                if self.enable_ertacache:
                    # ERTACache: steps listed in ``skip_list`` reuse the previous residual.
                    if self.cnt in self.skip_list:
                        hidden_states = ori_hidden_states + self.previous_residual
                        self.skip_cnt += 1
                        print(f"skip timestep:{self.cnt}")


                    else:
                        # Full computation: dual-stream blocks (text + image streams).
                        for index_block, block in enumerate(self.transformer_blocks):
                            if torch.is_grad_enabled() and self.gradient_checkpointing:

                                def create_custom_forward(module, return_dict=None):
                                    def custom_forward(*inputs):
                                        if return_dict is not None:
                                            return module(*inputs, return_dict=return_dict)
                                        else:
                                            return module(*inputs)

                                    return custom_forward

                                # NOTE(review): the checkpointed path does not pass joint_attention_kwargs.
                                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                                    create_custom_forward(block),
                                    hidden_states,
                                    encoder_hidden_states,
                                    temb,
                                    image_rotary_emb,
                                    **ckpt_kwargs,
                                )

                            else:
                                encoder_hidden_states, hidden_states = block(
                                    hidden_states=hidden_states,
                                    encoder_hidden_states=encoder_hidden_states,
                                    temb=temb,
                                    image_rotary_emb=image_rotary_emb,
                                    joint_attention_kwargs=joint_attention_kwargs,
                                )

                            # controlnet residual
                            if controlnet_block_samples is not None:
                                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                                interval_control = int(np.ceil(interval_control))
                                # For Xlabs ControlNet.
                                if controlnet_blocks_repeat:
                                    hidden_states = (
                                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                                    )
                                else:
                                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
                        # Single-stream blocks operate on [text tokens, image tokens] concatenated along dim 1.
                        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)


                        for index_block, block in enumerate(self.single_transformer_blocks):
                            if torch.is_grad_enabled() and self.gradient_checkpointing:

                                def create_custom_forward(module, return_dict=None):
                                    def custom_forward(*inputs):
                                        if return_dict is not None:
                                            return module(*inputs, return_dict=return_dict)
                                        else:
                                            return module(*inputs)

                                    return custom_forward

                                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                                hidden_states = torch.utils.checkpoint.checkpoint(
                                    create_custom_forward(block),
                                    hidden_states,
                                    temb,
                                    image_rotary_emb,
                                    **ckpt_kwargs,
                                )

                            else:
                                hidden_states = block(
                                    hidden_states=hidden_states,
                                    temb=temb,
                                    image_rotary_emb=image_rotary_emb,
                                    joint_attention_kwargs=joint_attention_kwargs,
                                )

                            # controlnet residual
                            if controlnet_single_block_samples is not None:
                                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                                interval_control = int(np.ceil(interval_control))
                                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                                    + controlnet_single_block_samples[index_block // interval_control]
                                )

                        # Drop the text tokens again; only image tokens continue to the output head.
                        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]


                        # Diagnostics: relative error of this step's residual w.r.t. the cached one
                        # (or the reference residual when SAVE_RESIDUAL=1), plus the relative change
                        # of the modulated input between consecutive steps (``l1_inp``).
                        if self.cnt > 0:

                            if int(os.getenv('SAVE_RESIDUAL',0))==1:
                                print('##### use real_residual')
                                real_residual = real_cache_dict[f'output_{self.cnt}'].to(device = ori_hidden_states.device, dtype = ori_hidden_states.dtype)
                                l1_residual = (hidden_states - ori_hidden_states - (self.previous_residual)).abs().mean() / real_residual.abs().mean()
                                l1_inp = (old_previous_inp - self.previous_modulated_input).abs().mean() / old_previous_inp.abs().mean()
                                real_l1_distance = l1_residual + l1_inp  ##### real_residual + v5

                            else:
                                l1_residual = (hidden_states - ori_hidden_states - (self.previous_residual)).abs().mean() / self.previous_residual.abs().mean()
                                l1_inp = (old_previous_inp - self.previous_modulated_input).abs().mean() / old_previous_inp.abs().mean()
                                real_l1_distance = l1_residual + l1_inp  ##### v5

                            print(f"real l1_distance:{real_l1_distance},l1_residual:{l1_residual}, l1_inp:{l1_inp}")


                        # Calibration: every step except the first and last whose combined distance is
                        # below ``calibrate_rel_l1_thresh`` is recorded in ``cache_list`` and replaced
                        # by the cached residual; its fresh result is discarded. ``real_l1_distance``
                        # only exists when cnt > 0, which the short-circuiting condition guarantees.
                        if self.calibrate:
                            if self.cnt > 0 and self.cnt < self.num_steps-1 and real_l1_distance < self.calibrate_rel_l1_thresh:
                                self.cache_list.append(self.cnt)
                                hidden_states = ori_hidden_states + self.previous_residual
                                self.skip_cnt += 1
                                print(f"skip timestep:{self.cnt}")
                            else:
                                self.previous_residual =  hidden_states - ori_hidden_states

                        else:
                            self.previous_residual = hidden_states - ori_hidden_states




        #### no cache, original
        else:
            for index_block, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )

                else:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_block_samples is not None:
                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    # For Xlabs ControlNet.
                    if controlnet_blocks_repeat:
                        hidden_states = (
                            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                        )
                    else:
                        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

            for index_block, block in enumerate(self.single_transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        temb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )

                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                # controlnet residual
                if controlnet_single_block_samples is not None:
                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                    interval_control = int(np.ceil(interval_control))
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                        + controlnet_single_block_samples[index_block // interval_control]
                    )

            hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)
        # Advance the step counter; after the last step of a run, report the skip statistics.
        self.cnt += 1
        if self.cnt  == self.num_steps:
            print(f"skip {self.skip_cnt} steps, skip ratio: {self.skip_cnt / self.num_steps}")
            print(f"cache_list: {self.cache_list}")

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


def generate_func(pipeline, prompt_list, output_dir, num_inference_steps=25, tensor_folder_name=None, loop: int = 1, kwargs: Optional[dict] = None, num_file : int=1000):
    """Render one PNG per (prompt, seed) pair into ``output_dir``, skipping existing images.

    For every prompt and every seed ``l`` in ``range(loop)``, the pipeline runs with a CPU
    ``torch.Generator`` seeded with ``l``. The first image is saved as
    ``<output_dir>/<prompt[:50]>-<l>.png``. Generation stops as soon as ``output_dir``
    contains at least ``num_file`` PNG files.

    Args:
        pipeline: Flux pipeline whose transformer uses ``ertacache_forward``.
        prompt_list: Iterable of prompt strings.
        output_dir: Destination directory (created if missing).
        num_inference_steps: Denoising steps per image.
        tensor_folder_name: Sub-directory of ``output_dir`` that receives one
            ``<n>_seed<l>.safetensors`` dump of ``pipeline.save_cache`` per image. Only used
            when ``pipeline.save_error`` is truthy.
        loop: Number of seeds (``0 .. loop-1``) rendered per prompt.
        kwargs: Currently unused; kept for interface compatibility.
        num_file: Stop once ``output_dir`` holds at least this many PNGs.

    Environment:
        SAVE_RESIDUAL: when ``1``, points ``real_cache_dict`` at the per-image reference residual file.
        ADJUST_FOLDER: directory used to build ``ADJUST_PATH`` (``K_hw_<seed>.safetensors``).
    """
    os.makedirs(output_dir, exist_ok=True)
    save_error = getattr(pipeline, "save_error", False)
    if save_error:
        tensor_folder_path = os.path.join(output_dir, tensor_folder_name)
        os.makedirs(tensor_folder_path, exist_ok=True)
    cnt = 0
    transformer = pipeline.transformer

    for prompt in tqdm.tqdm(prompt_list):
        for l in range(loop):
            len_png = len([i for i in os.listdir(output_dir) if '.png' in i])
            # ``>=`` rather than ``==`` so a directory that already exceeds the cap also stops the run.
            if len_png >= num_file:
                return None

            image_stem = f"{prompt[0:50]}-{l}"
            image_path = os.path.join(output_dir, f"{image_stem}.png")
            if os.path.exists(image_path):
                continue

            if int(os.getenv('SAVE_RESIDUAL', 0)) == 1:
                os.environ['real_cache_dict'] = os.path.join('./sample/original_20coco/real_residual', f"{image_stem}.safetensors")

            # Reset the per-run step state on both the class and the instance. ``ertacache_forward``
            # assigns ``self.cnt`` / ``self.skip_cnt`` / ``self.cache_list``, which creates instance
            # attributes that shadow the class ones. Resetting only the class would therefore be a
            # no-op after the first call, and an image that failed mid-denoising would leave the
            # counter mid-run for the next prompt.
            for holder in (transformer.__class__, transformer):
                holder.cnt = 0
                holder.skip_cnt = 0
                holder.cache_list = []
            os.environ['ADJUST_PATH'] = os.path.join(os.getenv('ADJUST_FOLDER', ''), f'K_hw_{l}.safetensors')
            try:
                img = pipeline(
                    prompt,
                    num_inference_steps=num_inference_steps,
                    generator=torch.Generator("cpu").manual_seed(l),
                ).images[0]
                img.save(image_path)
                if save_error:
                    tensor_path = os.path.join(tensor_folder_path, f"{cnt}_seed{l}.safetensors")
                    print('#### save file to ', tensor_path)
                    save_file(pipeline.__class__.save_cache, tensor_path)
                cnt += 1
            except Exception:
                # Best effort: one failing prompt must not abort the whole batch, but the failure
                # is reported instead of silently swallowed. KeyboardInterrupt/SystemExit propagate.
                logger.exception("Generation failed for prompt %r (seed %d); skipping.", prompt[0:50], l)

                

                
            
            


def calibrate_ertacache(prompt_list,loop=1):
    """Run the ERTACache calibration pass over ``prompt_list``.

    Loads FLUX.1-dev in fp16 from ``black-forest-labs-FLUX.1-dev`` and installs
    ``ertacache_forward``. It then configures the transformer class for calibration:
    TeaCache and ERTACache enabled, ``calibrate=True`` and an empty ``skip_list``. During
    each run, steps whose measured distance is below ``calibrate_rel_l1_thresh`` are
    collected in ``cache_list``; the forward prints the list after the last step.

    Environment variables:
        CALIBRATE_REL_L1_THRESH: calibration threshold (default ``-1000``).
        SAVE_ERROR: ``1`` makes ``generate_func`` dump ``pipeline.save_cache`` per image.
        OUTPUT_DIR: output directory (default ``./samples_flux/tmp``).

    Args:
        prompt_list: Prompts to render.
        loop: Seeds per prompt (seeds ``0 .. loop-1``).
    """
    FluxTransformer2DModel.forward = ertacache_forward
    num_inference_steps = 28
    pipeline = FluxPipeline.from_pretrained("black-forest-labs-FLUX.1-dev", torch_dtype=torch.float16)
    pipeline.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

    # TeaCache / ERTACache configuration. These are class attributes, i.e. defaults shared by
    # every instance of the transformer class in this process.
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config('black-forest-labs-FLUX.1-dev/scheduler')
    transformer_cls = pipeline.transformer.__class__
    transformer_cls.forward = ertacache_forward
    transformer_cls.enable_teacache = True
    transformer_cls.enable_ertacache = True
    transformer_cls.calibrate = True
    # The compared distance is a sum of ratios of mean absolute values, so it is never negative.
    # With the default of -1000 no step is ever recorded; set CALIBRATE_REL_L1_THRESH to calibrate.
    transformer_cls.calibrate_rel_l1_thresh = float(os.getenv('CALIBRATE_REL_L1_THRESH', -1000))
    transformer_cls.cnt = 0
    transformer_cls.skip_cnt = 0
    transformer_cls.num_steps = num_inference_steps
    transformer_cls.rel_l1_thresh = -100 # 0.25 for 1.5x speedup, 0.4 for 1.8x speedup, 0.6 for 2.0x speedup, 0.8 for 2.25x speedup
    transformer_cls.accumulated_rel_l1_distance = 0
    transformer_cls.previous_modulated_input = None
    transformer_cls.previous_residual = None
    transformer_cls.cache_list = []
    transformer_cls.skip_list = []
    pipeline.__class__.save_error = int(os.getenv("SAVE_ERROR", 0)) == 1
    pipeline.__class__.save_cache = {}
    pipeline.__class__.skip_list = transformer_cls.skip_list
    tensor_folder_name = 'ertacache_tensor' if transformer_cls.enable_ertacache else 'tensor'
    output_dir = os.getenv('OUTPUT_DIR', "./samples_flux/tmp")
    os.makedirs(output_dir, exist_ok=True)
    pipeline.to("cuda")
    generate_func(pipeline, prompt_list, output_dir, num_inference_steps=num_inference_steps,
                  tensor_folder_name=tensor_folder_name, loop=loop)
    
        

def eval_ertacache(prompt_list,loop=1,num_file=1000):
    """Render evaluation images for ``prompt_list``, optionally with ERTACache enabled.

    Loads FLUX.1-dev in fp16 from ``black-forest-labs-FLUX.1-dev`` and installs
    ``ertacache_forward``. It then configures the transformer class from environment
    variables and hands off to ``generate_func``.

    Environment variables:
        ERTACACHE: ``1`` enables both the TeaCache and ERTACache paths; otherwise the
            uncached forward path is used.
        SKIP_LIST: JSON list of step indices whose residual is reused (default ``[]``).
        CALIBRATE_REL_L1_THRESH: stored on the class; it has no effect here because
            ``calibrate`` is False.
        SAVE_ERROR: ``1`` makes ``generate_func`` dump ``pipeline.save_cache`` per image.
        OUTPUT_DIR: output directory (default ``./samples_flux/tmp``).

    Args:
        prompt_list: Prompts to render.
        loop: Seeds per prompt (seeds ``0 .. loop-1``).
        num_file: Stop once the output directory holds this many PNGs.
    """
    FluxTransformer2DModel.forward = ertacache_forward
    num_inference_steps = 28
    pipeline = FluxPipeline.from_pretrained("black-forest-labs-FLUX.1-dev", torch_dtype=torch.float16)
    pipeline.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

    # TeaCache / ERTACache configuration (class attributes: defaults for every instance).
    use_cache = int(os.getenv('ERTACACHE', 0)) == 1
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config('black-forest-labs-FLUX.1-dev/scheduler')
    transformer_cls = pipeline.transformer.__class__
    transformer_cls.forward = ertacache_forward
    # Both flags must move together: the cached forward runs no blocks if only TeaCache is on.
    transformer_cls.enable_teacache = use_cache
    transformer_cls.enable_ertacache = use_cache
    transformer_cls.calibrate = False
    transformer_cls.calibrate_rel_l1_thresh = float(os.getenv('CALIBRATE_REL_L1_THRESH', -1000))
    transformer_cls.cnt = 0
    transformer_cls.skip_cnt = 0
    transformer_cls.num_steps = num_inference_steps
    transformer_cls.rel_l1_thresh = -100
    transformer_cls.accumulated_rel_l1_distance = 0
    transformer_cls.previous_modulated_input = None
    transformer_cls.previous_residual = None
    transformer_cls.cache_list = []
    transformer_cls.skip_list = json.loads(os.getenv('SKIP_LIST', '[]'))
    pipeline.__class__.save_error = int(os.getenv("SAVE_ERROR", 0)) == 1
    pipeline.__class__.save_cache = {}
    pipeline.__class__.skip_list = transformer_cls.skip_list
    tensor_folder_name = 'ertacache_tensor' if transformer_cls.enable_ertacache else 'tensor'
    output_dir = os.getenv('OUTPUT_DIR', "./samples_flux/tmp")
    os.makedirs(output_dir, exist_ok=True)
    pipeline.to("cuda")

    # Honour the caller's ``loop``; it was previously overwritten with 1.
    generate_func(pipeline, prompt_list, output_dir, num_inference_steps=num_inference_steps,
                  tensor_folder_name=tensor_folder_name, loop=loop, num_file=num_file)
    

def read_prompt_list(prompt_list_path='mscoco_val2014_30k/metadata.csv'):
    """Load prompts from the ``text`` column of a CSV metadata file.

    Args:
        prompt_list_path: Path to a CSV file that has a ``text`` column.

    Returns:
        list: The values of the ``text`` column, in file order.
    """
    metadata = pd.read_csv(prompt_list_path)
    text_column = metadata["text"]
    return text_column.tolist()


if __name__ == "__main__":
    # DATA_TEST selects how many prompts (taken from the head of the metadata CSV) are rendered.
    # The same number caps the PNGs written. If unset it is 0: the pipeline is still loaded,
    # but no image is generated.
    prompt_count = int(os.getenv('DATA_TEST', 0))
    selected_prompts = read_prompt_list()[:prompt_count]
    eval_ertacache(selected_prompts, loop=1, num_file=prompt_count)

# python main.py personal -p pengxurui.pxr -n ERTACache -b master -t /mnt/bn/bytenn-yg2/pxr/ERTACache -s OPENSOURCE
