import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Models.TransMIL.nystrom_attention_catt import NystromAttention, Attention
# official code of transmil

class TransLayer_NystromAttention(nn.Module):
    """Residual attention layer built around ``NystromAttention``.

    Both inputs are normalised by the same ``norm_layer`` instance, handed to
    the attention module, and the attention output is added back onto ``x``.

    Args:
        norm_layer: Factory called as ``norm_layer(dim)`` to build the
            normalisation module.
        dim: Feature dimension; also determines the per-head size
            (``dim // 8``) and the number of landmarks (``dim // 2``).
    """

    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)
        nystrom_kwargs = dict(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            # number of landmarks
            num_landmarks=dim // 2,
            # number of moore-penrose iterations for approximating pinverse;
            # 6 was recommended by the paper
            pinv_iterations=6,
            # extra residual with the value; supposedly faster convergence
            # if turned on
            residual=True,
            dropout=0.1,
        )
        self.attn = NystromAttention(**nystrom_kwargs)

    def forward(self, x, y):
        """Return ``x`` plus the attention output for the normalised pair.

        Args:
            x: Tensor the attention output is added onto, so the attention
                module must return something broadcastable to ``x``.
            y: Second attention input (presumably the key/value source --
                confirm against ``NystromAttention``).
        """
        normed_x = self.norm(x)
        normed_y = self.norm(y)
        attended = self.attn(normed_x, normed_y)
        return x + attended

class TransLayer_Attention(nn.Module):
    """Residual attention layer built around the plain ``Attention`` module.

    Both inputs are normalised by the same ``norm_layer`` instance, handed to
    an 8-head ``Attention`` module, and the result is added back onto ``x``.

    Args:
        norm_layer: Factory called as ``norm_layer(dim)`` to build the
            normalisation module.
        dim: Feature dimension passed to the norm and attention modules.
    """

    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = Attention(dim=dim, heads=8)

    def forward(self, x, y):
        """Return ``x`` plus the attention output for the normalised pair.

        Args:
            x: Tensor the attention output is added onto.
            y: Second attention input (presumably the key/value source --
                confirm against ``Attention``).
        """
        normed_x = self.norm(x)
        normed_y = self.norm(y)
        attended = self.attn(normed_x, normed_y)
        return x + attended


class SelfCattLayer(nn.Module):
    """Self-attention on the bag followed by cross-attention with confounders.

    Args:
        dim: Feature dimension forwarded to both attention layers.
        att: ``"nystrom_attention"`` selects ``TransLayer_NystromAttention``
            for the cross-attention layer; any other value selects
            ``TransLayer_Attention``. The self-attention layer is always
            ``TransLayer_NystromAttention``.
    """

    def __init__(self, dim=512, att="nystrom_attention"):
        super().__init__()
        self.is_att = TransLayer_NystromAttention(dim=dim)
        cross_layer_cls = (
            TransLayer_NystromAttention
            if att == "nystrom_attention"
            else TransLayer_Attention
        )
        self.cs_att = cross_layer_cls(dim=dim)

    def forward(self, bag_feats, conf_feats):
        """Return the updated ``(bag_feats, conf_feats)`` pair.

        The bag first attends to itself. The cross layer then receives the
        refined bag as its first input and ``conf_feats`` as its second, and
        its output replaces ``conf_feats``.
        """
        refined_bag = self.is_att(bag_feats, bag_feats)
        crossed = self.cs_att(refined_bag, conf_feats)
        return refined_bag, crossed


