import math
from dataclasses import dataclass
from typing import Optional, Dict
import torch
from einops import rearrange
from torch import Tensor, nn

from flux.math import attention, rope
from scaling_cache.utils import taylor_formula, derivative_approximation, taylor_cache_init
from scaling_cache.utils import interleaved_alpha_update, interleaved_cache_update, interleaved_error_update, scaling_formula, dynamic_error_init, relative_l1_error

def _modulated_qkv(norm_layer, attn_module, mod, tokens, num_heads):
    """Normalize, modulate and project one stream into q/k/v heads.

    Args:
        norm_layer: Callable applied to ``tokens`` before modulation.
        attn_module: Object exposing ``qkv`` (projection) and ``norm``
            (called as ``norm(q, k, v)`` and returning the new ``q, k``).
        mod: Modulation object exposing ``scale`` and ``shift``.
        tokens: Stream features fed to ``norm_layer``.
        num_heads: Head count ``H`` used to split the fused projection.

    Returns:
        ``(q, k, v)``, each laid out as ``B H L D`` by the rearrange
        pattern; ``q`` and ``k`` are the outputs of ``attn_module.norm``.
    """
    modulated = norm_layer(tokens)
    modulated = (1 + mod.scale) * modulated + mod.shift
    fused = attn_module.qkv(modulated)
    q, k, v = rearrange(fused, "B L (K H D) -> K B H L D", K=3, H=num_heads)
    q, k = attn_module.norm(q, k, v)
    return q, k, v


def _record_block_feature(cache_dic, current, feature, taylor_module, scaling_module):
    """Hand a freshly computed feature to whichever cache is enabled.

    ``cache_dic['taylor_cache']`` takes precedence; ``cache_dic['scaling_cache']``
    is only consulted when the former is falsy. If neither is set nothing
    happens and ``current['module']`` is left untouched.

    Args:
        cache_dic: Cache state; read for the two flags and passed to the helpers.
        current: Per-call state; ``current['module']`` is overwritten with the
            key matching the active cache before the helpers are invoked.
        feature: Feature tensor to record.
        taylor_module: Module key used when the Taylor cache is active.
        scaling_module: Module key used when the scaling cache is active.
    """
    if cache_dic['taylor_cache']:
        current['module'] = taylor_module
        taylor_cache_init(cache_dic=cache_dic, current=current)
        derivative_approximation(cache_dic=cache_dic, current=current, feature=feature)
    elif cache_dic['scaling_cache']:
        current['module'] = scaling_module
        interleaved_error_update(cache_dic=cache_dic, current=current, feature=feature)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=feature)


def double_stream_block_forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, update_cache=False, **kwargs) -> tuple[Tensor, Tensor]:
    """Run the full (uncached) double-stream block on image and text tokens.

    Both streams are modulated from ``vec``, attend jointly over the
    concatenated text+image sequence, and then each goes through its own
    projection and MLP with gated residual connections.

    Args:
        self: Block providing ``img_mod``/``txt_mod``, ``img_norm1``/``txt_norm1``,
            ``img_norm2``/``txt_norm2``, ``img_attn``/``txt_attn``,
            ``img_mlp``/``txt_mlp`` and ``num_heads``.
        img: Image-stream features.
        txt: Text-stream features; ``txt.shape[1]`` is used as the text length
            when splitting the joint attention output.
        vec: Conditioning input handed to both modulation modules.
        pe: Positional encoding forwarded to ``attention``.
        cache_dic: Cache state. ``cache_dic["cal_amount"][current['stream']]``
            is incremented on every call.
        current: Per-call state; ``current['module']`` is rewritten while
            features are being recorded.
        update_cache: When truthy, each of the four intermediate features
            (img/txt attention projection and MLP output) is recorded in the
            enabled cache.
        **kwargs: Accepted and ignored.

    Returns:
        The updated ``(img, txt)`` pair.
    """
    # Count this full computation for the current stream.
    cache_dic["cal_amount"][current['stream']] += 1

    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)

    # Per-stream q/k/v; image first, then text.
    img_q, img_k, img_v = _modulated_qkv(self.img_norm1, self.img_attn, img_mod1, img, self.num_heads)
    txt_q, txt_k, txt_v = _modulated_qkv(self.txt_norm1, self.txt_attn, txt_mod1, txt, self.num_heads)

    # Joint attention: text tokens come first along the sequence axis (dim=2).
    joint_q = torch.cat((txt_q, img_q), dim=2)
    joint_k = torch.cat((txt_k, img_k), dim=2)
    joint_v = torch.cat((txt_v, img_v), dim=2)
    joint_attn = attention(joint_q, joint_k, joint_v, pe=pe)

    txt_len = txt.shape[1]
    txt_attn = joint_attn[:, :txt_len]
    img_attn = joint_attn[:, txt_len:]

    # Image stream: attention projection, then MLP.
    img_attn_out = self.img_attn.proj(img_attn)
    if update_cache:
        _record_block_feature(
            cache_dic,
            current,
            img_attn_out,
            taylor_module='img-attn-taylor',
            scaling_module="img-attn",
        )
    img = img + img_mod1.gate * img_attn_out

    img_mlp_out = self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
    if update_cache:
        _record_block_feature(
            cache_dic,
            current,
            img_mlp_out,
            taylor_module='img-mlp-taylor',
            scaling_module='img-mlp',
        )
    img = img + img_mod2.gate * img_mlp_out

    # Text stream: attention projection, then MLP.
    txt_attn_out = self.txt_attn.proj(txt_attn)
    if update_cache:
        _record_block_feature(
            cache_dic,
            current,
            txt_attn_out,
            taylor_module='txt-attn-taylor',
            scaling_module='txt-attn',
        )
    txt = txt + txt_mod1.gate * txt_attn_out

    txt_mlp_out = self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
    if update_cache:
        _record_block_feature(
            cache_dic,
            current,
            txt_mlp_out,
            taylor_module='txt-mlp-taylor',
            scaling_module='txt-mlp',
        )
    txt = txt + txt_mod2.gate * txt_mlp_out

    return img, txt

