import argparse
import os
import pickle

import math
import networkx as nx
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import torch.nn.functional as F
import numpy as np
import random
import torch.nn as nn
from torch_geometric.utils import from_networkx
from tqdm import tqdm
import matplotlib.pyplot as plt

# Absolute directory containing this script (symlinks resolved by realpath).
# Other functions in this file build their input/output paths from it.
dir_path = os.path.dirname(os.path.realpath(__file__))
# Parent directory of dir_path; not referenced in the visible part of this file.
parent_path = os.path.abspath(os.path.join(dir_path, os.pardir))


class GCN(torch.nn.Module):
    """Single-layer, two-branch graph convolution scorer.

    Two bias-free ``GCNConv`` layers (a "positive" and a "negative" branch)
    each map ``feat_dim`` input channels to ``m`` output channels. The node
    feature matrix is split column-wise at ``feat_dim``; both halves are fed
    through the same branch, activated, averaged over the ``m`` channels and
    summed. The model output is ``sigmoid(positive - negative)`` with one
    column per node.

    Args:
        feat_dim: Width of each feature half and input size of both layers.
        m: Number of output channels of each convolution.
        activation: ``'linear'``, ``'pow_relu'`` or anything else (plain ReLU).
        q: Exponent applied after ReLU when ``activation == 'pow_relu'``.
    """

    def __init__(self, feat_dim, m=10, activation='relu', q=3):
        super().__init__()
        self.feat_dim = feat_dim
        self.activation = activation
        self.q = q
        # Registration order (positive first) is kept so parameter order and
        # random initialisation stay the same.
        self.conv_pos = GCNConv(feat_dim, m, bias=False)
        self.conv_neg = GCNConv(feat_dim, m, bias=False)

    def act(self, input):
        """Apply the configured activation; unknown names fall back to ReLU."""
        if self.activation == 'linear':
            return input
        rectified = F.relu(input)
        if self.activation == 'pow_relu':
            return torch.pow(rectified, self.q)
        return rectified

    def forward(self, data):
        """Return per-node sigmoid scores for ``data`` (needs ``x`` and ``edge_index``)."""
        features, edges = data.x, data.edge_index
        # Columns before/after feat_dim; synthetic_graph in this file stores
        # the signal part first and the noise part second.
        halves = (features[:, :self.feat_dim], features[:, self.feat_dim:])

        def branch_score(conv):
            # Channel-mean of the activated convolution, summed over both halves.
            first, second = (
                torch.mean(self.act(conv(half, edges)), dim=1, keepdim=True)
                for half in halves
            )
            return first + second

        score_pos = branch_score(self.conv_pos)
        score_neg = branch_score(self.conv_neg)
        return torch.sigmoid(score_pos - score_neg)


