import torch
from torch import linalg as LA
import torch.nn.functional as F
import torch.nn as nn
from typing import *
from mm_lego.models.baselines.mcat import SNN_Block, BilinearFusion, Attn_Net_Gated

import ot

from mm_lego.models.model_utils import *


class OT_Attn_assem(nn.Module):
    """Optimal-transport based co-attention between two bags of feature vectors.

    ``forward`` takes two feature bags, shifts every feature vector so that its
    minimum is zero, and solves an entropic optimal-transport problem between
    the bags with the POT library (``ot``). The transport plan is returned as
    an attention map together with the transport cost.

    Args:
        impl: Which solver to use. ``"pot-sinkhorn-l2"`` runs balanced Sinkhorn
            with marginals derived from the feature mass of each row;
            ``"pot-uot-l2"`` runs unbalanced Sinkhorn-Knopp with uniform
            marginals. Any other value makes ``OT`` raise
            ``NotImplementedError``.
        ot_reg: Entropic regularisation strength passed to the solver as ``reg``.
        ot_tau: Marginal relaxation passed to the unbalanced solver as ``reg_m``.
    """

    def __init__(self, impl='pot-uot-l2', ot_reg=0.1, ot_tau=0.5) -> None:
        super().__init__()
        self.impl = impl
        self.ot_reg = ot_reg
        self.ot_tau = ot_tau
        print("ot impl: ", impl)

    def normalize_feature(self, x):
        """Shift ``x`` so that the minimum along its last dimension is zero."""
        x = x - x.min(-1)[0].unsqueeze(-1)
        return x

    @staticmethod
    def _squeeze_to_bag(feats):
        """Drop singleton dims but always return at least a ``(rows, D)`` matrix.

        A bare ``squeeze()`` turns a single-row bag such as ``(1, 1, D)`` into a
        1-D vector, which ``torch.cdist`` rejects and which would make the bag
        size be read as ``D``. In that case the feature dimension is restored
        from the last dimension of the input.
        """
        feat_dim = feats.size(-1)
        squeezed = feats.squeeze()
        if squeezed.dim() < 2:
            squeezed = squeezed.reshape(-1, feat_dim)
        return squeezed

    @staticmethod
    def _normalized_cost(cost_map):
        """Return a detached copy of ``cost_map`` rescaled by its maximum.

        When the maximum is not positive (all pairwise costs are zero) the
        detached cost is returned unscaled, avoiding the 0/0 = NaN cost matrix
        a plain division would hand to the solver.
        """
        cost = cost_map.detach()
        peak = cost.max()
        if peak > 0:
            return cost / peak
        return cost

    def OT(self, weight1, weight2):
        """Solve the transport problem between two feature matrices.

        Params:
            weight1 : (N, D)
            weight2 : (M, D)

        Return:
            flow : (N, M) transport plan. It is computed from a detached cost
                matrix, so no gradient flows through the solver.
            dist : scalar tensor, the sum of ``cost_map * flow``; gradients
                reach the inputs through ``cost_map`` only.

        Raises:
            NotImplementedError: if ``self.impl`` is not a supported solver.

        Side effect: the squared pairwise distance matrix is kept on
        ``self.cost_map``.
        """
        if self.impl == "pot-sinkhorn-l2":
            self.cost_map = torch.cdist(weight1, weight2) ** 2  # (N, M)

            # Marginals: share of the total feature mass carried by each row.
            src_weight = weight1.sum(dim=1) / weight1.sum()
            dst_weight = weight2.sum(dim=1) / weight2.sum()

            flow = ot.sinkhorn(a=src_weight.detach(), b=dst_weight.detach(),
                               M=self._normalized_cost(self.cost_map), reg=self.ot_reg)

        elif self.impl == "pot-uot-l2":
            # Uniform marginals; the solver is run on NumPy float64 arrays.
            a = ot.unif(weight1.size()[0]).astype('float64')
            b = ot.unif(weight2.size()[0]).astype('float64')
            self.cost_map = torch.cdist(weight1, weight2) ** 2  # (N, M)

            M_cost = self._normalized_cost(self.cost_map)

            flow = ot.unbalanced.sinkhorn_knopp_unbalanced(a=a, b=b,
                                                           M=M_cost.double().cpu().numpy(), reg=self.ot_reg,
                                                           reg_m=self.ot_tau)
            # Bring the plan back next to the cost matrix instead of assuming a
            # CUDA device is present, so the module also runs on CPU.
            flow = torch.from_numpy(flow).to(device=self.cost_map.device, dtype=self.cost_map.dtype)

        else:
            raise NotImplementedError(f"Unsupported OT implementation: {self.impl!r}")

        dist = torch.sum(self.cost_map * flow)  # scalar
        return flow, dist

    def forward(self, x, y):
        '''
        x: (N, 1, D)
        y: (M, 1, D)

        Returns the transposed transport plan with two leading singleton
        dimensions added, i.e. (1, 1, M, N), and the scalar transport cost.
        '''
        x = self.normalize_feature(self._squeeze_to_bag(x))
        y = self.normalize_feature(self._squeeze_to_bag(y))

        pi, dist = self.OT(x, y)
        return pi.T.unsqueeze(0).unsqueeze(0), dist


