import math

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import numpy as np

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from wan.modules import WanModel
from wan.modules.model import sinusoidal_embedding_1d, WanAttentionBlock
from wan.modules.attention import flash_attention

from scaling_cache.utils import taylor_cache_init, derivative_approximation, taylor_formula
from scaling_cache.utils import interleaved_cache_update, interleaved_alpha_update, interleaved_error_update, scaling_formula
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)
import matplotlib.pyplot as plt
import seaborn as sns

def _update_feature_cache(cache_dic, current, feature, taylor_module, scaling_module):
    """Store one sub-module output in whichever feature cache is enabled.

    Args:
        cache_dic(dict): Cache state; the ``'taylor_cache'`` flag is checked
            first, then ``'scaling_cache'``.
        current(dict): Per-call bookkeeping; ``current['module']`` is set to the
            selected module key before the cache helpers are called.
        feature(Tensor): Output of the sub-module being recorded.
        taylor_module(str): Module key used when the Taylor cache is enabled.
        scaling_module(str): Module key used when the scaling cache is enabled.

    If neither flag is set, nothing is recorded and ``current`` is left untouched.
    """
    if cache_dic['taylor_cache']:
        current['module'] = taylor_module
        taylor_cache_init(cache_dic=cache_dic, current=current)
        derivative_approximation(cache_dic=cache_dic, current=current, feature=feature)
    elif cache_dic['scaling_cache']:
        current['module'] = scaling_module
        # Keep the error update ahead of the cache update (presumably the error is
        # measured against the previously cached feature -- confirm in
        # scaling_cache.utils).
        interleaved_error_update(cache_dic=cache_dic, current=current, feature=feature)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=feature)


def wan_block_forward(
    self:WanAttentionBlock,
    x,
    e,
    seq_lens,
    grid_sizes,
    freqs,
    context,
    context_lens,
    cache_dic,
    current,
    update_cache:bool = False
):
    r"""Run the full (uncached) computation of one Wan attention block.

    Self-attention, cross-attention and FFN are all evaluated. When
    ``update_cache`` is true, each of the three sub-module outputs is also
    recorded in the enabled feature cache (Taylor or scaling).

    Args:
        x(Tensor): Shape [B, L, C]
        e(Tensor): Modulation input. It is added to ``self.modulation.unsqueeze(0)``
            and split into 6 chunks along dim 2, so presumably shape
            [B, L, 6, C] -- TODO confirm against the caller.
        seq_lens(Tensor): Shape [B], length of each sequence in batch
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        context(Tensor): Passed to ``self.cross_attn`` as its context input.
        context_lens(Tensor): Passed to ``self.cross_attn`` alongside ``context``.
        cache_dic(dict): Cache state. ``cache_dic["cal_amount"][current['stream']]``
            is incremented once per call; the ``'taylor_cache'`` /
            ``'scaling_cache'`` flags select which cache is updated.
        current(dict): Per-call bookkeeping. ``'stream'`` is read, and ``'module'``
            is overwritten whenever a cache is updated.
        update_cache(bool): Whether to record the sub-module outputs in the cache.

    Returns:
        Tensor: The updated hidden states.
    """
    assert e.dtype == torch.float32
    with torch.amp.autocast('cuda', dtype=torch.float32):
        e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
    assert e[0].dtype == torch.float32

    # Count every full block evaluation for this stream.
    cache_dic["cal_amount"][current['stream']] += 1

    # Self-attention on norm1(x) scaled by (1 + chunk 1) and shifted by chunk 0.
    sa = self.self_attn(
        self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
        seq_lens, grid_sizes, freqs)
    if update_cache:
        _update_feature_cache(cache_dic, current, sa,
                              'self-attention-taylor', 'self-attention')
    with torch.amp.autocast('cuda', dtype=torch.float32):
        x = x + sa * e[2].squeeze(2)

    # Cross-attention residual is added without a modulation factor.
    ca = self.cross_attn(self.norm3(x), context, context_lens)
    if update_cache:
        _update_feature_cache(cache_dic, current, ca,
                              'cross-attention-taylor', 'cross-attention')
    x = x + ca

    # FFN on norm2(x) scaled by (1 + chunk 4) and shifted by chunk 3; output scaled by chunk 5.
    ffn = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
    if update_cache:
        _update_feature_cache(cache_dic, current, ffn, 'ffn-taylor', 'ffn')
    with torch.amp.autocast('cuda', dtype=torch.float32):
        x = x + ffn * e[5].squeeze(2)
    return x

