import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, global_add_pool, SAGEConv
from typing import Optional
from torch import Tensor
#from torch_geometric.utils import scatter
from torch_scatter import scatter
import random


class serverGIN_1(torch.nn.Module):
    """Container holding only a stack of GIN convolutions (no ``forward``).

    ``self.graph_convs`` receives one ``GINConv`` built at construction plus
    ``nlayer - 1`` more, each wrapping a Linear-ReLU-Linear MLP of width
    ``nhid``.  The attribute names (``graph_convs``, ``nn1``, ``nnk``) are the
    same ones ``GIN_1`` uses for its convolution stack.

    Args:
        nlayer: Number of GIN layers; at least one layer is always created.
        nhid: Input and output width of every layer.
    """

    def __init__(self, nlayer, nhid):
        # Zero-argument super(): the earlier ``super(serverGIN, self)`` named a
        # different, unrelated class and raised TypeError on instantiation.
        super().__init__()
        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(),
                                       torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        for _ in range(nlayer - 1):
            # ``self.nnk`` is rebound every iteration, so it ends up pointing at
            # the last MLP; each GINConv keeps its own reference to its MLP.
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))

class GIN_1(torch.nn.Module):
    """GIN graph classifier with an auxiliary two-way head on the input embedding.

    Args:
        nfeat: Width of ``data.x``.
        nhid: Hidden width used by every layer.
        nclass: Number of outputs of the graph-level readout.
        nlayer: Number of GIN layers; at least one layer is always created.
        dropout: Dropout probability applied after each convolution and after
            ``post``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayer, dropout):
        # Zero-argument super(): the earlier ``super(GIN, self)`` named a
        # different, unrelated class and raised TypeError on instantiation.
        super().__init__()
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = torch.nn.Sequential(torch.nn.Linear(nfeat, nhid))
        self.proto = torch.nn.Sequential(torch.nn.Linear(nhid, nhid),torch.nn.ReLU(),torch.nn.Linear(nhid, nhid))
        # Auxiliary head: two logits per row of the embedded input.
        self.proto_clasif = torch.nn.Linear(nhid, 2)

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        for _ in range(nlayer - 1):
            # ``self.nnk`` is rebound every iteration and ends up pointing at
            # the last MLP; each GINConv keeps its own reference to its MLP.
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))

        self.post = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU())
        self.readout = torch.nn.Sequential(torch.nn.Linear(nhid, nclass))

    def forward(self, data):
        """Run the model on ``data`` (reads ``x``, ``edge_index`` and ``batch``).

        Returns:
            A pair ``(z, z_proto)``: ``z`` holds log-probabilities from the
            pooled graph representation, ``z_proto`` holds the raw
            ``proto_clasif`` logits computed from the embedding ``z_0``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        z_0 = self.pre(x)
        z_0 = self.proto(z_0)
        z = z_0*1

        for i in range(len(self.graph_convs)):
            z = self.graph_convs[i](z, edge_index)
            z = F.relu(z)
            z = F.dropout(z, self.dropout, training=self.training)
        # Product pooling (vcp_pool) is used here instead of global_add_pool.
        z = vcp_pool(z, batch)
        z = self.post(z)
        z = F.dropout(z, self.dropout, training=self.training)
        z = self.readout(z)
        z = F.log_softmax(z, dim=1)

        z_proto = self.proto_clasif(z_0)
        return z, z_proto

    def loss(self, pred, pred_proto, label, proto_label):
        """Return ``nll(pred, label) + 0.01 * nll(log_softmax(pred_proto), proto_label)``.

        Args:
            pred: Log-probabilities, as returned first by ``forward``.
            pred_proto: Auxiliary-head scores, as returned second by ``forward``.
            label: Targets for ``pred``.
            proto_label: Target tensor that is tiled ``pred_proto.shape[0]``
                times to give one target per row of ``pred_proto``.
        """
        proto_label = proto_label.repeat(pred_proto.shape[0])
        # ``forward`` returns raw logits for the auxiliary head, while nll_loss
        # expects log-probabilities.  Normalising here fixes that; input that is
        # already log-normalised passes through unchanged.
        proto_log_probs = F.log_softmax(pred_proto, dim=1)
        return F.nll_loss(pred, label)+0.01*F.nll_loss(proto_log_probs, proto_label)


