import torch
import torch.nn as nn
import torch.nn.functional as F
from architecture.network import Classifier_1fc, DimReduction



class AttentionLayer(nn.Module):
    """Parameter-free attention that scores instances with an external linear head.

    The layer owns no weights. ``forward`` receives the weight and bias of a
    linear layer, applies them to every instance and turns the resulting
    logits into one normalised weight per instance.
    """

    def __init__(self, dim=512):
        super().__init__()
        # Kept as an attribute for reference; forward() does not read it.
        self.dim = dim

    def forward(self, features, W_1, b_1):
        """Weight each instance by its normalised, class-summed exponentiated logit.

        Args:
            features: instance features, expected to be 2-D ``(N, D)``.
            W_1: weight passed to ``F.linear``, i.e. ``(C, D)``.
            b_1: bias passed to ``F.linear``, i.e. ``(C,)``.

        Returns:
            A tuple ``(context, out_c, alpha)``:

            * ``context`` -- ``features`` scaled row-wise by ``N * alpha``, so
              that a mean over dim 0 equals the alpha-weighted sum of the rows.
            * ``out_c`` -- the raw per-instance logits ``(N, C)``.
            * ``alpha`` -- per-instance weights of shape ``(N,)`` that sum to 1.
        """
        out_c = F.linear(features, W_1, b_1)

        # Shift by the global maximum so exp() never sees a positive argument;
        # the shift cancels in the normalisation below.
        scores = (out_c - out_c.max()).exp().sum(1, keepdim=True)
        alpha = scores / scores.sum(0)

        # (N, 1) broadcasts against (N, D); no explicit expand is required.
        context = features * (features.size(0) * alpha)

        # Squeeze only the class-sum axis. A bare torch.squeeze() would also
        # drop the instance axis when N == 1 and return a 0-d tensor.
        return context, out_c, alpha.squeeze(1)

class LBMIL(nn.Module):
    """Bag classifier whose attention reuses the classifier's own weights.

    Pipeline: ``DimReduction`` -> ``AttentionLayer`` (driven by the weight and
    bias of ``self.classifier``) -> mean over instances -> ``self.classifier``.
    """

    def __init__(self, conf, droprate=0):
        """Build the sub-modules from ``conf``.

        Args:
            conf: object providing ``D_feat``, ``D_inner`` and ``n_class``.
            droprate: accepted but not used by this class.
        """
        super().__init__()
        inner_dim = conf.D_inner
        self.dimreduction = DimReduction(conf.D_feat, inner_dim)
        self.attention = AttentionLayer(inner_dim)
        self.classifier = nn.Linear(inner_dim, conf.n_class)

    def forward(self, x):
        """Classify one bag.

        Only ``x[0]`` is used; presumably ``x`` is ``(1, N, L)`` with a single
        bag per batch -- confirm against the data loader.

        Returns:
            ``(y, out_c, alpha)``: the prediction for the pooled bag, the
            per-instance logits from the attention layer, and the attention
            weights as returned by the attention layer.
        """
        bag = x[0]
        reduced = self.dimreduction(bag)

        # The attention layer scores instances with the classifier's own
        # parameters instead of holding separate ones.
        head = self.classifier
        weighted, inst_logits, attn = self.attention(reduced, head.weight, head.bias)

        # Collapse the instance axis into a single pooled row for the head.
        pooled = weighted.mean(0, keepdim=True)
        return head(pooled), inst_logits, attn




