import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from architecture.network import Classifier_1fc, DimReduction

class Attention_Gated(nn.Module):
    def __init__(self, L=512, D=128, K=1):
        super(Attention_Gated, self).__init__()

        self.L = L
        self.D = D
        self.K = K

        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh()
        )

        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid()
        )

        self.attention_weights = nn.Linear(self.D, self.K)

    def forward(self, x):
        ## x: N x L
        A_V = self.attention_V(x)  # NxD
        A_U = self.attention_U(x)  # NxD
        A = self.attention_weights(A_V * A_U) # NxK
        A = torch.transpose(A, 1, 0)  # KxN


        return A  ### K x N


class IBMIL(nn.Module):
    def __init__(self, conf, confounder_dim=128, confounder_merge='cat'):
        super(IBMIL, self).__init__()
        self.confounder_merge = confounder_merge
        assert confounder_merge in ['cat', 'add', 'sub']
        self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
        self.attention = Attention_Gated(conf.D_inner, 128, 1)
        self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0)
        self.confounder_path = None
        if conf.c_path:
            print('deconfounding')
            self.confounder_path = conf.c_path
            conf_list = []
            for i in conf.c_path:
                conf_list.append(torch.from_numpy(np.load(i)).view(-1, conf.D_inner).float())
            conf_tensor = torch.cat(conf_list, 0)
            conf_tensor_dim = conf_tensor.shape[-1]
            if conf.c_learn:
                self.confounder_feat = nn.Parameter(conf_tensor, requires_grad=True)
            else:
                self.register_buffer("confounder_feat", conf_tensor)
            joint_space_dim = confounder_dim
            dropout_v = 0.5
            self.W_q = nn.Linear(conf.D_inner, joint_space_dim)
            self.W_k = nn.Linear(conf_tensor_dim, joint_space_dim)
            if confounder_merge == 'cat':
                self.classifier = nn.Linear(conf.D_inner + conf_tensor_dim, conf.n_class)
            elif confounder_merge == 'add' or 'sub':
                self.classifier = nn.Linear(conf.D_inner, conf.n_class)
            self.dropout = nn.Dropout(dropout_v)

    def forward(self, x):
        x = x[0]
        x = self.dimreduction(x)
        A = self.attention(x)  ## K x N
        A = F.softmax(A, dim=1)  # softmax over N
        M = torch.mm(A, x) ## K x L
        # x = x.squeeze(0)

        # H = self.feature_extractor_part1(x)
        # H = H.view(-1, 50 * 4 * 4)
        # H = self.feature_extractor_part2(H)  # NxL

        # A = self.attention_1(x)
        # A = self.attention_2(A)  # NxK
        # A = self.attention(x)  # NxK
        # A = torch.transpose(A, 1, 0)  # KxN
        # A = F.softmax(A, dim=1)  # softmax over N
        # print('norm')
        # A = F.softmax(A/ torch.sqrt(torch.tensor(x.shape[1])), dim=1)  # For Vis

        # M = torch.mm(A, x)  # KxL
        if self.confounder_path:
            device = M.device
            # bag_q = self.confounder_W_q(M)
            # conf_k = self.confounder_W_k(self.confounder_feat)
            bag_q = self.W_q(M)
            conf_k = self.W_k(self.confounder_feat)
            deconf_A = torch.mm(conf_k, bag_q.transpose(0, 1))
            deconf_A = F.softmax(
                deconf_A / torch.sqrt(torch.tensor(conf_k.shape[1], dtype=torch.float32, device=device)),
                0)  # normalize attention scores, A in shape N x C,
            conf_feats = torch.mm(deconf_A.transpose(0, 1),
                                  self.confounder_feat)  # compute bag representation, B in shape C x V
            if self.confounder_merge == 'cat':
                M = torch.cat((M, conf_feats), dim=1)
            elif self.confounder_merge == 'add':
                M = M + conf_feats
            elif self.confounder_merge == 'sub':
                M = M - conf_feats
        Y_prob = self.classifier(M)
        # Y_hat = torch.ge(Y_prob, 0.5).float()
        if self.confounder_path:
            return Y_prob, M, deconf_A
        else:
            return Y_prob, M, A

    # # AUXILIARY METHODS
    # def calculate_classification_error(self, X, Y):
    #     Y = Y.float()
    #     _, Y_hat, _ = self.forward(X)
    #     error = 1. - Y_hat.eq(Y).cpu().float().mean().data.item()
    #
    #     return error, Y_hat
    #
    # def calculate_objective(self, X, Y):
    #     Y = Y.float()
    #     Y_prob, _, A = self.forward(X)
    #     Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
    #     neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli
    #
    #     return neg_log_likelihood, A

