import copy
import math
import dgl
import torch

from graph_learners import *
from layers import GCNConv_dense, GCNConv_dgl
from torch.nn import Sequential, Linear, ReLU


def SNN_Block(dim1, dim2, dropout=0.25):
    r"""
    Build one self-normalizing block: Linear -> ELU -> AlphaDropout.

    args:
        dim1 (int): Dimension of input features
        dim2 (int): Dimension of output features
        dropout (float): Drop probability handed to ``AlphaDropout``

    returns:
        torch.nn.Sequential: the three layers chained in the order above
    """
    from torch import nn

    layers = [
        nn.Linear(dim1, dim2),
        nn.ELU(),
        nn.AlphaDropout(p=dropout, inplace=False),
    ]
    return nn.Sequential(*layers)


class Adaptor(nn.Module):
    """Embed every omic group with its own SNN stack and average the results.

    One ``nn.Sequential`` of ``SNN_Block`` layers is created per entry of
    ``omic_sizes``. Each stack maps its group's input width through the fixed
    hidden widths ``[768, 768]``.

    args:
        omic_sizes (iterable of int): Input feature dimension of each omic group.
    """

    def __init__(self, omic_sizes):
        super(Adaptor, self).__init__()
        hidden = [768, 768]
        sig_networks = []
        for input_dim in omic_sizes:
            # First block adapts the group-specific width to the shared width.
            fc_omic = [SNN_Block(dim1=input_dim, dim2=hidden[0])]
            # Remaining blocks chain consecutive hidden widths.
            for width_in, width_out in zip(hidden, hidden[1:]):
                fc_omic.append(SNN_Block(dim1=width_in, dim2=width_out, dropout=0.25))
            sig_networks.append(nn.Sequential(*fc_omic))
        self.sig_networks = nn.ModuleList(sig_networks)

    def forward(self, x_omic):
        """Run each omic group through its network and average over groups.

        args:
            x_omic (sequence): One entry per sample; each entry is indexable
                by group index and yields that sample's tensor for the group.
                The number of groups is taken from the first sample.

        returns:
            torch.Tensor: Element-wise mean of the per-group outputs.
        """
        n_groups = len(x_omic[0])
        # Regroup sample-major input into one stacked tensor per omic group.
        per_group = [
            torch.stack([sample[g] for sample in x_omic]) for g in range(n_groups)
        ]
        # Call the module itself (not ``.forward``) so registered hooks run.
        h_omic = [
            self.sig_networks[g](group_feat.float())
            for g, group_feat in enumerate(per_group)
        ]
        return torch.mean(torch.stack(h_omic, dim=0), dim=0)


