import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import MLP, GINEConv,GATConv,GCNConv,global_mean_pool,global_add_pool,global_max_pool,GlobalAttention,Set2Set
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
from torch import nn
from torch.distributions.normal import Normal

# Vocabulary sizes of the categorical node features (sizes of the x_embedding* tables).
num_atom_type: int = 120  # atom types, including the extra mask tokens
num_chirality_tag: int = 3

# Vocabulary sizes of the categorical edge features (sizes of the edge_embedding* tables).
num_bond_type: int = 6  # includes aromatic, self-loop, and extra masked-bond tokens
num_bond_direction: int = 3

class GINConv(MessagePassing):
    """
    GIN message-passing layer extended with categorical edge features.

    Each message is the neighbour's node embedding plus the sum of a learned
    bond-type embedding and a learned bond-direction embedding
    (``message = x_j + edge_emb``). Messages are aggregated with ``aggr`` and
    then passed through a two-layer MLP (``emb_dim -> 2*emb_dim -> emb_dim``).
    A self-loop edge (bond type 4, bond direction 0) is added for every node.

    Before the MLP, the aggregated messages can be adjusted by a prompt module
    installed with :meth:`set_prompt`, selected by ``self.modify``:

    * ``-1`` (default): unchanged.
    * ``1``: ``aggr_out * (1 - gating_m) + prompt(aggr_out) * gating``.
    * ``0``: ``aggr_out * (1 - gating_m) + prompt(x) * gating``, where ``x`` is
      the node input of the current ``forward`` call.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): aggregation scheme forwarded to ``MessagePassing``
            (default ``"add"``).

    See https://arxiv.org/abs/1810.00826
    """

    def __init__(self, emb_dim, aggr="add"):
        super().__init__(aggr=aggr)

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        # Aggregated messages of the last forward (before prompt mixing and MLP).
        self.aggr_x = None
        # Prompt mode: -1 = off, 1 = prompt(aggr_out), 0 = prompt(layer input).
        self.modify = -1
        # Node input of the last forward; read by update() when modify == 0.
        self.x = None
        self.gating = 0
        self.gating_m = 0
        self.prompt = None

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step.

        Args:
            x: node embeddings, one row per node.
            edge_index: graph connectivity in PyG ``[2, num_edges]`` format.
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.

        Returns:
            Tuple ``(out, aggr_x)``. ``out`` is the MLP output. ``aggr_x`` is
            the aggregated messages before prompt mixing and the MLP.
        """
        num_nodes = x.size(0)
        # add_self_loops returns (edge_index, edge_attr). Only the index is
        # needed; the self-loop edge features are built explicitly below.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Self-loop edges: bond type 4 (column 0), bond direction 0 (column 1).
        self_loop_attr = torch.zeros(num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device)
        self_loop_attr[:, 0] = 4
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        # Cached so update() can feed the layer input to the prompt (modify == 0).
        self.x = x
        out = self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
        # self.aggr_x is populated by update() during propagate().
        return out, self.aggr_x

    def message(self, x_j, edge_attr):
        """Message from a source node: its embedding plus the edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Record the raw aggregate, optionally mix in the prompt, then apply the MLP."""
        self.aggr_x = aggr_out
        if self.modify == 1:
            aggr_out = self.modify_aggr_out(aggr_out, aggr_out)
        elif self.modify == 0:
            aggr_out = self.modify_aggr_out(aggr_out, self.x)
        return self.mlp(aggr_out)

    def modify_aggr_out(self, aggr_out, delta):
        """Blend ``aggr_out`` with ``prompt(delta)`` using the installed gating weights."""
        return aggr_out * (1 - self.gating_m) + self.prompt(delta) * self.gating

    def set_prompt(self, prompt, gating, gating_m):
        """Install the prompt module and the weights used by :meth:`modify_aggr_out`.

        ``modify`` is not changed here; set it separately to enable the prompt.
        """
        self.prompt = prompt
        self.gating = gating
        self.gating_m = gating_m
        
        
class GNN_node(torch.nn.Module):
    """
    Node encoder: a stack of edge-aware GIN layers with two gated adapters per layer.

    Node features are integer ids. Column 0 is embedded with ``x_embedding1``
    and column 1 with ``x_embedding2``. Layer ``l`` then computes::

        h_mlp, x_aggr = gnns[l](h_l, edge_index, edge_attr)
        h = batch_norms[l](h_mlp)
            + prompts[0][l](h_l)    * gating[0][l]
            + prompts[1][l](x_aggr) * gating[1][l]

    This is followed by ReLU (on all but the last layer) and dropout.

    Each adapter is a ``Linear -> ReLU -> Linear -> BatchNorm1d`` bottleneck
    whose second linear layer is zero-initialised. The per-layer gates are
    initialised to 0.01.

    Args:
        num_layer (int): number of message-passing layers (must be >= 2).
        emb_dim (int): node/edge embedding dimension.
        JK (str): how the layer outputs (including the input embedding) are
            combined: ``"last"``, ``"concat"``, ``"max"`` or ``"sum"``.
            ``"concat"`` yields ``(num_layer + 1) * emb_dim`` features per node.
        drop_ratio (float): dropout probability applied after every layer.
        gnn_type (str): only ``"gin"`` is supported.

    Raises:
        ValueError: if ``num_layer < 2`` or ``gnn_type`` is not ``"gin"``.
            ``forward`` raises ValueError for an unknown ``JK`` mode.
    """

    # Adapter hyper-parameters (hard-coded in this design).
    _BOTTLENECK_DIM = 15
    _PROMPT_NUM = 2  # forward() uses exactly two adapter banks
    _GATING_INIT = 0.01

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"):
        super().__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Only GINConv returns the (output, aggregated messages) pair that
        # forward() unpacks. The former "gcn"/"gat"/"graphsage" branches could
        # not even be constructed (missing out_channels / undefined GraphSAGEConv).
        if gnn_type != "gin":
            raise ValueError(f"Unsupported gnn_type: {gnn_type!r} (only 'gin' is supported).")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)
        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        self.gnns = torch.nn.ModuleList(GINConv(emb_dim, aggr="add") for _ in range(num_layer))

        prompt_num = self._PROMPT_NUM
        # One learnable gate per (adapter bank, layer), starting at a small value.
        self.gating_parameter = torch.nn.Parameter(
            torch.full((prompt_num, num_layer, 1), self._GATING_INIT))
        # Assigning the Parameter to a second attribute also registers it under
        # the state_dict key "gating"; kept so existing checkpoints still load.
        self.gating = self.gating_parameter

        self.batch_norms = torch.nn.ModuleList()
        self.prompts = torch.nn.ModuleList(torch.nn.ModuleList() for _ in range(prompt_num))
        for _ in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
            for bank in self.prompts:
                bank.append(self._make_adapter(emb_dim, self._BOTTLENECK_DIM))

    @staticmethod
    def _make_adapter(emb_dim, bottleneck_dim):
        """Build one adapter; with a bottleneck its output is zero at initialisation."""
        if bottleneck_dim <= 0:
            return torch.nn.BatchNorm1d(emb_dim)
        adapter = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, bottleneck_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(bottleneck_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
        )
        # Zero the up-projection so the adapter initially contributes nothing.
        torch.nn.init.zeros_(adapter[2].weight.data)
        torch.nn.init.zeros_(adapter[2].bias.data)
        return adapter

    def forward(self, *argv):
        """Encode nodes.

        Accepts either ``(x, edge_index, edge_attr)`` or a single data object
        exposing ``.x``, ``.edge_index`` and ``.edge_attr``.

        Returns:
            Node representations combined according to ``self.JK``.
        """
        if len(argv) == 3:
            x, edge_index, edge_attr = argv
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        h_list = [x]
        for layer in range(self.num_layer):
            h_in = h_list[layer]
            h_mlp, x_aggr = self.gnns[layer](h_in, edge_index, edge_attr)

            h = self.batch_norms[layer](h_mlp)
            # Adapter 0 sees the layer input; adapter 1 sees the aggregated messages.
            h = h + self.prompts[0][layer](h_in) * self.gating[0][layer]
            h = h + self.prompts[1][layer](x_aggr) * self.gating[1][layer]

            if layer < self.num_layer - 1:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            h_list.append(h)

        # Jumping-knowledge combination. torch.stack avoids the in-place
        # unsqueeze_ on tensors that are already part of the autograd graph.
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0)[0]
        elif self.JK == "sum":
            # The former trailing [0] returned only the first node's row.
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        else:
            raise ValueError(f"Unsupported JK mode: {self.JK!r}")

        return node_representation
    
    
    
class GNN(torch.nn.Module):
    """
    Graph-level encoder: :class:`GNN_node` followed by a graph pooling readout.

    Args:
        emb_dim (int): node embedding dimension.
        num_layer (int): number of message-passing layers (must be >= 2).
        gnn_type (str): forwarded to :class:`GNN_node`.
        virtual_node (bool): accepted for interface compatibility; unused.
        residual (bool): accepted for interface compatibility; unused.
        drop_ratio (float): dropout probability forwarded to :class:`GNN_node`.
        JK (str): jumping-knowledge mode forwarded to :class:`GNN_node`.
        graph_pooling (str): ``"sum"``, ``"mean"``, ``"max"``, ``"attention"``
            or ``"set2set"``.

    Raises:
        ValueError: on ``num_layer < 2`` or an unknown ``graph_pooling``.
    """

    def __init__(self, emb_dim=300, num_layer=5, gnn_type='gin',
                 virtual_node=False, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        super().__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, gnn_type=gnn_type)

        # Pooling function to generate whole-graph embeddings.
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2 * emb_dim),
                torch.nn.BatchNorm1d(2 * emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * emb_dim, 1),
            ))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

    def from_pretrained(self, model_file, map_location="cpu"):
        """Load pretrained weights into ``self.gnn_node`` (non-strict).

        The checkpoint is deserialised onto ``map_location`` (CPU by default,
        so this works on machines without CUDA). ``load_state_dict`` then
        copies the values into the existing parameters wherever they live.
        ``strict=False`` skips keys that do not match, presumably the adapter
        and gating parameters absent from a plain pretrained encoder.

        Note: ``torch.load`` unpickles the file; only load trusted checkpoints.

        Returns:
            The ``missing_keys`` / ``unexpected_keys`` result of ``load_state_dict``.
        """
        state_dict = torch.load(model_file, map_location=map_location)
        return self.gnn_node.load_state_dict(state_dict, strict=False)

    def forward(self, batched_data):
        """Encode nodes, then pool them per graph using ``batched_data.batch``."""
        h_node = self.gnn_node(batched_data)
        return self.pool(h_node, batched_data.batch)
    
    
class MoE(torch.nn.Module):
    """
    Sparse mixture of pretrained GNN experts of different depths.

    Expert ``i`` (0-based) is a :class:`GNN` with ``min_layers + i`` layers,
    built with ``model_type`` and initialised from ``pre_trained_model``.
    A gating network scores the experts for every graph in the batch. Only the
    ``k`` highest scores are kept and softmax-normalised (noisy top-k gating,
    Shazeer et al., 2017); the rest are zero. The output is the gate-weighted
    sum of the experts' graph embeddings.

    Args:
        model_type (str): ``gnn_type`` of every expert.
        pre_trained_model (str): checkpoint path loaded into every expert.
        gate_type (str): ``'liner'`` (linear gate over node embeddings),
            ``'GIN'``, ``'GCN'`` or ``'GAT'``.
        input_size (int): embedding dimension of the experts and the gate.
        num_experts (int): number of experts.
        min_layers (int): depth of the shallowest expert.
        device: stored as ``self.device``; not otherwise used here.
        noisy_gating (bool): add learned Gaussian noise to the gate logits
            while training.
        k (int): number of experts kept per graph; must be <= ``num_experts``.
        coef (float): stored as ``self.loss_coef``; not otherwise used here.
        dropout (float): ``drop_ratio`` of every expert.
        gate_dropout (float): stored as ``self.gate_dropout``; not otherwise
            used here.
        heads (int): attention heads of the ``'GAT'`` gate.

    Raises:
        ValueError: if ``k > num_experts`` or ``gate_type`` is unknown.
    """

    _GATE_TYPES = ('liner', 'GIN', 'GCN', 'GAT')

    def __init__(self, model_type, pre_trained_model, gate_type, input_size, num_experts, min_layers, device,
                 noisy_gating=True, k=4, coef=1e-3, dropout=0.5, gate_dropout=0.2, heads=4):
        super().__init__()
        # Validate cheaply before loading num_experts checkpoints.
        if k > num_experts:
            raise ValueError(f"k ({k}) must be <= num_experts ({num_experts}).")
        if gate_type not in self._GATE_TYPES:
            raise ValueError(f"Unknown gate_type {gate_type!r}; expected one of {self._GATE_TYPES}.")

        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.k = k
        self.loss_coef = coef
        self.device = device
        self.gate_dropout = gate_dropout
        self.gate_type = gate_type
        self.experts = torch.nn.ModuleList()

        # Node/edge embeddings used only by the gate (each expert has its own).
        self.x_embedding1 = torch.nn.Embedding(num_atom_type, input_size)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, input_size)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, input_size)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, input_size)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        # Experts of increasing depth, all initialised from the same checkpoint.
        for depth in range(min_layers, num_experts + min_layers):
            expert = GNN(input_size, depth, model_type, drop_ratio=dropout)
            expert.from_pretrained(pre_trained_model)
            self.experts.append(expert)

        self._build_gate(gate_type, input_size, num_experts, heads)

        # Maps pooled node embeddings to a per-expert noise scale.
        self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Standard-normal parameters as buffers so they follow the module's device.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))

    def _build_gate(self, gate_type, input_size, num_experts, heads):
        """Create the gating network (``w_gate`` or ``gate_model``) for ``gate_type``."""
        print("Gate Type:", gate_type)
        if gate_type == 'liner':
            self.w_gate = torch.nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
            return

        self.gate_model = nn.ModuleList()
        input_channel = input_size
        if gate_type == 'GIN':
            for _ in range(2):
                mlp = MLP([input_channel, num_experts, num_experts], num_layers=2)
                self.gate_model.append(GINEConv(nn=mlp, train_eps=False, edge_dim=input_size))
                input_channel = num_experts
        elif gate_type == 'GCN':
            for _ in range(2):
                self.gate_model.append(GCNConv(input_channel, num_experts))
                input_channel = num_experts
        elif gate_type == 'GAT':
            for _ in range(2):
                self.gate_model.append(GATConv(input_channel, num_experts, heads, edge_dim=input_size, dropout=0.6))
                input_channel = num_experts * heads
            # Final MLP maps the pooled multi-head features to one logit per expert.
            self.gate_model.append(MLP([input_channel, input_channel, num_experts], norm=None, dropout=0.5))

    def _gates_to_load(self, gates):
        """Return, per expert, the number of graphs with a positive gate."""
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Smooth per-expert load estimate for noisy top-k gating.

        For every (graph, expert) entry, this computes the normal-CDF
        probability that the expert is in the top ``k`` if only that entry's
        noise were resampled. The threshold is the k-th or (k+1)-th largest
        noisy logit of the row, depending on whether the entry is currently in.
        ``noisy_top_values`` must hold at least the top ``k + 1`` noisy logits
        per row.
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat index of the (k+1)-th largest value in each row.
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        # Whether each value is currently in the top k.
        is_in = torch.gt(noisy_values, threshold_if_in)
        # Flat index of the k-th largest value in each row.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)

        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def _gate_logits(self, x, batch, edge_index, edge_attr):
        """Return clean gate logits with one row per graph and one column per expert."""
        if self.gate_type == 'liner':
            return global_mean_pool(x @ self.w_gate, batch)
        if self.gate_type == 'GAT':
            h = x
            for conv in self.gate_model[:-1]:
                h = conv(h, edge_index, edge_attr)
            return self.gate_model[-1](global_mean_pool(h, batch))
        if self.gate_type == 'GCN':
            h = x
            for conv in self.gate_model:
                h = conv(h, edge_index).relu()
            return global_mean_pool(h, batch)
        if self.gate_type == 'GIN':
            h = x
            for conv in self.gate_model:
                h = conv(h, edge_index, edge_attr).relu()
            return global_mean_pool(h, batch)
        raise ValueError(f"Unknown gate_type {self.gate_type!r}; expected one of {self._GATE_TYPES}.")

    def noisy_top_k_gating(self, x, batch, edge_index, edge_attr, train, noise_epsilon=1e-2):
        """Noisy top-k gating.

        Args:
            x: integer node features (columns 0/1 index the atom / chirality embeddings).
            batch: graph assignment of every node, as used by ``global_mean_pool``.
            edge_index: graph connectivity.
            edge_attr: integer edge features (columns 0/1 index the bond
                type / direction embeddings).
            train (bool): add noise to the logits (if ``noisy_gating``).
            noise_epsilon (float): floor added to the softplus noise scale.

        Returns:
            ``(gates, load)``. ``gates`` has one row per graph and one column per
            expert, with at most ``k`` nonzero entries per row that sum to 1.
            ``load`` has one entry per expert: the smooth estimate from
            :meth:`_prob_in_top_k` when noisy training with ``k < num_experts``,
            otherwise the count from :meth:`_gates_to_load`.
        """
        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])
        edge_attr = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        clean_logits = self._gate_logits(x, batch, edge_index, edge_attr)

        if self.noisy_gating and train:
            # Noise scale is predicted from the mean-pooled node embeddings.
            pooled_x = global_mean_pool(x, batch)
            raw_noise_stddev = pooled_x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
            logits = noisy_logits
        else:
            logits = clean_logits

        # Take k+1 values so _prob_in_top_k can see the first excluded logit.
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_k_logits)

        zero_gates = torch.zeros_like(logits, requires_grad=True)
        gates = zero_gates.scatter(1, top_k_indices, top_k_gates)

        if self.noisy_gating and self.k < self.num_experts and train:
            load = self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, batch_data):
        """Return the gate-weighted sum of all experts' graph embeddings.

        Every expert is evaluated on the whole batch, including experts whose
        gate is zero.
        """
        # NOTE(review): ``load`` (and ``self.loss_coef``) are not used, so no
        # load-balancing loss is produced. Confirm this is intended.
        gates, _load = self.noisy_top_k_gating(batch_data.x, batch_data.batch, batch_data.edge_index,
                                               batch_data.edge_attr, self.training)
        # Shape (num_graphs, num_experts, feature_dim).
        expert_outputs = torch.stack([expert(batch_data) for expert in self.experts], dim=1)
        return (gates.unsqueeze(dim=-1) * expert_outputs).sum(dim=1)
    
class MoLE_GNN(torch.nn.Module):
    """
    Graph property predictor: a mixture of pretrained GNN experts (:class:`MoE`)
    followed by a linear prediction head.

    Args:
        device: forwarded to :class:`MoE`.
        num_tasks (int): number of outputs of the prediction head.
        pre_trained_model (str): checkpoint loaded into every expert.
        num_layer (int): number of experts passed to :class:`MoE` (must be >= 2).
            Expert depths are ``min_layers, ..., min_layers + num_layer - 1``.
        emb_dim (int): embedding dimension (per the original note: 600 for
            GraphCL checkpoints, 300 for the base ones).
        topK (int): number of experts selected per graph.
        min_layers (int): depth of the shallowest expert.
        coef (float): forwarded to :class:`MoE` as ``coef``.
        gnn_type (str): expert GNN type.
        gate_type (str): gating network type.
        drop_ratio (float): expert dropout.
        JK (str): stored only; not forwarded to the experts.
        graph_pooling (str): only used to size the prediction head.
    """

    def __init__(self, device, num_tasks, pre_trained_model="masking.pth", num_layer=5, emb_dim=300, topK=4,
                 min_layers=2, coef=0.001, gnn_type='gin', gate_type='GIN', drop_ratio=0.5, JK="last",
                 graph_pooling="mean"):
        super().__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Mixture of experts producing graph embeddings. ``coef`` was
        # previously accepted here but never forwarded.
        self.gnn_node_moe = MoE(gnn_type, pre_trained_model, gate_type, emb_dim, num_layer, min_layers, device,
                                k=topK, coef=coef, dropout=drop_ratio)

        # NOTE(review): MoE builds its experts with GNN's default mean pooling,
        # so their output is emb_dim wide. With graph_pooling="set2set" this
        # head expects 2*emb_dim and forward() would fail; confirm intent.
        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        """Return per-graph predictions with ``num_tasks`` outputs."""
        h_graph = self.gnn_node_moe(batched_data)
        return self.graph_pred_linear(h_graph)