import argparse
import os

import networkx as nx
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import torch.nn.functional  as F
import numpy as np
import random
import torch.nn as nn
from torch_geometric.utils import from_networkx
from tqdm import tqdm
import matplotlib.pyplot as plt


# Directory holding this script, with symlinks resolved; plots are saved beneath it.
dir_path = os.path.dirname(os.path.realpath(__file__))
# The directory one level above ``dir_path``.
parent_path = os.path.dirname(dir_path)



class GCN(torch.nn.Module):
    """Two-branch GCN scorer returning ``sigmoid(pos_score - neg_score)`` per node.

    The node features ``x`` are split column-wise into a leading ``feat_dim``
    slice (the "core" part) and the remaining columns (the "noise" part).
    Each part goes separately through ``conv_pos`` and through ``conv_neg``.
    Every conv output is activated and averaged over its ``m`` output
    channels, and each branch sums the averages of its two parts.
    """

    def __init__(self, feat_dim, m=10, activation='relu', q=3):
        super().__init__()
        self.feat_dim = feat_dim
        # 'linear' -> identity, 'pow_relu' -> relu(x) ** q, anything else -> relu.
        self.activation = activation
        self.q = q
        # Keep this construction order: each GCNConv presumably samples its
        # initial weights from the global RNG, so swapping them would change
        # the seeded initialisation.
        self.conv_pos = GCNConv(feat_dim, m, bias=False)
        self.conv_neg = GCNConv(feat_dim, m, bias=False)

    def act(self, input):
        """Apply the configured activation element-wise."""
        if self.activation == 'linear':
            return input
        rectified = F.relu(input)
        return torch.pow(rectified, self.q) if self.activation == 'pow_relu' else rectified

    def _pooled(self, conv, features, edge_index):
        """Activate ``conv(features)`` and average over channels, keeping an (N, 1) column."""
        return torch.mean(self.act(conv(features, edge_index)), dim=1, keepdim=True)

    def forward(self, data):
        """Return per-node probabilities of shape (N, 1).

        ``data`` must expose ``x`` and ``edge_index``. ``x`` is expected to be
        (N, 2 * feat_dim), because both halves are fed to convs built with
        ``feat_dim`` input channels.
        """
        x, edge_index = data.x, data.edge_index
        core, noise = x[:, :self.feat_dim], x[:, self.feat_dim:]
        pos_score = self._pooled(self.conv_pos, core, edge_index) \
            + self._pooled(self.conv_pos, noise, edge_index)
        neg_score = self._pooled(self.conv_neg, core, edge_index) \
            + self._pooled(self.conv_neg, noise, edge_index)
        return torch.sigmoid(pos_score - neg_score)



