import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention
import qdiff.flatquant.globalvar as globalvar


def pad_to_multiple(query, key, multiple=128):
    """Zero-pad ``query`` and ``key`` along dim 2 up to a multiple of ``multiple``.

    The padding amount and the shape of both padding tensors are derived
    from ``query`` only; ``key`` is padded with a zero tensor of that same
    shape (in ``key``'s own dtype and on ``key``'s own device).

    Args:
        query: 4-D tensor; dim 2 is the one being padded.
        key: tensor concatenated along dim 2 with the same-shaped padding.
        multiple: target alignment for the length of dim 2.

    Returns:
        ``(padded_query, padded_key, padding_length)``. When dim 2 of
        ``query`` is already aligned, the inputs are returned untouched
        (same objects) together with ``0``.
    """
    n_batch, n_heads, n_tokens, width = query.shape

    remainder = n_tokens % multiple
    if remainder == 0:
        # Already aligned: nothing to append.
        return query, key, 0

    padding_length = multiple - remainder
    pad_shape = (n_batch, n_heads, padding_length, width)

    # Zeros are appended after the existing tokens.
    padded_query = torch.cat((query, query.new_zeros(pad_shape)), dim=2)
    padded_key = torch.cat((key, key.new_zeros(pad_shape)), dim=2)
    return padded_query, padded_key, padding_length

