from .base import PreTrain
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.inits import reset, uniform
from torch.optim import Adam
import torch
from torch import nn
import time
from prompt_graph.utils import generate_corrupted_graph
from prompt_graph.data import load4node, NodePretrain
import os
import numpy as np
import copy

class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        c_x = torch.unsqueeze(c, 0)
        c_x = c_x.expand_as(h_pl)

        sc_1 = self.f_k(h_pl, c_x)
        sc_2 = self.f_k(h_mi, c_x)

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 1)

        return logits


class DGI(PreTrain):
    def __init__(self, *args, **kwargs):    # hid_dim=16
        super().__init__(*args, **kwargs)
        
      
        self.disc = Discriminator(self.hid_dim).to(self.device)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.graph_data = self.load_data()
        self.initialize_gnn(self.input_dim, self.hid_dim)  
        self.optimizer = Adam(self.gnn.parameters(), lr=0.001, weight_decay = 0.0)


    def load_data(self):
        if self.dataset_name in ['PubMed', 'CiteSeer', 'Cora','Computers', 'Photo', 'Reddit', 'WikiCS', 'Flickr', 'ogbn-arxiv','Actor', 'Texas', 'Wisconsin']:
            data, input_dim, _ = load4node(self.dataset_name, use_different_dataset=self.use_different_dataset)
            self.input_dim = input_dim

        return data

    def pretrain_one_epoch(self):
        self.gnn.train()
        self.optimizer.zero_grad()
        device = self.device

        if self.dataset_name in ['PubMed', 'CiteSeer', 'Cora','Computers', 'Photo', 'Reddit', 'WikiCS', 'Flickr', 'ogbn-arxiv']:
            graph_original = self.graph_data
            graph_corrupted = copy.deepcopy(graph_original)
            idx_perm = np.random.permutation(graph_original.x.size(0))
            graph_corrupted.x = graph_original.x[idx_perm].to(self.device)

            graph_original.to(device)
            graph_corrupted.to(device)

            pos_z = self.gnn(graph_original.x, graph_original.edge_index)
            neg_z = self.gnn(graph_corrupted.x, graph_corrupted.edge_index)

            s = torch.sigmoid(torch.mean(pos_z, dim=0)).to(device)

            logits = self.disc(s, pos_z, neg_z)

            lbl_1 = torch.ones((pos_z.shape[0], 1))
            lbl_2 = torch.zeros((neg_z.shape[0], 1))
            lbl = torch.cat((lbl_1, lbl_2), 1).to(device)

            loss = self.loss_fn(logits, lbl)
            loss.backward()
            self.optimizer.step()

            accum_loss = float(loss.detach().cpu().item())            

        return accum_loss    
            
    def pretrain(self):
        train_loss_min = 1000000
        patience = 20
        cnt_wait = 0


        file_path = f"./dataspace/pre_train_results/{self.dataset_name}"
        if not os.path.exists(file_path):
            os.makedirs(file_path)

        for epoch in range(1, self.epochs + 1):
            time0 = time.time()
            train_loss = self.pretrain_one_epoch()
            print("***epoch: {}/{} | train_loss: {:.8}".format(epoch, self.epochs , train_loss))

            filename = "DGI.{}.{}hidden_dim.seed{}.txt".format(self.gnn_type, str(self.hid_dim), self.seed)
            save_path = os.path.join(file_path, filename)
            # if save_path already exist, clear all existing contents
            if (epoch == 1) and (os.path.exists(save_path)): 
                os.remove(save_path) 
            with open(save_path, 'a') as f:
                f.write('%d %.8f'%(epoch, train_loss))
                f.write("\n")

            
            if train_loss_min > train_loss:
                train_loss_min = train_loss
                cnt_wait = 0
            else:
                cnt_wait += 1
                if cnt_wait == patience:
                    print('-' * 100)
                    print('Early stopping at '+str(epoch) +' eopch!')
                    break

        folder_path = f"./dataspace/pre_trained_model/{self.dataset_name}"
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        torch.save(self.gnn.state_dict(),
                    "{}/{}.{}.{}.pth".format(folder_path, 'DGI', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
        print("+++model saved ! {}/{}.{}.{}.pth".format(self.dataset_name, 'DGI', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
