import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
from dgl.data import TexasDataset,CornellDataset,WisconsinDataset,SquirrelDataset, ActorDataset, ChameleonDataset
from dgl import AddSelfLoop
import scipy
from scipy import sparse
from scipy.io import savemat
import argparse
import numpy as np
import random
import time
import sys
import json
import os

class GCN(nn.Module):
    """Stack of DGL ``GraphConv`` layers with dropout between consecutive layers.

    ``forward`` applies no explicit activation between layers: each layer's
    output is passed through dropout (p=0.5) and fed to the next layer.

    Args:
        in_size: Dimensionality of the input node features.
        hid_size: Width of every intermediate representation (unused when
            ``num_layers == 1``).
        out_size: Number of output units per node.
        num_layers: Number of graph-convolution layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1. Previously such a
            value silently produced a model without any layer.
    """

    def __init__(self, in_size, hid_size, out_size, num_layers=1):
        if num_layers < 1:
            raise ValueError(
                "num_layers must be >= 1, got {}".format(num_layers)
            )
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # Layer widths: in -> hid -> ... -> hid -> out. With one layer this is
        # simply in -> out, with two layers in -> hid -> out.
        sizes = [in_size] + [hid_size] * (num_layers - 1) + [out_size]
        for src_size, dst_size in zip(sizes[:-1], sizes[1:]):
            self.layers.append(dglnn.GraphConv(src_size, dst_size))

        self.dropout = nn.Dropout(0.5)

    def forward(self, g, features):
        """Run all layers on graph ``g`` and return the last layer's output."""
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:  # Apply dropout before all but the first layer
                h = self.dropout(h)
            h = layer(g, h)
        return h