def synthetic_graph(sizes=None, porbs=None, feat_dim=100, noise_scaling=20.5, seed=42):
    """Generate a stochastic-block-model graph with "core + noise" node features.

    Each node's feature vector is the concatenation ``[core, noise]``:

    * ``core`` is one shared standard-normal vector. It is kept as-is for nodes
      ``0 .. sizes[0]-1`` and negated for every later node.
    * ``noise`` is independent standard-normal noise per node, multiplied by
      ``noise_scaling``.

    Nodes ``0 .. sizes[0]-1`` get label 1 and all later nodes get label 0.
    A random 20% of the nodes (drawn with Python's ``random`` module, so
    seed it beforehand for reproducibility) form the training split. The
    remaining nodes form the test split.

    Args:
        sizes: Number of nodes per block. Defaults to ``[1000, 1000]``.
        porbs: Block-to-block edge-probability matrix for networkx's
            ``stochastic_block_model``. The misspelled name is kept for
            backward compatibility. Defaults to ``[[0.5, 0.1], [0.1, 0.5]]``.
        feat_dim: Width of each of the core and noise halves.
        noise_scaling: Multiplier applied to the Gaussian noise half.
        seed: Seed for networkx's edge sampling. The default of 42 matches
            the value that used to be hard-coded.

    Returns:
        Tuple ``(data, empty_data, label, train_idxs, test_idxs,
        feature_core, feature_noise)``:

        * ``data`` is a ``Data`` object with the SBM edges.
        * ``empty_data`` has the same features and no edges.
        * ``label`` is a ``LongTensor`` of shape (N,).
        * ``train_idxs`` and ``test_idxs`` are sorted ``LongTensor`` index sets.
        * ``feature_core`` and ``feature_noise`` are the two (N, feat_dim)
          feature halves.
    """
    # None sentinels instead of list literals avoid shared mutable defaults.
    if sizes is None:
        sizes = [1000, 1000]
    if porbs is None:
        porbs = [[0.5, 0.1], [0.1, 0.5]]

    G = nx.stochastic_block_model(sizes, porbs, seed=seed)
    # Drop graph-level attributes. Presumably this stops from_networkx() from
    # trying to convert them (e.g. networkx's block 'partition') into tensors.
    G.graph = {}
    total_node_num = sum(sizes)
    n_train = int(0.2 * total_node_num)

    # The RNG draw order is unchanged, so results match for a given torch seed.
    feature_core = torch.randn(1, feat_dim).repeat(total_node_num, 1)
    feature_core[sizes[0]:, :] = feature_core[sizes[0]:, :] * -1
    feature_noise = torch.randn(total_node_num, feat_dim) * noise_scaling
    feature = torch.cat([feature_core, feature_noise], dim=1)

    train_idxs = torch.LongTensor(sorted(random.sample(range(total_node_num), n_train)))
    # Sorted, so the test split order is deterministic rather than set-iteration order.
    test_idxs = torch.LongTensor(sorted(set(range(total_node_num)) - set(train_idxs.tolist())))
    # Label depends only on block membership, so compute it for all nodes at once.
    label = torch.LongTensor([1 if i < sizes[0] else 0 for i in range(total_node_num)])

    data = Data(x=feature, edge_index=from_networkx(G).edge_index)
    # Same features, zero edges: the structure-free baseline input.
    empty_data = Data(x=feature, edge_index=torch.tensor([[], []]).long())
    print(data.has_isolated_nodes())
    return data, empty_data, label, train_idxs, test_idxs, feature_core, feature_noise



def _save_comparison_plot(gnn_series, cnn_series, title, plot_title, save_path, dashed_cnn=True):
    """Draw one GNN (blue) vs CNN (red) curve pair for ``title`` and save it to ``save_path``.

    Each series is a sequence of ``[epoch, value]`` pairs. With ``dashed_cnn``
    the GNN line is solid and the CNN line dashed. Otherwise both lines use
    matplotlib's default line style.
    """
    gnn_style = {'linestyle': '-'} if dashed_cnn else {}
    cnn_style = {'linestyle': '--'} if dashed_cnn else {}
    plt.figure()
    try:
        plt.plot([x[0] for x in gnn_series], [x[1] for x in gnn_series], 'b',
                 label='GNN ' + title, **gnn_style)
        plt.plot([x[0] for x in cnn_series], [x[1] for x in cnn_series], 'r',
                 label='CNN ' + title, **cnn_style)
        plt.xlabel('Epochs')
        plt.ylabel(y_label_dict[title])
        plt.title(plot_title)
        plt.legend()
        plt.grid(True, linestyle='-')
        # Light grey background.
        plt.gca().set_facecolor('#f9f9f9')
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close()


def curve_visualization(gnn_result_cache, cnn_result_cache, graph_info_str, title):
    """Save GNN-vs-CNN comparison plots for one tracked quantity under ``<dir_path>/img_vis/``.

    Args:
        gnn_result_cache, cnn_result_cache: Results returned by ``train_model``,
            indexed via ``training_result_dict[title]``.
        graph_info_str: Second line of every plot title (graph configuration).
        title: A key of ``training_result_dict``. The branch taken depends on it:

            * Titles containing ``'Feature Learning'`` produce one plot per
              Max/Mean/Min statistic.
            * Titles containing ``'Noise Learning'`` produce one plot per
              ``noise_tracker_dict`` statistic.
            * Any other title produces a single loss/accuracy plot.
    """
    save_dir = os.path.join(dir_path, 'img_vis')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    file_stem = '_'.join(title.split(' '))
    gnn_info = gnn_result_cache[training_result_dict[title]]
    cnn_info = cnn_result_cache[training_result_dict[title]]

    if 'Feature Learning' in title:
        for stat in ['Max', 'Mean', 'Min']:
            idx = feat_tracker_dict[stat]
            _save_comparison_plot(
                gnn_info[idx], cnn_info[idx], title,
                title + ' {}'.format(stat) + '\n{}'.format(graph_info_str),
                os.path.join(save_dir, file_stem + '_{}_curve.png'.format(stat)),
            )
    elif 'Noise Learning' in title:
        for stat in ['Max_Mean', 'Mean_Mean', 'Min_Mean', 'Max_Max', 'Mean_Max', 'Min_Max']:
            idx = noise_tracker_dict[stat]
            _save_comparison_plot(
                gnn_info[idx], cnn_info[idx], title,
                title + ' {}'.format(' '.join(stat.split('_'))) + '\n{}'.format(graph_info_str),
                os.path.join(save_dir, file_stem + '_{}_curve.png'.format(stat)),
            )
    else:
        _save_comparison_plot(
            gnn_info, cnn_info, title,
            title + '\n{}'.format(graph_info_str),
            os.path.join(save_dir, file_stem + '_curve.png'),
            dashed_cnn=False,
        )


