from torch_geometric.nn import GCNConv
from torch.nn import init
from torch_geometric.nn import GATConv

from torch_geometric.nn import SAGEConv
import torch.nn as nn
import torch.nn.functional as F
import torch

class MultiViewDNN(nn.Module):
    """Two-layer perceptron used to encode one feature view.

    The input is widened to ``2 * input_dim`` features, passed through a
    ReLU, and projected to ``output_dim`` features.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim=11, output_dim=80):
        super().__init__()
        widened = 2 * input_dim
        # Attribute names are part of the state_dict keys; keep them stable.
        self.mlp1 = nn.Linear(input_dim, widened)
        self.mlp2 = nn.Linear(widened, output_dim)

    def forward(self, x):
        """Return ``mlp2(relu(mlp1(x)))``."""
        return self.mlp2(torch.relu(self.mlp1(x)))

class MV_GCN_Model(nn.Module):
    """Link predictor combining a multi-view global encoder with a two-layer GCN.

    Global node features are viewed as ``NUM_VIEWS`` views of ``VIEW_DIM``
    values; each view goes through its own ``MultiViewDNN`` and the view
    encodings are averaged. Local node features go through two ``GCNConv``
    layers. Each candidate edge is scored by an MLP over the concatenated
    global and local encodings of both endpoints.

    Args:
        in_channels: Width of the local node features and of the GCN output.
        layer_num: Accepted for interface compatibility; not used.
    """

    # Width of each MultiViewDNN output, i.e. of the global node encoding.
    GLOBAL_DIM = 80
    # Each global feature row is viewed as NUM_VIEWS x VIEW_DIM values.
    NUM_VIEWS = 11
    VIEW_DIM = 11

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        hidden_channels = in_channels
        self.multi_view_dnns = nn.ModuleList([
            MultiViewDNN(input_dim=self.VIEW_DIM, output_dim=self.GLOBAL_DIM)
            for _ in range(self.NUM_VIEWS)
        ])
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, in_channels)

        # decode() feeds [global_src, local_src, global_dst, local_dst], i.e.
        # 2 * GLOBAL_DIM + 2 * in_channels features. Deriving the width
        # (instead of assuming 4 * in_channels) keeps the model valid for any
        # in_channels; for in_channels == GLOBAL_DIM the shapes are the same.
        self.mlp = nn.Sequential(
            nn.Linear(2 * self.GLOBAL_DIM + 2 * in_channels, 2 * hidden_channels),
            nn.ReLU(),
            nn.Linear(2 * hidden_channels, 1),
        )

    def encode_global(self, x):
        """Encode flat global features into one vector per row.

        Args:
            x: Tensor with ``x.size(0)`` rows of ``NUM_VIEWS * VIEW_DIM``
                (121) values each; it is reshaped with ``view``.

        Returns:
            Tensor of shape ``(x.size(0), GLOBAL_DIM)``: the mean over views
            of each view's ``MultiViewDNN`` output.
        """
        views = x.view(x.size(0), self.NUM_VIEWS, self.VIEW_DIM)
        encoded = torch.stack(
            [self.multi_view_dnns[i](views[:, i, :]) for i in range(self.NUM_VIEWS)],
            dim=1,
        )
        return encoded.mean(dim=1)

    def encode_local(self, x, edge_index):
        """Apply GCNConv -> ReLU -> GCNConv over the graph given by ``edge_index``."""
        x = torch.relu(self.gcn1(x, edge_index))
        return self.gcn2(x, edge_index)

    def decode(self, z1, z2, edge_label_index):
        """Score edges from two node-encoding tensors.

        Args:
            z1: First node encoding (``forward`` passes the global encoding).
            z2: Second node encoding (``forward`` passes the local encoding).
            edge_label_index: Row 0 holds source node indices and row 1 holds
                destination node indices of the edges to score.

        Returns:
            ``mlp`` output with one row per edge.
        """
        src, dst = edge_label_index[0], edge_label_index[1]
        combined = torch.cat([z1[src], z2[src], z1[dst], z2[dst]], dim=-1)
        return self.mlp(combined)

    def forward(self, local_embedding, global_embedding, edge_index, edge_label_index):
        """Encode both feature sets and score the ``edge_label_index`` edges."""
        z_global = self.encode_global(global_embedding)
        z_local = self.encode_local(local_embedding, edge_index)
        return self.decode(z_global, z_local, edge_label_index)


class MV_GCN_Fusion_Model(nn.Module):
    """Link predictor that scores local and global encodings separately.

    The two per-edge scores are combined by a learned ``Linear(2, 1)``.

    Args:
        in_channels: Width of the local node features and of the GCN output.
        layer_num: Accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        width = in_channels

        # One encoder per view of the global features (11 views of 11 values).
        self.multi_view_dnns = nn.ModuleList(
            [MultiViewDNN(input_dim=11, output_dim=80) for _ in range(11)]
        )

        self.gcn1 = GCNConv(in_channels, width)
        self.gcn2 = GCNConv(width, in_channels)

        # Edge scorers over concatenated [source, destination] encodings.
        self.mlp_local = nn.Sequential(nn.Linear(2 * in_channels, 1))
        self.mlp_global = nn.Sequential(nn.Linear(160, 1))  # 2 * 80 global dims

        # Combines the local and the global score into one prediction.
        self.fusion_layer = nn.Linear(2, 1)

    def encode_global(self, x):
        """Average the 11 per-view MultiViewDNN encodings of each row of ``x``.

        ``x`` must hold 121 values per row; it is reshaped with ``view``.
        """
        rows = x.size(0)
        per_view = x.view(rows, 11, 11)
        stacked = torch.stack(
            [self.multi_view_dnns[v](per_view[:, v, :]) for v in range(11)],
            dim=1,
        )
        return stacked.mean(dim=1)

    def encode_local(self, x, edge_index):
        """Apply GCNConv -> ReLU -> GCNConv over the graph given by ``edge_index``."""
        hidden = F.relu(self.gcn1(x, edge_index))
        return self.gcn2(hidden, edge_index)

    def decode_local(self, z_local, edge_label_index):
        """Score edges (row 0 = sources, row 1 = destinations) from local encodings."""
        heads, tails = edge_label_index[0], edge_label_index[1]
        pair = torch.cat([z_local[heads], z_local[tails]], dim=-1)
        return self.mlp_local(pair)

    def decode_global(self, z_global, edge_label_index):
        """Score edges (row 0 = sources, row 1 = destinations) from global encodings."""
        heads, tails = edge_label_index[0], edge_label_index[1]
        pair = torch.cat([z_global[heads], z_global[tails]], dim=-1)
        return self.mlp_global(pair)

    def fusion(self, local_pred, global_pred):
        """Fusion Layer: Combine local and global predictions"""
        return self.fusion_layer(torch.cat([local_pred, global_pred], dim=-1))

    def forward(self, local_embedding, global_embedding, edge_index, edge_label_index):
        """Encode both feature sets, score each, and fuse the two scores."""
        z_global = self.encode_global(global_embedding)
        z_local = self.encode_local(local_embedding, edge_index)

        local_score = self.decode_local(z_local, edge_label_index)
        global_score = self.decode_global(z_global, edge_label_index)

        return self.fusion(local_score, global_score)


