import math
from dataclasses import dataclass
from typing import Optional, Dict
import torch
from einops import rearrange
from torch import Tensor, nn

from flux.math import attention, rope

from scaling_cache.utils import taylor_formula, derivative_approximation, taylor_cache_init
from scaling_cache.utils import interleaved_cache_update, interleaved_alpha_update, interleaved_error_update, scaling_formula, relative_l1_error

def single_stream_block_forward(self, x: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, update_cache=False, **kwargs) -> Tensor:
    """Compute the single-stream block in full and optionally refresh the cache.

    The block modulates the pre-normalised input, projects it with
    ``self.linear1``, splits the projection into a fused q/k/v part and an
    MLP part, runs attention, and fuses both parts through ``self.linear2``.

    Args:
        self: Block instance providing ``modulation``, ``pre_norm``,
            ``linear1``, ``linear2``, ``norm``, ``mlp_act``, ``hidden_size``,
            ``mlp_hidden_dim`` and ``num_heads``.
        x: Input features; also the residual the gated output is added to.
        vec: Conditioning input handed to ``self.modulation``.
        pe: Positional information forwarded to ``attention``.
        cache_dic: Shared cache state. ``cache_dic["cal_amount"]`` is always
            incremented for the current stream; the ``'taylor_cache'`` and
            ``'scaling_cache'`` flags are read only when ``update_cache`` is
            true.
        current: Per-call state. ``current['stream']`` is read, and
            ``current['module']`` is written before any cache update.
        update_cache: When true, store the freshly computed block output via
            the Taylor helpers or the interleaved scaling helpers, whichever
            flag in ``cache_dic`` is set (Taylor takes precedence).
        **kwargs: Accepted and ignored.

    Returns:
        ``x + gate * output`` where ``gate`` comes from the modulation and
        ``output`` is the result of ``self.linear2``.
    """
    # Tally one more full computation for the active stream.
    cache_dic["cal_amount"][current['stream']] += 1

    modulation, _ = self.modulation(vec)
    scaled = (1 + modulation.scale) * self.pre_norm(x)
    modulated = scaled + modulation.shift

    # linear1 yields the fused q/k/v projection followed by the MLP input.
    projected = self.linear1(modulated)
    section_sizes = [3 * self.hidden_size, self.mlp_hidden_dim]
    qkv, mlp = torch.split(projected, section_sizes, dim=-1)

    q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)
    attn = attention(q, k, v, pe=pe)

    # Concatenate attention with the activated MLP stream along dim 2, then
    # apply the second linear layer.
    fused = torch.cat([attn, self.mlp_act(mlp)], dim=2)
    block_out = self.linear2(fused)

    # Taylor caching takes precedence over scaling caching when both are set.
    if update_cache and cache_dic['taylor_cache']:
        current['module'] = 'total-taylor'
        taylor_cache_init(cache_dic=cache_dic, current=current)
        derivative_approximation(cache_dic=cache_dic, current=current, feature=block_out)
    elif update_cache and cache_dic['scaling_cache']:
        current['module'] = 'total'
        # The error update runs before the cache update, as in the original.
        interleaved_error_update(cache_dic=cache_dic, current=current, feature=block_out)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=block_out)

    return x + modulation.gate * block_out

