import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn import functional as F

from diffusers.models.attention_processor import Attention


class AttentionStore:
    """Collect per-layer attention probabilities reported by attention processors.

    Each call records one layer's map into ``step_store`` (bucketed by UNet location
    and cross/self attention), but only if its query length matches ``attn_res``.
    After ``num_att_layers`` calls the buckets are published as ``attention_store``
    and a fresh ``step_store`` is started. ``attention_store`` therefore always holds
    the most recently completed cycle of ``num_att_layers`` calls.
    """

    def __init__(self, attn_res):
        """Initialize an empty store.

        Args:
            attn_res: two-element resolution ``(attn_res[0], attn_res[1])``. Only maps
                whose second dimension equals ``attn_res[0] * attn_res[1]`` are kept,
                and stored maps are reshaped to this spatial grid when aggregated.
        """
        # Number of attention layers per cycle. -1 means "not configured": with it,
        # cur_att_layer never equals num_att_layers, so between_steps() never fires.
        # NOTE(review): presumably assigned by the pipeline after the processors are
        # registered -- confirm.
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        # NOTE(review): not read anywhere in this file; kept for external callers.
        self.curr_step_index = 0
        self.attn_res = attn_res

    @staticmethod
    def get_empty_store():
        """Return a fresh dict of empty lists keyed ``"<down|mid|up>_<cross|self>"``."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Record one layer's attention probabilities and advance the layer counter.

        Args:
            attn: attention probabilities. They are stored only when ``attn.shape[1]``
                equals ``prod(attn_res)``.
            is_cross: selects the ``*_cross`` bucket if True, ``*_self`` otherwise.
            place_in_unet: location prefix. It must be ``"down"``, ``"mid"`` or
                ``"up"``, otherwise the bucket lookup raises ``KeyError``.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # NOTE(review): cur_att_layer is never negative inside this class. This guard
        # only matters if a caller sets it below 0 to skip layers.
        if self.cur_att_layer >= 0:
            if attn.shape[1] == np.prod(self.attn_res):
                self.step_store[key].append(attn)

        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # A full cycle of layers has reported: publish and start over.
            self.cur_att_layer = 0
            self.between_steps()

    def between_steps(self):
        """Publish the current ``step_store`` as ``attention_store`` and start a new one."""
        self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the most recently published store.

        No averaging happens here despite the name; the dict is returned as-is.
        """
        return self.attention_store

    def aggregate_attention(self, from_where: List[str], is_cross: bool = True) -> torch.Tensor:
        """Aggregate the attention across the different layers and heads at ``attn_res``.

        Every stored map from the requested buckets is reshaped to
        ``(-1, attn_res[0], attn_res[1], K)``. The results are concatenated along the
        leading axis (presumably batch*heads -- TODO confirm) and averaged over it.

        Args:
            from_where: UNet locations to include, e.g. ``["up", "down"]``.
            is_cross: aggregate cross-attention buckets if True, self-attention otherwise.

        Returns:
            Tensor of shape ``(attn_res[0], attn_res[1], K)``, where ``K`` is the last
            dimension of the stored maps.

        Raises:
            KeyError: if nothing has been published yet (``attention_store`` is the
                empty dict), or if a location is not ``"down"``, ``"mid"`` or ``"up"``.
            RuntimeError: if the selected buckets contain no maps.
        """
        kind = "cross" if is_cross else "self"
        attention_maps = self.get_average_attention()
        maps = [
            item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
            for location in from_where
            for item in attention_maps[f"{location}_{kind}"]
        ]
        if not maps:
            # torch.cat([]) would fail with an opaque message; say what is missing instead.
            raise RuntimeError(
                f"No {kind}-attention maps stored for {from_where} at "
                f"attn_res={self.attn_res}; was a full cycle of "
                f"{self.num_att_layers} layers recorded?"
            )
        out = torch.cat(maps, dim=0)
        return out.sum(0) / out.shape[0]

    def reset(self):
        """Drop all recorded and published maps and restart the layer counter."""
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}


class AttendExciteAttnProcessor:
    """Attention processor that reports attention probabilities to an attention store.

    Applies q/k/v projections, ``attn.get_attention_scores`` and the ``to_out``
    projection and dropout, and forwards the probabilities to ``attnstore``.

    NOTE(review): other options on the ``Attention`` module (e.g. ``norm_cross``,
    ``residual_connection``, ``rescale_output_factor``) are not consulted here.
    Confirm the target UNet does not rely on them.
    """

    def __init__(self, attnstore, place_in_unet):
        """
        Args:
            attnstore: callable invoked as ``attnstore(probs, is_cross, place_in_unet)``.
            place_in_unet: label (e.g. ``"down"``/``"mid"``/``"up"``) forwarded with
                every call to ``attnstore``.
        """
        super().__init__()
        self.place_in_unet = place_in_unet
        self.attnstore = attnstore

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Compute one attention layer, recording the probabilities when grads are on.

        Args:
            attn: the ``Attention`` module providing projections and reshape helpers.
            hidden_states: query-side input, unpacked as ``(batch, seq_len, _)``.
            encoder_hidden_states: key/value context; ``None`` means self-attention.
            attention_mask: optional mask, passed through ``attn.prepare_attention_mask``.

        Returns:
            The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
        """
        n_batch, n_tokens, _ = hidden_states.shape
        mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        q = attn.to_q(hidden_states)

        # Without an encoder context, keys/values come from hidden_states (self-attention).
        cross = encoder_hidden_states is not None
        context = encoder_hidden_states if cross else hidden_states
        k = attn.to_k(context)
        v = attn.to_v(context)

        # map() is consumed in order q, k, v -- same call sequence as three separate calls.
        q, k, v = map(attn.head_to_batch_dim, (q, k, v))

        probs = attn.get_attention_scores(q, k, mask)

        # Maps are only needed during the Attend-and-Excite optimization, which is
        # detected here by autograd tracking the probabilities.
        if probs.requires_grad:
            self.attnstore(probs, cross, self.place_in_unet)

        merged = attn.batch_to_head_dim(torch.bmm(probs, v))
        # to_out[0] is the linear projection, to_out[1] the dropout.
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        return out_dropout(out_proj(merged))
    