def synthetic_graph(sizes=None, porbs=None, feat_dim=100, noise_scaling=20.5, feature_core=None):
    """Build a stochastic-block-model graph with signal + noise node features.

    Every node gets a ``2 * feat_dim`` feature row: the first ``feat_dim``
    columns hold a shared "core" vector (negated for every node whose index
    is ``>= sizes[0]``), the remaining columns hold Gaussian noise scaled by
    ``noise_scaling``. Nodes with index ``< sizes[0]`` are labelled 1, all
    others 0. 20% of the nodes (rounded down) are sampled as training nodes
    with the global ``random`` module; the rest are test nodes.

    Args:
        sizes: Block sizes passed to ``nx.stochastic_block_model``.
            Defaults to ``[1000, 1000]``.
        porbs: Block edge-probability matrix. Defaults to
            ``[[0.5, 0.1], [0.1, 0.5]]``.
        feat_dim: Width of the core vector and of the noise part.
        noise_scaling: Multiplier applied to the standard-normal noise.
        feature_core: Optional tensor whose first row is reused as the core
            vector; when ``None`` a fresh ``torch.randn`` vector is drawn.
            The argument itself is not modified.

    Returns:
        Tuple ``(data, empty_data, label, train_idxs, test_idxs,
        feature_core, feature_noise)``. ``data`` carries the SBM edges,
        ``empty_data`` shares the same feature tensor but has no edges,
        ``label`` is an int64 tensor with one entry per node, both index
        tensors are sorted ascending, ``feature_core`` is the per-node
        (sign-flipped) core matrix.
    """
    # Avoid mutable default arguments; None selects the historical defaults.
    if sizes is None:
        sizes = [1000, 1000]
    if porbs is None:
        porbs = [[0.5, 0.1], [0.1, 0.5]]

    # The SBM seed is fixed, so topology only depends on sizes/porbs.
    G = nx.stochastic_block_model(sizes, porbs, seed=42)
    # Clear graph-level attributes before conversion -- presumably so that
    # from_networkx does not try to convert them (TODO confirm).
    G.graph = {}
    total_node_num = sum(sizes)
    n_train = int(0.2 * total_node_num)

    if feature_core is None:
        feature_core = torch.randn(1, feat_dim).repeat(total_node_num, 1)
    else:
        # repeat() allocates a new tensor, so the caller's tensor is untouched.
        feature_core = feature_core[0].reshape(1, -1).cpu().repeat(total_node_num, 1)
    # Nodes outside the first block carry the negated core vector.
    feature_core[sizes[0]:, :] = feature_core[sizes[0]:, :] * -1
    feature_noise = torch.randn(total_node_num, feat_dim) * noise_scaling
    feature = torch.cat([feature_core, feature_noise], dim=1)

    train_idxs = torch.LongTensor(sorted(random.sample(range(total_node_num), n_train)))
    # Sort explicitly instead of relying on set iteration order.
    test_idxs = torch.LongTensor(sorted(set(range(total_node_num)) - set(train_idxs.tolist())))

    # Train and test indices together cover every node, so the label vector is
    # simply "belongs to the first block".
    label = (torch.arange(total_node_num) < sizes[0]).long()

    data = Data(x=feature, edge_index=from_networkx(G).edge_index)
    empty_data = Data(x=feature, edge_index=torch.tensor([[], []]).long())
    print(data.has_isolated_nodes())
    return data, empty_data, label, train_idxs, test_idxs, feature_core, feature_noise


def feat_noise_tracking(epoch, weight_m, feature_core, feature_noise, feat_tracker, noise_tracker):
    """Record how strongly the filters in ``weight_m`` respond to signal and noise.

    Args:
        epoch: Value stored as the first element of every appended pair.
        weight_m: 2-D weight matrix, one filter per row.
        feature_core: Matrix whose first row is the signal vector.
        feature_noise: Matrix with one noise vector per row.
        feat_tracker: List of three lists; receives ``[epoch, value]`` for the
            max, mean and min (in that order) over filters of the
            filter/signal inner products.
        noise_tracker: List of six lists. For every noise row the filter
            responses are reduced over filters with max, mean and min; slots
            0-2 receive the average of those per-row statistics over all
            rows, slots 3-5 their maximum over all rows.

    Returns:
        The same ``(feat_tracker, noise_tracker)`` objects, mutated in place.
    """
    # Inner product of every filter with the signal vector -> one value per filter.
    core_response = torch.matmul(weight_m, feature_core[0].unsqueeze(0).t())[:, 0]
    feat_tracker[0].append([epoch, core_response.max().item()])
    feat_tracker[1].append([epoch, core_response.mean().item()])
    feat_tracker[2].append([epoch, core_response.min().item()])

    # Rows: filters, columns: noise vectors.
    noise_response = torch.matmul(weight_m, feature_noise.t())
    # Reduce over filters (dim=0). max/min with ``dim`` return (values, indices)
    # while mean returns a plain tensor -- indexing the mean with [0] would keep
    # only the first noise vector, so the reductions are unpacked explicitly.
    per_row_stats = (
        noise_response.max(dim=0).values,
        noise_response.mean(dim=0),
        noise_response.min(dim=0).values,
    )
    for slot, stat in enumerate(per_row_stats):
        noise_tracker[slot].append([epoch, stat.mean().item()])
        noise_tracker[slot + 3].append([epoch, stat.max().item()])
    return feat_tracker, noise_tracker


