import os
import shutil
import random
import argparse
import datetime
import numpy as np
import pandas as pd
import networkx as nx
from collections import Counter
from typing import Any, Optional, Tuple
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn import metrics
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.autograd import Variable
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch.utils.data import TensorDataset, DataLoader
import warnings
warnings.filterwarnings("ignore")

class GCN(nn.Module):
    """Two-layer graph convolution encoder: GCNConv -> BatchNorm -> ReLU -> GCNConv.

    Args:
        feature_dims: size of each input node feature vector.
        out_dims: size of each output node embedding.
        hidden_dims: width of the intermediate (batch-normalised) layer.
    """

    def __init__(self, feature_dims, out_dims, hidden_dims):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.conv1 = GCNConv(feature_dims, hidden_dims)
        self.bn = nn.BatchNorm1d(hidden_dims)
        self.relu = nn.ReLU()
        self.conv2 = GCNConv(hidden_dims, out_dims)

    def forward(self, data):
        """Embed every node of `data` (an object exposing `.x` and `.edge_index`)."""
        edges = data.edge_index
        hidden = self.relu(self.bn(self.conv1(data.x, edges)))
        return self.conv2(hidden, edges)

class GateFusion(nn.Module):
    """Gated fusion of two feature vectors into one `size_out`-dim vector.

    Each input gets its own bias-free projection followed by tanh:
    h1 = tanh(W1 x1), h2 = tanh(W2 x2). A scalar gate per sample,
    z = sigmoid(Wz [x1; x2]), is computed from the concatenated raw inputs,
    and the output is z * h1 + (1 - z) * h2.

    Args:
        size_in1: feature size of the first input.
        size_in2: feature size of the second input.
        size_out: feature size of both projections and of the output.
    """

    def __init__(self, size_in1, size_in2, size_out=16):
        super(GateFusion, self).__init__()
        self.size_in1, self.size_in2, self.size_out = size_in1, size_in2, size_out
        self.hidden1 = nn.Linear(size_in1, size_out, bias=False)
        self.hidden2 = nn.Linear(size_in2, size_out, bias=False)
        # The gate reads the concatenated *inputs*, so its fan-in is
        # size_in1 + size_in2 (identical to size_out * 2 when all sizes match).
        self.hidden_sigmoid = nn.Linear(size_in1 + size_in2, 1, bias=False)
        self.tanh_f = nn.Tanh()
        self.sigmoid_f = nn.Sigmoid()

    def forward(self, x1, x2):
        """Fuse x1 (..., size_in1) and x2 (..., size_in2) into (..., size_out)."""
        h1 = self.tanh_f(self.hidden1(x1))
        # Each input uses its own projection (x2 -> hidden2).
        h2 = self.tanh_f(self.hidden2(x2))
        z = self.sigmoid_f(self.hidden_sigmoid(torch.cat((x1, x2), dim=-1)))
        # z has shape (..., 1) and broadcasts over the feature dimension.
        return z * h1 + (1 - z) * h2