def feat_noise_tracking(epoch, weight_m, feature_core, feature_noise, feat_tracker, noise_tracker):
    """Record how strongly a bank of filters responds to the core feature and to the noise.

    Args:
        epoch: Epoch number stored next to every recorded value.
        weight_m: Filter weight matrix with one filter per row, so that
            ``weight_m @ v.T`` is defined for feature rows ``v``. This is
            presumably the (m, feat_dim) ``lin.weight`` of a GCNConv — confirm.
        feature_core: Core features. Only row 0 is used.
        feature_noise: Noise features, one row per node.
        feat_tracker: List of 3 lists. Receives ``[epoch, value]`` for the
            max, mean and min filter response to ``feature_core[0]``.
        noise_tracker: List of 6 lists. For every node, the max, mean and min
            response over filters is computed. Indices 0-2 receive the mean of
            those per-node statistics over nodes; indices 3-5 receive their
            maximum over nodes.

    Returns:
        ``(feat_tracker, noise_tracker)``: the same lists, appended in place.
    """
    # Response of each filter to the core feature: shape (m,).
    core_resp = torch.matmul(weight_m, feature_core[0].unsqueeze(0).t())[:, 0]
    feat_tracker[0].append([epoch, torch.max(core_resp).item()])
    feat_tracker[1].append([epoch, torch.mean(core_resp).item()])
    feat_tracker[2].append([epoch, torch.min(core_resp).item()])

    # Response of each filter to each node's noise: shape (m, N).
    noise_resp = torch.matmul(weight_m, feature_noise.t())
    # torch.max / torch.min with ``dim`` return (values, indices); torch.mean
    # returns the tensor directly. Indexing the mean with ``[0]`` (as was done
    # before) silently selected node 0 only instead of all nodes.
    per_node_max = torch.max(noise_resp, dim=0).values
    per_node_mean = torch.mean(noise_resp, dim=0)
    per_node_min = torch.min(noise_resp, dim=0).values
    noise_tracker[0].append([epoch, per_node_max.mean().item()])
    noise_tracker[1].append([epoch, per_node_mean.mean().item()])
    noise_tracker[2].append([epoch, per_node_min.mean().item()])
    noise_tracker[3].append([epoch, per_node_max.max().item()])
    noise_tracker[4].append([epoch, per_node_mean.max().item()])
    noise_tracker[5].append([epoch, per_node_min.max().item()])
    return feat_tracker, noise_tracker



