import torch
import torch.optim as optim
from torch.autograd import Variable
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from prompt_graph.utils import mkdir
from torch.optim import Adam
from prompt_graph.data import load4node, NodePretrain
from copy import deepcopy
from.base import PreTrain
import os
import ipdb

class SimGRACE(PreTrain):
    """SimGRACE-style contrastive pre-training of ``self.gnn``.

    Two views of every batch are produced without data augmentation: one by
    the current encoder and one by a copy of the encoder whose weights were
    perturbed with Gaussian noise. A contrastive loss pulls the two
    embeddings of the same graph together.

    The class reads ``dataset_name``, ``hid_dim``, ``device``, ``gnn``,
    ``gnn_type``, ``epochs`` and ``seed`` from ``self``; none of them is
    assigned here, so they are presumably provided by ``PreTrain`` /
    ``initialize_gnn`` -- TODO confirm against the base class.
    """

    def __init__(self, *args, **kwargs):    # hid_dim=16
        """Load the graph data, build the GNN and the two-layer projection head."""
        super().__init__(*args, **kwargs)
        self.load_graph_data()
        self.initialize_gnn(self.input_dim, self.hid_dim)
        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(self.hid_dim, self.hid_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(self.hid_dim, self.hid_dim),
        ).to(self.device)

    def load_graph_data(self):
        """Populate ``self.graph_list`` and ``self.input_dim`` for known node datasets.

        For any dataset name outside the list below nothing is assigned here.
        """
        # NOTE(review): for other dataset names ``graph_list`` / ``input_dim``
        # must come from somewhere else (base class?) -- confirm.
        if self.dataset_name in ['PubMed', 'CiteSeer', 'Cora','Computers', 'Photo', 'Reddit', 'WikiCS', 'Flickr','ogbn-arxiv', 'Actor', 'Texas', 'Wisconsin']:
            self.graph_list, self.input_dim = NodePretrain(dataname = self.dataset_name, num_parts=500)

    def get_loader(self, graph_list, batch_size):
        """Wrap ``graph_list`` in an unshuffled ``DataLoader``.

        Raises:
            KeyError: if the final batch would contain exactly one graph.
        """
        if len(graph_list) % batch_size == 1:
            # The placeholder is now actually filled in with the batch size.
            raise KeyError(
                "batch_size {} makes the last batch only contain 1 graph, \n which will trigger a zero bug in SimGRACE!".format(batch_size))

        loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False, num_workers=1)
        return loader

    def forward_cl(self, x, edge_index, batch):
        """Encode a batch with ``self.gnn`` and pass the result through the projection head."""
        x = self.gnn(x, edge_index, batch)
        x = self.projection_head(x)
        return x

    def loss_cl(self, x1, x2, temperature=0.1):
        """Contrastive loss between two sets of embeddings.

        Args:
            x1: 2-D tensor of embeddings, one row per sample.
            x2: 2-D tensor of embeddings; row ``i`` is the positive of ``x1[i]``.
            temperature: divisor applied to the cosine similarities before
                exponentiation (defaults to the previously hard-coded 0.1).

        Returns:
            Scalar tensor: mean over rows of
            ``-log(exp(sim_ii / T) / (sum_j exp(sim_ij / T) + 1e-4))``.
        """
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        # Clamp the norm product so an all-zero embedding row yields a finite
        # similarity instead of a 0/0 NaN that would poison the whole loss.
        norm_prod = torch.einsum('i,j->ij', x1_abs, x2_abs).clamp(min=1e-8)
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / norm_prod
        sim_matrix = torch.exp(sim_matrix / temperature)
        pos_sim = torch.diagonal(sim_matrix)
        loss = - torch.log(pos_sim / (sim_matrix.sum(dim=1) + 1e-4)).mean()
        return loss

    def perturbate_gnn(self, data):
        """Embed ``data`` with a noise-perturbed copy of the encoder.

        Every parameter of a copy of ``self.gnn`` receives additive Gaussian
        noise with standard deviation ``0.1 * std(parameter)`` (``0.1`` for
        single-element parameters). The projection head is applied
        unperturbed. The result carries no autograd history.
        """
        # Copy only the encoder: deep-copying ``self`` would also duplicate
        # ``self.graph_list`` (the whole dataset) on every single batch.
        vice_gnn = deepcopy(self.gnn).to(self.device)

        # The perturbed view is only ever used as a constant target, so no
        # autograd graph is needed for it.
        with torch.no_grad():
            for vice_param in vice_gnn.parameters():
                if vice_param.numel() > 1:
                    std = vice_param.std()
                else:
                    # std() of a single element is undefined; fall back to 1.
                    std = torch.ones((), dtype=vice_param.dtype, device=vice_param.device)
                noise = 0.1 * torch.normal(0, torch.ones_like(vice_param) * std)
                vice_param.add_(noise)
            z2 = vice_gnn(data.x, data.edge_index, data.batch)
            z2 = self.projection_head(z2)
        return z2

    def train_simgrace(self, loader, optimizer):
        """Run one optimisation pass over ``loader`` and return the mean batch loss."""
        self.train()
        train_loss_accum = 0
        total_step = 0
        for data in loader:
            optimizer.zero_grad()
            data = data.to(self.device)
            # Perturbed view: constant target, already detached from autograd.
            x2 = self.perturbate_gnn(data).detach()
            x1 = self.forward_cl(data.x, data.edge_index, data.batch)
            loss = self.loss_cl(x1, x2)
            loss.backward()
            optimizer.step()
            train_loss_accum += float(loss.item())
            total_step += 1
        return train_loss_accum / total_step

    def test_simgrace(self, loader):
        """Return the mean contrastive loss over ``loader`` without updating weights."""
        self.eval()
        test_loss_accum = 0
        total_step = 0
        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for data in loader:
                data = data.to(self.device)
                x2 = self.perturbate_gnn(data)
                x1 = self.forward_cl(data.x, data.edge_index, data.batch)
                loss = self.loss_cl(x1, x2)
                test_loss_accum += float(loss.item())
                total_step += 1
        return test_loss_accum / total_step

    def pretrain(self, batch_size=10, lr=0.01,decay=0.0001):
        """Pre-train ``self.gnn`` and save its ``state_dict``.

        The first 80% of ``self.graph_list`` is used for training and the
        rest for evaluation. Per-epoch losses are appended to a text file
        under ``./dataspace/pre_train_results/<dataset>``; training stops
        early after 20 consecutive epochs without a new minimum training
        loss. The encoder weights are written under
        ``./dataspace/pre_trained_model/<dataset>``.

        Args:
            batch_size: graphs per batch (overridden to 512 for a few
                graph-level datasets).
            lr: Adam learning rate.
            decay: Adam weight decay.
        """
        if self.dataset_name in ['COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY', 'ogbg-ppa', 'DD']:
            batch_size = 512
        split = int(len(self.graph_list) * 0.8)
        train_graph_list = self.graph_list[:split]
        test_graph_list = self.graph_list[split:]
        train_loader = self.get_loader(train_graph_list, batch_size)
        test_loader = self.get_loader(test_graph_list, batch_size)

        print('start training {} | {} | {}...'.format(self.dataset_name, 'SimGRACE', self.gnn_type))
        # NOTE(review): only the encoder's parameters are optimised, so
        # ``projection_head`` keeps its initial weights for the whole run --
        # confirm that a fixed random projection is intended.
        optimizer = optim.Adam(self.gnn.parameters(), lr=lr, weight_decay=decay)

        train_loss_min = 1000000
        patience = 20
        cnt_wait = 0

        file_path = f"./dataspace/pre_train_results/{self.dataset_name}"
        os.makedirs(file_path, exist_ok=True)
        filename = "SimGRACE.{}.{}hidden_dim.seed{}.txt".format(self.gnn_type, str(self.hid_dim), self.seed)
        save_path = os.path.join(file_path, filename)

        for epoch in range(1, self.epochs + 1):  # 1..100

            train_loss = self.train_simgrace(train_loader, optimizer)
            test_loss = self.test_simgrace(test_loader)

            print("***epoch: {}/{} | train_loss: {:.8} | test_loss: {:.8}".format(epoch, self.epochs, train_loss, test_loss))

            # Truncate any log left over from a previous run on the first
            # epoch, then append one line per epoch.
            with open(save_path, 'w' if epoch == 1 else 'a') as f:
                f.write('%d %.8f %.8f'%(epoch, train_loss, test_loss))
                f.write("\n")

            if train_loss_min > train_loss:
                train_loss_min = train_loss
                cnt_wait = 0
            else:
                cnt_wait += 1
                if cnt_wait == patience:
                    print('-' * 100)
                    print('Early stopping at '+str(epoch) +' eopch!')
                    break
            print(cnt_wait)

        folder_path = f"./dataspace/pre_trained_model/{self.dataset_name}"
        os.makedirs(folder_path, exist_ok=True)

        # Only the encoder is persisted; the projection head is discarded.
        torch.save(self.gnn.state_dict(),
                    "{}/{}.{}.{}.pth".format(folder_path, 'SimGRACE', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
        print("+++model saved ! {}/{}.{}.{}.pth".format(self.dataset_name, 'SimGRACE', self.gnn_type, str(self.hid_dim) + 'hidden_dim'))