class GradReverse(torch.autograd.Function):
    """Gradient-reversal layer.

    Forward: returns the input values unchanged, as a new tensor.
    Backward: multiplies the incoming gradient by -coeff. The second returned
    gradient is None because `coeff` is a plain number, not a tensor input.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        # Keep the reversal strength for the backward pass.
        ctx.coeff = coeff
        return torch.mul(input, 1.0)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        reversed_grad = -(grad_output * ctx.coeff)
        return reversed_grad, None

def grad_reverse(x, coeff):
    """Pass `x` through unchanged; its gradient is scaled by -`coeff` on backward."""
    reversed_x = GradReverse.apply(x, coeff)
    return reversed_x


class Adversarial(nn.Module):
    """Multiplex-network link predictor with an adversarial layer discriminator.

    One GCN per network layer embeds all nodes of that layer. A candidate
    link (layer, u, v) is represented by concatenating the embeddings of u
    and v taken from that layer's GCN. Two 1x1-conv branches map the edge to
    a "target" view and a "generality" view. The generality view feeds a
    layer classifier through a gradient-reversal layer, which pushes it to be
    uninformative about the source layer. The two views are then fused by a
    gate and classified as link / no link.

    Args:
        embedding_dim: node embedding size produced by each GCN.
        layer_number: number of network layers (one GCN per layer).
        gcn_data: per-layer graph objects exposing `.x` and `.edge_index`,
            indexed by layer id; each is consumed by that layer's GCN.
    """

    def __init__(self, embedding_dim, layer_number, gcn_data):
        super(Adversarial, self).__init__()

        self.layer_number = layer_number
        self.node_dim = embedding_dim
        # An edge vector is the concatenation of its two endpoint embeddings.
        self.edge_dim = embedding_dim * 2
        self.gcn_data = gcn_data

        # One encoder per layer, registered as gcn0, gcn1, ... (these names are
        # part of the state_dict keys). Nodes carry a single input feature.
        for i in range(self.layer_number):
            gcn = GCN(feature_dims=1, out_dims=self.node_dim, hidden_dims=64)
            setattr(self, 'gcn%i' % i, gcn)

        # kernel_size=1 convolutions: the same linear map + ReLU is applied to
        # every edge of the batch independently.
        self.target_conv = nn.Sequential(
            nn.Conv1d(in_channels=self.edge_dim, out_channels=self.edge_dim, kernel_size=1),
            nn.ReLU())

        self.generality_conv = nn.Sequential(
            nn.Conv1d(in_channels=self.edge_dim, out_channels=self.edge_dim, kernel_size=1),
            nn.ReLU())

        self.fusion_gate = GateFusion(self.edge_dim, self.edge_dim, self.edge_dim)

        # Binary link classifier: two output probabilities, column 1 = "link".
        # The final Linear(20, 2) projects to the two classes (mirrors
        # network_classifier, which projects to layer_number classes).
        # NOTE(review): these are softmax probabilities, yet the training loop
        # passes them to nn.CrossEntropyLoss, which applies log-softmax again.
        self.link_classifier = nn.Sequential(
            nn.Linear(self.edge_dim, 20),
            nn.ReLU(),
            nn.Linear(20, 2),
            nn.Softmax(dim=1))

        # Layer discriminator: predicts which network layer an edge came from.
        self.network_classifier = nn.Sequential(
            nn.Linear(self.edge_dim, 20),
            nn.ReLU(),
            nn.Linear(20, self.layer_number),
            nn.Softmax(dim=1))

    def _edge_features(self, now_layer, leftnode, rightnode):
        """Run every layer's GCN and build one edge vector per sample.

        Row k of the result is cat(E[now_layer[k]][leftnode[k]],
        E[now_layer[k]][rightnode[k]]), where E[l] is layer l's node embedding.
        """
        layer_embeds = [getattr(self, 'gcn%i' % i)(self.gcn_data[i])
                        for i in range(self.layer_number)]
        # Stack all layers into one table and shift node ids by each layer's
        # first row, so the whole batch is gathered with two index operations
        # (works even if layers had different node counts).
        all_embeds = torch.cat(layer_embeds, dim=0)
        device = all_embeds.device
        sizes = [embed.size(0) for embed in layer_embeds]
        starts = torch.tensor([0] + sizes[:-1], device=device).cumsum(0)
        base = starts[now_layer.to(device)]
        left = all_embeds[base + leftnode.to(device)]
        right = all_embeds[base + rightnode.to(device)]
        return torch.cat((left, right), dim=1)

    def forward(self, now_layer, leftnode, rightnode, coeff=1):
        """Score a batch of candidate edges.

        Args:
            now_layer: 1-D tensor with the layer id of each candidate edge.
            leftnode, rightnode: 1-D tensors with the endpoint node ids.
            coeff: gradient-reversal strength for the layer discriminator.

        Returns:
            (link_output, network_output): softmax over {no link, link} and
            softmax over the `layer_number` source layers, one row per edge.
        """
        edge = self._edge_features(now_layer, leftnode, rightnode)

        # Generality branch. Conv1d takes (channels, length), so edge features
        # become channels and the batch becomes the sequence axis.
        generality_embed = self.generality_conv(edge.permute(1, 0)).permute(1, 0)
        # Network (layer) identification through the gradient-reversal layer.
        reverse_feature = grad_reverse(generality_embed, coeff)
        network_output = self.network_classifier(reverse_feature)
        # Target branch.
        target_embed = self.target_conv(edge.permute(1, 0)).permute(1, 0)
        # Feature fusion.
        fusion_embed = self.fusion_gate(target_embed, generality_embed)
        # Link prediction.
        link_output = self.link_classifier(fusion_embed)
        return link_output, network_output

    def metrics_eval(self, eval_data):
        """Evaluate link prediction on `eval_data`.

        Args:
            eval_data: iterable of (network_labels, left_nodes, right_nodes,
                link_labels) tensor batches, e.g. a loader from get_loader.

        Returns:
            (accuracy, ROC-AUC), both as percentages. Predictions are the argmax
            of the link output; AUC uses its column 1 as the positive score.

        Note: this does not switch modes itself; call `self.eval()` first so
        the GCN BatchNorm layers use their running statistics.
        """
        device = next(self.parameters()).device
        scores = []
        labels = []
        preds = []
        # Inference only: no autograd graph for the whole forward pass.
        with torch.no_grad():
            for network_labels, left_nodes, right_nodes, link_labels in eval_data:
                output, _ = self(network_labels.to(device),
                                 left_nodes.to(device),
                                 right_nodes.to(device))
                scores.extend(output[:, 1].cpu().numpy())
                labels.extend(link_labels.cpu().numpy())
                preds.extend(output.argmax(dim=1).cpu().numpy())
        # Binary classification metrics.
        acc = metrics.accuracy_score(labels, preds)
        auc = metrics.roc_auc_score(labels, scores, average=None)
        return acc * 100, auc * 100

def _train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over `train_loader`.

    Returns:
        (whole_loss, link_loss, network_loss, accuracy) averaged over batches;
        accuracy is a fraction in [0, 1].
    """
    model.train()
    whole_losses, link_losses, network_losses, accs = [], [], [], []
    for network_labels, left_nodes, right_nodes, link_labels in train_loader:
        network_labels = network_labels.to(device)
        left_nodes = left_nodes.to(device)
        right_nodes = right_nodes.to(device)
        link_labels = link_labels.to(device)

        # Forward pass: link prediction plus source-layer prediction.
        link_outs, network_output = model(network_labels, left_nodes, right_nodes)
        link_loss = criterion(link_outs, link_labels)
        network_loss = criterion(network_output, network_labels)
        batch_loss = link_loss + network_loss

        whole_losses.append(batch_loss.item())
        link_losses.append(link_loss.item())
        network_losses.append(network_loss.item())
        accs.append((link_outs.argmax(dim=1) == link_labels).float().mean().item())

        optimizer.zero_grad()
        # The graph is rebuilt on every forward pass, so it need not be retained.
        batch_loss.backward()
        optimizer.step()
    return np.mean(whole_losses), np.mean(link_losses), np.mean(network_losses), np.mean(accs)


