from torch import nn
from modules import (ResidualModuleWrapper, FeedForwardModule, GCNModule, SAGEModule, GATModule, GATSepModule,
                     TransformerAttentionModule, TransformerAttentionSepModule, SGCModule)
import copy


'''
To create a rewriting method, we need to choose:
1. Distance basis: raw features, or topological information (aggregated features, trained features)
2. Topology construction: cosine similarity, KNN
3. Fusion: new topology only, or the new topology combined with the original one
'''

# Layer recipe per model name. For each of ``num_layers`` layers, Model wraps
# every class in the list, in order, with ResidualModuleWrapper, so 'GT' and
# 'GT-sep' contribute two residual sub-layers per layer and every other entry
# contributes one.
MODULES = {
    # one sub-layer per layer
    'ResNet': [FeedForwardModule],
    'GCN': [GCNModule],
    'SGC': [SGCModule],
    'SAGE': [SAGEModule],
    'GAT': [GATModule],
    'GAT-sep': [GATSepModule],
    # attention sub-layer followed by a feed-forward sub-layer
    'GT': [TransformerAttentionModule, FeedForwardModule],
    'GT-sep': [TransformerAttentionSepModule, FeedForwardModule],
}


# Maps the ``normalization`` config string to a layer class. Model passes the
# selected class to every ResidualModuleWrapper and instantiates it with
# ``hidden_dim`` for the output normalization; 'None' selects nn.Identity.
NORMALIZATION = {
    'None': nn.Identity,            # no normalization
    'LayerNorm': nn.LayerNorm,
    'BatchNorm': nn.BatchNorm1d,
}