# GCN for evaluation.
class GCN(nn.Module):
    """Graph-convolution stack followed by an output convolution.

    The stack holds one ``in_channels -> hidden_channels`` layer plus
    ``num_layers`` ``hidden_channels -> hidden_channels`` layers. ReLU and
    feature dropout follow every stacked layer; the output layer is applied
    without activation.

    args:
        in_channels (int): Input feature width.
        hidden_channels (int): Width of every stacked layer.
        out_channels (int): Width produced by the output layer.
        num_layers (int): Number of hidden-to-hidden layers after the first.
        dropout (float): Feature dropout probability.
        dropout_adj (float): Dropout probability applied to the adjacency.
        sparse (bool): If true, ``Adj`` is treated as a graph object exposing
            ``edata['w']`` and ``GCNConv_dgl`` layers are stacked; otherwise
            ``Adj`` is a tensor and ``GCNConv_dense`` layers are stacked.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, dropout_adj, sparse):
        super(GCN, self).__init__()
        self.layers_fused = nn.ModuleList()

        conv_cls = GCNConv_dgl if sparse else GCNConv_dense
        self.layers_fused.append(conv_cls(in_channels, hidden_channels))
        for _ in range(num_layers):
            self.layers_fused.append(conv_cls(hidden_channels, hidden_channels))

        # NOTE(review): the output layer is GCNConv_dense even when sparse=True;
        # confirm GCNConv_dense accepts the graph object used in sparse mode.
        self.out_put = GCNConv_dense(hidden_channels, out_channels)
        self.dropout = dropout
        self.dropout_adj = nn.Dropout(p=dropout_adj)
        self.sparse = sparse

    def forward(self, x, Adj):
        """Propagate ``x`` over ``Adj`` without back-propagating into ``Adj``.

        args:
            x (torch.Tensor): Node features.
            Adj: Adjacency tensor (dense mode) or graph with ``edata['w']``
                (sparse mode).

        returns:
            Output of the final convolution.
        """
        if self.sparse:
            # A graph object has no tensor-style ``detach``; take a local view
            # so the caller's edge weights are not overwritten by the
            # dropped-out copy, and detach the weights themselves.
            Adj = Adj.local_var()
            Adj.edata['w'] = self.dropout_adj(Adj.edata['w'].detach())
        else:
            # ``detach`` already yields a tensor that does not require grad.
            Adj = self.dropout_adj(Adj.detach())
        for conv in self.layers_fused:
            x = conv(x, Adj)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.out_put(x, Adj)



    
# GCN for evaluation.
class GCN_ft(nn.Module):
    """Graph-convolution encoder ending at ``emb_dim`` plus an output layer.

    The encoder holds an ``in_channels -> hidden_channels`` layer,
    ``num_layers - 2`` hidden layers and a ``hidden_channels -> emb_dim``
    layer, stored under ``gnn_encoder_layers``. ReLU and feature dropout
    follow every encoder layer; ``out_put`` then maps ``emb_dim`` to
    ``out_channels``.

    args:
        in_channels (int): Input feature width.
        hidden_channels (int): Width of the hidden encoder layers.
        emb_dim (int): Width produced by the last encoder layer.
        out_channels (int): Width produced by the output layer.
        num_layers (int): Encoder depth; ``num_layers - 2`` hidden layers are
            inserted between the first and last encoder layers.
        dropout (float): Feature dropout probability.
        dropout_adj (float): Dropout probability applied to the adjacency.
        sparse (bool): If true, ``Adj`` is treated as a graph object exposing
            ``edata['w']`` and ``GCNConv_dgl`` layers are used in the encoder.
    """

    def __init__(self, in_channels, hidden_channels, emb_dim, out_channels, num_layers, dropout, dropout_adj, sparse):
        super(GCN_ft, self).__init__()
        self.gnn_encoder_layers = nn.ModuleList()

        conv_cls = GCNConv_dgl if sparse else GCNConv_dense
        self.gnn_encoder_layers.append(conv_cls(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.gnn_encoder_layers.append(conv_cls(hidden_channels, hidden_channels))
        # Both modes need the projection to emb_dim: ``out_put`` below expects
        # emb_dim-wide input (the sparse branch previously skipped this layer).
        self.gnn_encoder_layers.append(conv_cls(hidden_channels, emb_dim))

        # NOTE(review): the output layer is GCNConv_dense even when sparse=True;
        # confirm GCNConv_dense accepts the graph object used in sparse mode.
        self.out_put = GCNConv_dense(emb_dim, out_channels)
        self.dropout = dropout
        self.dropout_adj = nn.Dropout(p=dropout_adj)
        self.sparse = sparse

    def forward(self, x, Adj):
        """Encode ``x`` over ``Adj`` and apply the output layer.

        args:
            x (torch.Tensor): Node features.
            Adj: Adjacency tensor (dense mode) or graph with ``edata['w']``
                (sparse mode).

        returns:
            Output of the final convolution.
        """
        if self.sparse:
            # Work on a local view so the caller's edge weights are not
            # replaced by the dropped-out weights on every call.
            Adj = Adj.local_var()
            Adj.edata['w'] = self.dropout_adj(Adj.edata['w'])
        else:
            Adj = self.dropout_adj(Adj)
        path_x = x
        for conv in self.gnn_encoder_layers:
            path_x = conv(path_x, Adj)
            path_x = F.relu(path_x)
            path_x = F.dropout(path_x, p=self.dropout, training=self.training)
        return self.out_put(path_x, Adj)


class GraphEncoder(nn.Module):
    """Graph-convolution encoder mapping ``in_dim`` features to ``emb_dim``.

    Layer layout: ``in_dim -> hidden_dim``, then ``nlayers - 2`` layers of
    ``hidden_dim -> hidden_dim``, then ``hidden_dim -> emb_dim``. The layer
    class is ``GCNConv_dgl`` when ``sparse`` is true and ``GCNConv_dense``
    otherwise.
    """

    def __init__(self, nlayers, in_dim, hidden_dim, emb_dim, dropout, sparse):
        super(GraphEncoder, self).__init__()
        self.dropout = dropout
        self.gnn_encoder_layers = nn.ModuleList()
        self.act = nn.ReLU()

        conv_cls = GCNConv_dgl if sparse else GCNConv_dense
        # (input width, output width) of every layer, in application order.
        layer_dims = (
            [(in_dim, hidden_dim)]
            + [(hidden_dim, hidden_dim)] * (nlayers - 2)
            + [(hidden_dim, emb_dim)]
        )
        for width_in, width_out in layer_dims:
            self.gnn_encoder_layers.append(conv_cls(width_in, width_out))
        self.sparse = sparse

    def forward(self, x, Adj):
        """Encode ``x`` over ``Adj``.

        Dropout is applied to the input and after every layer except the
        last; the activation follows every layer except the last.
        """
        *inner_convs, last_conv = self.gnn_encoder_layers

        h = F.dropout(x, p=self.dropout, training=self.training)
        for conv in inner_convs:
            h = self.act(conv(h, Adj))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return last_conv(h, Adj)


class GCL(nn.Module):
    """Three graph encoders plus projection heads for contrastive losses.

    Holds separate ``GraphEncoder`` instances for the 'omic', 'path' and
    fused inputs (the fused encoder takes ``in_dim * 2`` features), ``num_g``
    projection heads each for the specific and augmented views, and one head
    for the fused view.

    args:
        nlayers, in_dim, hidden_dim, emb_dim, dropout, sparse: forwarded to
            ``GraphEncoder``.
        proj_dim (int): Output width of every projection head.
        num_g (int): Number of graphs/views handled by ``cal_loss``.
    """

    def __init__(self, nlayers, in_dim, hidden_dim, emb_dim, proj_dim, dropout, sparse, num_g):
        super(GCL, self).__init__()

        self.num_g = num_g
        self.omic_encoder = GraphEncoder(nlayers, in_dim, hidden_dim, emb_dim, dropout, sparse)
        self.path_encoder = GraphEncoder(nlayers, in_dim, hidden_dim, emb_dim, dropout, sparse)
        self.fused_encoder = GraphEncoder(nlayers, in_dim * 2, hidden_dim, emb_dim, dropout, sparse)

        self.proj_s = nn.ModuleList(
            [self._projection_head(emb_dim, proj_dim) for _ in range(self.num_g)])
        self.proj_u = nn.ModuleList(
            [self._projection_head(emb_dim, proj_dim) for _ in range(self.num_g)])
        self.proj_f = self._projection_head(emb_dim, proj_dim)

    @staticmethod
    def _projection_head(emb_dim, proj_dim):
        """Return a Linear -> ReLU -> Linear head from emb_dim to proj_dim."""
        return nn.Sequential(
            nn.Linear(emb_dim, proj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim))

    def forward(self, x, Adj_, mode):
        """Encode ``x`` with the encoder chosen by ``mode`` and L2-normalize.

        args:
            x: Node features.
            Adj_: Adjacency handed to the encoder.
            mode (str): 'path' or 'omic'; any other value selects the fused
                encoder.

        returns:
            torch.Tensor: Embedding normalized to unit L2 norm along dim 1.
        """
        if mode == 'path':
            embedding = self.path_encoder(x, Adj_)
        elif mode == 'omic':
            embedding = self.omic_encoder(x, Adj_)
        else:
            embedding = self.fused_encoder(x, Adj_)
        embedding = F.normalize(embedding, dim=1, p=2)
        return embedding

    def cal_loss(self, z_specific_adjs, z_aug_adjs, z_fused_adjs):
        """Combine the four contrastive terms computed by ``calc_lower_bound``.

        args:
            z_specific_adjs (sequence): ``num_g`` 2-D embeddings, one per graph.
            z_aug_adjs (sequence): ``num_g`` embeddings of the augmented views.
            z_fused_adjs (torch.Tensor): Embedding of the fused graph.

        returns:
            tuple: ``(loss_fused, loss_smi, loss_umi, loss_aug_fused, loss)``
            where ``loss`` is the sum of the other four. ``loss_smi`` averages
            over all unordered pairs of specific views and is a zero tensor
            when there are no pairs (``num_g == 1``).
        """
        batch_size, _ = z_specific_adjs[0].size()
        # Identity matrix marks row i / column i as the positive pair.
        pos_eye = torch.eye(batch_size).to(z_fused_adjs[0].device)
        z_proj_s = [self.proj_s[i](z_specific_adjs[i]) for i in range(self.num_g)]
        z_proj_u = [self.proj_u[i](z_aug_adjs[i]) for i in range(self.num_g)]

        pair_losses = [
            calc_lower_bound(z_proj_s[i], z_proj_s[j], pos_eye)
            for i in range(self.num_g)
            for j in range(i + 1, self.num_g)
        ]
        if pair_losses:
            loss_smi = sum(pair_losses) / len(pair_losses)
        else:
            # A single view has no cross-view pairs; avoid dividing by zero.
            loss_smi = torch.zeros((), device=pos_eye.device)

        loss_fused = 0
        loss_umi = 0
        loss_aug_fused = 0

        z_proj_fuse = self.proj_f(z_fused_adjs)
        for i in range(self.num_g):
            loss_fused += calc_lower_bound(z_proj_fuse, z_proj_s[i], pos_eye)
            loss_umi += calc_lower_bound(z_proj_s[i], z_proj_u[i], pos_eye)
            loss_aug_fused += calc_lower_bound(z_proj_fuse, z_proj_u[i], pos_eye)

        loss_fused = loss_fused / self.num_g
        loss_umi = loss_umi / self.num_g
        loss_aug_fused = loss_aug_fused / self.num_g
        loss = loss_fused + loss_smi + loss_umi + loss_aug_fused
        return loss_fused, loss_smi, loss_umi, loss_aug_fused, loss




def AGG(h_list, adjs_o, nlayer, sparse=False):
    """Propagate each feature matrix over its adjacency and L2-normalize.

    For every adjacency in ``adjs_o`` the feature tensor at the same index in
    ``h_list`` is left-multiplied by that adjacency ``nlayer`` times, then
    normalized to unit L2 norm along dim 1.

    args:
        h_list (sequence of torch.Tensor): Features, indexed like ``adjs_o``.
        adjs_o (sequence of torch.Tensor): One adjacency per feature tensor.
        nlayer (int): Number of propagation steps.
        sparse (bool): Use ``torch.sparse.mm`` instead of ``torch.matmul``.

    returns:
        list of torch.Tensor: One propagated, normalized tensor per adjacency.
    """
    propagate = torch.sparse.mm if sparse else torch.matmul
    f_list = []
    for idx, adj in enumerate(adjs_o):
        z = h_list[idx]
        # Separate loop name: the original reused ``i`` for both loops.
        for _ in range(nlayer):
            z = propagate(adj, z)
        f_list.append(F.normalize(z, dim=1, p=2))

    return f_list



def sim_con(z1, z2, temperature):
    """Return the temperature-scaled cosine-similarity matrix of two batches.

    Entry ``(i, j)`` is the dot product of row ``i`` of ``z1`` and row ``j``
    of ``z2`` divided by the product of their norms (plus ``EOS`` to avoid a
    zero denominator), then divided by ``temperature``.
    """
    row_norms_1 = z1.norm(dim=-1, keepdim=True)
    row_norms_2 = z2.norm(dim=-1, keepdim=True)
    pairwise_dots = z1.mm(z2.t())
    norm_products = row_norms_1.mm(row_norms_2.t()) + EOS
    return pairwise_dots / norm_products / temperature


def calc_lower_bound(z_1, z_2, pos, temperature=0.2):
    """Symmetric contrastive loss between two batches of embeddings.

    The exponentiated similarity matrix from ``sim_con`` is row-normalized,
    the entries selected by ``pos`` are summed per row, and the negative log
    of that sum is averaged. The same is done for the transposed matrix and
    the two directions are averaged.

    args:
        z_1, z_2 (torch.Tensor): Embedding batches passed to ``sim_con``.
        pos (torch.Tensor): Mask/weight matrix marking positive pairs.
        temperature (float): Similarity temperature.
    """
    exp_sim = torch.exp(sim_con(z_1, z_2, temperature))

    def _direction_loss(scores):
        # Row-normalize, keep the positive mass, and average -log over rows.
        row_prob = scores / (torch.sum(scores, dim=1).view(-1, 1) + EOS)
        return -torch.log(row_prob.mul(pos).sum(dim=-1)).mean()

    forward_loss = _direction_loss(exp_sim)
    backward_loss = _direction_loss(exp_sim.t())
    return (forward_loss + backward_loss) / 2


def calc_upper_bound(z_1, z_2, pos, temperature=0.2):
    """Mean positive-pair similarity minus the mean of all similarities.

    The similarity matrix comes from ``sim_con``; ``pos`` selects the
    positive entries, which are summed per row and averaged over rows.
    """
    similarity = sim_con(z_1, z_2, temperature)
    positive_mean = (similarity * pos).sum(dim=-1).mean()
    overall_mean = similarity.mean()
    return positive_mean - overall_mean


def sce_loss(x, y, beta=1):
    """Mean of ``(1 - cosine_similarity(x, y)) ** beta`` over the batch.

    Both inputs are L2-normalized along the last dimension before their
    element-wise product is summed along that dimension.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return (1 - cosine).pow(beta).mean()