def run_Adversarial_Model(train_loader, valid_data, test_data, model, initial_learning_rate, epochs, record_file):
    """Train `model`, keep the best-validation-AUC checkpoint, and report test metrics.

    Trains for `epochs + 1` passes. The learning rate follows the annealing
    lr0 / (1 + 10 p) ** 0.75 with p = epoch / epochs. After each pass the model
    is evaluated on `valid_data`. Whenever the validation AUC improves, the
    weights are saved under result/model/; that directory is wiped at start.
    The best weights are reloaded, evaluated on `test_data`, and the test line
    is printed and appended to `record_file`.
    """
    model_path = 'result/model/'
    # Remove checkpoints produced by earlier runs.
    if os.path.exists(model_path):
        shutil.rmtree(model_path)
    os.makedirs(model_path, exist_ok=True)
    device = next(model.parameters()).device
    print(datetime.datetime.now())
    print('Training...')
    best_valid_dir = ''
    best_valid = 0
    criterion = nn.CrossEntropyLoss()
    # One optimizer for the whole run so Adam's moment estimates persist
    # across epochs; only its learning rate is annealed below.
    optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate, weight_decay=0.0001)
    for epoch in range(epochs + 1):
        p = epoch / epochs if epochs > 0 else 0.0
        learning_rate = initial_learning_rate / pow((1 + 10 * p), 0.75)
        for group in optimizer.param_groups:
            group['lr'] = learning_rate

        # Training pass
        whole_loss, link_loss, network_loss, train_acc = _train_epoch(
            model, train_loader, optimizer, criterion, device)

        print('epoch:[{}/{}], lr:{:.4f}, whole_loss:{:.4f}, pred_loss:{:.4f}, network_loss:{:.4f}'.format(
                epoch, epochs, learning_rate, whole_loss, link_loss, network_loss))
        print('Train acc:{:.2f}'.format(train_acc*100))

        # Validation
        model.eval()
        valid_acc, valid_auc = model.metrics_eval(valid_data)
        print("Valid acc:{:.2f}, auc:{:.2f} ".format(valid_acc, valid_auc))
        print("=================================")
        # Keep the checkpoint with the best validation AUC so far.
        if valid_auc > best_valid:
            best_valid = valid_auc
            best_valid_dir = model_path + 'model' + str(epoch) + '.pkl'
            # Save weights only: a state_dict loads under torch.load's
            # weights_only default (torch >= 2.6), unlike a pickled module.
            torch.save(model.state_dict(), best_valid_dir)

    # Test with the best checkpoint
    print('Load best model ...')
    print(best_valid_dir)
    if best_valid_dir:
        model.load_state_dict(torch.load(best_valid_dir, map_location=device))
    else:
        print('No checkpoint was saved; testing the final model.')
    model.eval()
    acc, auc = model.metrics_eval(test_data)
    write_infor = "\nTest acc:{:.2f}, auc:{:.2f} ".format(acc, auc)
    print(write_infor)
    with open(record_file, 'a', encoding='utf-8') as outfile:
        outfile.write(write_infor)