class MV_GCN_Tuning_Model(nn.Module):
    """Link predictor that feeds a local edge score plus global encodings to an MLP.

    The local (GCN) encodings of an edge's endpoints are scored by
    ``mlp_local``. That scalar score is concatenated with the raw global
    encodings of both endpoints and passed through ``mlp_fusion``.

    Args:
        in_channels: Width of the local node features and of the GCN output.
        layer_num: Accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        width = in_channels
        self.multi_view_dnns = nn.ModuleList(
            [MultiViewDNN(input_dim=11, output_dim=80) for _ in range(11)]
        )
        self.gcn1 = GCNConv(in_channels, width)
        self.gcn2 = GCNConv(width, in_channels)

        pair_dim = 2 * in_channels
        self.mlp_local = nn.Sequential(
            nn.Linear(pair_dim, pair_dim),
            nn.ReLU(),
            nn.Linear(pair_dim, 1),
        )

        # Input layout: 1 local score + 2 * 80 global endpoint features = 161.
        self.mlp_fusion = nn.Sequential(
            nn.Linear(161, 322),
            nn.ReLU(),
            nn.Linear(322, 1),
        )

    def encode_global(self, x):
        """Average the 11 per-view MultiViewDNN encodings of each row of ``x``.

        ``x`` must hold 121 values per row; it is reshaped with ``view``.
        """
        per_view = x.view(x.size(0), 11, 11)
        encodings = [self.multi_view_dnns[k](per_view[:, k, :]) for k in range(11)]
        return torch.stack(encodings, dim=1).mean(dim=1)

    def encode_local(self, x, edge_index):
        """Apply GCNConv -> ReLU -> GCNConv over the graph given by ``edge_index``."""
        hidden = F.relu(self.gcn1(x, edge_index))
        return self.gcn2(hidden, edge_index)

    def decode_local(self, z_local, edge_label_index):
        """Score edges (row 0 = sources, row 1 = destinations) from local encodings."""
        heads = z_local[edge_label_index[0]]
        tails = z_local[edge_label_index[1]]
        return self.mlp_local(torch.cat([heads, tails], dim=-1))

    def forward(self, local_embedding, global_embedding, edge_index, edge_label_index):
        """Score ``edge_label_index`` edges from local and global node features."""
        z_global = self.encode_global(global_embedding)
        z_local = self.encode_local(local_embedding, edge_index)

        local_score = self.decode_local(z_local, edge_label_index)

        heads, tails = edge_label_index[0], edge_label_index[1]
        # Feature layout: [local score, global source, global destination].
        features = torch.cat([local_score, z_global[heads], z_global[tails]], dim=-1)
        return self.mlp_fusion(features)


class MV_GCN_Embedding_Model(nn.Module):
    """Link predictor that fuses global and local encodings per node before scoring.

    For each endpoint of a candidate edge, its global (multi-view DNN) and
    local (GCN) encodings are concatenated and projected by ``mlp`` to
    ``in_channels`` features. The two projected endpoints are concatenated
    and scored by ``predicator``.

    Args:
        in_channels: Width of the local node features and of the GCN output.
        layer_num: Accepted for interface compatibility; not used.
    """

    # Width of each MultiViewDNN output, i.e. of the global node encoding.
    GLOBAL_DIM = 80
    # Each global feature row is viewed as NUM_VIEWS x VIEW_DIM values.
    NUM_VIEWS = 11
    VIEW_DIM = 11

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = in_channels * 2
        hidden_channels = self.hidden_channels
        self.multi_view_dnns = nn.ModuleList([
            MultiViewDNN(input_dim=self.VIEW_DIM, output_dim=self.GLOBAL_DIM)
            for _ in range(self.NUM_VIEWS)
        ])
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, in_channels)

        # Per-node fusion over [global (GLOBAL_DIM), local (in_channels)].
        # The width is derived from both parts rather than assumed to be
        # 2 * in_channels, which only held when in_channels == GLOBAL_DIM.
        self.mlp = nn.Sequential(
            nn.Linear(self.GLOBAL_DIM + in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, in_channels),
        )

        # Edge scorer over [fused_src, fused_dst], each in_channels wide.
        self.predicator = nn.Sequential(
            nn.Linear(2 * in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, 1),
        )

    @staticmethod
    def _normalize_with_max(tensor):
        """Divide ``tensor`` by its global maximum.

        Empty tensors and tensors whose maximum is zero are returned unchanged.
        NOTE(review): a negative maximum flips every sign, so this presumably
        assumes non-negative inputs — confirm. The scale depends only on the
        tensor passed in, so it varies with whatever batch the caller supplies.
        """
        if tensor.numel() == 0:
            return tensor
        max_val = tensor.max()
        if max_val == 0:
            return tensor
        return tensor / max_val

    def encode_global(self, x):
        """Encode flat global features into one vector per row.

        Args:
            x: Tensor with ``x.size(0)`` rows of ``NUM_VIEWS * VIEW_DIM``
                (121) values each; it is reshaped with ``view``.

        Returns:
            Tensor of shape ``(x.size(0), GLOBAL_DIM)``: the mean over views
            of each view's ``MultiViewDNN`` output.
        """
        views = x.view(x.size(0), self.NUM_VIEWS, self.VIEW_DIM)
        encoded = torch.stack(
            [self.multi_view_dnns[i](views[:, i, :]) for i in range(self.NUM_VIEWS)],
            dim=1,
        )
        return encoded.mean(dim=1)

    def encode_local(self, x, edge_index):
        """Apply GCNConv -> ReLU -> GCNConv over the graph given by ``edge_index``."""
        x = torch.relu(self.gcn1(x, edge_index))
        return self.gcn2(x, edge_index)

    def decode(self, z1, z2, edge_label_index):
        """Score edges from two node-encoding tensors.

        Args:
            z1: First node encoding (``forward`` passes the global encoding).
            z2: Second node encoding (``forward`` passes the local encoding).
            edge_label_index: Row 0 holds source node indices and row 1 holds
                destination node indices of the edges to score.

        Returns:
            ``predicator`` output with one row per edge.
        """
        src, dst = edge_label_index[0], edge_label_index[1]
        # Fuse each endpoint's two encodings into one in_channels-wide vector.
        fused_src = self.mlp(torch.cat([z1[src], z2[src]], dim=-1))
        fused_dst = self.mlp(torch.cat([z1[dst], z2[dst]], dim=-1))
        return self.predicator(torch.cat([fused_src, fused_dst], dim=-1))

    def forward(self, local_embedding, global_embedding, edge_index, edge_label_index):
        """Max-normalize both inputs, encode them, and score the requested edges."""
        local_embedding = self._normalize_with_max(local_embedding)
        global_embedding = self._normalize_with_max(global_embedding)

        z_global = self.encode_global(global_embedding)
        z_local = self.encode_local(local_embedding, edge_index)

        return self.decode(z_global, z_local, edge_label_index)

class GCN_Model(nn.Module):
    """Link predictor using only a two-layer GCN over local node features.

    Args:
        in_channels: Width of the node features and of the GCN output.
        layer_num: Accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        width = in_channels

        self.gcn1 = GCNConv(in_channels, width)
        self.gcn2 = GCNConv(width, in_channels)

        # Scores the concatenated [source, destination] encodings.
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def encode_local(self, x, edge_index):
        """Apply GCNConv -> ReLU -> GCNConv over the graph given by ``edge_index``."""
        hidden = F.relu(self.gcn1(x, edge_index))
        return self.gcn2(hidden, edge_index)

    def decode(self, z1, edge_label_index):
        """Score edges (row 0 = sources, row 1 = destinations) from encodings ``z1``."""
        src_idx, dst_idx = edge_label_index[0], edge_label_index[1]
        pair = torch.cat((z1[src_idx], z1[dst_idx]), dim=-1)
        return self.mlp(pair)

    def forward(self, local_embedding, edge_index, edge_label_index):
        """Encode node features with the GCN and score the requested edges."""
        encoded = self.encode_local(local_embedding, edge_index)
        return self.decode(encoded, edge_label_index)