class Model(nn.Module):
    """Stack of residual graph layers with optional fusion of a rewritten graph.

    Pipeline: input linear -> dropout -> GELU -> residual layers -> output
    normalization -> output linear (dimension 1 squeezed). The residual layers
    are the classes in ``MODULES[model_name]``, each wrapped in
    ``ResidualModuleWrapper``, repeated ``num_layers`` times.

    When built with ``rewrite_fusion == 'both_seperate_param'``, a second stack
    (``residual_modules_new``) is created for the rewritten graph. It is made
    with ``copy.deepcopy``, so both stacks start from identical weights. The
    misspelling "seperate" is part of the accepted option string and is kept
    for compatibility with callers.
    """

    def __init__(self, model_name, num_layers, input_dim, hidden_dim, output_dim, hidden_dim_multiplier, num_heads, sgc_k,
                 normalization, residual, dropout, rewrite_fusion):
        """Build the model.

        Args:
            model_name: key of ``MODULES`` selecting the sub-layer classes of one layer.
            num_layers: how many times the ``MODULES[model_name]`` stack is repeated.
            input_dim: input feature size (in_features of ``input_linear``).
            hidden_dim: width of the residual layers.
            output_dim: out_features of ``output_linear``.
            hidden_dim_multiplier, num_heads, sgc_k, residual: forwarded
                unchanged to every ``ResidualModuleWrapper``.
            normalization: key of ``NORMALIZATION``. The selected class is
                passed to each wrapper and instantiated with ``hidden_dim`` for
                the output normalization.
            dropout: probability for the input dropout; also forwarded to
                every wrapper.
            rewrite_fusion: only ``'both_seperate_param'`` changes construction
                (it adds ``residual_modules_new``). Any other value builds a
                single stack.
        """
        super().__init__()

        normalization = NORMALIZATION[normalization]

        self.input_linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.GELU()

        self.residual_modules = nn.ModuleList()
        for _ in range(num_layers):
            for module in MODULES[model_name]:
                residual_module = ResidualModuleWrapper(module=module,
                                                        normalization=normalization,
                                                        residual=residual,
                                                        dim=hidden_dim,
                                                        hidden_dim_multiplier=hidden_dim_multiplier,
                                                        num_heads=num_heads,
                                                        sgc_k=sgc_k,
                                                        dropout=dropout)

                self.residual_modules.append(residual_module)

        if rewrite_fusion == 'both_seperate_param':
            # Independent parameters for the rewritten-graph branch; deepcopy
            # means they are initialised identically to the original stack.
            self.residual_modules_new = self._get_clones(self.residual_modules)

        self.output_normalization = normalization(hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def _get_clones(self, modules):
        """Return an ``nn.ModuleList`` of deep copies of ``modules`` (same initial weights)."""
        return nn.ModuleList([copy.deepcopy(module) for module in modules])

    def _new_branch_modules(self, rewrite_fusion):
        """Return the residual modules that process ``graph_new`` for a fusion mode.

        Raises:
            ValueError: if ``rewrite_fusion`` is not a supported fusion mode, or
                if separate parameters are requested but ``residual_modules_new``
                was never built because the model was constructed with a
                different ``rewrite_fusion``.
        """
        if rewrite_fusion == 'both_share_param':
            return self.residual_modules
        if rewrite_fusion == 'both_seperate_param':
            new_modules = getattr(self, 'residual_modules_new', None)
            if new_modules is None:
                raise ValueError(
                    "rewrite_fusion='both_seperate_param' requires the model to be "
                    "constructed with rewrite_fusion='both_seperate_param'")
            return new_modules
        raise ValueError(f'Unsupported rewrite_fusion: {rewrite_fusion!r}')

    def forward(self, graph, x, graph_new=None, rewrite_fusion=None, rewrite_fusion_state=None):
        """Run the model.

        Args:
            graph: original graph, passed to every residual module. Its type is
                whatever the wrapped modules expect (not visible here).
            x: node features; the last dimension must equal ``input_dim``.
            graph_new: rewritten graph for the second branch. Only used when
                ``rewrite_fusion`` is not None.
            rewrite_fusion: one of:
                None: use ``graph`` only.
                ``'both_share_param'``: both graphs go through the same layers.
                ``'both_seperate_param'``: ``graph_new`` goes through
                ``residual_modules_new``.
            rewrite_fusion_state: ``'late'`` runs each graph through its whole
                stack and sums the two final outputs. ``'early'`` sums the two
                branches after every layer. Ignored when ``rewrite_fusion`` is
                None.

        Returns:
            Output of ``output_linear`` with dimension 1 squeezed. Dimension 1
            is removed only when it has size 1.

        Raises:
            ValueError: for an unsupported ``rewrite_fusion`` /
                ``rewrite_fusion_state`` combination. Such combinations used to
                silently skip every residual layer.
        """
        # Validate the fusion configuration before doing any work.
        if rewrite_fusion is None:
            new_modules = None
        else:
            new_modules = self._new_branch_modules(rewrite_fusion)
            if rewrite_fusion_state not in ('late', 'early'):
                raise ValueError(
                    f'Unsupported rewrite_fusion_state: {rewrite_fusion_state!r} '
                    f"(expected 'late' or 'early' when rewrite_fusion={rewrite_fusion!r})")

        x = self.input_linear(x)
        x = self.dropout(x)
        x = self.act(x)

        if new_modules is None:
            for residual_module in self.residual_modules:
                x = residual_module(graph, x)
        elif rewrite_fusion_state == 'late':
            # Each graph runs through its full stack independently; the two
            # results are summed once at the end.
            x_old, x_new = x, x
            for old_module, new_module in zip(self.residual_modules, new_modules):
                x_old = old_module(graph, x_old)
                x_new = new_module(graph_new, x_new)
            x = x_old + x_new
        else:  # rewrite_fusion_state == 'early'
            # Branch outputs are summed after every layer, so each layer of
            # both branches sees the fused features.
            # NOTE(review): if ResidualModuleWrapper adds a skip connection,
            # the skip path is counted twice per layer here. Confirm intended.
            for old_module, new_module in zip(self.residual_modules, new_modules):
                x = old_module(graph, x) + new_module(graph_new, x)

        x = self.output_normalization(x)
        return self.output_linear(x).squeeze(1)
