import pdb
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .transformer_2d import Transformer2DModel

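'''
Inputs to PersonalizedAttention.forward:
    hidden_feature: tensor (bs, f, c_h, w', h')
    ref_feature: tensor (bs, k, c_r, w', h'), indexed per instance as ref_feature[i]
    masks: dict with two keys
        'instance': list of tensors (bs, 1, w, h), one mask per instance
        'bg': background mask tensor
Output:
    hidden_feature: tensor (bs, f, c, w', h')
'''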
class PersonalizedAttention(ModelMixin, ConfigMixin):
    """Blend per-instance and background transformer outputs using masks.

    Two ``Transformer2DModel`` branches are held: ``instance_module`` is run
    once per instance mask, with the matching entry of ``ref_feature`` passed
    as ``encoder_hidden_states``; ``background_module`` is run once with the
    background mask. Each branch output is multiplied by its mask and the
    products are summed.
    """

    @register_to_config
    def __init__(self, level):
        """Create the instance and background transformer branches.

        Args:
            level: Exponent of the mask scale factor; ``forward`` resizes each
                instance mask by ``0.5 ** level``.
        """
        super().__init__()
        self.level = level
        self.instance_module = Transformer2DModel()
        self.background_module = Transformer2DModel()

    def forward(
        self,
        hidden_feature,
        ref_feature,
        masks,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Return the mask-weighted sum of instance and background features.

        Args:
            hidden_feature: Passed unchanged as ``hidden_states`` to both
                transformer branches.
            ref_feature: Indexable collection; ``ref_feature[i]`` is the
                ``encoder_hidden_states`` for the i-th instance mask.
            masks: Mapping with key ``'instance'`` (a sequence of mask
                tensors) and key ``'bg'`` (the background mask tensor).
            timestep: Forwarded to both transformer branches.
            added_cond_kwargs: Forwarded to both transformer branches.

        Returns:
            The sum over instances of ``mask_i * instance_module(...)`` plus
            ``masks['bg'] * background_module(...)``. When
            ``masks['instance']`` is empty only the background term is
            returned.
        """
        mask_scale = 0.5 ** self.level

        blended = None
        for index, mask in enumerate(masks['instance']):
            # Resize the instance mask by the scale factor of this level
            # before it is used as both attention mask and blending weight.
            mask = F.interpolate(mask, scale_factor=mask_scale)
            feature = mask * self.instance_module(
                hidden_states=hidden_feature,
                encoder_hidden_states=ref_feature[index],
                attention_mask=mask,
                timestep=timestep,
                added_cond_kwargs=added_cond_kwargs,
            )
            # Out-of-place accumulation: avoids mutating the first instance
            # feature and works when there is no previous term yet.
            blended = feature if blended is None else blended + feature

        # NOTE(review): unlike the instance masks, the background mask is not
        # resized here — presumably the caller supplies it at feature
        # resolution already; confirm against the caller.
        bg_mask = masks['bg']
        background = bg_mask * self.background_module(
            hidden_states=hidden_feature,
            attention_mask=bg_mask,
            timestep=timestep,
            added_cond_kwargs=added_cond_kwargs,
        )

        # With no instance masks there is nothing to add the background to.
        if blended is None:
            return background
        return blended + background
    
if __name__ == '__main__':
    # Smoke test: run one forward pass on all-zero tensors.
    model = PersonalizedAttention(level=0)
    feature_shape = (2, 16, 320, 40, 24)
    hidden_feature = torch.zeros(feature_shape)
    ref_feature = [torch.zeros(feature_shape),
                   torch.zeros(feature_shape)]
    # forward() reads masks['instance'] (one mask per reference feature) and
    # masks['bg'], so masks must be a dict rather than a plain list.
    masks = {
        'instance': [torch.zeros(feature_shape),
                     torch.zeros(feature_shape)],
        'bg': torch.zeros(feature_shape),
    }
    new_hidden_feature = model.forward(hidden_feature, ref_feature, masks)
        
        