def load_model(train_loader, valid_loader, test_loader, gcn_data, network_numbers, args):
    """Move the layer graphs to the GPU, build the Adversarial model there, and train/evaluate it.

    Reads `dim_node`, `learning_rate`, `epochs` and `record_file` from `args`.
    The first `network_numbers` entries of `gcn_data` are modified in place
    (their `.x` and `.edge_index` are replaced by CUDA copies).
    """
    for layer_idx in range(network_numbers):
        layer_graph = gcn_data[layer_idx]
        layer_graph.x = layer_graph.x.cuda()
        layer_graph.edge_index = layer_graph.edge_index.cuda()
    net = Adversarial(embedding_dim=args.dim_node, layer_number=network_numbers, gcn_data=gcn_data).cuda()
    run_Adversarial_Model(
        train_loader=train_loader,
        valid_data=valid_loader,
        test_data=test_loader,
        model=net,
        initial_learning_rate=args.learning_rate,
        epochs=args.epochs,
        record_file=args.record_file,
    )

def obtain_sample(inter, all_nodes, network_layer):
    """Enumerate labelled candidate pairs for one network layer.

    Every node that appears in the 'left' column of `inter` is paired with
    every other node of `all_nodes` (self-pairs are skipped). The label is 1
    when that (left, right) pair is a row of `inter`, otherwise 0.

    NOTE(review): only the (left, right) direction is labelled positive, so the
    reverse pair is emitted as a negative when `right` also appears as a left
    node. gcndata_load builds an undirected graph from the same edges;
    confirm the directed labelling is intended.

    Args:
        inter: DataFrame with 'left' and 'right' columns, one row per edge.
        all_nodes: iterable of all node ids.
        network_layer: layer id stored in column 0 of every sample.

    Returns:
        List of [network_layer, left, right, label] rows.
    """
    sample = []
    for node, node_edges in inter.groupby('left'):
        # A set gives O(1) membership tests instead of scanning a list per pair.
        positives = set(node_edges['right'].tolist())
        for temp_node in all_nodes:
            if temp_node == node:
                continue
            sample.append([network_layer, node, temp_node, 1 if temp_node in positives else 0])
    return sample

