import torch
import torch.nn.functional as F
from torch_geometric.data import Batch, Data
from prompt_graph.utils import act
from deprecated.sphinx import deprecated
from sklearn.cluster import KMeans
from torch_geometric.nn.inits import glorot
import pdb
from torch_geometric.utils import to_undirected
import torch.nn as nn


# Similarity mode switch read by MorpherGraphPrompt.forward:
#   0         -> inner edges via LightPrompt.inner_structure_update (sigmoid of dot products)
#                and cross edges via sigmoid of dot products (negatives pushed to -10 first);
#   otherwise -> cosine similarity for both inner and cross edges.
COSINE = 1


def tempsigmoid(x, temp):
    """Return the temperature-scaled sigmoid ``sigmoid(x / temp)``.

    A smaller ``temp`` sharpens the transition around zero.
    """
    logits = x / temp
    return torch.sigmoid(logits)


class LightPrompt(torch.nn.Module):
    """Learnable prompt tokens organised as one or more prompt sub-graphs.

    Every group in ``token_list`` is a ``(token_num_per_group, token_dim)``
    parameter. ``token_view`` (sigmoid of scaled dot products) and
    ``inner_structure_update_cosine`` (cosine similarity) turn each group into
    a PyG ``Data`` graph whose edges link token pairs whose similarity is not
    below ``inner_prune``; all groups are returned together as one ``Batch``.
    """

    def __init__(self, token_dim, token_num_per_group, group_num=1, inner_prune=None):
        """
        :param token_dim: dimensionality of every prompt token.
        :param token_num_per_group: number of tokens in each group.
        :param group_num: number of token groups; the total token number is
            token_num_per_group * group_num. In most cases group_num=1.
            In prompt_w_o_h mode for classification, each class can correspond
            to one group. A group can also be used as one prompt of a batch.
        :param inner_prune: similarity threshold for token-token (inner) edges;
            pairs whose similarity is below it are not connected. It must not
            be None when a prompt graph is built (ValueError otherwise).
        """
        super(LightPrompt, self).__init__()

        self.inner_prune = inner_prune

        self.token_list = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(token_num_per_group, token_dim)) for _ in range(group_num)])

        self.token_init(init_method="kaiming_uniform")

    def token_init(self, init_method="kaiming_uniform"):
        """Initialise every token group in place.

        :param init_method: only "kaiming_uniform" is supported.
        :raises ValueError: for any other init_method.
        """
        if init_method != "kaiming_uniform":
            raise ValueError("only support kaiming_uniform init, more init methods will be included soon")
        for token in self.token_list:
            # kaiming_uniform_ fills `token` in place; no rescaling is applied afterwards.
            torch.nn.init.kaiming_uniform_(token, nonlinearity='leaky_relu', mode='fan_in', a=0.01)

    def inner_structure_update(self):
        """Build the prompt-graph batch using sigmoid-of-dot-product similarity."""
        return self.token_view()

    def token_view(self):
        """
        View each token group as a prompt sub-graph and batch all groups.

        Inner similarity is tempsigmoid(t_i . t_j, 0.2), i.e. values in (0, 1).
        :return: torch_geometric Batch with one graph per group; y holds the group index.
        :raises ValueError: if inner_prune is None.
        """
        def dot_similarity(tokens):
            token_dot = torch.mm(tokens, tokens.t())
            # Diagonal -> -10 before the sigmoid so self-pairs get a near-zero similarity
            # and are not selected as edges.
            token_dot = self._fill_diagonal(token_dot, -10.0)
            return tempsigmoid(token_dot, 0.2)

        return self._build_prompt_batch(dot_similarity)

    def inner_structure_update_cosine(self):
        """Build the prompt-graph batch using cosine similarity between tokens.

        :return: torch_geometric Batch with one graph per group; y holds the group index.
        :raises ValueError: if inner_prune is None.
        """
        return self._build_prompt_batch(self._cosine_similarity)

    @staticmethod
    def _fill_diagonal(matrix, value):
        """Return a copy of square ``matrix`` with its diagonal set to ``value``."""
        eye = torch.eye(matrix.shape[0], dtype=torch.bool, device=matrix.device)
        return matrix.masked_fill(eye, value)

    @staticmethod
    def _cosine_similarity(tokens, eps=1e-8):
        """Pairwise cosine similarity of the rows of ``tokens`` with the diagonal set to -10.

        The norm product is clamped to ``eps`` so a zero-norm token yields 0 instead of NaN.
        """
        token_dot = torch.mm(tokens, tokens.t())
        token_norm = torch.norm(tokens, dim=1, keepdim=True)
        denom = torch.mm(token_norm, token_norm.t()).clamp_min(eps)
        # -10 lies below any threshold in [-1, 1], so self-pairs are never kept as edges.
        return LightPrompt._fill_diagonal(token_dot / denom, -10.0)

    def _build_prompt_batch(self, similarity_fn):
        """Prune each group's similarity matrix with ``inner_prune`` and batch the resulting graphs.

        :param similarity_fn: maps a (n, d) token matrix to an (n, n) similarity matrix.
        :raises ValueError: if inner_prune is None.
        """
        if self.inner_prune is None:
            raise ValueError("inner_prune must be set before building the prompt graph")
        pg_list = []
        for i, tokens in enumerate(self.token_list):
            token_sim = similarity_fn(tokens)
            inner_adj = token_sim.masked_fill(token_sim < self.inner_prune, 0)
            # every surviving non-zero entry (row=src, col=dst) becomes an edge
            edge_index = inner_adj.nonzero().t().contiguous()
            pg_list.append(Data(x=tokens, edge_index=edge_index, y=torch.tensor([i]).long()))
        return Batch.from_data_list(pg_list)
        