#fedkalter
class serverGIN(torch.nn.Module):
    """Container holding two parallel GIN stacks and no ``forward``.

    ``graph_convs_A`` and ``graph_convs_H`` each receive ``GINConv`` layers
    wrapping Linear-ReLU-Linear MLPs of width ``nhid``; the attribute names
    match the ones ``GIN`` uses for the same stacks.

    Args:
        n_se: Accepted for signature parity; not used by this constructor.
        nlayer: Number of layers per stack; at least one is always created.
        nhid: Input and output width of every layer.
    """

    def __init__(self, n_se, nlayer, nhid):
        super().__init__()

        def make_mlp():
            # Two-layer perceptron that each GINConv wraps.
            return torch.nn.Sequential(
                torch.nn.Linear(nhid, nhid),
                torch.nn.ReLU(),
                torch.nn.Linear(nhid, nhid),
            )

        self.graph_convs_A = torch.nn.ModuleList()
        self.graph_convs_H = torch.nn.ModuleList()

        # First layer of each stack.
        self.nn1_A = make_mlp()
        self.graph_convs_A.append(GINConv(self.nn1_A))

        self.nn1_H = make_mlp()
        self.graph_convs_H.append(GINConv(self.nn1_H))

        # Remaining layers.  ``nnk_A`` / ``nnk_H`` are rebound on every pass and
        # finish pointing at the last pair of MLPs.
        for _ in range(1, nlayer):
            self.nnk_A = make_mlp()
            self.nnk_H = make_mlp()

            self.graph_convs_A.append(GINConv(self.nnk_A))
            self.graph_convs_H.append(GINConv(self.nnk_H))




class GIN(torch.nn.Module):
    """Graph classifier built from three parallel convolution stacks.

    ``graph_convs_A`` and ``graph_convs_H`` are GIN stacks, ``graph_convs_F``
    is a GCN stack applied to the normalised feature branch.  ``forward`` also
    stores a covariance penalty in ``self.cov_loss`` that ``loss`` adds to the
    classification term.

    Args:
        nfeat: Width of ``data.x``.
        n_se: Width of ``data.stc_enc``.
        nhid: Hidden width used by every layer.
        nclass: Number of outputs of the readout.
        nlayer: Number of layers per stack; at least one is always created.
        dropout: Dropout probability used throughout ``forward``.
    """

    def __init__(self, nfeat, n_se, nhid, nclass, nlayer, dropout):
        super().__init__()
        nn = torch.nn

        def make_mlp():
            # Two-layer perceptron that each GINConv wraps.
            return nn.Sequential(nn.Linear(nhid, nhid), nn.ReLU(), nn.Linear(nhid, nhid))

        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = nn.Sequential(nn.Linear(nfeat, nhid))
        self.embedding_s = nn.Linear(n_se, nhid)

        # First layer of every stack.
        self.graph_convs_A = nn.ModuleList()
        self.nn1_A = make_mlp()
        self.graph_convs_A.append(GINConv(self.nn1_A))

        self.graph_convs_F = nn.ModuleList()
        self.graph_convs_F.append(GCNConv(nhid, nhid))

        self.graph_convs_H = nn.ModuleList()
        self.nn1_H = make_mlp()
        self.graph_convs_H.append(GINConv(self.nn1_H))

        # Remaining layers; ``nnk_A`` / ``nnk_H`` end up pointing at the last
        # MLPs created.
        for _ in range(1, nlayer):
            self.nnk_A = make_mlp()
            self.graph_convs_A.append(GINConv(self.nnk_A))

            self.nnk_H = make_mlp()
            self.graph_convs_H.append(GINConv(self.nnk_H))

            self.graph_convs_F.append(GCNConv(nhid, nhid))

        # ``post`` consumes the concatenation of two nhid-wide branches.
        self.post = nn.Sequential(nn.Linear(nhid + nhid, nhid), nn.ReLU())
        self.readout = nn.Sequential(nn.Linear(nhid, nclass))

    def forward(self, data):
        """Return log-probabilities per graph and refresh ``self.cov_loss``.

        Reads ``x``, ``edge_index``, ``batch`` and ``stc_enc`` from ``data``.
        NOTE(review): the x_k / z_k / fixed 0.5 gain structure looks like a
        Kalman-style predict/correct recursion - confirm with the author.
        """
        x, edge_index, batch, s = data.x, data.edge_index, data.batch, data.stc_enc
        p = self.dropout
        training = self.training

        noise = F.normalize(self.pre(x), dim=1)
        s = self.embedding_s(s)

        # Both tracks start from the embedded ``stc_enc``; only ``x_k`` has the
        # feature branch subtracted before each A-convolution.
        x_k = s
        x_kn = s
        cov_loss = 0
        layers = zip(self.graph_convs_A, self.graph_convs_F, self.graph_convs_H)
        for conv_a, conv_f, conv_h in layers:
            x_k = F.dropout(F.relu(conv_a(x_k - noise, edge_index)), p, training=training)

            noise = F.dropout(F.relu(conv_f(noise, edge_index)), p, training=training)

            z_k = conv_h(x_k, edge_index)

            x_kn = F.dropout(F.relu(conv_a(x_kn, edge_index)), p, training=training)
            z_kn = conv_h(x_kn, edge_index)

            # Correct the second track with half of the H-output gap, then
            # penalise the covariance norm of what still separates the tracks.
            corrected = x_kn + 0.5 * (z_k - z_kn)
            residual = x_k - corrected
            cov_loss = cov_loss + torch.norm(torch.cov(residual.T))
            x_kn = corrected

        self.cov_loss = cov_loss

        pooled = global_add_pool(torch.cat((x_kn, noise), -1), batch)
        hidden = F.dropout(self.post(pooled), p, training=training)
        return F.log_softmax(self.readout(hidden), dim=1)

    def loss(self, pred, label):
        """Return ``nll(pred, label)`` plus the ``cov_loss`` stored by ``forward``."""
        classification_term = F.nll_loss(pred, label)
        return classification_term + self.cov_loss