class WanAttnProcessor2_0_Preprocessor:
    """Attention processor that records attention statistics via ``globalvar``.

    On every call it projects the inputs to query/key/value, applies the
    optional norms and rotary embedding, and hands query/key to
    ``scaled_dot_product_attention_to_list``, which stores a block-pooled
    attention map plus the attention among the most-attended tokens through
    ``globalvar.add_attn_distill`` and then raises ``Exception("stop here")``.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _block_mean(padded: torch.Tensor, block: int = 128) -> torch.Tensor:
        """Average consecutive groups of ``block`` entries along dim 2.

        ``padded`` must have a dim-2 length divisible by ``block``.
        ``unflatten`` honours the tensor's real strides, so this is correct
        for padded (longer) and for non-contiguous inputs alike, unlike an
        ``as_strided`` view built from hand-computed strides.
        """
        return padded.unflatten(2, (-1, block)).mean(dim=3)

    def scaled_dot_product_attention_to_list(self, query: torch.Tensor, key: torch.Tensor) -> None:
        """Collect attention statistics for ``query``/``key`` and stop the forward pass.

        Stores ``(attn_low_res, top_attn, top_token_indices)`` through
        ``globalvar.add_attn_distill``:

        * ``attn_low_res``: unscaled ``Q @ K^T`` of query/key averaged over
          blocks of 128 positions along dim 2 (zero padding added by
          ``pad_to_multiple`` is part of the last block's average).
        * ``top_token_indices``: per batch element, the indices of the (up to)
          256 positions with the largest softmax attention mass summed over
          heads and over query positions.
        * ``top_attn``: unscaled ``Q @ K^T`` restricted to those positions.

        ``query`` and ``key`` are unpacked/indexed as 4-D tensors with the
        sequence on dim 2; the accumulation below requires both to have the
        same dim-2 length.

        Raises:
            Exception: always, with message ``"stop here"``, after the
                statistics have been stored.
                NOTE(review): presumably callers catch this to abort the
                forward pass early — confirm before changing it.
        """
        scale_factor = 1 / math.sqrt(query.size(-1))
        batch_size, _, seq_length, _ = query.shape

        # Block-pooled ("low resolution") attention logits.
        query_low_res, key_low_res, _ = pad_to_multiple(query, key)
        query_low_res = self._block_mean(query_low_res)
        key_low_res = self._block_mean(key_low_res)
        attn_low_res = torch.matmul(query_low_res, key_low_res.transpose(-2, -1))

        # Accumulate, chunk by chunk over the queries, how much softmax
        # attention each key position receives; chunking bounds the size of
        # the materialised attention matrix.
        chunk_size = 2048
        attn_weight = torch.zeros(batch_size, seq_length, device=query.device)
        for start_idx in range(0, seq_length, chunk_size):
            end_idx = min(start_idx + chunk_size, seq_length)

            chunk_query = query[:, :, start_idx:end_idx]
            chunk_attn = chunk_query @ key.transpose(-2, -1)
            chunk_attn.mul_(scale_factor)
            chunk_attn = torch.softmax(chunk_attn, dim=-1)
            chunk_attn = chunk_attn.sum(dim=1)  # sum over heads
            chunk_attn = chunk_attn.sum(dim=1)  # sum over the chunk's queries
            attn_weight += chunk_attn

            del chunk_query, chunk_attn
            torch.cuda.empty_cache()

        batch_indices = torch.arange(batch_size, device=query.device)[:, None]

        # Clamp k so that sequences shorter than 256 tokens do not make
        # torch.topk raise.
        top_k = min(256, seq_length)
        top_token_indices = torch.topk(attn_weight, k=top_k, dim=-1)[1]

        # Advanced indexing yields (batch, k, heads, dim); move heads back to dim 1.
        top_query = query[batch_indices, :, top_token_indices].transpose(1, 2)
        top_key = key[batch_indices, :, top_token_indices].transpose(1, 2)

        top_attn = torch.matmul(top_query, top_key.transpose(-2, -1))

        globalvar.add_attn_distill((attn_low_res, top_attn, top_token_indices))

        del query_low_res, key_low_res
        torch.cuda.empty_cache()

        raise Exception("stop here")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention projections and record statistics for query/key.

        Args:
            attn: attention module providing ``to_q``/``to_k``/``to_v``,
                optional ``norm_q``/``norm_k``, ``heads``, ``to_out`` and the
                optional ``add_k_proj``/``add_v_proj``/``norm_added_k``.
            hidden_states: input that is projected to the queries.
            encoder_hidden_states: source of keys/values; falls back to
                ``hidden_states`` when ``None``.
            attention_mask: forwarded to ``F.scaled_dot_product_attention``.
            rotary_emb: optional complex factors multiplied into query and key.

        Returns:
            The attention output after ``attn.to_out``. In this class the
            statistics call below always raises, so the return statement is
            not reached.
        """
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # The first 257 positions go to the additional ("img") key/value
            # projections; the remainder feeds the regular ones.
            encoder_hidden_states_img = encoder_hidden_states[:, :257]
            encoder_hidden_states = encoder_hidden_states[:, 257:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                # Pair up the last dim as complex numbers (in float64),
                # multiply by the factors, then restore layout and dtype.
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        # Stores the statistics and then raises Exception("stop here");
        # the statements after this call are therefore not executed.
        self.scaled_dot_product_attention_to_list(query, key)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
    
class WanAttnProcessor2_0_Trainer:
    """Attention processor that publishes the current attention statistics.

    It computes the regular attention output and, on every call, also stores
    a block-pooled attention map plus the attention among the token indices
    obtained from ``globalvar.get_current_tok_index()`` through
    ``globalvar.set_current_attn``. The statistics are computed from
    detached copies of query/key.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _block_mean(padded: torch.Tensor, block: int = 128) -> torch.Tensor:
        """Average consecutive groups of ``block`` entries along dim 2.

        ``padded`` must have a dim-2 length divisible by ``block``.
        ``unflatten`` honours the tensor's real strides, so this is correct
        for padded (longer) and for non-contiguous inputs alike, unlike an
        ``as_strided`` view built from hand-computed strides.
        """
        return padded.unflatten(2, (-1, block)).mean(dim=3)

    def scaled_dot_product_attention_to_list(self, query: torch.Tensor, key: torch.Tensor) -> None:
        """Compute attention statistics for ``query``/``key`` and store them.

        Stores ``(attn_low_res, top_attn)`` through
        ``globalvar.set_current_attn``:

        * ``attn_low_res``: unscaled ``Q @ K^T`` of query/key averaged over
          blocks of 128 positions along dim 2 (zero padding added by
          ``pad_to_multiple`` is part of the last block's average).
        * ``top_attn``: unscaled ``Q @ K^T`` restricted to the positions
          returned by ``globalvar.get_current_tok_index()``.
          NOTE(review): those indices are assumed to be a (batch, k) integer
          tensor on the same device as ``query`` — confirm against the
          ``globalvar`` module.

        ``query`` and ``key`` are indexed as 4-D tensors with the sequence on
        dim 2. Nothing is returned.
        """
        batch_size = query.shape[0]

        # Block-pooled ("low resolution") attention logits.
        query_low_res, key_low_res, _ = pad_to_multiple(query, key)
        query_low_res = self._block_mean(query_low_res)
        key_low_res = self._block_mean(key_low_res)
        attn_low_res = torch.matmul(query_low_res, key_low_res.transpose(-2, -1))

        batch_indices = torch.arange(batch_size, device=query.device)[:, None]
        top_token_indices = globalvar.get_current_tok_index()

        # Advanced indexing yields (batch, k, heads, dim); move heads back to dim 1.
        top_query = query[batch_indices, :, top_token_indices].transpose(1, 2)
        top_key = key[batch_indices, :, top_token_indices].transpose(1, 2)
        top_attn = torch.matmul(top_query, top_key.transpose(-2, -1))

        globalvar.set_current_attn((attn_low_res, top_attn))

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the attention output and publish statistics for query/key.

        Args:
            attn: attention module providing ``to_q``/``to_k``/``to_v``,
                optional ``norm_q``/``norm_k``, ``heads``, ``to_out`` and the
                optional ``add_k_proj``/``add_v_proj``/``norm_added_k``.
            hidden_states: input that is projected to the queries.
            encoder_hidden_states: source of keys/values; falls back to
                ``hidden_states`` when ``None``.
            attention_mask: forwarded to ``F.scaled_dot_product_attention``.
            rotary_emb: optional complex factors multiplied into query and key.

        Returns:
            The attention output after both entries of ``attn.to_out``.
        """
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # The first 257 positions go to the additional ("img") key/value
            # projections; the remainder feeds the regular ones.
            encoder_hidden_states_img = encoder_hidden_states[:, :257]
            encoder_hidden_states = encoder_hidden_states[:, 257:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                # Pair up the last dim as complex numbers (in float64),
                # multiply by the factors, then restore layout and dtype.
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        # The statistics are computed on detached copies, so they do not take
        # part in the autograd graph of the attention output below.
        self.scaled_dot_product_attention_to_list(query.detach().clone(), key.detach().clone())
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