class MorpherGraphPrompt(LightPrompt):
    """Single prompt graph that is connected into every input graph.

    ``forward`` prepends the prompt tokens to each graph's node features and
    adds three edge sets: prompt-internal edges, the graph's own edges
    (re-indexed past the prompt tokens) and undirected prompt<->node cross
    edges selected by similarity.
    """

    def __init__(self, token_dim, token_num, cross_prune=0.1, inner_prune=0.6):
        """
        :param token_dim: feature dimension of the prompt tokens; it must equal the node-feature
            dimension of input graphs (they are multiplied and concatenated together in forward).
        :param token_num: number of prompt tokens.
        :param cross_prune: prompt<->node pairs whose similarity is below this are not connected.
        :param inner_prune: prompt-token pairs whose similarity is below this are not connected.
        """
        super(MorpherGraphPrompt, self).__init__(token_dim, token_num, 1, inner_prune)  # only has one prompt graph.
        self.cross_prune = cross_prune

    @staticmethod
    def _cross_similarity(prompt_x, graph_x, use_cosine, eps=1e-8):
        """Similarity between prompt tokens (rows) and graph nodes (columns).

        :return: tensor of shape (num_tokens, num_nodes).
        """
        cross_dot = torch.mm(prompt_x, graph_x.t())
        if not use_cosine:
            # negative dot products are pushed to -10 so their sigmoid is ~0
            cross_dot = cross_dot.masked_fill(cross_dot < 0, -10.0)
            return torch.sigmoid(cross_dot)
        prompt_norm = torch.norm(prompt_x, dim=1, keepdim=True)
        graph_norm = torch.norm(graph_x, dim=1, keepdim=True)
        # clamp avoids 0/0 = NaN for all-zero node features (similarity becomes 0 instead)
        denom = torch.mm(prompt_norm, graph_norm.t()).clamp_min(eps)
        return cross_dot / denom

    def forward(self, graph_batch: Batch):
        """
        Insert the prompt graph into every graph of ``graph_batch``.

        TODO: although it receives a graph batch, currently we only implement one-by-one computing instead of batch computing
        TODO: we will implement batch computing once we figure out the memory sharing mechanism within PyG
        :param graph_batch: PyG Batch of input graphs.
        :return: PyG Batch of prompted graphs; in each, nodes [0, token_num) are prompt tokens
                 and the graph's own nodes follow; y is copied from the input graph.
        """
        use_cosine = COSINE != 0
        if use_cosine:
            pg = self.inner_structure_update_cosine()
        else:
            pg = self.inner_structure_update()  # batch of prompt graph (currently only 1 prompt graph in the batch)

        inner_edge_index = pg.edge_index
        token_num = pg.x.shape[0]

        re_graph_list = []
        for g in Batch.to_data_list(graph_batch):
            # the graph's nodes are placed after the prompt tokens
            g_edge_index = g.edge_index + token_num
            num_nodes_in_g = g.x.shape[0]
            average_degree = (g_edge_index.shape[1] / num_nodes_in_g) // 2

            cross_sim = self._cross_similarity(pg.x, g.x, use_cosine)

            # For every node (column) keep only its top-k most similar prompt tokens.
            # topk runs along dim 0 (the token axis), so k must not exceed token_num.
            topk_k = min(token_num, num_nodes_in_g, max(int(average_degree), 1))
            kth_best = torch.topk(cross_sim, topk_k, dim=0)[0][-1, :]
            cross_sim = cross_sim.masked_fill(cross_sim < kth_best, 0)
            cross_adj = cross_sim.masked_fill(cross_sim < self.cross_prune, 0)

            cross_edge_index = cross_adj.nonzero().t().contiguous()
            cross_edge_index[1] = cross_edge_index[1] + token_num
            cross_edge_index = to_undirected(cross_edge_index)

            x = torch.cat([pg.x, g.x], dim=0)
            edge_index = torch.cat([inner_edge_index, g_edge_index, cross_edge_index], dim=1)
            re_graph_list.append(Data(x=x, edge_index=edge_index, y=g.y))

        return Batch.from_data_list(re_graph_list)

    def Tune(self, train_loader, gnn, answering, lossfn, opi, device):
        """Run one epoch through prompt -> gnn -> answering head and step ``opi``.

        :return: mean training loss over the batches of ``train_loader``.
        """
        running_loss = 0.
        for train_batch in train_loader:
            train_batch = train_batch.to(device)
            prompted_graph = self.forward(train_batch)

            graph_emb = gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch)
            pre = answering(graph_emb)
            train_loss = lossfn(pre, train_batch.y)

            opi.zero_grad()
            train_loss.backward()
            opi.step()
            running_loss += train_loss.item()

        return running_loss / len(train_loader)

    def TuneWithoutAnswering(self, train_loader, gnn, answering, lossfn, opi, device):
        """Run one epoch scoring graphs against prompt-graph embeddings (no answering head).

        ``answering`` is accepted for signature compatibility but unused.
        NOTE(review): with a single prompt graph, ``pg_emb`` likely has one row, which makes
        the softmax over dim=1 constant — confirm against the gnn's output shape.
        :return: mean loss over the batches of ``train_loader``.
        """
        total_loss = 0.0
        for batch in train_loader:
            # use the optimizer/device passed in (as Tune does); self.optimizer/self.device
            # are never set on this class
            opi.zero_grad()
            batch = batch.to(device)
            emb0 = gnn(batch.x, batch.edge_index, batch.batch)
            pg_batch = self.inner_structure_update()
            pg_batch = pg_batch.to(device)
            pg_emb = gnn(pg_batch.x, pg_batch.edge_index, pg_batch.batch)
            # cross link between prompt and input graphs
            dot = torch.mm(emb0, torch.transpose(pg_emb, 0, 1))
            sim = torch.softmax(dot, dim=1)
            loss = lossfn(sim, batch.y)
            loss.backward()
            opi.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def TuneOnlyAnswering(self, train_loader, gnn, answering, lossfn, opi, device):
        """Run one epoch through gnn -> answering head on the raw graphs (prompt unused).

        :return: mean training loss over the batches of ``train_loader``.
        """
        running_loss = 0.
        for train_batch in train_loader:
            train_batch = train_batch.to(device)

            graph_emb = gnn(train_batch.x, train_batch.edge_index, train_batch.batch)
            pre = answering(graph_emb)
            train_loss = lossfn(pre, train_batch.y)

            opi.zero_grad()
            train_loss.backward()
            opi.step()
            running_loss += train_loss.item()

        return running_loss / len(train_loader)