def train_model(model, data, label, train_idxs, test_idxs, optimizer, feature_core, feature_noise, args):
    """Train ``model`` full-batch for ``args.epochs`` epochs and log metrics.

    Each epoch performs one optimisation step on the training nodes with
    binary cross-entropy, then evaluates loss and accuracy (threshold 0.5) on
    both splits and records filter/signal and filter/noise statistics for
    ``model.conv_pos.lin.weight`` and ``model.conv_neg.lin.weight``.

    Args:
        model: Module returning one probability per node; must expose
            ``conv_pos.lin.weight`` and ``conv_neg.lin.weight``.
        data: Graph object passed straight to ``model``.
        label: Target tensor indexed by node; ``nn.BCELoss`` needs it to be
            float and shaped like the model output -- presumably ``(N, 1)``,
            verify against the caller.
        train_idxs: Indices of training nodes.
        test_idxs: Indices of test nodes.
        optimizer: Optimiser already bound to ``model``'s parameters.
        feature_core: Forwarded to ``feat_noise_tracking``.
        feature_noise: Forwarded to ``feat_noise_tracking``.
        args: Namespace providing ``epochs``.

    Returns:
        1-D NumPy object array of length 8 holding, in this order, the lists
        train loss, test loss, train accuracy, test accuracy, positive
        feature tracker, positive noise tracker, negative feature tracker and
        negative noise tracker. Loss/accuracy entries are ``[epoch, value]``.
    """
    train_loss_list, test_loss_list = [], []
    train_acc_list, test_acc_list = [], []
    loss_fn = nn.BCELoss()
    w_pos_feat_tracker = [[], [], []]
    w_pos_noise_tracker = [[], [], [], [], [], []]
    w_neg_feat_tracker = [[], [], []]
    w_neg_noise_tracker = [[], [], [], [], [], []]
    for epoch in tqdm(range(args.epochs)):
        # The evaluation pass below switches to eval mode, so training mode
        # has to be restored at the start of every epoch.
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out[train_idxs], label[train_idxs])
        loss.backward()
        optimizer.step()
        train_loss_list.append([epoch, loss.item()])

        model.eval()
        with torch.no_grad():
            out = model(data)
            test_loss = loss_fn(out[test_idxs], label[test_idxs])
            test_loss_list.append([epoch, test_loss.item()])
            # Probabilities above 0.5 are predicted as class 1.
            pred = (out.squeeze() > 0.5).long()
            train_correct = pred[train_idxs].eq(label[train_idxs].squeeze().long()).sum().item()
            test_correct = pred[test_idxs].eq(label[test_idxs].squeeze().long()).sum().item()
            train_acc_list.append([epoch, train_correct / len(train_idxs)])
            test_acc_list.append([epoch, test_correct / len(test_idxs)])
            # feature_core[0] is the orig feat core, feature_core[-1] is the orig feat core * -1
            w_pos_feat_tracker, w_pos_noise_tracker = \
                feat_noise_tracking(epoch, model.conv_pos.lin.weight, feature_core, feature_noise,
                                    w_pos_feat_tracker, w_pos_noise_tracker)
            w_neg_feat_tracker, w_neg_noise_tracker = \
                feat_noise_tracking(epoch, model.conv_neg.lin.weight, feature_core, feature_noise,
                                    w_neg_feat_tracker, w_neg_noise_tracker)

    series = [train_loss_list, test_loss_list, train_acc_list, test_acc_list,
              w_pos_feat_tracker, w_pos_noise_tracker, w_neg_feat_tracker, w_neg_noise_tracker]
    # The eight series have different nesting depths/lengths. Letting np.array
    # infer a shape from such ragged input raises on recent NumPy releases, so
    # the object array is filled slot by slot.
    results = np.empty(len(series), dtype=object)
    for slot, values in enumerate(series):
        results[slot] = values
    return results


# Position of every recorded series inside the array returned by train_model.
training_result_dict = {
    series_name: position
    for position, series_name in enumerate((
        'Train Loss',
        'Test Loss',
        'Train Accuracy',
        'Test Accuracy',
        'Positive Feature Learning',
        'Positive Noise Learning',
        'Negative Feature Learning',
        'Negative Noise Learning',
    ))
}

# Y-axis caption associated with every series name.
y_label_dict = {
    **dict.fromkeys(('Train Loss', 'Test Loss'), 'Loss'),
    **dict.fromkeys(('Train Accuracy', 'Test Accuracy'), 'Accuracy'),
    **dict.fromkeys(
        (
            'Positive Feature Learning',
            'Positive Noise Learning',
            'Negative Feature Learning',
            'Negative Noise Learning',
        ),
        'Logit',
    ),
}