def train_model(model, data, label, train_idxs, test_idxs, optimizer, feature_core, feature_noise, args):
    """Train ``model`` full-batch with BCE loss for ``args.epochs`` epochs, logging per-epoch metrics.

    After each optimisation step the model is re-run under ``torch.no_grad()``
    to record:

    * test loss;
    * train and test accuracy, where a probability > 0.5 counts as class 1;
    * projections of the ``conv_pos`` and ``conv_neg`` weights onto the core
      and noise features, via ``feat_noise_tracking``.

    Args:
        model: Module whose forward returns per-node probabilities aligned
            with ``label`` and which exposes ``conv_pos.lin.weight`` and
            ``conv_neg.lin.weight``.
        data: Graph object passed straight to ``model``.
        label: Float targets for ``nn.BCELoss``. ``main`` passes shape (N, 1).
        train_idxs, test_idxs: Node indices of the two splits.
        optimizer: Optimizer over ``model``'s parameters.
        feature_core, feature_noise: Forwarded to ``feat_noise_tracking``.
        args: Namespace. Only ``args.epochs`` is read.

    Returns:
        A length-8 object ``np.ndarray`` ordered like ``training_result_dict``:

        * train loss, test loss, train accuracy and test accuracy, each a
          list of ``[epoch, value]``;
        * then the positive feature (3 lists), positive noise (6 lists),
          negative feature (3 lists) and negative noise (6 lists) trackers.
    """
    loss_fn = nn.BCELoss()
    train_loss_list, test_loss_list = [], []
    train_acc_list, test_acc_list = [], []
    w_pos_feat_tracker = [[] for _ in range(3)]
    w_pos_noise_tracker = [[] for _ in range(6)]
    w_neg_feat_tracker = [[] for _ in range(3)]
    w_neg_noise_tracker = [[] for _ in range(6)]
    # Loop-invariant label slices.
    train_labels = label[train_idxs]
    test_labels = label[test_idxs]

    for epoch in tqdm(range(args.epochs)):
        # Re-enter training mode every epoch: the evaluation step below
        # switches to eval mode, which previously persisted for all later epochs.
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out[train_idxs], train_labels)
        loss.backward()
        optimizer.step()
        train_loss_list.append([epoch, loss.item()])

        model.eval()
        with torch.no_grad():
            out = model(data)
            test_loss_list.append([epoch, loss_fn(out[test_idxs], test_labels).item()])
            pred = (out.squeeze() > 0.5).long()
            train_correct = pred[train_idxs].eq(train_labels.squeeze().long()).sum().item()
            test_correct = pred[test_idxs].eq(test_labels.squeeze().long()).sum().item()
            train_acc_list.append([epoch, train_correct / len(train_idxs)])
            test_acc_list.append([epoch, test_correct / len(test_idxs)])
            # feature_core[0] is the orig feat core, feature_core[-1] is the orig feat core * -1
            w_pos_feat_tracker, w_pos_noise_tracker = \
                feat_noise_tracking(epoch, model.conv_pos.lin.weight, feature_core, feature_noise,
                                    w_pos_feat_tracker, w_pos_noise_tracker)
            w_neg_feat_tracker, w_neg_noise_tracker = \
                feat_noise_tracking(epoch, model.conv_neg.lin.weight, feature_core, feature_noise,
                                    w_neg_feat_tracker, w_neg_noise_tracker)

    results = [train_loss_list, test_loss_list, train_acc_list, test_acc_list,
               w_pos_feat_tracker, w_pos_noise_tracker, w_neg_feat_tracker, w_neg_noise_tracker]
    # np.array() on these ragged nested lists raises ValueError on NumPy >= 1.24.
    # Filling a pre-allocated object array keeps the ndarray return type
    # without relying on NumPy's shape inference.
    packed = np.empty(len(results), dtype=object)
    for i, series in enumerate(results):
        packed[i] = series
    return packed


# Position of each recorded series in the array returned by ``train_model``.
# The keys double as the plot titles accepted by ``curve_visualization``.
training_result_dict = {
    title: position
    for position, title in enumerate((
        'Train Loss',
        'Test Loss',
        'Train Accuracy',
        'Test Accuracy',
        'Positive Feature Learning',
        'Positive Noise Learning',
        'Negative Feature Learning',
        'Negative Noise Learning',
    ))
}

# Y-axis label per plot title: two loss curves, two accuracy curves, then
# four weight-projection ("Logit") curves.
y_label_dict = dict(zip(training_result_dict, ['Loss'] * 2 + ['Accuracy'] * 2 + ['Logit'] * 4))

# Matplotlib colour per tracker statistic (not referenced elsewhere in this view).
color_dict = dict(
    Max='b', Mean='g', Min='r',
    Max_Mean='c', Mean_Mean='m', Min_Mean='y',
    Max_Max='k', Mean_Max='w', Min_Max='#eeefff',
)

# Index into the 3-list feature tracker filled by ``feat_noise_tracking``.
feat_tracker_dict = {stat: i for i, stat in enumerate(('Max', 'Mean', 'Min'))}