class PPEG(nn.Module):
    """Convolutional token-mixing block (used as the position layer in TransMIL).

    The patch tokens are arranged on an ``H x W`` grid and summed with the
    outputs of three depthwise convolutions (kernel sizes 7, 5 and 3). The
    first token of the sequence bypasses the convolutions unchanged.

    Args:
        dim: Number of channels; each convolution uses ``groups=dim``.
    """

    def __init__(self, dim=512):
        super().__init__()
        # Stride 1 with padding k // 2 keeps the H x W grid size unchanged.
        self.proj = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x, H, W):
        """Mix the patch tokens spatially and re-attach the leading token.

        Args:
            x: Tensor of shape ``(B, 1 + H * W, C)``; index 0 along dim 1 is
                the class token.
            H: Grid height used to reshape the patch tokens.
            W: Grid width used to reshape the patch tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, _, channels = x.shape
        # Slicing with :1 keeps the token axis, so no unsqueeze is needed later.
        leading_token = x[:, :1]
        grid = x[:, 1:].transpose(1, 2).view(batch, channels, H, W)

        out_k7 = self.proj(grid)
        out_k5 = self.proj1(grid)
        out_k3 = self.proj2(grid)
        mixed = out_k7 + grid + out_k5 + out_k3

        patch_tokens = mixed.flatten(2).transpose(1, 2)
        return torch.cat((leading_token, patch_tokens), dim=1)


class TransMIL(nn.Module):
    """TransMIL bag classifier with an additional confounder branch.

    A bag of patch features (``*_is`` modules) and a fixed set of confounder
    features loaded from disk (``*_cs`` modules) are each projected to 512
    dimensions, prefixed with a learnable class token and passed through two
    ``SelfCattLayer`` blocks with a ``PPEG`` layer in between. The two
    normalised class tokens are concatenated and classified by a linear head.

    Args:
        n_classes: Number of output classes.
        input_size: Feature dimension of the input patches and of the
            confounder features.
        confounder_path: Path, or iterable of paths, to ``.npy`` files that
            are loaded with ``np.load`` and concatenated along dim 0. The
            result must be 2-D ``(num_confounders, input_size)`` because
            ``forward`` adds the batch dimension itself. Required: the
            ``None`` default only exists for signature compatibility.

    Raises:
        ValueError: If ``confounder_path`` is ``None`` or contains no paths.
    """

    def __init__(self, n_classes, input_size, confounder_path=None):
        super(TransMIL, self).__init__()

        # Validate up front so a missing confounder list fails with a clear
        # message instead of "'NoneType' object is not iterable".
        if confounder_path is None:
            raise ValueError(
                "confounder_path is required: pass a path or an iterable of "
                "paths to .npy confounder feature files"
            )
        if isinstance(confounder_path, str) or hasattr(confounder_path, "__fspath__"):
            # A single path would otherwise be iterated character by character.
            conf_paths = [confounder_path]
        else:
            conf_paths = list(confounder_path)
        if not conf_paths:
            raise ValueError("confounder_path must contain at least one .npy file")

        self.n_classes = n_classes

        self._fc_is = nn.Sequential(nn.Linear(input_size, 512), nn.ReLU())
        self.cls_token_is = nn.Parameter(torch.randn(1, 1, 512))
        self.pos_layer_is = PPEG(dim=512)

        self._norm_is = nn.LayerNorm(512)
        self._norm_cs = nn.LayerNorm(512)
        self._fc_cat = nn.Linear(1024, self.n_classes)

        self._fc_cs = nn.Sequential(nn.Linear(input_size, 512), nn.ReLU())
        self.cls_token_cs = nn.Parameter(torch.randn(1, 1, 512))
        self.pos_layer_cs = PPEG(dim=512)

        self.self_catt_layer_1 = SelfCattLayer(dim=512, att="attention")
        self.self_catt_layer_2 = SelfCattLayer(dim=512, att="nystrom_attention")

        self.confounder_path = confounder_path
        conf_list = [torch.from_numpy(np.load(path)).float() for path in conf_paths]
        # Concatenate the per-file confounder sets along the first axis.
        conf_tensor = torch.cat(conf_list, 0)
        # A buffer follows the module across .to()/.cuda() and is saved in the
        # state dict, but is not trained.
        self.register_buffer("confounder_feat", conf_tensor)

    def forward(self, feats):
        """Classify one bag of patch features.

        Args:
            feats: Tensor of shape ``(num_patches, input_size)``; the batch
                dimension (size 1) is added here.

        Returns:
            dict with keys ``'logits'`` ``(1, n_classes)``, ``'Y_prob'``
            (softmax of the logits), ``'Y_hat'`` (argmax of the logits),
            ``'Bag_feature'`` (concatenated class tokens, ``(1, 1024)``),
            ``'A'`` (always ``None``) and ``'h_not_norm'`` (bag class token
            before the final LayerNorm).
        """
        # current bag
        bag_feats = feats.unsqueeze(0)  # [1, num_patches_bag, input_size]
        bag_feats = self._fc_is(bag_feats)  # [1, num_patches_bag, 512]
        # ---->pad: repeat the first tokens so the count is a perfect square,
        # which PPEG needs to lay the tokens out on an _H x _W grid.
        H = bag_feats.shape[1]  # H = num_patches_bag
        _H = _W = int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        bag_feats = torch.cat([bag_feats, bag_feats[:, :add_length, :]], dim=1)  # [1, _H * _W, 512]
        # ---->cls_token_is
        # The parameter already lives on the module's device, so no explicit
        # .cuda() is needed (it broke CPU runs and non-default GPUs).
        batch_size = bag_feats.shape[0]
        cls_tokens_is = self.cls_token_is.expand(batch_size, -1, -1)
        bag_feats = torch.cat((cls_tokens_is, bag_feats), dim=1)

        # bag dictionary (confounders)
        conf_feats = self.confounder_feat.unsqueeze(0)  # [1, num_patches_conf, input_size]
        conf_feats = self._fc_cs(conf_feats)
        # ---->(without padding)
        # ---->cls_token
        batch_size = conf_feats.shape[0]
        cls_tokens_cs = self.cls_token_cs.expand(batch_size, -1, -1)
        conf_feats = torch.cat((cls_tokens_cs, conf_feats), dim=1)

        # ---->self_catt_layer x1
        bag_feats, conf_feats = self.self_catt_layer_1(bag_feats, conf_feats)
        # After the first self-catt layer conf_feats is the cross-attention
        # output added onto bag_feats, so both have the bag's token count and
        # the same (_H, _W) grid applies to both PPEG calls below.
        # ---->PPEG (separate weights per branch)
        bag_feats = self.pos_layer_is(bag_feats, _H, _W)
        conf_feats = self.pos_layer_cs(conf_feats, _H, _W)
        # ---->self_catt_layer x2
        bag_feats, conf_feats = self.self_catt_layer_2(bag_feats, conf_feats)

        # ---->cls_token
        # taken after the second self-catt layer, according to the original
        # code of TransMIL
        h_not_norm = bag_feats[:, 0]
        A = None

        cls_token_bag = self._norm_is(bag_feats)[:, 0]
        cls_token_conf = self._norm_cs(conf_feats)[:, 0]
        cat_feats = torch.cat((cls_token_bag, cls_token_conf), dim=1)

        # ---->predict
        logits = self._fc_cat(cat_feats)  # [batch_size, n_classes]
        Y_hat = torch.argmax(logits, dim=1)
        Y_prob = F.softmax(logits, dim=1)
        results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat, "Bag_feature": cat_feats, "A": A, 'h_not_norm': h_not_norm}

        return results_dict


if __name__ == "__main__":
    # Smoke test: build the model on random confounders and run one random bag.
    import os
    import tempfile

    # Prefer CUDA when present; otherwise build on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_size = 512

    # TransMIL needs at least one confounder file, so write a throwaway one.
    # The array is read fully into memory in __init__, so the directory can be
    # removed right after construction.
    with tempfile.TemporaryDirectory() as tmp_dir:
        conf_file = os.path.join(tmp_dir, "confounder.npy")
        np.save(conf_file, np.random.randn(8, input_size).astype(np.float32))
        model = TransMIL(n_classes=2, input_size=input_size, confounder_path=[conf_file]).to(device)

    # forward() adds the batch dimension itself, so the bag is (num_patches, input_size).
    data = torch.randn((2, input_size), device=device)
    print(model.eval())
    results_dict = model(feats=data)
    print(results_dict)