# Colour code per statistic name (presumably Matplotlib colour specs -- no
# consumer is visible in this part of the file).
color_dict = {
    'Max': 'b',
    'Mean': 'g',
    'Min': 'r',
    'Max_Mean': 'c',
    'Mean_Mean': 'm',
    'Min_Mean': 'y',
    'Max_Max': 'k',
    'Mean_Max': 'w',
    'Min_Max': '#eeefff',
}

# Slot of each statistic inside the feature tracker filled by feat_noise_tracking.
feat_tracker_dict = {
    stat_name: slot
    for slot, stat_name in enumerate(('Max', 'Mean', 'Min'))
}

# Slot of each statistic inside the noise tracker filled by feat_noise_tracking.
noise_tracker_dict = {
    stat_name: slot
    for slot, stat_name in enumerate((
        'Max_Mean', 'Mean_Mean', 'Min_Mean',
        'Max_Max', 'Mean_Max', 'Min_Max',
    ))
}



def draw_sns_heatmap(grid_info):
    """Save heatmaps of the GNN and CNN results stored in ``grid_info``.

    Args:
        grid_info: Sequence of rows ``[sample_size, snr, gnn_value,
            cnn_value]``. Rows are assumed to be grouped by sample size in
            ascending order with the same SNR sequence repeated inside every
            group -- this matches the disabled sweep in ``main``; confirm for
            other producers. The grid may be any rectangular size.

    Raises:
        ValueError: If ``grid_info`` is empty, is not a table with at least
            four columns, or its row count is not a multiple of the number of
            distinct sample sizes.

    Side effects:
        Creates ``<dir_path>/img_vis/`` when missing and writes
        ``gnn_heatmap.png`` and ``cnn_heatmap.png`` into it.
    """
    import seaborn as sns

    save_dir = dir_path + '/img_vis/'
    # savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    grid = np.array(grid_info)
    if grid.ndim != 2 or grid.shape[0] == 0 or grid.shape[1] < 4:
        raise ValueError('grid_info must be a non-empty table with at least 4 columns')

    sample_sizes = sorted(set(grid[:, 0].tolist()))
    n_rows = len(sample_sizes)
    n_cols, leftover = divmod(grid.shape[0], n_rows)
    if leftover:
        raise ValueError(
            'grid_info has {} rows, which is not a multiple of the {} distinct sample sizes'.format(
                grid.shape[0], n_rows))

    # Tick labels are shown in reverse order of the stored data; the matrices
    # below are flipped along both axes to match.
    x_labels = ["{:.2e}".format(snr) for snr in grid[:n_cols, 1].tolist()]
    y_labels = ["{}".format(int(size)) for size in sample_sizes]
    x_labels.reverse()
    y_labels.reverse()

    gnn_matrix = grid[:, 2].reshape(n_rows, n_cols)[::-1, ::-1]
    cnn_matrix = grid[:, 3].reshape(n_rows, n_cols)[::-1, ::-1]

    def save_heatmap(matrix, file_name, figsize=None):
        # Draw one annotated heatmap and write it to save_dir.
        plt.figure(figsize=figsize)
        ax = sns.heatmap(matrix, cmap='viridis', annot=True, linewidths=0.5,
                         xticklabels=x_labels, yticklabels=y_labels)
        ax.set_xlabel('SNR')
        ax.set_ylabel('Sample Size')
        plt.savefig(save_dir + file_name)
        plt.close()

    # NOTE(review): only the GNN figure uses an enlarged canvas; the CNN figure
    # keeps Matplotlib's default size, as before.
    save_heatmap(gnn_matrix, 'gnn_heatmap.png', figsize=(9, 7))
    save_heatmap(cnn_matrix, 'cnn_heatmap.png')