def single_stream_block_taylor_forward(self, x: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, **kwargs) -> Tensor:
    """Dispatch a single-stream block step for the Taylor cache mode.

    Args:
        self: Block instance; ``self.modulation`` is used on Taylor steps.
        x: Input features and residual.
        vec: Conditioning input handed to ``self.modulation``.
        pe: Positional information, used only on full steps.
        cache_dic: Shared cache state passed through to the helpers.
        current: Per-call state. ``current['type']`` selects the path
            (``'full'`` or ``'Taylor'``); on Taylor steps
            ``current['module']`` is set to ``'total-taylor'``.
        **kwargs: Accepted and ignored.

    Returns:
        On ``'full'`` steps, the result of ``single_stream_block_forward``
        with cache updating enabled. On ``'Taylor'`` steps,
        ``x + gate * taylor_formula(...)``.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor
            ``'Taylor'``.
    """
    step_type = current['type']

    # Full step: do the real computation and refresh the cache.
    if step_type == 'full':
        return single_stream_block_forward(self, x, vec, pe, cache_dic, current, update_cache=True)

    if step_type != 'Taylor':
        raise ValueError(f"Unknown cache type {step_type}")

    # Taylor step: only the modulation gate is computed; the block output is
    # obtained from the cache through taylor_formula.
    modulation, _ = self.modulation(vec)
    current['module'] = 'total-taylor'
    predicted = taylor_formula(cache_dic=cache_dic, current=current)
    return x + modulation.gate * predicted


def single_stream_block_scaling_forward(self, x: Tensor, vec: Tensor, pe: Tensor, cache_dic: Dict, current: Dict, **kwargs) -> Tensor:
    """Dispatch a single-stream block step for the scaling cache mode.

    Args:
        self: Block instance providing ``modulation``, ``pre_norm``,
            ``linear1``, ``linear2``, ``norm``, ``mlp_act``, ``hidden_size``,
            ``mlp_hidden_dim`` and ``num_heads``.
        x: Input features and residual.
        vec: Conditioning input handed to ``self.modulation``.
        pe: Positional information forwarded to ``attention``.
        cache_dic: Shared cache state. Reads
            ``cache_dic['cache']["activated_layers"][stream][step]`` and
            ``cache_dic['update_alpha']`` on Scaling steps.
        current: Per-call state. ``current['type']`` selects the path
            (``'full'`` or ``'Scaling'``); ``'layer'``, ``'stream'`` and
            ``'step'`` are read on Scaling steps, and ``current['module']``
            is set to ``'total'`` when the cache is used.
        **kwargs: Accepted and ignored.

    Returns:
        On ``'full'`` steps, or on Scaling steps whose layer index exceeds the
        activated-layer threshold, the result of
        ``single_stream_block_forward`` with cache updating enabled.
        Otherwise ``x + gate * output``, where ``output`` is either freshly
        computed (when ``cache_dic['update_alpha']`` is truthy) or produced by
        ``scaling_formula``.

    Raises:
        ValueError: If ``current['type']`` is neither ``'full'`` nor
            ``'Scaling'``. The message includes the offending value.
    """
    step_type = current['type']

    if step_type == 'full':
        return single_stream_block_forward(self, x, vec, pe, cache_dic, current, update_cache=True)

    if step_type != 'Scaling':
        # Report the offending value, using the same message format as
        # single_stream_block_taylor_forward.
        raise ValueError(f"Unknown cache type {step_type}")

    # Layers beyond the per-stream, per-step threshold bypass the cache and
    # are recomputed in full.
    not_use_cache = current['layer'] > cache_dic['cache']["activated_layers"][current['stream']][current['step']]
    if not_use_cache:
        return single_stream_block_forward(self, x, vec, pe, cache_dic, current, update_cache=True)

    mod, _ = self.modulation(vec)
    current['module'] = 'total'

    if cache_dic['update_alpha']:
        # Recompute the true block output (same steps as
        # single_stream_block_forward) so it can serve as the target for the
        # alpha update and then be stored in the cache.
        # NOTE(review): unlike single_stream_block_forward, this path does not
        # increment cache_dic["cal_amount"] -- confirm whether that is
        # intentional.
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        interleaved_alpha_update(cache_dic=cache_dic, current=current, target_feature=output)
        interleaved_cache_update(cache_dic=cache_dic, current=current, feature=output)
    else:
        # Cheap path: obtain the block output from the cache.
        output = scaling_formula(cache_dic=cache_dic, current=current)

    return x + mod.gate * output