import torch
import torch.optim as optim
from torch.autograd import Variable
from torch_geometric.loader import DataLoader
from torch.utils.data import TensorDataset
from prompt_graph.data import load4link_prediction_single_graph
from torch.optim import Adam
import time
from .base import PreTrain
import os

class Edgepred_GPPT(PreTrain):
    """Edge-prediction pre-training task (GPPT style) on a single graph.

    Node representations produced by ``self.gnn`` are projected by
    ``graph_pred_linear`` and scored pairwise with ``self.gnn.decode``; the
    scores are trained against binary edge labels with
    ``BCEWithLogitsLoss``.  The labelled edges are split 80/20 into a
    training and a held-out loader.
    """

    # Datasets for which ``generate_loader_data`` can build loaders.
    _SINGLE_GRAPH_DATASETS = ('PubMed', 'CiteSeer', 'Cora', 'Computers', 'Photo',
                              'ogbn-arxiv', 'Flickr', 'Actor', 'Texas', 'Wisconsin')
    # Datasets that get the larger edge batch size.
    _LARGE_DATASETS = ('ogbn-arxiv', 'Flickr')
    # Datasets whose node embeddings are assembled from ``self.batch_dataloader``.
    _GRAPH_LEVEL_DATASETS = ('COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY', 'ogbg-ppa', 'DD')

    def __init__(self, *args, **kwargs):
        """Build the edge loaders, the GNN and the linear projection head."""
        super().__init__(*args, **kwargs)
        # The loaders must be built first: this call also sets
        # ``self.input_dim`` / ``self.output_dim`` used just below.
        self.train_dataloader, self.test_dataloader = self.generate_loader_data()
        self.initialize_gnn(self.input_dim, self.hid_dim)
        # NOTE(review): ``self.optimizer`` is created outside this class;
        # confirm it also covers this layer's parameters, otherwise the
        # projection head keeps its initial weights.
        self.graph_pred_linear = torch.nn.Linear(self.hid_dim, self.output_dim).to(self.device)

    def generate_loader_data(self):
        """Load the dataset and return ``(train_loader, test_loader)``.

        Side effects: sets ``self.data``, ``self.input_dim`` and
        ``self.output_dim``.

        Raises:
            ValueError: if ``self.dataset_name`` is not a supported
                single-graph dataset.
        """
        if self.dataset_name not in self._SINGLE_GRAPH_DATASETS:
            raise ValueError(
                f"Edgepred_GPPT does not support dataset '{self.dataset_name}'; "
                f"expected one of {list(self._SINGLE_GRAPH_DATASETS)}"
            )

        self.data, edge_label, edge_index, self.input_dim, self.output_dim = \
            load4link_prediction_single_graph(
                self.dataset_name, use_different_dataset=self.use_different_dataset)
        self.data = self.data.to(self.device)

        # One row per labelled edge so labels and edges can be sliced and
        # batched together (presumably (2, E) -> (E, 2); verify against loader).
        edge_index = edge_index.transpose(0, 1)

        # Positional 80/20 split into training and held-out samples.
        # NOTE(review): no shuffling happens before the split; if the loader
        # returns samples ordered by label the two splits will not share the
        # same class balance -- confirm against the loader.
        split = int(len(edge_label) * 0.8)
        train_data = TensorDataset(edge_label[:split], edge_index[:split])
        test_data = TensorDataset(edge_label[split:], edge_index[split:])

        batch_size = 1024 if self.dataset_name in self._LARGE_DATASETS else 64
        return (DataLoader(train_data, batch_size=batch_size, shuffle=True),
                DataLoader(test_data, batch_size=batch_size, shuffle=False))

    def _compute_node_embeddings(self):
        """Run the GNN over all nodes and apply the linear projection head."""
        if self.dataset_name in self._GRAPH_LEVEL_DATASETS:
            # Encode every mini-batch of graphs and stack the node outputs.
            outputs = []
            for batch_graph in self.batch_dataloader:
                batch_graph = batch_graph.to(self.device)
                outputs.append(self.gnn(batch_graph.x, batch_graph.edge_index))
            out = torch.cat(outputs, dim=0)
        else:
            out = self.gnn(self.data.x, self.data.edge_index)
        return self.graph_pred_linear(out)

    def pretrain_one_epoch(self):
        """Train for one pass over the training edges.

        Returns:
            The mean batch loss over the epoch.
        """
        accum_loss, total_step = 0, 0
        device = self.device
        criterion = torch.nn.BCEWithLogitsLoss()

        self.gnn.train()
        for batch_edge_label, batch_edge_index in self.train_dataloader:
            self.optimizer.zero_grad()

            batch_edge_label = batch_edge_label.to(device)
            # Back to (2, batch) layout before handing the edges to decode.
            batch_edge_index = batch_edge_index.to(device).transpose(0, 1)

            # Recomputed every step because the weights change after each update.
            node_emb = self._compute_node_embeddings()
            batch_pred_log = self.gnn.decode(node_emb, batch_edge_index).view(-1)
            loss = criterion(batch_pred_log, batch_edge_label)

            loss.backward()
            self.optimizer.step()

            accum_loss += loss.item()
            total_step += 1

        return accum_loss / total_step

    def test(self):
        """Evaluate on the held-out edges without updating any weights.

        Returns:
            The mean batch loss over the held-out loader.
        """
        accum_loss, total_step = 0, 0
        device = self.device
        criterion = torch.nn.BCEWithLogitsLoss()

        self.gnn.eval()
        # No gradients are needed for evaluation; skipping autograd avoids
        # building a graph for every full-graph forward pass.
        with torch.no_grad():
            # Weights are fixed during evaluation, so one forward pass serves
            # every batch (assumes the GNN is deterministic in eval mode).
            node_emb = self._compute_node_embeddings()

            for batch_edge_label, batch_edge_index in self.test_dataloader:
                batch_edge_label = batch_edge_label.to(device)
                batch_edge_index = batch_edge_index.to(device).transpose(0, 1)

                batch_pred_log = self.gnn.decode(node_emb, batch_edge_index).view(-1)
                loss = criterion(batch_pred_log, batch_edge_label)

                accum_loss += loss.item()
                total_step += 1

        return accum_loss / total_step

    def pretrain(self):
        """Run the full pre-training loop and save the GNN weights.

        Each epoch's train/test loss is appended to a text log under
        ``./dataspace/pre_train_results/<dataset>``.  Training stops early
        once the training loss has not improved for ``patience`` consecutive
        epochs.  The GNN ``state_dict`` is written to
        ``./dataspace/pre_trained_model/<dataset>`` when the loop ends.
        """
        num_epoch = self.epochs
        train_loss_min = 1000000
        patience = 20
        cnt_wait = 0

        file_path = f"./dataspace/pre_train_results/{self.dataset_name}"
        os.makedirs(file_path, exist_ok=True)
        filename = "Edgepred_GPPT.{}.{}hidden_dim.seed{}.txt".format(self.gnn_type, str(self.hid_dim), self.seed)
        save_path = os.path.join(file_path, filename)

        for epoch in range(1, num_epoch + 1):
            st_time = time.time()
            train_loss = self.pretrain_one_epoch()
            test_loss = self.test()
            print(f"Edgepred_GPPT [Pretrain] Epoch {epoch}/{num_epoch} | Train Loss {train_loss:.5f} | Test Loss {test_loss:.5f} "
                  f"Cost Time {time.time() - st_time:.3}s")

            # Start a fresh log on the first epoch, append afterwards.
            mode = 'w' if epoch == 1 else 'a'
            with open(save_path, mode) as f:
                f.write('%d %.8f %.8f' % (epoch, train_loss, test_loss))
                f.write("\n")

            # Early stopping is driven by the training loss.
            if train_loss_min > train_loss:
                train_loss_min = train_loss
                cnt_wait = 0
            else:
                cnt_wait += 1
                if cnt_wait == patience:
                    print('-' * 100)
                    print('Early stopping at '+str(epoch) +' eopch!')
                    break
            print(cnt_wait)

        folder_path = f"./dataspace/pre_trained_model/{self.dataset_name}"
        os.makedirs(folder_path, exist_ok=True)

        torch.save(self.gnn.state_dict(),
                   "{}/{}.{}.{}.pth".format(folder_path, 'Edgepred_GPPT', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
        print("+++model saved ! {}/{}.{}.{}.pth".format(self.dataset_name, 'Edgepred_GPPT', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