def wan_block_taylor_forward(
    self:WanAttentionBlock,
    x,
    e,
    seq_lens,
    grid_sizes,
    freqs,
    context,
    context_lens,
    cache_dic,
    current
):
    r"""Wan block forward: full computation or Taylor-cache extrapolation.

    ``current['type']`` selects the path:

    * ``'full'`` -- delegate to ``wan_block_forward``. Caches are refreshed only
      while ``current['step'] < cache_dic['num_steps'] - cache_dic['last_enhance']``.
    * ``'Taylor'`` -- do not run attention/FFN; each sub-module output is
      obtained from ``taylor_formula`` and the same residual updates are applied.

    Args:
        x(Tensor): Shape [B, L, C]
        e(Tensor): Modulation input. It is added to ``self.modulation.unsqueeze(0)``
            and split into 6 chunks along dim 2, so presumably shape
            [B, L, 6, C] -- TODO confirm against the caller.
        seq_lens(Tensor): Shape [B], length of each sequence in batch
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        context(Tensor): Forwarded to ``wan_block_forward`` on the full path.
        context_lens(Tensor): Forwarded to ``wan_block_forward`` on the full path.
        cache_dic(dict): Cache state shared with the cache helpers.
        current(dict): Per-call bookkeeping; ``'type'`` and ``'step'`` are read and
            ``'module'`` is overwritten on the Taylor path.

    Returns:
        Tensor: The updated hidden states.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Taylor'``.
    """
    if current['type'] == 'full':
        # No cache refresh during the final `last_enhance` steps.
        update_cache = (current['step'] < cache_dic['num_steps'] - cache_dic['last_enhance'])
        x = wan_block_forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens,
                              cache_dic, current, update_cache=update_cache)
    elif current['type'] == 'Taylor':
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
        assert e[0].dtype == torch.float32

        # Only the output scale chunks (2 and 5) are needed here: the sub-module
        # outputs come from the Taylor cache instead of being recomputed.
        current['module'] = 'self-attention-taylor'
        sa = taylor_formula(cache_dic=cache_dic, current=current)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + sa * e[2].squeeze(2)

        current['module'] = 'cross-attention-taylor'
        ca = taylor_formula(cache_dic=cache_dic, current=current)
        x = x + ca

        current['module'] = 'ffn-taylor'
        ffn = taylor_formula(cache_dic=cache_dic, current=current)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + ffn * e[5].squeeze(2)
    else:
        raise ValueError(f"Not supported type: {current['type']}")
    return x

def wan_block_scaling_forward(
    self:WanAttentionBlock,
    x,
    e,
    seq_lens,
    grid_sizes,
    freqs,
    context,
    context_lens,
    cache_dic,
    current
):
    r"""Wan block forward: full computation or scaling-cache approximation.

    ``current['type']`` selects the path:

    * ``'full'`` -- delegate to ``wan_block_forward``. Caches are refreshed only
      while ``current['step'] < cache_dic['num_steps'] - cache_dic['last_enhance']``.
    * ``'Scaling'`` -- for each sub-module, if ``cache_dic['update_alpha']`` is set
      the real output is computed and passed to ``interleaved_alpha_update``;
      otherwise the output is taken from ``scaling_formula``.

    Args:
        x(Tensor): Shape [B, L, C]
        e(Tensor): Modulation input. It is added to ``self.modulation.unsqueeze(0)``
            and split into 6 chunks along dim 2, so presumably shape
            [B, L, 6, C] -- TODO confirm against the caller.
        seq_lens(Tensor): Shape [B], length of each sequence in batch
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        context(Tensor): Passed to ``self.cross_attn`` (or forwarded on the full path).
        context_lens(Tensor): Passed to ``self.cross_attn`` alongside ``context``.
        cache_dic(dict): Cache state; ``'update_alpha'`` is read on the Scaling path.
        current(dict): Per-call bookkeeping; ``'type'`` and ``'step'`` are read and
            ``'module'`` is overwritten on the Scaling path.

    Returns:
        Tensor: The updated hidden states.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Scaling'``.
    """
    if current['type'] == 'full':
        # No cache refresh during the final `last_enhance` steps.
        update_cache = (current['step'] < cache_dic['num_steps'] - cache_dic['last_enhance'])
        x = wan_block_forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens,
                              cache_dic, current, update_cache=update_cache)
    elif current['type'] == 'Scaling':
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
        assert e[0].dtype == torch.float32

        # NOTE: cache_dic['update_alpha'] is re-read before each sub-module on
        # purpose -- the alpha-update helpers receive cache_dic and may change it.

        current['module'] = 'self-attention'
        if cache_dic['update_alpha']:
            sa = self.self_attn(
                self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
                seq_lens, grid_sizes, freqs)
            interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=sa)
            # NOTE(review): only self-attention refreshes its cached feature here;
            # cross-attention and FFN below update alpha only. Confirm intended.
            interleaved_cache_update(cache_dic=cache_dic, current=current, feature=sa)
        else:
            sa = scaling_formula(cache_dic=cache_dic, current=current)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + sa * e[2].squeeze(2)

        current['module'] = 'cross-attention'
        if cache_dic['update_alpha']:
            ca = self.cross_attn(self.norm3(x), context, context_lens)
            interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=ca)
        else:
            ca = scaling_formula(cache_dic=cache_dic, current=current)
        x = x + ca

        current['module'] = 'ffn'
        if cache_dic['update_alpha']:
            ffn = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
            interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=ffn)
        else:
            ffn = scaling_formula(cache_dic=cache_dic, current=current)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + ffn * e[5].squeeze(2)
    else:
        raise ValueError(f"Not supported type: {current['type']}")
    return x