import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
# Module-level side effect: prints a notice whenever this file is executed or
# first imported.
print('multi-scale fusion scUniGP model loaded')

class scUniGP(nn.Module):
    """Score a gene pair from its expression profiles and pretrained GNN embeddings.

    The two expression rows of a pair are encoded by a small MLP, summed with a
    learned two-slot position embedding, and passed through a transformer
    encoder. The flattened transformer output is concatenated with whichever
    GNN embedding branches are enabled (``l1``, ``l2``, ``tf``, ``target``),
    then fed through a fully connected head that ends in a sigmoid.

    Notes:
        * ``layernorm*`` / ``batchnorm*`` modules and ``encoder_layer`` are
          registered but not called in ``forward``; they are kept so that
          ``state_dict`` keys stay compatible with existing checkpoints.
        * ``self.device`` is kept for callers that read it, but ``forward``
          does not depend on it (it follows the module's actual parameters).
    """

    def __init__(self, expression_data_shape, embed_size, num_layers, num_head,
                 tf_gnn_embedding_data, target_gnn_embedding_data,
                 l1_gnn_embedding_data, l2_gnn_embedding_data,
                 use_l1, use_l2, use_tf, use_target):
        """Build the encoder, the optional GNN branches and the prediction head.

        Args:
            expression_data_shape: Shape of the expression matrix; only index 1
                is read and used as the input width of the expression encoder.
            embed_size: Model width of the transformer and of every projected
                GNN feature.
            num_layers: Number of transformer encoder layers.
            num_head: Number of attention heads per encoder layer.
            tf_gnn_embedding_data: 2-D pretrained table for the ``tf`` branch
                (read only when ``use_tf`` is true).
            target_gnn_embedding_data: 2-D pretrained table for the ``target``
                branch (read only when ``use_target`` is true).
            l1_gnn_embedding_data: 2-D pretrained table for the ``l1`` branch
                (read only when ``use_l1`` is true).
            l2_gnn_embedding_data: 2-D pretrained table for the ``l2`` branch
                (read only when ``use_l2`` is true).
            use_l1, use_l2, use_tf, use_target: Enable the respective branch.
        """
        super().__init__()

        # Kept for backward compatibility only; forward() derives the device
        # from the module's own parameters.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_l1 = use_l1
        self.use_l2 = use_l2
        self.use_tf = use_tf
        self.use_target = use_target
        self.embed_size = embed_size

        # NOTE: the layer is also stored as an attribute, so its parameters
        # appear in state_dict next to the copies held by the encoder stack.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_head, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        # One learned vector per slot of the pair (slot 0 and slot 1).
        self.position_embedding = nn.Embedding(2, embed_size)

        self.encoder512 = nn.Linear(expression_data_shape[1], 512)
        self.encoder768 = nn.Linear(512, embed_size)

        # l1/l2 look up both members of the pair (2 * embed_size features
        # each); tf/target look up a single column (embed_size features each).
        gnn_input_dim = 0
        if use_l1:
            self._add_gnn_branch('l1', l1_gnn_embedding_data)
            gnn_input_dim += 2 * embed_size
        if use_l2:
            self._add_gnn_branch('l2', l2_gnn_embedding_data)
            gnn_input_dim += 2 * embed_size
        if use_tf:
            self._add_gnn_branch('tf', tf_gnn_embedding_data)
            gnn_input_dim += embed_size
        if use_target:
            self._add_gnn_branch('target', target_gnn_embedding_data)
            gnn_input_dim += embed_size

        self.flatten = nn.Flatten()
        # Flattened transformer output contributes 2 * embed_size features.
        fusion_input_dim = embed_size * 2 + gnn_input_dim
        self.linear1024 = nn.Linear(fusion_input_dim, 1024)
        self.layernorm1024 = nn.LayerNorm(1024)
        self.batchnorm1024 = nn.BatchNorm1d(1024)

        self.linear512 = nn.Linear(1024, 512)
        self.layernorm512 = nn.LayerNorm(512)
        self.batchnorm512 = nn.BatchNorm1d(512)

        # +1 for the scalar gnn_score appended in forward().
        self.linear256 = nn.Linear(512 + 1, 256)
        self.layernorm256 = nn.LayerNorm(256)
        self.batchnorm256 = nn.BatchNorm1d(256)

        self.linear2 = nn.Linear(256, 1)
        self.actf = nn.PReLU()
        self.dropout = nn.Dropout(p=0.2)
        # Averages groups of 4 features: 1024 -> 256 for the skip connection.
        self.pool = nn.AvgPool1d(kernel_size=4, stride=4)

    def _add_gnn_branch(self, name, pretrained):
        """Register ``<name>_gnn_embedding`` (trainable) and ``<name>_proj``."""
        table = nn.Embedding.from_pretrained(torch.FloatTensor(pretrained), freeze=False)
        setattr(self, name + '_gnn_embedding', table)
        setattr(self, name + '_proj', nn.Linear(pretrained.shape[1], self.embed_size))

    def _gnn_features(self, main_pair, bs):
        """Return the projected features of every enabled GNN branch, in order."""
        features = []
        if self.use_l1:
            projected = self.l1_proj(self.l1_gnn_embedding(main_pair))
            features.append(projected.reshape(bs, -1))
        if self.use_l2:
            projected = self.l2_proj(self.l2_gnn_embedding(main_pair))
            features.append(projected.reshape(bs, -1))
        if self.use_tf:
            # First column of the pair (presumably the TF index).
            features.append(self.tf_proj(self.tf_gnn_embedding(main_pair[:, 0])))
        if self.use_target:
            # Second column of the pair (presumably the target index).
            features.append(self.target_proj(self.target_gnn_embedding(main_pair[:, 1])))
        return features

    def forward(self, main_pair, main_expr, related_exprs, gnn_score, return_penultimate=False):
        """Compute the sigmoid score for a batch of pairs.

        Args:
            main_pair: Integer index tensor; the layer sizes imply shape
                ``(batch, 2)``. Used to look up the GNN embedding tables.
            main_expr: Expression tensor; the layer sizes imply shape
                ``(batch, 2, expression_data_shape[1])``.
            related_exprs: Unused; accepted for interface compatibility.
            gnn_score: Per-sample scalar, shape ``(batch,)``; appended as one
                extra feature before the 256-wide layer.
            return_penultimate: If true, also return the 513-wide feature
                vector that feeds the 256-wide layer.

        Returns:
            ``outs`` of shape ``(batch, 1)`` with values in ``[0, 1]``, or the
            tuple ``(penultimate, outs)`` when ``return_penultimate`` is true.
        """
        bs = main_expr.shape[0]

        # Build the [0, 1] slot indices on the device the embedding actually
        # lives on, so the lookup works wherever the module has been moved.
        param_device = self.position_embedding.weight.device
        position = torch.arange(2, device=param_device).expand(bs, 2)
        p_e = self.position_embedding(position)

        out_expr_e = self.encoder512(main_expr)
        out_expr_e = F.leaky_relu(self.encoder768(out_expr_e))

        transformer_output = self.transformer_encoder(out_expr_e + p_e)
        transformer_output_flat = self.flatten(transformer_output)

        gnn_features = self._gnn_features(main_pair, bs)
        if gnn_features:
            gnn_concat = torch.cat(gnn_features, dim=1)
            fusion_vector = torch.cat([transformer_output_flat, gnn_concat], dim=1)
        else:
            fusion_vector = transformer_output_flat

        out = self.linear1024(fusion_vector)
        out = self.dropout(out)
        out = self.actf(out)

        # Skip connection: average-pool the 1024 features down to 256 so they
        # can be added to the output of linear256 below.
        r = self.pool(out.unsqueeze(1)).squeeze(1)

        out = self.linear512(out)
        out = self.dropout(out)
        out = self.actf(out)

        gnn_score = gnn_score.unsqueeze(1).to(torch.float32)
        out = torch.cat([out, gnn_score], dim=1)

        penultimate = out.clone()

        out = self.linear256(out) + r
        out = self.dropout(out)
        out = self.actf(out)

        outs = torch.sigmoid(self.linear2(out))

        if return_penultimate:
            return penultimate, outs
        return outs