def get_loader(infor, batch_size):
    """Wrap an (n, 4) integer sample array in a shuffling DataLoader.

    Columns are [network layer, left node, right node, link label]. Each
    column becomes a LongTensor, and every batch yields the four tensors in
    that order.
    """
    columns = [torch.LongTensor(infor[:, col]) for col in range(4)]
    return DataLoader(TensorDataset(*columns), batch_size=batch_size, shuffle=True)

def gcndata_load(inters, all_nodes):
    """Build the torch_geometric graph consumed by one layer's GCN.

    Args:
        inters: two-column (left, right) edge table of already-mapped node ids.
        all_nodes: every node id. All of them become graph nodes (isolated
            ones included), and their order fixes the adjacency row order.

    Returns:
        Data whose `x` has shape (len(all_nodes), 1) and holds each node id as
        a float feature, and whose `edge_index` is the COO (row, col) index of
        the symmetric adjacency matrix, with positions taken in `all_nodes`
        order.
    """
    pos_edge = np.array(inters).tolist()
    # Interactions -> undirected graph, so each edge appears in both
    # directions in the adjacency matrix.
    g = nx.Graph(pos_edge)
    g.add_nodes_from(all_nodes)
    # networkx 3.0 removed to_scipy_sparse_matrix; to_scipy_sparse_array
    # (networkx >= 2.7) replaces it. Fall back only on older releases.
    to_sparse = getattr(nx, 'to_scipy_sparse_array', None) or nx.to_scipy_sparse_matrix
    adj = to_sparse(g, nodelist=all_nodes, dtype=int, format='coo')
    # COO (row, col) pairs form the edge_index that GCNConv expects.
    edge_index = torch.LongTensor(np.vstack((adj.row, adj.col)))
    # NOTE(review): the only node feature is the node id itself; confirm this
    # is intended rather than e.g. a constant or learned embedding.
    x = torch.unsqueeze(torch.FloatTensor(all_nodes), 1)
    return Data(x=x, edge_index=edge_index)