def vcp_pool(x: Tensor, batch: Optional[Tensor],
            size: Optional[int] = None) -> Tensor:
    """Pool node features by multiplying them together, per graph.

    Every entry of ``x`` is shifted by ``+0.1`` before the product is taken.

    Args:
        x: Node features; the node axis is the last axis for 1-D input and
            the second-to-last axis otherwise.
        batch: Graph index of every node, or ``None`` to treat all nodes as
            belonging to a single graph.
        size: Forwarded to ``scatter`` as ``dim_size``; ignored when ``batch``
            is ``None``.

    Returns:
        The pooled tensor.  With ``batch=None`` and input of at most two
        dimensions the node axis is kept with length 1, so a single graph
        still produces a leading graph axis like the batched branch does.
    """
    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2
    shifted = x + 0.1

    if batch is None:
        # keepdim follows the convention of PyG's global_add_pool; without it
        # the graph axis vanished and callers indexing dim=1 afterwards failed.
        return torch.prod(shifted, dim=dim, keepdim=x.dim() <= 2)
    return scatter(shifted, batch, dim=dim, dim_size=size, reduce='mul')



class serverGIN_dc(torch.nn.Module):
    """Container mirroring the shared layers of ``GIN_dc``; it has no ``forward``.

    It holds ``embedding_s``, ``Whp``, a GIN stack (``graph_convs``) whose
    MLPs take ``2 * nhid`` inputs, and a GCN stack (``graph_convs_s_gcn``).

    Args:
        n_se: Input width of ``embedding_s``.
        nlayer: Number of layers per stack; at least one is always created.
        nhid: Hidden width.
    """

    def __init__(self, n_se, nlayer, nhid):
        super().__init__()
        nn = torch.nn
        wide = nhid + nhid

        def make_mlp():
            # MLP wrapped by each GINConv: 2*nhid -> nhid -> nhid.
            return nn.Sequential(nn.Linear(wide, nhid), nn.ReLU(), nn.Linear(nhid, nhid))

        self.embedding_s = nn.Linear(n_se, nhid)
        self.Whp = nn.Linear(wide, nhid)

        # First layer of each stack.
        self.graph_convs = nn.ModuleList()
        self.nn1 = make_mlp()
        self.graph_convs.append(GINConv(self.nn1))
        self.graph_convs_s_gcn = nn.ModuleList()
        self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        # Remaining layers; ``nnk`` ends up pointing at the last MLP created.
        for _ in range(1, nlayer):
            self.nnk = make_mlp()
            self.graph_convs.append(GINConv(self.nnk))
            self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