def evaluate(g, features, labels, mask, model):
    """Score ``model`` on the nodes selected by ``mask`` without tracking gradients.

    Args:
        g: Graph passed unchanged to ``model``.
        features: Node features passed unchanged to ``model``.
        labels: Per-node labels; indexed with ``mask``.
        mask: Index or mask used to select the evaluated rows of both the
            logits and ``labels``.
        model: Callable module invoked as ``model(g, features)``; it is put
            into evaluation mode first.

    Returns:
        A pair ``(probabilities, accuracy)``: the softmax over the selected
        logits, and the fraction of selected nodes whose highest-scoring class
        equals the label.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        targets = labels[mask]
        probabilities = F.softmax(selected_logits, dim=-1)
        predictions = selected_logits.max(dim=1).indices
        hits = (predictions == targets).sum().item()
    return probabilities, hits * 1.0 / len(targets)

def train(data, features, labels, train_mask, valid_mask, model,train_epoch, graph_type, prob_lambda, seed):
    """Train ``model`` with full-batch Adam and report validation accuracy each epoch.

    Args:
        data: Input graph.
        features: Node features fed to the model.
        labels: Per-node class labels.
        train_mask: Selects the nodes that contribute to the loss (and, for
            ``graph_type == "new"``, the nodes eligible for added edges).
        valid_mask: Selects the nodes used for the per-epoch accuracy print.
        model: Module called as ``model(g, features)``.
        train_epoch: Number of optimisation steps.
        graph_type: ``"original"`` to train on ``data`` as is, ``"new"`` to
            train on the graph returned by ``generate_ngraph_rand_addedge``.
        prob_lambda: Edge-adding probability forwarded to the graph generator.
        seed: Random seed forwarded to the graph generator.

    Returns:
        The graph the model was actually trained on.

    Raises:
        ValueError: If ``graph_type`` is neither ``"original"`` nor ``"new"``.
    """
    # Select the training graph. An unknown type used to fall through and
    # crash later with an unbound local variable; fail early and clearly.
    if graph_type == "original":
        g = data
    elif graph_type == "new":
        g = generate_ngraph_rand_addedge(data, prob_lambda=prob_lambda, train_mask=train_mask, seed=seed)
    else:
        raise ValueError(
            f"Unknown graph_type: {graph_type!r} (expected 'original' or 'new')"
        )

    loss_fcn = nn.CrossEntropyLoss()  # cross-entropy over the training nodes
    # Adam over all model parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # training loop
    for epoch in range(train_epoch):
        model.train()  # set the model in the training mode
        logits = model(g, features)  # raw model outputs for every node
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Validation accuracy is reported after every epoch.
        _, acc = evaluate(g, features, labels, valid_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f}".format(
            epoch, loss.item(), acc
            )
        )
    return g

def rand_train_test_idx(label, train_prop=.6, valid_prop=.2, ignore_negative=True,seed=None):
    """Randomly partition nodes into train / validation / test index tensors.

    Args:
        label: 1-D label tensor.
        train_prop: Fraction of the candidate nodes assigned to training.
        valid_prop: Fraction of the candidate nodes assigned to validation;
            everything left over goes to the test split.
        ignore_negative: When true, nodes labelled ``-1`` are excluded and the
            returned values are indices into ``label``. When false, every
            position is a candidate and the raw permutation slices are
            returned.
        seed: If not ``None``, seeds both ``random`` and ``numpy.random``
            (global state) before drawing the permutation.

    Returns:
        Tuple ``(train_idx, valid_idx, test_idx)`` of int64 index tensors.
    """
    candidates = torch.where(label != -1)[0] if ignore_negative else label

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    total = candidates.shape[0]
    n_train = int(total * train_prop)
    n_valid = int(total * valid_prop)

    shuffled = torch.as_tensor(np.random.permutation(total), dtype=torch.int64)
    parts = (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_valid],
        shuffled[n_train + n_valid:],
    )

    if ignore_negative:
        # Map positions within the labelled subset back to node indices.
        return tuple(candidates[part] for part in parts)
    return parts

# def generate_ngraph_rand_addedge(g, prob_lambda, train_mask):
#     labels = g.ndata['label']
#     device = labels.device
#     # train_idx, valid_idx, test_idx = rand_train_test_idx(labels, train_prop=0.6, valid_prop=0.2, ignore_negative=True,seed=None)
#     #train_mask = g.ndata['train_mask'][:,0]
#     # print(train_idx, valid_idx, test_idx)
#     # train_nodes = train_idx.cpu().numpy()
#     train_nodes = np.where(train_mask.cpu().numpy())[0]
#     print(train_nodes.shape)
#
#
#     p_add_edge = prob_lambda  # probability of adding edge
#
#     new_edges = []
#     for i in train_nodes:
#         for j in train_nodes:
#             if i != j and labels[i] == labels[j] and np.random.rand() < p_add_edge:
#                 new_edges.append((i, j))
#
#     new_g = g.clone()
#     if new_edges:
#         new_edges = np.array(new_edges).T
#         new_g.add_edges(new_edges[0], new_edges[1])
#
#     # print(new_edges.shape)
#     g_simple = dgl.to_simple(new_g.to('cpu'), return_counts='count', writeback_mapping=False)
#     print(g_simple)
#     g_simple = g_simple.to(device)
#     print('edge num increased:', g_simple.num_edges()-g.num_edges())
#     return g_simple

# def generate_ngraph_rand_addedge(g, prob_lambda, train_mask):
#     labels = g.ndata['label']
#     device = labels.device
#
#     # Get indices of training nodes
#     train_nodes = torch.where(train_mask)[0]
#     print(f"Train nodes shape: {train_nodes.shape}")
#
#     p_add_edge = prob_lambda  # probability of adding edge
#
#     # Create a meshgrid of train nodes indices
#     i, j = torch.meshgrid(train_nodes, train_nodes, indexing='ij')
#     i = i.flatten()
#     j = j.flatten()
#
#     # Ensure no self loops and labels match
#     mask = (i != j) & (labels[i] == labels[j])
#
#     # Apply the probability mask
#     random_prob = torch.rand(i.size(), dtype=torch.float, device=device)
#     mask = mask & (random_prob < p_add_edge)
#
#     new_edges = (i[mask], j[mask])
#
#     # Clone the graph and add edges
#     new_g = g.clone()
#     if new_edges[0].size(0) > 0:
#         new_g.add_edges(new_edges[0], new_edges[1])
#
#     # Simplify the graph
#     g_simple = dgl.to_simple(new_g.to('cpu'), return_counts='count', writeback_mapping=False)
#     print(g_simple)
#     g_simple = g_simple.to(device)
#     print(f'Edge number increased by: {g_simple.num_edges() - g.num_edges()}')
#
#     return g_simple

def generate_ngraph_rand_addedge(data, prob_lambda, train_mask, seed):
    """Return a copy of ``data`` with random extra edges between same-label training nodes.

    Every ordered pair of distinct training nodes that share a label receives
    a new edge with probability ``prob_lambda``. The result is passed through
    ``dgl.to_simple`` on the CPU and then moved back to the device of the
    label tensor.

    Args:
        data: Input graph; its ``ndata['label']`` supplies the node labels.
        prob_lambda: Probability of adding each candidate edge.
        train_mask: Mask whose non-zero entries mark the training nodes.
        seed: Seed applied to the global ``random`` and ``numpy.random`` state
            before sampling.

    Returns:
        The simplified, augmented graph.
    """
    node_labels = data.ndata['label']
    target_device = node_labels.device
    node_labels = node_labels.cpu().numpy()
    train_nodes = np.where(train_mask.cpu().numpy())[0]

    random.seed(seed)
    np.random.seed(seed)

    # Enumerate all ordered (source, destination) pairs of training nodes in
    # row-major order: the source varies slowest, the destination fastest.
    n_train = train_nodes.size
    src = np.repeat(train_nodes, n_train)
    dst = np.tile(train_nodes, n_train)

    # One uniform draw per candidate pair; a pair survives when it is not a
    # self-loop, both ends carry the same label, and its draw is below the
    # requested probability.
    draws = np.random.rand(src.size)
    keep = (src != dst) & (node_labels[src] == node_labels[dst])
    keep = keep & (draws < prob_lambda)
    added_src = src[keep]
    added_dst = dst[keep]

    # Add the sampled edges to a clone so the input graph stays untouched.
    new_g = data.clone()
    if added_src.size > 0:
        new_g.add_edges(added_src, added_dst)

    # Simplify the graph
    g_simple = dgl.to_simple(new_g.to('cpu'), return_counts='count', writeback_mapping=False)
    print(g_simple)
    g_simple = g_simple.to(target_device)
    print(f'Edge number increased by: {g_simple.num_edges() - data.num_edges()}')
    return g_simple



def compute_label_distribution(graph, labels, num_classes):
    """Compute, for every node, the class distribution over its in-neighbours.

    Args:
        graph: Graph exposing ``number_of_nodes()`` and ``in_edges(node)``.
        labels: Integer label tensor indexed by node id.
        num_classes: Minimum length of each per-node count vector.

    Returns:
        A tensor with one row per node holding the normalised label counts of
        the source nodes of that node's incoming edges. The node itself is not
        added explicitly; it only contributes if the graph has a self-loop.
    """

    def _neighbourhood_distribution(node):
        # Sources of the edges pointing at ``node``.
        sources = graph.in_edges(node)[0]
        counts = torch.bincount(labels[sources], minlength=num_classes)
        # Normalise the counts into probabilities.
        return counts.float() / counts.sum()

    return torch.stack(
        [_neighbourhood_distribution(node) for node in range(graph.number_of_nodes())]
    )


if __name__ == '__main__':
    # Script entry point: parse options, load the dataset, train and test a
    # GCN on each of the 10 predefined splits, then print and log the results.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="squirrel"
    )
    parser.add_argument(
        "--graph_type",
        type=str,
        default="original"
    )
    parser.add_argument(
        "--prob_lambda",
        type=float,
        default=0,
        help= "Probability lambda for new graph generation"
    )
    parser.add_argument(
        "--GCN_layer",
        type=int,
        default=1,
        help= "The number of GCN layer"
    )
    parser.add_argument(
        "--train_epoch",
        type=int,
        default=100
    )
    parser.add_argument(
        "--feature_norm",
        type=str,
        default=None
    )
    parser.add_argument("--seed", type=int, default=0, help="The value of the seed")
    args = parser.parse_args()

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    dataset_classes = {
        "texas": TexasDataset,
        "cornell": CornellDataset,
        "wisconsin": WisconsinDataset,
        "squirrel": SquirrelDataset,
        "actor": ActorDataset,
        "chameleon": ChameleonDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data_raw = dataset_classes[args.dataset](transform=transform)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = data_raw[0].to(device)

    # Feature normalization. The original condition was written as
    # `a or b or c and d`; because `and` binds tighter than `or`, actor and
    # chameleon were normalized even without --feature_norm Yes. The flag now
    # gates all three datasets.
    features = data.ndata["feat"]
    if args.dataset in ("actor", "chameleon", "squirrel") and args.feature_norm == "Yes":
        # Add a small constant to every entry, then rescale each row to an
        # L2 norm of 20.
        features = features + 0.1 / (features.shape[1]) * torch.ones_like(features)
        norm = features.norm(p=2, dim=1, keepdim=True)
        features = 20*features/norm

    labels = data.ndata["label"]
    masks = data.ndata["train_mask"], data.ndata["val_mask"], data.ndata["test_mask"]

    in_size = features.shape[1]
    num_classes = data_raw.num_classes
    out_size = num_classes
    acc_list = []
    l1_list_mean = []
    l1_list_median = []

    timestr = time.strftime("%Y%m%d-%H%M%S")
    filename = ("log/"+ "GCN/"+ str(args.dataset) + str(args.graph_type) + str(args.prob_lambda) + str(args.GCN_layer)+ timestr + ".txt")
    # The log directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    command_args = " ".join(sys.argv)
    # Append mode creates the file when it is missing and otherwise keeps the
    # existing content, so a single mode covers both cases.
    with open(filename, 'a') as f:
        json.dump(command_args, f)
        f.write("\n")

    for split in range(10):
        train_mask = masks[0][:, split]
        valid_mask = masks[1][:, split]
        test_mask = masks[2][:, split]

        # Start every split from freshly initialised weights. Reusing one
        # model across splits would let parameters fitted on an earlier
        # split's training nodes influence the test result of a later split.
        model = GCN(in_size, 16, out_size, num_layers=args.GCN_layer).to(device)

        print("Training...")
        g = train(data, features, labels, train_mask, valid_mask, model, train_epoch=args.train_epoch, graph_type=args.graph_type, prob_lambda=args.prob_lambda, seed=args.seed)

        # model testing
        print("Testing....")
        prob_distribution, acc = evaluate(g, features, labels, test_mask, model)
        acc_list.append(acc*100)
        print("Test accuracy {:.4f}".format(acc))

        # Compare the predicted class probabilities of the test nodes with
        # the label distribution of their in-neighbours in the training graph.
        ground_truth_distributions = compute_label_distribution(g, labels, num_classes)[test_mask]
        l1_distances = torch.abs(prob_distribution - ground_truth_distributions).sum(1)
        l1_values = l1_distances.cpu().numpy()
        l1_list_mean.append(np.mean(l1_values))
        l1_list_median.append(np.median(l1_values))

    print(acc_list)
    acc_mean = np.mean(acc_list)
    acc_var = np.std(acc_list)  # standard deviation across splits
    print("Test accuracy mean {:.4f}".format(acc_mean))
    print("Test accuracy variation {:.4f}".format(acc_var))

    l1_mean = np.mean(l1_list_mean)
    l1_median = np.mean(l1_list_median)  # mean of the per-split medians
    print("L1 distances mean for all nodes:", l1_mean)
    print("L1 distances median among all nodes:", l1_median)

    # Append the summary to the log, one `label"value"` line per entry.
    summary = [
        ("Test accuracy mean: ", acc_mean),
        ("Test accuracy variation: ", acc_var),
        ("Training epoch: ", args.train_epoch),
        ("GCN_layer: ", args.GCN_layer),
        ("graph_type: ", args.graph_type),
        ("prob_lambda: ", args.prob_lambda),
        ("L1 distance mean for all nodes:", l1_mean),
        ("L1 distances median among all nodes:", l1_median),
    ]
    with open(filename, 'a') as f:
        for label, value in summary:
            f.write(label)
            json.dump(str(value), f)
            f.write("\n")