def load_data(dataset, batch_size):
    """Read all layers of a multiplex dataset and build train/valid/test loaders.

    Layer k (1-based) is read from data/<dataset>_data/<dataset><k>.txt as
    space-separated "left right" pairs. The number of layers is the number of
    entries in data/<dataset>_data/. Raw node ids of all layers are mapped to
    a shared 0..N-1 index, and each layer yields labelled candidate pairs via
    obtain_sample. The samples are split 80/10/10 into train/valid/test. The
    training split is then balanced 1:1: under-sampling when it has more than
    10000 positives, over-sampling otherwise.

    Returns:
        (train_loader, valid_loader, test_loader, gcn_data, network_numbers)

    Raises:
        ValueError: if the training split contains only one label.
    """
    datadir = 'data/' + dataset + '_data/'
    # NOTE(review): every directory entry is counted as a layer, so any stray
    # file in this directory breaks the layer-file lookup below.
    network_numbers = len(os.listdir(datadir))
    # Read each layer file once; it is reused for id mapping and sampling.
    layers = []
    for i in range(network_numbers):
        now_layer = datadir + dataset + str(i+1) + '.txt'
        layers.append(pd.read_csv(now_layer, sep=' ', header=None, names=['left', 'right']))

    # Shared id mapping over the union of node ids from all layers.
    raw_ids = set()
    for now_inter in layers:
        raw_ids.update(np.array(now_inter).reshape(-1))
    change_dict = {raw: idx for idx, raw in enumerate(raw_ids)}
    whole_nodes = list(change_dict.values())
    print('Nodes:', len(whole_nodes))

    data = []
    gcn_data = []
    for i, now_inter in enumerate(layers):
        print('Edges of layer ' + str(i+1) + ": " + str(now_inter.shape[0]))
        now_inter['left'] = now_inter['left'].map(change_dict)  # raw id -> index
        now_inter['right'] = now_inter['right'].map(change_dict)  # raw id -> index
        data += obtain_sample(now_inter, whole_nodes, i)
        # NOTE(review): the GCN graph is built from *all* edges of the layer,
        # including positives that end up in the valid/test splits below, so
        # message passing can see held-out links. Confirm this is intended.
        gcn_data.append(gcndata_load(now_inter, whole_nodes))
    print('-----------------------')
    data = np.array(data)
    np.random.shuffle(data)
    # 80% train + 10% valid + 10% test
    train_infor, test_infor, train_label, test_label = train_test_split(data,  data[:, 3], test_size=0.2)
    valid_infor, test_infor, valid_label, test_label = train_test_split(test_infor, test_label, test_size=0.5)
    label_counts = Counter(train_label)
    train_counts = sorted(label_counts.items())
    print("train counter", train_counts)
    if len(label_counts) < 2:
        raise ValueError('training split contains a single label: {}'.format(train_counts))
    if label_counts.get(1, 0) > 10000:  # enough positive samples
        # Under-sample the training split
        under = RandomUnderSampler(sampling_strategy=1)
        train_infor, train_label = under.fit_resample(train_infor, train_label)
        print("train under sampling results: ", sorted(Counter(train_label).items()))
    else:
        # Over-sample the training split
        over = RandomOverSampler(sampling_strategy=1)
        train_infor, train_label = over.fit_resample(train_infor, train_label)
        print("train over sampling results: ", sorted(Counter(train_label).items()))
    print("valid counter: ", sorted(Counter(valid_label).items()))
    print("test counter: ", sorted(Counter(test_label).items()))
    train_loader = get_loader(train_infor, batch_size)
    valid_loader = get_loader(valid_infor, batch_size)
    test_loader = get_loader(test_infor, batch_size)
    return train_loader, valid_loader, test_loader, gcn_data, network_numbers

def run_model(args):
    """Append a run header to `args.record_file`, then load data and train/evaluate `args.repeats` times.

    Each repeat reloads the data, so every repeat gets a fresh random split.
    """
    write_infor = "\n\n" + args.model + ", dataset:" + args.dataset + ", lr:{:.4f}, epochs:{}, batch:{}, dim:{}".format(args.learning_rate, args.epochs, args.batch_size, args.dim_node)+'\n'
    # Context manager closes the log even if the write fails.
    with open(args.record_file, 'a', encoding='utf-8') as outfile:
        outfile.write(write_infor)
    for _ in range(args.repeats):
        print('The program starts running.')
        begin = datetime.datetime.now()
        print('Start time ', begin)
        # load data
        train_loader, valid_loader, test_loader, gcn_data, network_numbers = load_data(args.dataset, args.batch_size)
        # build, train and evaluate the model
        load_model(train_loader, valid_loader, test_loader, gcn_data, network_numbers, args)
        end = datetime.datetime.now()
        print('End time ', end)
        print('Run time ', end-begin)
        print('\n\n\n\n\n')

if __name__ == '__main__':
    # Command-line entry point: read experiment settings, echo them, then run.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', '-m', type=str, default='AER', help='name of model')
    cli.add_argument('--dataset', '-d', type=str, default='TF', help='name of dataset')
    cli.add_argument('--learning_rate', '-l', type=float, default=0.01, help='initial learning rate')
    cli.add_argument('--epochs', '-e', type=int, default=50, help='numbers of iterations')
    cli.add_argument('--batch_size', '-b', type=int, default=256, help='batch size of train data')
    cli.add_argument('--dim_node', '-dim', type=int, default=16, help='dims of embedding')
    cli.add_argument('--repeats', '-r', type=int, default=10, help='numbers of repeats')
    cli.add_argument('--record_file', '-f', type=str, default='result/log.txt', help='record file path')

    args = cli.parse_args()
    print(args)
    run_model(args)