def double_stream_block_taylor_forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, **kwargs) -> tuple[Tensor, Tensor]:
    """Double-stream block forward that dispatches on ``current['type']``.

    * ``'full'``: delegates to ``double_stream_block_forward`` with
      ``update_cache=True``.
    * ``'Taylor'``: attention and MLP are not executed; each of the four
      residual updates uses ``taylor_formula(cache_dic, current)`` for the
      matching module key, scaled by the corresponding modulation gate.

    Args:
        self: Block providing ``img_mod`` and ``txt_mod`` (plus everything the
            full forward needs).
        img: Image-stream features.
        txt: Text-stream features.
        vec: Conditioning input handed to the modulation modules.
        pe: Positional encoding; only used on the ``'full'`` path.
        cache_dic: Cache state passed through to the helpers.
        current: Per-call state; ``current['module']`` is overwritten on the
            ``'Taylor'`` path.
        **kwargs: Accepted and ignored.

    Returns:
        The updated ``(img, txt)`` pair.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Taylor'``.
    """
    mode = current['type']

    if mode == 'full':
        img, txt = double_stream_block_forward(
            self, img, txt, vec, pe, cache_dic, current, update_cache=True
        )
        return img, txt

    if mode != 'Taylor':
        raise ValueError(f"Unknown cache type {current['type']}")

    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)

    # Image stream: attention term, then MLP term.
    for module_key, mod in (('img-attn-taylor', img_mod1), ('img-mlp-taylor', img_mod2)):
        current['module'] = module_key
        img = img + mod.gate * taylor_formula(cache_dic=cache_dic, current=current)

    # Text stream: attention term, then MLP term.
    for module_key, mod in (('txt-attn-taylor', txt_mod1), ('txt-mlp-taylor', txt_mod2)):
        current['module'] = module_key
        txt = txt + mod.gate * taylor_formula(cache_dic=cache_dic, current=current)

    return img, txt

def _scaling_joint_attention(block, img: Tensor, txt: Tensor, img_mod1, txt_mod1, pe: Tensor) -> tuple[Tensor, Tensor]:
    """Compute joint text+image attention and split it back per stream.

    Args:
        block: Block providing ``img_norm1``/``txt_norm1``,
            ``img_attn``/``txt_attn`` and ``num_heads``.
        img: Image-stream features.
        txt: Text-stream features; ``txt.shape[1]`` is used as the text length.
        img_mod1: Image modulation exposing ``scale`` and ``shift``.
        txt_mod1: Text modulation exposing ``scale`` and ``shift``.
        pe: Positional encoding forwarded to ``attention``.

    Returns:
        ``(txt_attn, img_attn)``: the leading ``txt.shape[1]`` positions of the
        attention output and the remainder, respectively.
    """
    # Prepare image for attention.
    img_modulated = block.img_norm1(img)
    img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
    img_qkv = block.img_attn.qkv(img_modulated)
    img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
    img_q, img_k = block.img_attn.norm(img_q, img_k, img_v)

    # Prepare text for attention.
    txt_modulated = block.txt_norm1(txt)
    txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
    txt_qkv = block.txt_attn.qkv(txt_modulated)
    txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
    txt_q, txt_k = block.txt_attn.norm(txt_q, txt_k, txt_v)

    # Text tokens are placed first along the sequence axis (dim=2).
    q = torch.cat((txt_q, img_q), dim=2)
    k = torch.cat((txt_k, img_k), dim=2)
    v = torch.cat((txt_v, img_v), dim=2)

    attn = attention(q, k, v, pe=pe)
    txt_len = txt.shape[1]
    return attn[:, :txt_len], attn[:, txt_len:]


