import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, MLP
from torch_geometric.nn import pool
from torch_geometric.loader import DataLoader


class GCN_E(torch.nn.Module):
    """Graph encoder: six stacked GCN layers followed by global mean pooling.

    Node inputs are the per-node values ``u`` concatenated column-wise with the
    node positions ``graph.pos``. The output is one ``dim_z``-wide latent row
    per graph.

    Args:
        dim_z: width of every hidden layer and of the returned latent vector.
        dim_node_features: input width of the first layer; must equal
            ``u.shape[1] + graph.pos.shape[1]``.
        device: stored on the instance; not used by this class itself.
    """

    def __init__(self, dim_z, dim_node_features, device='cuda'):
        # nn.Module.__init__ must run before any attribute is assigned on the child.
        super().__init__()
        self.dim_z = dim_z
        self.dim_node_features = dim_node_features
        self.device = device
        self.conv1 = GCNConv(self.dim_node_features, self.dim_z)
        self.conv2 = GCNConv(self.dim_z, self.dim_z)
        self.conv3 = GCNConv(self.dim_z, self.dim_z)
        self.conv4 = GCNConv(self.dim_z, self.dim_z)
        self.conv5 = GCNConv(self.dim_z, self.dim_z)
        self.conv6 = GCNConv(self.dim_z, self.dim_z)

    def forward(self, u, graph):
        """Encode every graph in ``graph`` into a single latent vector.

        Args:
            u: per-node input values, used as the node features (same number of
                rows as ``graph.pos``).
            graph: graph/batch providing ``pos`` (appended to ``u`` as extra
                node features), ``edge_index`` and ``edge_attr`` (graph
                structure), and ``batch`` (node-to-graph assignment for pooling).

        Returns:
            Mean-pooled embeddings from the last layer, one row per graph.
        """
        edge_index, edge_attr = graph.edge_index, graph.edge_attr
        # NOTE(review): edge_attr is used as GCNConv's edge_weight, which expects
        # one scalar per edge; confirm edge_attr is shaped [num_edges].
        x = torch.concat([u, graph.pos], dim=1)

        # SiLU after every layer except the last.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))
        x = self.conv6(x, edge_index, edge_weight=edge_attr)

        return pool.global_mean_pool(x, graph.batch)
    
    
    

class GCN_D(torch.nn.Module):
    """Graph decoder: maps one latent vector per graph back to per-node values.

    Each node's input is its graph's latent vector concatenated with random
    Fourier features of its position, ``sin(2*pi*pos@freqs)`` and
    ``cos(2*pi*pos@freqs)``. These inputs then pass through six GCN layers.

    Args:
        dim_z: latent width (also the hidden width of the GCN stack).
        dim_u: number of output channels per node.
        nfreq: number of random Fourier frequencies; the first layer's input
            width is ``dim_z + 2 * nfreq``.
        dim_x: number of coordinates per node in ``graph.pos``.
        device: stored on the instance. The frequency matrix is a registered
            buffer, so it follows the module when ``.to(device)`` is called.
    """

    def __init__(self, dim_z, dim_u, nfreq, dim_x=2, device='cuda'):
        # nn.Module.__init__ must run before any attribute is assigned on the child.
        super().__init__()
        self.dim_z = dim_z
        self.dim_u = dim_u
        self.nfreq = nfreq
        self.dim_x = dim_x
        self.device = device

        # Input = latent vector + nfreq sine features + nfreq cosine features.
        self.conv1 = GCNConv(self.dim_z + self.nfreq * 2, self.dim_z)
        self.conv2 = GCNConv(self.dim_z, self.dim_z)
        self.conv3 = GCNConv(self.dim_z, self.dim_z)
        self.conv4 = GCNConv(self.dim_z, self.dim_z)
        self.conv5 = GCNConv(self.dim_z, self.dim_z)
        self.conv6 = GCNConv(self.dim_z, self.dim_u)

        # The frequency count must match conv1's input width, so it is sized by
        # nfreq. It is a buffer so that forward() can read self.freqs, the matrix
        # moves with the module on .to(device), and it is saved in state_dict.
        # Saving it matters because a reloaded decoder must reuse the same
        # random frequencies it was trained with.
        self.register_buffer('freqs', self._sample_frequencies(self.nfreq, self.dim_x))

    @staticmethod
    def _sample_frequencies(nfreq, dim_x):
        """Draw a (dim_x, nfreq) float32 matrix of zero-mean Gaussian frequencies.

        Column j uses a standard deviation taken from scales log-spaced between
        1e-2 and 1e1. Sampling uses NumPy's global RNG (seed with
        ``np.random.seed`` for reproducibility).
        """
        scales = np.exp(np.linspace(np.log(1e-2), np.log(1e1), nfreq))
        freqs = np.random.normal(loc=0.0, scale=scales, size=(dim_x, nfreq))
        return torch.as_tensor(freqs, dtype=torch.float32)

    def _node_features(self, z, graph):
        """Build the decoder input: each node's graph latent plus Fourier features of pos."""
        pos = graph.pos
        batch = getattr(graph, 'batch', None)
        if batch is None:
            # Un-batched graph: every node belongs to graph 0.
            batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        angle = torch.pi * 2 * torch.einsum('nx,xd->nd', pos, self.freqs)
        # z[batch] gathers row i of z for every node whose batch index is i.
        return torch.concat([z[batch], torch.sin(angle), torch.cos(angle)], dim=1)

    def forward(self, z, graph):
        """Decode latent vectors ``z`` into per-node outputs on ``graph``.

        Args:
            z: latent matrix; row i is used for the nodes whose ``graph.batch``
                entry is i (row 0 for an un-batched graph).
            graph: graph/batch providing ``pos``, ``edge_index``, ``edge_attr``
                and, when batched, ``batch``.

        Returns:
            Output of the last GCN layer: ``dim_u`` columns, one row per node.

        Note:
            ``graph.x`` is overwritten with the decoder input features, as the
            original implementation did. ``graph.x`` does not need to be
            pre-allocated.
        """
        graph.x = self._node_features(z, graph)
        edge_index, edge_attr = graph.edge_index, graph.edge_attr

        x = graph.x
        # SiLU after every layer except the last.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))
        return self.conv6(x, edge_index, edge_weight=edge_attr)


