import torch
import torch.nn.functional as F
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GINConv, global_mean_pool,global_max_pool, GINEConv, JumpingKnowledge
from torch_geometric.utils import to_undirected
# from prompt_graph.utils import act
# from deprecated.sphinx import deprecated
from sklearn.cluster import KMeans
from torch_geometric.nn.inits import glorot
import torch
from termcolor import colored
import math
import torch.nn as nn
import torch.nn.functional as F
import random
import pdb
from tqdm import tqdm
class InactivePrompt(torch.nn.Module):
    """Prompt-token graph model that scores protein pairs into 7 outputs.

    Every protein is embedded by a small Conv1d / BatchNorm / MaxPool
    (optionally bidirectional GRU) encoder followed by a linear layer.  For
    each requested pair a "large graph" is built whose nodes are the two
    protein embeddings plus the learnable prompt tokens ``A_B_token``;
    consecutive token pairs ``(x_i, y_i)`` form paths ``a - x_i - y_i - b``.
    A stack of ``GINEConv`` layers processes the batched large graphs, the
    node states are mean-pooled per graph and passed to a linear head.
    Optionally a gate network (``gin_model`` -> ``gate1`` -> ``gate2``)
    produces per-path gate values that are turned into edge weights.
    """

    def __init__(self, in_len=2000, hidden=13, gin_num_layer=5, num_token=10,
                 use_jk=True, use_GRU=True, method='xavier', device=None):
        """Create all sub-modules and initialise the prompt tokens.

        Args:
            in_len: Sequence length of each protein input; determines the
                input width of ``fc1``.
            hidden: Width of protein embeddings, prompt tokens and GNN layers.
            gin_num_layer: Total number of ``GINEConv`` layers in the head
                (one first layer plus ``gin_num_layer - 1`` stacked layers).
            num_token: Number of learnable prompt tokens.
            use_jk: If true, concatenate the outputs of all head layers
                (jumping knowledge) before ``lin1``.
            use_GRU: If true, run the bidirectional GRU branch in the encoder.
            method: Token initialiser: ``'kaiming'``, ``'normal'``,
                ``'zeros'`` or ``'xavier'``.  Any other value leaves the
                tokens as allocated by ``torch.empty`` (uninitialised).
            device: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.device = device
        self.hidden = hidden

        # Protein embedding module.
        self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm1d(1)
        self.biGRU = nn.GRU(1, 1, bidirectional=True, batch_first=True, num_layers=1)
        self.maxpool1d = nn.MaxPool1d(3, stride=3)
        self.global_avgpool1d = nn.AdaptiveAvgPool1d(1)
        # The un-padded kernel-3 conv shortens the sequence to in_len - 2 and
        # the stride-3 max-pool then yields (in_len - 2) // 3 positions.  This
        # equals floor(in_len / 3) only when in_len % 3 == 2 (e.g. 2000), so
        # size the layer from the real encoder output length.
        self.fc1 = nn.Linear((in_len - 2) // 3, self.hidden)

        self.num_token = num_token
        self.A_B_token = torch.nn.Parameter(torch.empty(num_token, self.hidden))
        if method == 'kaiming':
            torch.nn.init.kaiming_uniform_(self.A_B_token, nonlinearity='relu')
        elif method == 'normal':
            torch.nn.init.normal_(self.A_B_token, mean=0.0, std=0.02)
        elif method == 'zeros':
            torch.nn.init.zeros_(self.A_B_token)
        elif method == 'xavier':
            torch.nn.init.xavier_uniform_(self.A_B_token)

        self.use_jk = use_jk
        self.use_GRU = use_GRU
        self.jump = JumpingKnowledge('cat')
        # With jumping knowledge the head concatenates one tensor per layer.
        lin1_in = gin_num_layer * self.hidden if use_jk else self.hidden
        self.lin1 = nn.Linear(lin1_in, self.hidden)

        # Gate encoder applied to the 4-node path subgraphs.
        # NOTE(review): `edge_dim` / `add_self_loops` are not documented
        # GINConv arguments; presumably they are swallowed as extra kwargs by
        # the installed PyG version — confirm before upgrading.
        self.gin_model = GINConv(
            nn.Sequential(
                nn.Linear(self.hidden, self.hidden),
                nn.ReLU(),
                nn.BatchNorm1d(self.hidden),
            ), eps=0.2, train_eps=True, edge_dim=1, add_self_loops=True
        )
        self.gin_model_head_1st = self._make_gine_layer(self.hidden)
        self.gin_model_head = torch.nn.ModuleList(
            self._make_gine_layer(self.hidden) for _ in range(gin_num_layer - 1)
        )

        self.gate1 = nn.Linear(self.hidden, self.hidden)
        self.gate2 = nn.Linear(self.hidden, 1)
        self.head = torch.nn.Linear(hidden, 7)

    @staticmethod
    def _make_gine_layer(hidden):
        """Build one GINEConv whose MLP is Linear-ReLU-Linear-ReLU-BatchNorm."""
        return GINEConv(
            nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.BatchNorm1d(hidden),
            ), eps=0.2, train_eps=True, edge_dim=1, add_self_loops=True
        )

    @staticmethod
    def _set_requires_grad(modules, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``modules``."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    def freeze_param_token(self):
        """Stop gradient updates of the prompt tokens."""
        self.A_B_token.requires_grad = False

    def freeze_param_gin(self):
        """Freeze all GIN/GINE layers and make the gate layers and head trainable."""
        self._set_requires_grad(
            (self.gin_model, self.gin_model_head, self.gin_model_head_1st), False
        )
        self._set_requires_grad((self.gate1, self.gate2, self.head), True)

    def trainable_param_gin(self):
        """Re-enable gradients for all GIN/GINE layers."""
        self._set_requires_grad(
            (self.gin_model, self.gin_model_head, self.gin_model_head_1st), True
        )

    def build_large_graph(self, x1, x2, token_emb):
        """Build the prompt graph for one protein pair.

        Node 0 is ``x1``, node 1 is ``x2`` and the rows of ``token_emb``
        follow as nodes ``2, 3, ...``.  Tokens are consumed in consecutive
        pairs ``(x_i, y_i)``, each forming the path ``0 - x_i - y_i - 1``; a
        trailing unpaired token gets no edges.

        Returns:
            A ``Data`` object with the stacked node features and the
            undirected edge index.
        """
        node_features = torch.cat([torch.stack([x1, x2], dim=0), token_emb], dim=0)
        a_idx, b_idx = 0, 1
        edge_list = []
        for i in range(token_emb.shape[0] // 2):
            x_idx = 2 + 2 * i
            y_idx = x_idx + 1
            edge_list.extend(([a_idx, x_idx], [x_idx, y_idx], [y_idx, b_idx]))

        # reshape keeps a well-formed (2, 0) edge index when there are no paths.
        edge_index = torch.tensor(
            edge_list, dtype=torch.long, device=x1.device
        ).reshape(-1, 2).t()
        return Data(x=node_features, edge_index=to_undirected(edge_index))

    def reset_parameters(self):
        """Re-initialise ``conv1d``, ``fc1``, ``gin_model`` and ``head``.

        NOTE(review): ``bn1``, ``biGRU``, ``lin1``, the GINE head layers, the
        gate layers and ``A_B_token`` are left untouched — confirm this
        partial reset is intentional.
        """
        self.conv1d.reset_parameters()
        self.fc1.reset_parameters()
        self.gin_model.reset_parameters()
        self.head.reset_parameters()

    def diff(self, logits, probabilities):
        """Turn gate probabilities into hard 0/1 decisions.

        In training mode the decisions come from thresholding a Gumbel-noised,
        temperature-0.1 sigmoid of ``logits``; in eval mode ``probabilities``
        are thresholded at 0.5 directly and ``logits`` is not used.

        Returns:
            ``(probabilities, decisions)`` with ``probabilities`` unchanged.
        """
        if self.training:
            uniforms = torch.rand_like(probabilities)
            gumbels = -torch.log(-torch.log(uniforms + 1e-10) + 1e-10)
            samples = torch.sigmoid((logits + gumbels) / 0.1)
            decisions = (samples > 0.5).float()
        else:
            decisions = (probabilities > 0.5).float()
        return probabilities, decisions

    def sample_subgraphs(self, original_data):
        """Extract one 4-node subgraph per token path by masking the edges.

        For path ``i`` the nodes ``[0, 1, 2 + 2*i, 3 + 2*i]`` of
        ``original_data`` are kept together with every edge whose endpoints
        are both kept.  Edge endpoints are relabelled to the local ids
        ``0..3`` so they are valid for the 4-node subgraph, and all helper
        tensors are created on the device of ``original_data.x``.

        Returns:
            List of ``Data`` objects, one per path.
        """
        x = original_data.x
        device = x.device
        edge_index = original_data.edge_index.to(device)
        total_nodes = x.size(0)
        local_ids = torch.arange(4, device=device)

        subgraphs = []
        for i in range((total_nodes - 2) // 2):
            node_indices = torch.tensor(
                [0, 1, 2 + 2 * i, 3 + 2 * i], dtype=torch.long, device=device
            )
            node_mask = torch.zeros(total_nodes, dtype=torch.bool, device=device)
            node_mask[node_indices] = True
            edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]

            # Global -> local node id lookup; only kept nodes are ever read.
            relabel = torch.full((total_nodes,), -1, dtype=torch.long, device=device)
            relabel[node_indices] = local_ids

            subgraphs.append(Data(
                x=x[node_indices],
                edge_index=relabel[edge_index[:, edge_mask]],
                num_nodes=4,
            ))
        return subgraphs

    def sample_subgraphs_rep(self, original_data):
        """Extract one 4-node subgraph per token path using a fixed edge template.

        Same node selection as ``sample_subgraphs`` but the edge index is the
        constant undirected path ``0 - 2 - 3 - 1`` in local ids.  The template
        tensor is shared (read-only) by all returned subgraphs.

        Returns:
            List of ``Data`` objects, one per path.
        """
        x = original_data.x
        path_edges = torch.tensor(
            [[0, 1, 2, 2, 3, 3], [2, 3, 0, 3, 1, 2]], device=x.device
        )
        subgraphs = []
        for i in range((x.shape[0] - 2) // 2):
            node_indices = torch.tensor([0, 1, 2 + 2 * i, 3 + 2 * i], dtype=torch.long)
            subgraphs.append(Data(x=x[node_indices], edge_index=path_edges, num_nodes=4))
        return subgraphs

    def differentiable_edge_gating(self, data, gate_logits, temperature=0.1):
        """Attach straight-through 0/1 edge weights derived from gate values.

        ``gate_logits`` is treated as probabilities (``forward`` passes
        sigmoid outputs): they are converted back to logits, perturbed with
        Gumbel noise and squashed with a temperature sigmoid.  The forward
        value of each weight is the hard ``gate > 0.5`` decision while the
        gradient flows through the soft sample.

        NOTE(review): ``forward`` supplies one gate per path subgraph, but the
        lookup below indexes by large-graph id, so only the first
        ``num_graphs`` rows are read and all edges of a graph share one value.
        This looks unintended; it is left as-is because changing it alters
        model outputs — confirm with the model owner.

        Args:
            data: Batched graph; modified in place.
            gate_logits: Tensor with a trailing dimension of size 1.
            temperature: Temperature of the relaxed sample.

        Returns:
            ``data`` with ``edge_weight`` set to a column with one row per edge.
        """
        edge_src = data.edge_index[0]
        edge_batch = data.batch[edge_src]
        edge_gates = gate_logits.squeeze(1)[edge_batch]
        u = torch.rand_like(edge_gates)
        gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
        gumbel_logits = (
            torch.log(edge_gates + 1e-10)
            - torch.log(1 - edge_gates + 1e-10)
            + gumbel_noise
        )
        edge_probs = torch.sigmoid(gumbel_logits / temperature)
        hard_decisions = (edge_gates > 0.5).float()
        # Straight-through estimator: hard value forward, soft gradient backward.
        edge_weights = hard_decisions - edge_probs.detach() + edge_probs
        data.edge_weight = edge_weights.unsqueeze(1)
        return data

    def _encode_proteins(self, x):
        """Embed raw protein features into ``hidden``-wide vectors.

        Only the known singleton dimension is squeezed, so a batch holding a
        single protein keeps its leading batch dimension.
        """
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = self.bn1(x)
        x = self.maxpool1d(x)
        if self.use_GRU:
            x = x.transpose(1, 2)
            x, _ = self.biGRU(x)
            # Average the two GRU directions, then drop that axis.
            x = self.global_avgpool1d(x).squeeze(-1)
        else:
            # Drop the single conv output channel.
            x = x.squeeze(1)
        x = self.fc1(x)
        return F.dropout(x, p=0.3, training=self.training)

    def forward(self, x, edge_index, train_edge_id, active_prune, gate_work):
        """Score the protein pairs selected by ``train_edge_id``.

        Args:
            x: Protein features; dims 1 and 2 are swapped before the
                13-channel Conv1d, i.e. ``(num_proteins, in_len, 13)``.
            edge_index: Protein-pair index with pairs along dim 1.
            train_edge_id: Index into the columns of ``edge_index`` selecting
                the pairs to score.
            active_prune: Unused; kept for interface compatibility.
            gate_work: If truthy, compute path gates and gated edge weights;
                otherwise every edge weight is 1.

        Returns:
            ``(final_x, gate_logits)``: ``final_x`` has one row of 7 scores
            per selected pair; ``gate_logits`` holds the sigmoid gate values
            (one row per path subgraph) or ``[]`` when gating is disabled.
        """
        x = self._encode_proteins(x)
        node_id = edge_index[:, train_edge_id]

        large_graph_list, path_subgraphs = [], []
        for i in range(node_id.shape[1]):
            large_graph = self.build_large_graph(
                x[node_id[0, i]], x[node_id[1, i]], self.A_B_token
            )
            large_graph_list.append(large_graph)
            if gate_work:
                path_subgraphs += self.sample_subgraphs_rep(large_graph)
        batch_graph = Batch.from_data_list(large_graph_list)

        if gate_work:
            batch_sub = Batch.from_data_list(path_subgraphs)
            sub_x = self.gin_model(batch_sub.x, batch_sub.edge_index)
            sub_x = global_mean_pool(sub_x, batch_sub.batch)
            gate_logits = torch.sigmoid(self.gate2(F.relu(self.gate1(sub_x))))
            after_edit = self.differentiable_edge_gating(
                batch_graph, gate_logits, temperature=0.2
            )
        else:
            after_edit = batch_graph
            after_edit.edge_weight = torch.ones(
                (batch_graph.edge_index.shape[1], 1), device=batch_graph.x.device
            )
            gate_logits = []

        head_x = self.gin_model_head_1st(
            after_edit.x, after_edit.edge_index, after_edit.edge_weight
        )
        layer_outputs = [head_x]
        for conv in self.gin_model_head:
            head_x = conv(head_x, after_edit.edge_index, after_edit.edge_weight)
            layer_outputs.append(head_x)

        if self.use_jk:
            x = F.gelu(self.lin1(self.jump(layer_outputs)))
        else:
            x = F.gelu(self.lin1(head_x))

        x = global_mean_pool(x, batch=batch_graph.batch)
        final_x = self.head(x)

        return final_x, gate_logits
        
class InactivePromptBinding(torch.nn.Module):
    """Prompt-token graph model that scores protein pairs into a single output.

    Every protein is embedded by a small Conv1d / BatchNorm / MaxPool
    (optionally bidirectional GRU) encoder followed by a linear layer.  For
    each requested pair a "large graph" is built whose nodes are the two
    protein embeddings plus the learnable prompt tokens ``A_B_token``;
    consecutive token pairs ``(x_i, y_i)`` form paths ``a - x_i - y_i - b``.
    A stack of ``GINEConv`` layers processes the batched large graphs, the
    node states are mean-pooled per graph and passed to a one-output linear
    head.  Optionally a gate network (``gin_model`` -> ``gate1`` ->
    ``gate2``) produces per-path gate values that are turned into edge
    weights.
    """

    def __init__(self, in_len=2000, hidden=13, gin_num_layer=2, num_token=10,
                 use_jk=True, use_GRU=True, method='xavier', device=None):
        """Create all sub-modules and initialise the prompt tokens.

        Args:
            in_len: Sequence length of each protein input; determines the
                input width of ``fc1``.
            hidden: Width of protein embeddings, prompt tokens and GNN layers.
            gin_num_layer: Total number of ``GINEConv`` layers in the head
                (one first layer plus ``gin_num_layer - 1`` stacked layers).
            num_token: Number of learnable prompt tokens.
            use_jk: If true, concatenate the outputs of all head layers
                (jumping knowledge) before ``lin1``.
            use_GRU: If true, run the bidirectional GRU branch in the encoder.
            method: Token initialiser: ``'kaiming'``, ``'normal'``,
                ``'zeros'`` or ``'xavier'``.  Any other value leaves the
                tokens as allocated by ``torch.empty`` (uninitialised).
            device: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.device = device
        self.hidden = hidden

        # Protein embedding module.
        self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm1d(1)
        self.biGRU = nn.GRU(1, 1, bidirectional=True, batch_first=True, num_layers=1)
        self.maxpool1d = nn.MaxPool1d(3, stride=3)
        self.global_avgpool1d = nn.AdaptiveAvgPool1d(1)
        # The un-padded kernel-3 conv shortens the sequence to in_len - 2 and
        # the stride-3 max-pool then yields (in_len - 2) // 3 positions.  This
        # equals floor(in_len / 3) only when in_len % 3 == 2 (e.g. 2000), so
        # size the layer from the real encoder output length.
        self.fc1 = nn.Linear((in_len - 2) // 3, self.hidden)

        self.num_token = num_token
        self.A_B_token = torch.nn.Parameter(torch.empty(num_token, self.hidden))
        if method == 'kaiming':
            torch.nn.init.kaiming_uniform_(self.A_B_token, nonlinearity='relu')
        elif method == 'normal':
            torch.nn.init.normal_(self.A_B_token, mean=0.0, std=0.02)
        elif method == 'zeros':
            torch.nn.init.zeros_(self.A_B_token)
        elif method == 'xavier':
            torch.nn.init.xavier_uniform_(self.A_B_token)

        self.use_jk = use_jk
        self.use_GRU = use_GRU
        self.jump = JumpingKnowledge('cat')
        # With jumping knowledge the head concatenates one tensor per layer.
        lin1_in = gin_num_layer * self.hidden if use_jk else self.hidden
        self.lin1 = nn.Linear(lin1_in, self.hidden)

        # Gate encoder applied to the 4-node path subgraphs.
        # NOTE(review): `edge_dim` / `add_self_loops` are not documented
        # GINConv arguments; presumably they are swallowed as extra kwargs by
        # the installed PyG version — confirm before upgrading.
        self.gin_model = GINConv(
            nn.Sequential(
                nn.Linear(self.hidden, self.hidden),
                nn.ReLU(),
                nn.BatchNorm1d(self.hidden),
            ), eps=0.2, train_eps=True, edge_dim=1, add_self_loops=True
        )
        self.gin_model_head_1st = self._make_gine_layer(self.hidden)
        self.gin_model_head = torch.nn.ModuleList(
            self._make_gine_layer(self.hidden) for _ in range(gin_num_layer - 1)
        )

        self.gate1 = nn.Linear(self.hidden, self.hidden)
        self.gate2 = nn.Linear(self.hidden, 1)
        self.head = torch.nn.Linear(hidden, 1)

    @staticmethod
    def _make_gine_layer(hidden):
        """Build one GINEConv whose MLP is Linear-ReLU-Linear-ReLU-BatchNorm."""
        return GINEConv(
            nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.BatchNorm1d(hidden),
            ), eps=0.2, train_eps=True, edge_dim=1, add_self_loops=True
        )

    @staticmethod
    def _set_requires_grad(modules, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``modules``."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    def freeze_param_token(self):
        """Stop gradient updates of the prompt tokens."""
        self.A_B_token.requires_grad = False

    def freeze_param_gin(self):
        """Freeze all GIN/GINE layers and make the gate layers and head trainable."""
        self._set_requires_grad(
            (self.gin_model, self.gin_model_head, self.gin_model_head_1st), False
        )
        self._set_requires_grad((self.gate1, self.gate2, self.head), True)

    def trainable_param_gin(self):
        """Re-enable gradients for all GIN/GINE layers."""
        self._set_requires_grad(
            (self.gin_model, self.gin_model_head, self.gin_model_head_1st), True
        )

    def build_large_graph(self, x1, x2, token_emb):
        """Build the prompt graph for one protein pair.

        Node 0 is ``x1``, node 1 is ``x2`` and the rows of ``token_emb``
        follow as nodes ``2, 3, ...``.  Tokens are consumed in consecutive
        pairs ``(x_i, y_i)``, each forming the path ``0 - x_i - y_i - 1``; a
        trailing unpaired token gets no edges.

        Returns:
            A ``Data`` object with the stacked node features and the
            undirected edge index.
        """
        node_features = torch.cat([torch.stack([x1, x2], dim=0), token_emb], dim=0)
        a_idx, b_idx = 0, 1
        edge_list = []
        for i in range(token_emb.shape[0] // 2):
            x_idx = 2 + 2 * i
            y_idx = x_idx + 1
            edge_list.extend(([a_idx, x_idx], [x_idx, y_idx], [y_idx, b_idx]))

        # reshape keeps a well-formed (2, 0) edge index when there are no paths.
        edge_index = torch.tensor(
            edge_list, dtype=torch.long, device=x1.device
        ).reshape(-1, 2).t()
        return Data(x=node_features, edge_index=to_undirected(edge_index))

    def reset_parameters(self):
        """Re-initialise ``conv1d``, ``fc1``, ``gin_model`` and ``head``.

        NOTE(review): ``bn1``, ``biGRU``, ``lin1``, the GINE head layers, the
        gate layers and ``A_B_token`` are left untouched — confirm this
        partial reset is intentional.
        """
        self.conv1d.reset_parameters()
        self.fc1.reset_parameters()
        self.gin_model.reset_parameters()
        self.head.reset_parameters()

    def diff(self, logits, probabilities):
        """Turn gate probabilities into hard 0/1 decisions.

        In training mode the decisions come from thresholding a Gumbel-noised,
        temperature-0.1 sigmoid of ``logits``; in eval mode ``probabilities``
        are thresholded at 0.5 directly and ``logits`` is not used.

        Returns:
            ``(probabilities, decisions)`` with ``probabilities`` unchanged.
        """
        if self.training:
            uniforms = torch.rand_like(probabilities)
            gumbels = -torch.log(-torch.log(uniforms + 1e-10) + 1e-10)
            samples = torch.sigmoid((logits + gumbels) / 0.1)
            decisions = (samples > 0.5).float()
        else:
            decisions = (probabilities > 0.5).float()
        return probabilities, decisions

    def sample_subgraphs(self, original_data):
        """Extract one 4-node subgraph per token path by masking the edges.

        For path ``i`` the nodes ``[0, 1, 2 + 2*i, 3 + 2*i]`` of
        ``original_data`` are kept together with every edge whose endpoints
        are both kept.  Edge endpoints are relabelled to the local ids
        ``0..3`` so they are valid for the 4-node subgraph, and all helper
        tensors are created on the device of ``original_data.x``.

        Returns:
            List of ``Data`` objects, one per path.
        """
        x = original_data.x
        device = x.device
        edge_index = original_data.edge_index.to(device)
        total_nodes = x.size(0)
        local_ids = torch.arange(4, device=device)

        subgraphs = []
        for i in range((total_nodes - 2) // 2):
            node_indices = torch.tensor(
                [0, 1, 2 + 2 * i, 3 + 2 * i], dtype=torch.long, device=device
            )
            node_mask = torch.zeros(total_nodes, dtype=torch.bool, device=device)
            node_mask[node_indices] = True
            edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]

            # Global -> local node id lookup; only kept nodes are ever read.
            relabel = torch.full((total_nodes,), -1, dtype=torch.long, device=device)
            relabel[node_indices] = local_ids

            subgraphs.append(Data(
                x=x[node_indices],
                edge_index=relabel[edge_index[:, edge_mask]],
                num_nodes=4,
            ))
        return subgraphs

    def sample_subgraphs_rep(self, original_data):
        """Extract one 4-node subgraph per token path using a fixed edge template.

        Same node selection as ``sample_subgraphs`` but the edge index is the
        constant undirected path ``0 - 2 - 3 - 1`` in local ids.  The template
        tensor is shared (read-only) by all returned subgraphs.

        Returns:
            List of ``Data`` objects, one per path.
        """
        x = original_data.x
        path_edges = torch.tensor(
            [[0, 1, 2, 2, 3, 3], [2, 3, 0, 3, 1, 2]], device=x.device
        )
        subgraphs = []
        for i in range((x.shape[0] - 2) // 2):
            node_indices = torch.tensor([0, 1, 2 + 2 * i, 3 + 2 * i], dtype=torch.long)
            subgraphs.append(Data(x=x[node_indices], edge_index=path_edges, num_nodes=4))
        return subgraphs

    def differentiable_edge_gating(self, data, gate_logits, temperature=0.1):
        """Attach straight-through 0/1 edge weights derived from gate values.

        ``gate_logits`` is treated as probabilities (``forward`` passes
        sigmoid outputs): they are converted back to logits, perturbed with
        Gumbel noise and squashed with a temperature sigmoid.  The forward
        value of each weight is the hard ``gate > 0.5`` decision while the
        gradient flows through the soft sample.

        NOTE(review): ``forward`` supplies one gate per path subgraph, but the
        lookup below indexes by large-graph id, so only the first
        ``num_graphs`` rows are read and all edges of a graph share one value.
        This looks unintended; it is left as-is because changing it alters
        model outputs — confirm with the model owner.

        Args:
            data: Batched graph; modified in place.
            gate_logits: Tensor with a trailing dimension of size 1.
            temperature: Temperature of the relaxed sample.

        Returns:
            ``data`` with ``edge_weight`` set to a column with one row per edge.
        """
        edge_src = data.edge_index[0]
        edge_batch = data.batch[edge_src]
        edge_gates = gate_logits.squeeze(1)[edge_batch]
        u = torch.rand_like(edge_gates)
        gumbel_noise = -torch.log(-torch.log(u + 1e-10) + 1e-10)
        gumbel_logits = (
            torch.log(edge_gates + 1e-10)
            - torch.log(1 - edge_gates + 1e-10)
            + gumbel_noise
        )
        edge_probs = torch.sigmoid(gumbel_logits / temperature)
        hard_decisions = (edge_gates > 0.5).float()
        # Straight-through estimator: hard value forward, soft gradient backward.
        edge_weights = hard_decisions - edge_probs.detach() + edge_probs
        data.edge_weight = edge_weights.unsqueeze(1)
        return data

    def _encode_proteins(self, x):
        """Embed raw protein features into ``hidden``-wide vectors.

        Only the known singleton dimension is squeezed, so a batch holding a
        single protein keeps its leading batch dimension.
        """
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = self.bn1(x)
        x = self.maxpool1d(x)
        if self.use_GRU:
            x = x.transpose(1, 2)
            x, _ = self.biGRU(x)
            # Average the two GRU directions, then drop that axis.
            x = self.global_avgpool1d(x).squeeze(-1)
        else:
            # Drop the single conv output channel.
            x = x.squeeze(1)
        x = self.fc1(x)
        return F.dropout(x, p=0.3, training=self.training)

    def forward(self, x, edge_index, train_edge_id, active_prune, gate_work):
        """Score the protein pairs listed in ``train_edge_id``.

        Args:
            x: Protein features; dims 1 and 2 are swapped before the
                13-channel Conv1d, i.e. ``(num_proteins, in_len, 13)``.
            edge_index: Unused; kept for interface compatibility.
            train_edge_id: Pair index used directly: row 0 and row 1 hold the
                protein indices of each pair, pairs run along dim 1.
            active_prune: Unused; kept for interface compatibility.
            gate_work: If truthy, compute path gates and gated edge weights;
                otherwise every edge weight is 1.

        Returns:
            ``(final_x, gate_logits)``: ``final_x`` has one row with a single
            score per pair; ``gate_logits`` holds the sigmoid gate values
            (one row per path subgraph) or ``[]`` when gating is disabled.
        """
        x = self._encode_proteins(x)
        node_id = train_edge_id

        large_graph_list, path_subgraphs = [], []
        for i in range(node_id.shape[1]):
            large_graph = self.build_large_graph(
                x[node_id[0, i]], x[node_id[1, i]], self.A_B_token
            )
            large_graph_list.append(large_graph)
            if gate_work:
                path_subgraphs += self.sample_subgraphs_rep(large_graph)
        batch_graph = Batch.from_data_list(large_graph_list)

        if gate_work:
            batch_sub = Batch.from_data_list(path_subgraphs)
            sub_x = self.gin_model(batch_sub.x, batch_sub.edge_index)
            sub_x = global_mean_pool(sub_x, batch_sub.batch)
            gate_logits = torch.sigmoid(self.gate2(F.relu(self.gate1(sub_x))))
            after_edit = self.differentiable_edge_gating(
                batch_graph, gate_logits, temperature=0.2
            )
        else:
            after_edit = batch_graph
            after_edit.edge_weight = torch.ones(
                (batch_graph.edge_index.shape[1], 1), device=batch_graph.x.device
            )
            gate_logits = []

        head_x = self.gin_model_head_1st(
            after_edit.x, after_edit.edge_index, after_edit.edge_weight
        )
        layer_outputs = [head_x]
        for conv in self.gin_model_head:
            head_x = conv(head_x, after_edit.edge_index, after_edit.edge_weight)
            layer_outputs.append(head_x)

        if self.use_jk:
            x = F.gelu(self.lin1(self.jump(layer_outputs)))
        else:
            x = F.gelu(self.lin1(head_x))

        x = global_mean_pool(x, batch=batch_graph.batch)
        final_x = self.head(x)

        return final_x, gate_logits