class MorpherTextPrompt(nn.Module):
    """Soft text prompt: a learned embedding placed in front of the token embeddings."""

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True,
                 start_vocab_emb=None):
        """Prepend a learned task embedding to the output of ``wte``.

        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of learned prompt tokens. Defaults to 10.
            random_range (float, optional): half-width of the uniform init range (used when
                not initializing from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize from vocab embeddings. Defaults to True.
            start_vocab_emb (torch.Tensor, optional): explicit initial embedding used when
                initialize_from_vocab is True (presumably shaped (n_tokens, embedding_dim)).
                It is copied, so training never modifies the caller's tensor.
        """
        super(MorpherTextPrompt, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                                  n_tokens,
                                                                                  random_range,
                                                                                  initialize_from_vocab,
                                                                                  start_vocab_emb))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True,
                             start_vocab_emb=None):
        """Build the initial learned embedding.

        Args:
            same as __init__

        Returns:
            torch.Tensor: the first ``n_tokens`` rows of ``wte.weight`` or a copy of
            ``start_vocab_emb`` (vocab init), else a float32 (n_tokens, embedding_dim)
            tensor drawn uniformly from [-random_range, random_range).
        """
        if initialize_from_vocab:
            source = wte.weight[:n_tokens] if start_vocab_emb is None else start_vocab_emb
            # clone + detach so the Parameter owns its storage instead of aliasing
            # the embedding table or the caller's tensor
            return source.clone().detach()
        return torch.empty(n_tokens, wte.weight.size(1), dtype=torch.float32).uniform_(-random_range, random_range)

    def forward(self, tokens):
        """Run the forward pass.

        Args:
            tokens (torch.long): input token ids of shape (batch, seq_len); the first
                ``n_tokens`` positions are placeholders and are replaced by the learned embedding.

        Returns:
            torch.float: learned task-specific embedding concatenated (in front, along dim 1)
            with the embedding of the remaining tokens.
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)