class TRAIN:
    """Jointly train an encoder/decoder pair.

    The objective is the mean squared reconstruction error of ``data.x`` plus an
    MMD penalty between the encoder's latent codes and standard-normal samples.

    Args:
        gcnE: encoder, called as ``gcnE(data.x, data)``.
        gcnD: decoder, called as ``gcnD(z, data_clone)``.
        optimizerE: optimizer stepped after every batch (encoder).
        optimizerD: optimizer stepped after every batch (decoder).
        dim_z: latent width.
        nfreq: decoder frequency count; sizes the placeholder ``x`` handed to
            the decoder.
        device: device the training graphs are moved to.
    """

    def __init__(self, gcnE, gcnD, optimizerE, optimizerD, dim_z, nfreq, device='cuda'):
        self.gcnE = gcnE
        self.gcnD = gcnD
        self.optimizerE = optimizerE
        self.optimizerD = optimizerD
        self.dim_z = dim_z
        self.nfreq = nfreq
        self.device = device

    def mmd(self, x, y):
        """Biased (V-statistic) estimate of the squared MMD between samples x and y.

        Uses a Gaussian kernel exp(-||a - b||^2 / (2 * dim_z)). For independent
        standard-normal samples in dim_z dimensions, E||a - b||^2 = 2 * dim_z,
        which motivates this bandwidth.
        NOTE(review): the kernel and bandwidth are a default choice; override
        this method if a different estimator is intended.
        """
        def kernel(a, b):
            # Squared distances computed directly to keep gradients finite at a == b.
            sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
            return torch.exp(-sq_dist / (2.0 * self.dim_z))

        return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()

    def loss_function(self, data):
        """Return ``(loss, (reconstruction_loss, mmd_loss))`` for one batch."""
        z = self.gcnE(data.x, data)
        # Decode on a copy so the decoder's write to graph.x does not touch `data`.
        data_clone = data.clone()
        data_clone.x = torch.zeros(
            (data_clone.x.shape[0], self.dim_z + self.nfreq * 2),
            device=data.x.device,
        )
        u = self.gcnD(z, data_clone)
        lossL = torch.mean((data.x - u) ** 2.0)
        # Prior samples; randn_like already matches z's device and dtype.
        Xd = torch.randn_like(z)
        LossD = self.mmd(z, Xd)
        loss = lossL + LossD
        return loss, (lossL, LossD)

    def fit(self, data_list=None, epochs=5_000, batch_size=100, verbose=True):
        """Train both networks on ``data_list``.

        Args:
            data_list: sequence of graphs to train on (required).
            epochs: number of passes over the data.
            batch_size: graphs per mini-batch.
            verbose: print the loss after every batch.

        Returns:
            Tuple ``(total, reconstruction, mmd)`` of per-batch loss histories
            as Python floats.

        Raises:
            ValueError: if ``data_list`` is not given.
        """
        if data_list is None:
            raise ValueError("fit() requires data_list, a sequence of graphs to train on")
        data_list = [data.to(self.device) for data in data_list]
        loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

        self.gcnE.train()
        self.gcnD.train()

        LOSS = []
        LOSS_D = []
        LOSS_L = []
        for epoch in range(epochs):
            for data in loader:
                self.optimizerE.zero_grad()
                self.optimizerD.zero_grad()
                data = data.to(self.device)
                loss, (lossL, LossD) = self.loss_function(data)
                LOSS.append(loss.item())
                LOSS_D.append(LossD.item())
                LOSS_L.append(lossL.item())
                loss.backward()
                self.optimizerE.step()
                self.optimizerD.step()

                if verbose:
                    print("Epoch: ", epoch, "Loss: ", loss.item())
        return LOSS, LOSS_L, LOSS_D


# TRAIN was previously (accidentally) nested inside GCN_D; keep that access path working.
GCN_D.TRAIN = TRAIN