def args_parser():
    """Parse the command line and return the resulting namespace.

    Besides the declared options the namespace gets two derived attributes:
    ``edge_prob`` (the four comma-separated values of ``--edge_prob_str`` as a
    2x2 nested list of floats) and ``device`` (CUDA when available, otherwise
    CPU).
    """
    option_table = (
        ('--seed', int, 42),
        ('--feat_dim', int, 100),
        ('--learning_rate', float, 0.0005),
        ('--epochs', int, 200),
        # graph generation config
        ('--node_num', int, 2000),
        ('--noise_scale', float, 10),
        ('--edge_prob_str', str, '0.21,0.08,0.08,0.18'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    probabilities = [float(token) for token in args.edge_prob_str.split(',')]
    args.edge_prob = np.array(probabilities).reshape(2, 2).tolist()

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = torch.device(device_name)
    return args

def main():
    """Seed the RNGs, load the cached grid-search results and plot them.

    The sweep that produces ``grid_info.pkl`` is kept below in commented-out
    form; as written, this function only reads the pickle stored next to the
    script and hands it to ``draw_sns_heatmap``.
    """
    args = args_parser()
    # Seed torch, NumPy and the stdlib ``random`` module with the same value.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # --- Disabled experiment -------------------------------------------------
    # For every (node count, noise scale) pair the code below builds a synthetic
    # graph, trains one model on the graph ("gnn") and one on the edge-free copy
    # ("cnn"), and stores [node count, SNR, best gnn test accuracy, best cnn
    # test accuracy] in grid_info, which is then pickled to grid_info.pkl.
    # feature_core = None
    # # synthetic_graph(sizes=[1000,1000], porbs=[[0.5, 0.1], [0.1, 0.5]], feat_dim=100, noise_scaling=20.5)
    # run_num = 0
    # grid_info = []
    # for _num_node in range(200, 8200, 1000):
    #     for _noise_scale in range(0, 15, 2):
    #         _noise_scale = math.exp(_noise_scale // 1.2) # _noise_scale * 3 + 1
    #         # data, empty_data, label, train_idxs, test_idxs, feature_core, feature_noise
    #         data, empty_data, label, train_idxs, test_idxs, feature_core, feature_noise = synthetic_graph(
    #             sizes=[int(_num_node / 2), _num_node - int(_num_node / 2)],
    #             porbs=args.edge_prob,
    #             feat_dim=args.feat_dim,
    #             noise_scaling=_noise_scale,
    #             feature_core=feature_core
    #         )
    #
    #         SNR = torch.norm(feature_core[0], p=2).item() / (math.sqrt(args.feat_dim) * _noise_scale)
    #         graph_info_str = '[{}] Node Num: {}, Edge Prob: {}, Noise Scale: {}'.format(run_num,
    #                                                 _num_node, args.edge_prob_str, _noise_scale)
    #         print(graph_info_str)
    #         run_num += 1
    #         data = data.to(args.device)
    #         empty_data = empty_data.to(args.device)
    #         feature_core = torch.tensor(feature_core).to(args.device)
    #         feature_noise = torch.tensor(feature_noise).to(args.device)
    #         label = label.to(args.device).unsqueeze(1).float()
    #
    #         args.epochs = 200
    #         gnn_model = GCN(feat_dim=args.feat_dim).to(args.device)
    #         gnn_optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.learning_rate, weight_decay=5e-4)
    #         gnn_result_cache = train_model(gnn_model, data, label, train_idxs, test_idxs, gnn_optimizer, feature_core, feature_noise, args)
    #
    #         args.epochs = 1000
    #         cnn_model = GCN(feat_dim=args.feat_dim).to(args.device)
    #         cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=args.learning_rate, weight_decay=5e-4)
    #         cnn_result_cache = train_model(cnn_model, empty_data, label, train_idxs, test_idxs, cnn_optimizer, feature_core, feature_noise, args)
    #
    #         grid_info.append([_num_node, SNR,
    #                           np.array(gnn_result_cache[training_result_dict['Test Accuracy']])[:, 1].max(),
    #                           np.array(cnn_result_cache[training_result_dict['Test Accuracy']])[:, 1].max()]
    #                          )

    # with open(dir_path + '/grid_info.pkl', 'wb') as f:
    #     pickle.dump(grid_info, f)

    # Load the cached results written by the disabled block above.
    # NOTE(review): pickle.load can execute arbitrary code -- only point this at
    # a grid_info.pkl you produced yourself.
    with open(dir_path + '/grid_info.pkl', 'rb') as f:
        grid_info = pickle.load(f)

    draw_sns_heatmap(grid_info)



# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()