import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import sys
import numpy as np

class FirstPMA(nn.Module):
    """Multi-head attention pooling with a mix of "set" heads and "slot" heads.

    Queries ``Q`` attend over keys/values projected from ``K``. The first
    ``num_set_heads`` heads use an ordinary softmax over the keys. The
    remaining heads take the softmax along ``softmax_dim`` (with the default
    of 1, the query/slot axis, so queries compete for each key). They are
    then renormalised along ``-softmax_dim`` (the key axis by default).

    Args:
        dim_Q: feature width of the query input.
        dim_K: feature width of the key/value input.
        dim_V: width of the projected queries/keys/values and of the output.
            Must be divisible by ``num_heads`` so the per-head split is even.
        num_heads: number of attention heads.
        ln: if True, apply LayerNorm to the query input, the key input and
            before the output MLP; if False those steps are identities.
        softmax_dim: axis used for the slot-head softmax (see above).
        num_set_heads: how many leading heads use the plain key-softmax
            (defaults to 2, the previously hard-coded value).
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=True, softmax_dim=1,
                 num_set_heads=2):
        super(FirstPMA, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.softmax_dim = softmax_dim
        self.num_set_heads = num_set_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            # ln0/ln1 normalise the raw inputs, so they must match the input
            # widths (identical to the old LayerNorm(dim_V) when dims agree).
            self.ln0 = nn.LayerNorm(dim_Q)
            self.ln1 = nn.LayerNorm(dim_K)
            self.ln2 = nn.LayerNorm(dim_V)
        else:
            # Identity keeps forward() valid without any learnable parameters.
            self.ln0 = nn.Identity()
            self.ln1 = nn.Identity()
            self.ln2 = nn.Identity()
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` to keys/values derived from ``K``.

        Args:
            Q: query tensor of shape (b, n_q, dim_Q).
            K: key/value source of shape (b, n_k, dim_K); ``torch.bmm`` below
               requires the same batch size b as ``Q``.

        Returns:
            Tensor of shape (b, n_q, dim_V).
        """
        Q, K = self.ln0(Q), self.ln1(K)
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature axis into heads and stack them along the batch
        # axis (head-major): row index = head * b + batch_index.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # NOTE(review): scales by sqrt(dim_V), not by the per-head width
        # dim_V // num_heads; kept so trained weights behave the same.
        attention = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)

        # Heads are stacked head-major, so the first ``num_set_heads`` heads
        # occupy the first ``num_set_heads * b`` rows (for b == 1 this is the
        # previous hard-coded ``[:2]``).
        n_set_rows = self.num_set_heads * Q.size(0)
        set_attention = attention[:n_set_rows]
        slot_attention = attention[n_set_rows:]

        # Set heads: softmax over keys; add a small epsilon, then renormalise.
        A_set = torch.softmax(set_attention, 2) + 1e-8  # for stability
        A_set = A_set / A_set.sum(dim=-self.softmax_dim, keepdim=True)

        # Slot heads: softmax along softmax_dim, then renormalise along
        # -softmax_dim so (by default) each query's weights over keys sum to 1.
        A_slot = torch.softmax(slot_attention, self.softmax_dim) + 1e-8
        A_slot = A_slot / A_slot.sum(dim=-self.softmax_dim, keepdim=True)

        A = torch.cat((A_set, A_slot), dim=0)

        # Residual on the projected queries, then move heads back onto the
        # feature axis.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O + F.relu(self.fc_o(self.ln2(O)))
        return O