class GIN_dc(torch.nn.Module):
    """GIN classifier that concatenates a second, GCN-propagated branch.

    The feature branch goes through ``pre`` and ``proto``; the ``stc_enc``
    branch goes through ``embedding_s`` and a GCN stack.  Before every GIN
    layer the two branches are concatenated.  An auxiliary four-way head
    (``proto_clasif``) is applied to the feature embedding.

    Args:
        nfeat: Width of ``data.x``.
        n_se: Width of ``data.stc_enc``.
        nhid: Hidden width.
        nclass: Number of outputs of the readout.
        nlayer: Number of layers per stack; at least one is always created.
        dropout: Dropout probability used in ``forward``.
    """

    def __init__(self, nfeat, n_se, nhid, nclass, nlayer, dropout):
        super().__init__()
        nn = torch.nn
        wide = nhid + nhid

        def make_mlp():
            # MLP wrapped by each GINConv: 2*nhid -> nhid -> nhid.
            return nn.Sequential(nn.Linear(wide, nhid), nn.ReLU(), nn.Linear(nhid, nhid))

        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = nn.Sequential(nn.Linear(nfeat, nhid))

        self.proto = nn.Sequential(nn.Linear(nhid, nhid), nn.ReLU(), nn.Linear(nhid, nhid))
        self.proto_clasif = nn.Linear(nhid, 4)

        self.embedding_s = nn.Linear(n_se, nhid)

        # First layer of each stack.
        self.graph_convs = nn.ModuleList()
        self.nn1 = make_mlp()
        self.graph_convs.append(GINConv(self.nn1))
        self.graph_convs_s_gcn = nn.ModuleList()
        self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        # Remaining layers; ``nnk`` ends up pointing at the last MLP created.
        for _ in range(1, nlayer):
            self.nnk = make_mlp()
            self.graph_convs.append(GINConv(self.nnk))
            self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        self.Whp = nn.Linear(wide, nhid)
        self.post = nn.Sequential(nn.Linear(nhid, nhid), nn.ReLU())
        self.readout = nn.Sequential(nn.Linear(nhid, nclass))

    def forward(self, data):
        """Return ``(graph_log_probs, proto_log_probs)`` for ``data``.

        Reads ``x``, ``edge_index``, ``batch`` and ``stc_enc`` from ``data``.
        The first output comes from the pooled representation, the second
        from ``proto_clasif`` applied to the feature embedding.
        """
        x, edge_index, batch, s = data.x, data.edge_index, data.batch, data.stc_enc

        z_0 = self.proto(self.pre(x))
        z = z_0 * 1

        s = self.embedding_s(s)
        for gin_layer, gcn_layer in zip(self.graph_convs, self.graph_convs_s_gcn):
            # The GIN layer sees both branches; the GCN layer advances ``s``.
            z = gin_layer(torch.cat((z, s), -1), edge_index)
            z = F.dropout(F.relu(z), self.dropout, training=self.training)
            s = torch.tanh(gcn_layer(s, edge_index))

        z = global_add_pool(self.Whp(torch.cat((z, s), -1)), batch)
        z = F.dropout(self.post(z), self.dropout, training=self.training)
        graph_log_probs = F.log_softmax(self.readout(z), dim=1)

        proto_log_probs = F.log_softmax(self.proto_clasif(z_0), dim=1)

        return graph_log_probs, proto_log_probs

    def loss(self, pred, pred_proto, label, proto_label):
        """Return ``nll(pred, label) + nll(pred_proto, tiled proto_label)``.

        ``proto_label`` is tiled ``pred_proto.shape[0]`` times so that every
        row of ``pred_proto`` gets a target.
        """
        n_rows = pred_proto.shape[0]
        graph_term = F.nll_loss(pred, label)
        proto_term = F.nll_loss(pred_proto, proto_label.repeat(n_rows))
        return graph_term + proto_term


class serverGraphSage(torch.nn.Module):
    """Container holding only a stack of ``SAGEConv`` layers (no ``forward``).

    Args:
        nlayer: Number of layers; at least one layer is always created.
        nhid: Input and output width of every layer.
    """

    def __init__(self, nlayer, nhid):
        super().__init__()
        # One unconditional layer plus ``nlayer - 1`` extras, i.e. never fewer
        # than one layer in total.
        n_convs = max(nlayer, 1)
        self.graph_convs = torch.nn.ModuleList(
            [SAGEConv(nhid, nhid) for _ in range(n_convs)]
        )

class GraphSage(torch.nn.Module):
    """GraphSAGE graph classifier: linear embedding, SAGE stack, sum pooling.

    Args:
        nfeat: Width of ``data.x``.
        nhid: Hidden width.
        nclass: Number of outputs of the readout.
        nlayer: Number of SAGE layers; at least one layer is always created.
        dropout: Dropout probability used in ``forward``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayer, dropout):
        super().__init__()
        nn = torch.nn
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = nn.Sequential(nn.Linear(nfeat, nhid))

        # One unconditional layer plus ``nlayer - 1`` extras.
        self.graph_convs = nn.ModuleList(
            [SAGEConv(nhid, nhid) for _ in range(max(nlayer, 1))]
        )

        self.post = nn.Sequential(nn.Linear(nhid, nhid), nn.ReLU())
        self.readout = nn.Sequential(nn.Linear(nhid, nclass))

    def forward(self, data):
        """Return log-probabilities for ``data`` (reads ``x``, ``edge_index``, ``batch``)."""
        h = self.pre(data.x)
        edge_index = data.edge_index
        for conv in self.graph_convs:
            h = F.dropout(F.relu(conv(h, edge_index)), self.dropout, training=self.training)
        h = self.post(global_add_pool(h, data.batch))
        h = F.dropout(h, self.dropout, training=self.training)
        return F.log_softmax(self.readout(h), dim=1)

    def loss(self, pred, label):
        """Return the negative log-likelihood of ``label`` under ``pred``."""
        nll = F.nll_loss(pred, label)
        return nll