# Index into the 6-list noise tracker. Keys read '<per-node stat over filters>_<aggregate over nodes>'.
noise_tracker_dict = {
    stat: i
    for i, stat in enumerate(('Max_Mean', 'Mean_Mean', 'Min_Mean', 'Max_Max', 'Mean_Max', 'Min_Max'))
}



def args_parser():
    """Parse command-line options for the synthetic GNN-vs-CNN experiment.

    Two derived attributes are added to the returned namespace:

    * ``edge_prob``: ``--edge_prob_str`` reshaped row-major into a 2x2 nested list.
    * ``device``: CUDA if available, else CPU.

    Exits through ``parser.error`` (status 2, with a usage message) when
    ``--edge_prob_str`` is not exactly four comma-separated numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42,
                        help='seed for torch, numpy and random')
    parser.add_argument('--feat_dim', type=int, default=100,
                        help='width of the core half and of the noise half of node features')
    parser.add_argument('--learning_rate', type=float, default=0.0005,
                        help='Adam learning rate')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of full-batch training epochs')
    # graph generation config
    parser.add_argument('--node_num', type=int, default=8000,
                        help='total number of nodes, split into two near-equal blocks')
    parser.add_argument('--noise_scale', type=float, default=20,
                        help='multiplier applied to the Gaussian noise features')
    parser.add_argument('--edge_prob_str', type=str, default='0.21,0.08,0.08,0.18',
                        help='row-major 2x2 block edge-probability matrix, comma separated')
    args = parser.parse_args()

    try:
        probs = [float(p) for p in args.edge_prob_str.split(',')]
    except ValueError:
        parser.error('--edge_prob_str must be comma-separated numbers, got {!r}'.format(args.edge_prob_str))
    if len(probs) != 4:
        # Fail early with a clear message instead of a cryptic reshape error.
        parser.error('--edge_prob_str must contain exactly 4 values for a 2x2 matrix, got {}'.format(len(probs)))
    args.edge_prob = np.array(probs).reshape(2, 2).tolist()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return args




def main():
    """Run the experiment end to end.

    Steps:

    1. Parse arguments and seed torch, numpy and random.
    2. Build the synthetic graph.
    3. Train one GCN on the edgeless ``empty_data`` (the "CNN" baseline) and
       one on the real graph ``data`` (the "GNN").
    4. Save comparison plots for every tracked quantity.
    """
    args = args_parser()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    first_block = args.node_num // 2
    data, empty_data, label, train_idxs, test_idxs, feature_core, feature_noise = synthetic_graph(
        sizes=[first_block, args.node_num - first_block],
        porbs=args.edge_prob,
        feat_dim=args.feat_dim,
        noise_scaling=args.noise_scale,
    )

    graph_info_str = 'Node Num: {}, Edge Prob: {}, Noise Scale: {}'.format(
        args.node_num, args.edge_prob_str, args.noise_scale)

    data = data.to(args.device)
    empty_data = empty_data.to(args.device)
    # synthetic_graph already returns tensors; .to() avoids the copy-construct
    # warning (and extra copy) that torch.tensor(existing_tensor) triggers.
    feature_core = feature_core.to(args.device)
    feature_noise = feature_noise.to(args.device)
    # BCELoss needs float targets shaped like the model output (N, 1).
    label = label.to(args.device).unsqueeze(1).float()

    # "CNN" baseline: the same architecture trained on the edgeless graph.
    cnn_model = GCN(feat_dim=args.feat_dim).to(args.device)
    cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=args.learning_rate, weight_decay=5e-4)
    cnn_result_cache = train_model(cnn_model, empty_data, label, train_idxs, test_idxs, cnn_optimizer,
                                   feature_core, feature_noise, args)

    gnn_model = GCN(feat_dim=args.feat_dim).to(args.device)
    gnn_optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.learning_rate, weight_decay=5e-4)
    gnn_result_cache = train_model(gnn_model, data, label, train_idxs, test_idxs, gnn_optimizer,
                                   feature_core, feature_noise, args)

    for title in ('Positive Feature Learning', 'Positive Noise Learning',
                  'Negative Feature Learning', 'Negative Noise Learning',
                  'Train Loss', 'Test Loss', 'Train Accuracy', 'Test Accuracy'):
        curve_visualization(gnn_result_cache, cnn_result_cache, graph_info_str, title=title)


if __name__ == '__main__':
    main()