def _refresh_scaling_cache(cache_dic: Dict, current: Dict, feature: Tensor) -> Tensor:
    """Feed a freshly computed feature to the alpha and cache updates.

    Calls ``interleaved_alpha_update`` followed by ``interleaved_cache_update``
    for the module named by ``current['module']`` and returns ``feature``
    unchanged so the call can be used inline.
    """
    interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=feature)
    interleaved_cache_update(cache_dic=cache_dic, current=current, feature=feature)
    return feature


def double_stream_block_scaling_forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, **kwargs) -> tuple[Tensor, Tensor]:
    """Double-stream block forward that dispatches on ``current['type']``.

    * ``'full'``: delegates to ``double_stream_block_forward`` with
      ``update_cache=True``.
    * ``'Scaling'``: if ``current['layer']`` is greater than
      ``cache_dic['cache']["activated_layers"][current['stream']][current['step']]``
      the block is also fully recomputed. Otherwise each of the four features
      (img/txt attention projection and MLP output) is either

      - computed for real and passed to ``interleaved_alpha_update`` and
        ``interleaved_cache_update`` when ``cache_dic['update_alpha']`` is
        truthy, or
      - obtained from ``scaling_formula(cache_dic, current)``.

    Args:
        self: Block providing the modulation, norm, attention and MLP modules
            plus ``num_heads``.
        img: Image-stream features.
        txt: Text-stream features.
        vec: Conditioning input handed to the modulation modules.
        pe: Positional encoding forwarded to ``attention``.
        cache_dic: Cache state read for ``'cache'`` and ``'update_alpha'`` and
            passed through to the helpers.
        current: Per-call state; ``current['module']`` is overwritten on the
            cached path.
        **kwargs: Accepted and ignored.

    Returns:
        The updated ``(img, txt)`` pair.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor ``'Scaling'``.
    """
    cache_type = current['type']
    if cache_type == 'full':
        recompute = True
    elif cache_type == 'Scaling':
        layer_limit = cache_dic['cache']["activated_layers"][current['stream']][current['step']]
        # Layers strictly above the limit are recomputed in full.
        recompute = current['layer'] > layer_limit
    else:
        raise ValueError(f"Unknown cache type {current['type']}")

    if recompute:
        img, txt = double_stream_block_forward(
            self, img, txt, vec, pe, cache_dic, current, update_cache=True
        )
        return img, txt

    img_mod1, img_mod2 = self.img_mod(vec)
    txt_mod1, txt_mod2 = self.txt_mod(vec)

    # Real attention is only needed when the features are being refreshed.
    if cache_dic['update_alpha']:
        txt_attn, img_attn = _scaling_joint_attention(self, img, txt, img_mod1, txt_mod1, pe)

    # Image stream: attention term.
    current['module'] = 'img-attn'
    if cache_dic['update_alpha']:
        img_attn_out = _refresh_scaling_cache(cache_dic, current, self.img_attn.proj(img_attn))
    else:
        img_attn_out = scaling_formula(cache_dic, current)
    img = img + img_mod1.gate * img_attn_out

    # Image stream: MLP term (uses the image features updated just above).
    current['module'] = 'img-mlp'
    if cache_dic['update_alpha']:
        img_mlp_out = _refresh_scaling_cache(
            cache_dic,
            current,
            self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift),
        )
    else:
        img_mlp_out = scaling_formula(cache_dic, current)
    img = img + img_mod2.gate * img_mlp_out

    # Text stream: attention term.
    current['module'] = 'txt-attn'
    if cache_dic['update_alpha']:
        txt_attn_out = _refresh_scaling_cache(cache_dic, current, self.txt_attn.proj(txt_attn))
    else:
        txt_attn_out = scaling_formula(cache_dic, current)
    txt = txt + txt_mod1.gate * txt_attn_out

    # Text stream: MLP term (uses the text features updated just above).
    current['module'] = 'txt-mlp'
    if cache_dic['update_alpha']:
        txt_mlp_out = _refresh_scaling_cache(
            cache_dic,
            current,
            self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift),
        )
    else:
        txt_mlp_out = scaling_formula(cache_dic, current)
    txt = txt + txt_mod2.gate * txt_mlp_out

    return img, txt