class MVDNN_Model(nn.Module):
    """Link predictor using only the multi-view global encoder.

    Global node features are viewed as ``NUM_VIEWS`` views of ``VIEW_DIM``
    values; each view is encoded by its own ``MultiViewDNN`` and the view
    encodings are averaged. Each candidate edge is scored by an MLP over the
    concatenated encodings of its two endpoints.

    Args:
        in_channels: Width of the global node encoding and of the scorer's
            hidden layer.
        layer_num: Accepted for interface compatibility; not used.
    """

    # Each global feature row is viewed as NUM_VIEWS x VIEW_DIM values.
    NUM_VIEWS = 11
    VIEW_DIM = 11

    def __init__(self, in_channels, layer_num=1):
        super().__init__()

        self.in_channels = in_channels
        hidden_channels = in_channels

        # decode() feeds 2 * (encoding width) features into
        # Linear(2 * in_channels, ...), so the view encoders must emit
        # in_channels features. The previous hard-coded output_dim=8080 only
        # matched in_channels == 8080 (for which this is unchanged).
        self.multi_view_dnns = nn.ModuleList([
            MultiViewDNN(input_dim=self.VIEW_DIM, output_dim=in_channels)
            for _ in range(self.NUM_VIEWS)
        ])

        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode_global(self, x):
        """Encode flat global features into one vector per row.

        Args:
            x: Tensor with ``x.size(0)`` rows of ``NUM_VIEWS * VIEW_DIM``
                (121) values each; it is reshaped with ``view``.

        Returns:
            The mean over views of each view's ``MultiViewDNN`` output.
        """
        views = x.view(x.size(0), self.NUM_VIEWS, self.VIEW_DIM)
        encoded = torch.stack(
            [self.multi_view_dnns[i](views[:, i, :]) for i in range(self.NUM_VIEWS)],
            dim=1,
        )
        return encoded.mean(dim=1)

    def decode(self, z1, edge_label_index):
        """Score edges (row 0 = sources, row 1 = destinations) from encodings ``z1``."""
        src, dst = edge_label_index[0], edge_label_index[1]
        return self.mlp(torch.cat([z1[src], z1[dst]], dim=-1))

    def forward(self, global_embedding, edge_label_index):
        """Encode the global features and score the requested edges."""
        z_global = self.encode_global(global_embedding)
        return self.decode(z_global, edge_label_index)