#############################
### MOTCAT Implementation ###
#############################
class MOTCAT(nn.Module):
    """Multimodal classifier combining WSI patch features with omic features.

    Patch features are embedded by a linear layer, omic features by a stack of
    ``SNN_Block`` layers. The patch embeddings are re-weighted against the omic
    embeddings with optimal-transport co-attention (``OT_Attn_assem``). Each
    modality is then passed through a transformer encoder and a gated attention
    pooling head, the two pooled vectors are fused, and a linear classifier
    produces the logits.

    Args:
        omic_shape: Iterable of input sizes; one omic network is built per
            entry, with that entry as its input dimension.
        wsi_shape: Shape of the WSI input; ``wsi_shape[1]`` is used as the
            patch feature dimension.
        fusion: ``'concat'`` or ``'bilinear'``. Other values are accepted at
            construction time (no fusion module is built) but ``forward``
            raises ``NotImplementedError`` for them.
        omic_sizes: Stored on the instance; the networks built here are sized
            from ``omic_shape`` instead.
        n_classes: Number of output logits.
        model_size_wsi: Key into ``size_dict_WSI`` (``'small'`` or ``'big'``).
        model_size_omic: Key into ``size_dict_omic`` (``'small'`` or ``'big'``).
        dropout: Dropout probability used by the WSI layer, the transformers,
            the attention heads and the ``rho`` layers.
        ot_reg: Forwarded to ``OT_Attn_assem``.
        ot_tau: Forwarded to ``OT_Attn_assem``.
        ot_impl: Forwarded to ``OT_Attn_assem`` as ``impl``.
    """

    def __init__(self, omic_shape: Tuple, wsi_shape: Tuple, fusion='concat', omic_sizes=[100, 200, 300, 400, 500, 600], n_classes=4,
                 model_size_wsi: str = 'small', model_size_omic: str = 'small', dropout=0.25, ot_reg=0.1, ot_tau=0.5,
                 ot_impl="pot-uot-l2"):
        super().__init__()
        self.fusion = fusion
        self.omic_shape = omic_shape
        self.wsi_shape = wsi_shape
        self.n_classes = n_classes
        self.omic_sizes = omic_sizes
        self.size_dict_WSI = {"small": [1024, 256, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256], 'big': [1024, 1024, 1024, 256]}

        ### FC Layer over WSI bag
        size = self.size_dict_WSI[model_size_wsi]
        # The table's first entry is overwritten with the actual patch feature size.
        size[0] = wsi_shape[1]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        fc.append(nn.Dropout(dropout))
        self.wsi_net = nn.Sequential(*fc)

        ### Constructing Genomic SNN (one network per entry of omic_shape)
        hidden = self.size_dict_omic[model_size_omic]
        sig_networks = []
        for input_dim in self.omic_shape:
            fc_omic = [SNN_Block(dim1=input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i + 1], dropout=0.25))
            sig_networks.append(nn.Sequential(*fc_omic))
        self.sig_networks = nn.ModuleList(sig_networks)

        ### OT-based Co-attention
        self.coattn = OT_Attn_assem(impl=ot_impl, ot_reg=ot_reg, ot_tau=ot_tau)

        ### Path Transformer + Attention Head
        # NOTE(review): d_model is fixed at 256 while the other layers follow
        # `size`; this lines up for the 'small' WSI config - confirm for 'big'.
        path_encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, dropout=dropout,
                                                        activation='relu', batch_first=True)
        self.path_transformer = nn.TransformerEncoder(path_encoder_layer, num_layers=2)
        self.path_attention_head = Attn_Net_Gated(L=size[2], D=size[2], dropout=dropout, n_classes=1)
        self.path_rho = nn.Sequential(*[nn.Linear(size[2], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Omic Transformer + Attention Head
        omic_encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, dropout=dropout,
                                                        activation='relu', batch_first=True)
        self.omic_transformer = nn.TransformerEncoder(omic_encoder_layer, num_layers=2)
        self.omic_attention_head = Attn_Net_Gated(L=size[2], D=size[2], dropout=dropout, n_classes=1)
        self.omic_rho = nn.Sequential(*[nn.Linear(size[2], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Fusion Layer
        if self.fusion == 'concat':
            self.mm = nn.Sequential(*[nn.Linear(256 * 2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
        elif self.fusion == 'bilinear':
            self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
        else:
            self.mm = None

        ### Classifier
        self.classifier = nn.Linear(size[2], n_classes)

    def forward(self, data, **kwargs):
        """Compute class logits for one sample.

        Args:
            data: Indexable container; ``data[0]`` is the omic tensor and
                ``data[1]`` the WSI patch feature tensor. The container is
                only read, never written to.
                (assumes a single sample per call - TODO confirm with caller)
            **kwargs: Accepted and ignored.

        Returns:
            The output of the final linear classifier.

        Raises:
            NotImplementedError: if ``self.fusion`` is neither ``'bilinear'``
                nor ``'concat'``.
        """
        # Work on a local squeezed view rather than writing back into `data`.
        omic_input = data[0].squeeze()

        x_omic = [omic_input]  # needs to be wrapped in list (used to expect multiple omic sources)
        ### each omic signature goes through its own FC network (only the first one is used here)
        h_omic = [self.sig_networks[idx](sig_feat) for idx, sig_feat in enumerate(x_omic)]

        h_omic_bag = torch.stack(h_omic)  ### omic embeddings are stacked (to be used in co-attention)

        x_path = data[1]
        h_path_bag = self.wsi_net(x_path).unsqueeze(1)  ### path embeddings are fed through a FC layer

        ### Coattn: re-express the path bag as one row per omic token
        A_coattn, _ = self.coattn(h_path_bag, h_omic_bag)
        h_path_coattn = torch.mm(A_coattn.squeeze(), h_path_bag.squeeze()).unsqueeze(1)

        ### Path
        # NOTE(review): the encoder is built with batch_first=True, so the
        # singleton dim added above sits on the sequence axis - confirm intended.
        h_path_trans = self.path_transformer(h_path_coattn)
        A_path, h_path = self.path_attention_head(h_path_trans.squeeze(1))
        A_path = torch.transpose(A_path, 1, 0)
        h_path = torch.mm(F.softmax(A_path, dim=1), h_path)
        h_path = self.path_rho(h_path).squeeze()

        ### Omic
        h_omic_trans = self.omic_transformer(h_omic_bag)
        A_omic, h_omic = self.omic_attention_head(h_omic_trans.squeeze(1))
        A_omic = torch.transpose(A_omic, 1, 0)
        h_omic = torch.mm(F.softmax(A_omic, dim=1), h_omic)
        h_omic = self.omic_rho(h_omic).squeeze()

        ### Fusion
        if self.fusion == 'bilinear':
            h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
        elif self.fusion == 'concat':
            # Join along the feature (last) dimension: the pooled vectors have
            # been squeezed, so a fixed axis=1 does not exist when they are 1-D.
            h = self.mm(torch.cat([h_path, h_omic], dim=-1))
        else:
            raise NotImplementedError(f"Unsupported fusion mode: {self.fusion!r}")

        logits = self.classifier(h)
        return logits