class SecondPMA(nn.Module):
    """Single-head cross-attention pooling: queries ``Q`` attend over ``K``.

    Attention scores are ``fc_q(Q) @ fc_k(K)^T / sqrt(dim_K)``, softmaxed over
    the key axis, and used to average ``fc_v(K)``. There is no residual
    connection and no output projection.

    Args:
        dim_Q: feature width of the query input.
        dim_K: feature width of the key/value input (also the attention width).
        dim_V: width of the values and therefore of the output.
        ln: if True, LayerNorm the query and key inputs; if False, skip it.
        softmax_dim: stored on the module but not used by ``forward``, which
            always takes the softmax over the key axis (dim 2).
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim_Q, dim_K, dim_V, ln=True, softmax_dim=1, **kwargs):
        super(SecondPMA, self).__init__()
        self.dim_K = dim_K
        self.softmax_dim = softmax_dim
        self.fc_q = nn.Linear(dim_Q, dim_K)
        self.fc_k = nn.Linear(dim_K, dim_K)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            # ln0 normalises the raw query input, so it must be dim_Q wide
            # (identical to the old LayerNorm(dim_K) when dim_Q == dim_K).
            self.ln0 = nn.LayerNorm(dim_Q)
            self.ln1 = nn.LayerNorm(dim_K)
        else:
            # Identity keeps forward() valid when normalisation is disabled.
            self.ln0 = nn.Identity()
            self.ln1 = nn.Identity()

    def forward(self, Q, K):
        """Pool ``K`` into one vector per query.

        Args:
            Q: query tensor of shape (b, n_q, dim_Q).
            K: key/value source of shape (b, n_k, dim_K), same batch size b.

        Returns:
            Tensor of shape (b, n_q, dim_V).
        """
        Q, K = self.ln0(Q), self.ln1(K)
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Softmax over keys, so each query's weights sum to 1.
        A = torch.softmax(Q.bmm(K.transpose(1, 2)) / math.sqrt(self.dim_K), 2)
        O = A.bmm(V)

        return O

class Slotmil(nn.Module):
    """Two-stage attention-pooling classifier over a bag of patch embeddings.

    Stage 1 (``pma1``, a FirstPMA): ``num_slots`` learned query vectors
    ``I1`` attend over the patches to produce slot embeddings.
    Stage 2 (``pma2``, a SecondPMA with ``dim_V=1``): ``num_classes`` learned
    query vectors ``I2`` attend over the slots, giving one scalar per class.

    Args:
        dim: embedding width of the patches (and of all internal projections).
        num_heads: number of heads in the first-stage attention.
        num_slots: number of learned slot queries.
        num_classes: number of learned class queries / output scores.
        ln: whether the attention layers apply LayerNorm; now forwarded to
            both stages (it was previously ignored).
    """

    def __init__(self, dim, num_heads, num_slots, num_classes, ln=True):
        super(Slotmil, self).__init__()

        # Learned queries; torch.empty + xavier init (values are overwritten).
        self.I1 = nn.Parameter(torch.empty(1, num_slots, dim))
        self.I2 = nn.Parameter(torch.empty(1, num_classes, dim))

        nn.init.xavier_uniform_(self.I1)
        nn.init.xavier_uniform_(self.I2)

        self.pma1 = FirstPMA(dim, dim, dim, num_heads, ln=ln)
        self.pma2 = SecondPMA(dim, dim, dim_V=1, ln=ln)

    def forward(self, patches, **kwargs):
        """Score a bag of patches.

        Args:
            patches: tensor of shape (1, n_patches, dim). The learned queries
                have a leading batch dim of 1 and are not expanded, so the
                attention ``bmm`` only lines up for a batch size of 1.
            **kwargs: accepted and ignored.

        Returns:
            Per-class scores of shape (1, num_classes, 1).
        """
        slot = self.pma1(self.I1, patches)
        class_logit = self.pma2(self.I2, slot)

        return class_logit
    
    
if __name__ == "__main__":
    # Smoke test: a single bag (batch size 1) of 1612 patch embeddings of
    # width 512 through a 2-class model.
    model = Slotmil(dim=512, num_heads=4, num_slots=16, num_classes=2)

    bag = torch.randn((1, 1612, 512))
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(bag)
    # Show the result so the smoke test actually reports something.
    print(tuple(logits.shape))