import argparse
import os.path as osp
import random
from time import perf_counter as t
import yaml
import json
from yaml import SafeLoader

import torch
import torch_geometric.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid, CitationFull
from torch_geometric.utils import dropout_edge
from torch_geometric.nn import GCNConv, SGConv

from collections import defaultdict
from model import Encoder, Model, drop_feature, Decoder
from validation_concatenation import fs_dev, fs_test

from data_split import *
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import normalize

import numpy as np
from copy import deepcopy
from math import sqrt
from args import get_args
import networkx as nx
from sklearn.cluster import KMeans
from GCN_emb import GCN_emb
from GCN_motif import MotifEncoder
import pickle
import os
import sys
import time


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix/array to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: any scipy sparse matrix or array (converted via ``tocoo()``).

    Returns:
        An (uncoalesced) ``torch`` sparse COO tensor with the same shape,
        int64 indices and float32 values.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # ``torch.sparse.FloatTensor(...)`` is a deprecated legacy constructor;
    # ``sparse_coo_tensor`` is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape)

def get_wl_hash_with_edge_direction(motif):
    """Return the Weisfeiler-Lehman hash of ``motif`` with every edge labelled ``'forward'``.

    The ``'direction'`` edge attribute is set on a copy, so the caller's graph
    is left untouched. When ``motif`` is a subgraph view, its edge attribute
    dicts belong to the parent graph, so this also leaves the parent untouched.

    Args:
        motif: a networkx graph.

    Returns:
        str: ``nx.weisfeiler_lehman_graph_hash`` using ``edge_attr='direction'``.
    """
    labelled = motif.copy()
    # A scalar value is applied to every edge.
    nx.set_edge_attributes(labelled, 'forward', 'direction')
    return nx.weisfeiler_lehman_graph_hash(labelled, edge_attr='direction')

def get_n_hop_subgraph(adj, i):
    """Extract the 1-hop neighbourhood of node ``i`` from an adjacency tensor.

    Neighbours are nodes with an edge into ``i`` (non-zero in column ``i``)
    or out of ``i`` (non-zero in row ``i``). Node ``i`` is always placed
    first, followed by its neighbours in ascending id order. A self-loop does
    not duplicate ``i``.

    Args:
        adj: square adjacency tensor supporting dense indexing.
        i: id of the centre node.

    Returns:
        (n_hop_cur, original_new_node, new_original_node):
            n_hop_cur: ``adj`` restricted to the neighbourhood. Rows and
                columns are ordered by local index.
            original_new_node: dict mapping global id -> local index.
            new_original_node: dict mapping local index -> global id.
    """
    in_nbrs = adj[:, i].nonzero().flatten()
    out_nbrs = adj[i].nonzero().flatten()
    neighbours = torch.unique(torch.cat((in_nbrs, out_nbrs)))
    # Drop the centre itself (self-loop) so it appears exactly once.
    neighbours = neighbours[neighbours != i]
    centre = torch.tensor([i], dtype=neighbours.dtype, device=neighbours.device)
    idx_total = torch.cat((centre, neighbours))

    original_new_node = {}
    new_original_node = {}
    for local, node in enumerate(idx_total.tolist()):
        original_new_node[node] = local
        new_original_node[local] = node

    n_hop_cur = adj[idx_total, :][:, idx_total]
    return n_hop_cur, original_new_node, new_original_node


def get_3_node_subgraph_undirected(G, motif, o_n, n_o, node_num, max_node, seen, motif_adj, motif_x):
    """Record 3-node induced subgraphs found along ``node_num -> nbr -> far`` walks in ``G``.

    A node triple is skipped if its sorted tuple is already in ``seen``.
    Because the set is shared, each triple is handled once across all calls
    that share it. For every new triple, the induced subgraph is keyed by its
    Weisfeiler-Lehman hash over the ``"Labels"`` node attribute, and then:
      * ``motif[key]`` (a length-``max_node`` zero array on first sight) is
        incremented at the three node ids;
      * ``motif_adj[key]`` gets the subgraph's float32 adjacency tensor;
      * ``motif_x[key]`` gets the row sums of that adjacency.

    ``o_n`` and ``n_o`` are accepted but not used.

    Returns:
        ``motif``, the same dict, mutated in place.
    """
    center = node_num
    for nbr in G.neighbors(center):
        if nbr == center:
            continue
        for far in G.neighbors(nbr):
            if far == nbr or far == center:
                continue
            triple = [center, nbr, far]
            triple_key = tuple(sorted(triple))
            if triple_key in seen:
                continue
            seen.add(triple_key)

            induced = G.subgraph(triple)
            wl_key = nx.weisfeiler_lehman_graph_hash(induced, node_attr="Labels")
            if wl_key not in motif:
                motif[wl_key] = np.zeros(max_node)
                motif_adj[wl_key] = []
                motif_x[wl_key] = []
            counts = motif[wl_key]
            for member in (nbr, center, far):
                counts[member] += 1

            sub_adj = torch.tensor(nx.to_numpy_array(induced), dtype=torch.float32)
            motif_adj[wl_key].append(sub_adj)
            motif_x[wl_key].append(sub_adj.sum(dim=1))
    return motif




def add_edge_with_tf_idf(Graph, node, motif_list, motif_dict, total_num):
    """Link motif node ``node`` to every node where the motif occurs, weighted by TF-IDF.

    For node ``i`` with occurrence count ``a = motif_list[i] != 0``:
    ``tf = log(a)`` and ``idf = log(total_num / number of non-zero entries)``.
    An edge ``(i, node)`` with ``weight = tf * idf`` is added unless that
    product is zero. That happens for a single occurrence, or for a motif
    present at every node.

    Args:
        Graph: graph exposing ``add_edge(u, v, weight=...)``; modified in place.
        node: id of the motif node to connect.
        motif_list: 1-D array-like of per-node occurrence counts.
        motif_dict: unused; kept for interface compatibility.
        total_num: number of original nodes (the IDF "document" count).

    Returns:
        ``Graph`` (the same object).
    """
    counts = np.asarray(motif_list)
    present = counts.nonzero()[0]
    if len(present) == 0:
        # Motif occurs nowhere: IDF is undefined and there is nothing to link.
        return Graph
    idf = np.log(total_num / len(present))
    # Only visit the non-zero entries instead of scanning every node.
    for i in present.tolist():
        tf_idf = np.log(counts[i]) * idf
        if tf_idf == 0:
            continue
        Graph.add_edge(i, node, weight=tf_idf)
    return Graph


def relabeling(labels, train_class, dev_class, test_class, id_by_class):
    """Build contrastive pseudo-labels from ground-truth class labels.

    Nodes of a training class get that class's rank among the sorted
    ``train_class`` ids (0 .. len(train_class)-1). Every node of any other
    class in ``id_by_class`` receives its own fresh label. These labels
    start at ``len(train_class)`` and follow ``id_by_class`` iteration order.

    Args:
        labels: per-node class ids; must provide ``.tolist()`` (e.g. a tensor).
        train_class: training class ids. Not modified; it is sorted into a copy.
        dev_class, test_class: held-out class ids; kept for interface
            compatibility but not needed to compute the labels.
        id_by_class: mapping class id -> list of node indices.

    Returns:
        list: the relabelled per-node ids.
    """
    print("Start relabeling...")
    contrast_labels = labels.tolist()  # fresh list; ``labels`` is not touched

    # Rank of each training class in sorted order. ``setdefault`` keeps the
    # first rank if a class id is listed twice.
    train_class_map = {}
    for rank, cla in enumerate(sorted(train_class)):
        train_class_map.setdefault(cla, rank)

    next_label = len(train_class)
    for cla, idx_list in id_by_class.items():
        if cla in train_class_map:
            for idx in idx_list:
                contrast_labels[idx] = train_class_map[cla]
        else:
            # Non-training classes: every node becomes its own singleton class.
            for idx in idx_list:
                contrast_labels[idx] = next_label
                next_label += 1
    print("Relabeling finished!")
    return contrast_labels

def train(model: Model, x, contrast_labels, edge_index, motif_emb, adj_motif, motif_num, optimizer):
    """Run one optimisation step of instance- plus set-level contrastive training.

    Two augmented views are built with ``dropout_edge`` and ``drop_feature``,
    using the module-level ``drop_edge_rate_1/2`` and ``drop_feature_rate_1/2``.
    Both views are encoded by ``model``. For each node, the ``args.topk``
    view-2 embeddings with the largest dot product against its view-1
    embedding are randomly permuted. They are then split into two equal
    halves, each encoded by ``model.set_forward``.

    The loss is
    ``alpha * model.loss(z1, z2) + beta * model.loss(zs_1, zs_2, loss_type='set')``.
    ``alpha`` and ``beta`` are module-level globals.

    Args:
        model: contrastive model exposing ``set_forward`` and ``loss``.
        x: node features.
        contrast_labels: not used by the unsupervised objective; kept for
            interface compatibility.
        edge_index: graph connectivity, augmented by ``dropout_edge``.
        motif_emb, adj_motif, motif_num: forwarded unchanged to ``model``.
        optimizer: stepped once.

    Returns:
        float: the loss value of this step.

    Raises:
        NotImplementedError: if ``args.sup`` is truthy; only the unsupervised
            loss is implemented here.
        ValueError: if ``args.topk`` < 2, since two non-empty sets cannot be formed.
    """
    # Check configuration up front. Previously a truthy ``args.sup`` left
    # ``loss`` unbound after a full forward pass.
    if args.sup:
        raise NotImplementedError("train(): supervised contrastive loss (args.sup=True) is not implemented")
    half = args.topk // 2
    if half < 1:
        raise ValueError(f"args.topk must be >= 2 to form two sets, got {args.topk}")

    model.train()
    optimizer.zero_grad()
    edge_index_1 = dropout_edge(edge_index, p=drop_edge_rate_1)[0]
    edge_index_2 = dropout_edge(edge_index, p=drop_edge_rate_2)[0]
    x_1 = drop_feature(x, drop_feature_rate_1)
    x_2 = drop_feature(x, drop_feature_rate_2)
    z1 = model(x_1, edge_index_1, motif_emb, adj_motif, motif_num)  # N x d
    z2 = model(x_2, edge_index_2, motif_emb, adj_motif, motif_num)  # N x d
    device = z1.device

    # The similarity search runs on CPU copies of the embeddings.
    z1_cpu, z2_cpu = z1.cpu(), z2.cpu()
    sim = z1_cpu @ z2_cpu.T

    topk_idx = torch.topk(sim, k=args.topk, dim=1, largest=True).indices  # N x k
    # Random column permutation of the k neighbours (shared by all nodes).
    topk_idx = topk_idx[:, torch.randperm(topk_idx.size(1))[:args.topk]]
    set_emb = z2_cpu[topk_idx]  # N x k x d

    # Two equal halves. With odd k the last neighbour is dropped;
    # ``torch.split`` would have returned three chunks and broken unpacking.
    set_1 = set_emb[:, :half].to(device)
    set_2 = set_emb[:, half:2 * half].to(device)

    zs_1, zs_2 = model.set_forward(set_1), model.set_forward(set_2)

    ins_loss = model.loss(z1, z2)
    set_loss = model.loss(zs_1, zs_2, loss_type='set')
    loss = alpha * ins_loss + beta * set_loss

    loss.backward()
    optimizer.step()
    return loss.item()


def _pseudo_labelled_graph(adj, x, n_clusters=15, seed=42):
    """Return ``nx.Graph(adj)`` with KMeans cluster ids of row-normalised ``x`` in node attribute ``'Labels'``."""
    features = normalize(x.numpy(), axis=1)
    cluster_ids = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(features)
    G = nx.Graph(adj.numpy())
    nx.set_node_attributes(G, {i: cluster_ids[i] for i in range(len(cluster_ids))}, name='Labels')
    return G


def _count_3_node_motifs(G, adj, num_nodes):
    """Count labelled 3-node motifs in every node's 1-hop neighbourhood.

    Returns a dict mapping each WL hash to a length-``num_nodes`` array of
    per-node occurrence counts, as filled by get_3_node_subgraph_undirected.
    """
    seen = set()
    motif_counts = {}
    # Filled by the helper but not consumed further in this script.
    motif_adj = {}
    motif_x = {}
    for node in range(num_nodes):
        _, original_new_node, new_original_node = get_n_hop_subgraph(adj, node)
        sub_G = G.subgraph(list(original_new_node.keys()))
        motif_counts = get_3_node_subgraph_undirected(
            sub_G, motif_counts, original_new_node, new_original_node,
            node, num_nodes, seen, motif_adj, motif_x)
    return motif_counts


def _attach_motif_nodes(G, motif_counts, num_nodes):
    """Add one node per motif type and link it by TF-IDF edges.

    Motif nodes get ids ``num_nodes``, ``num_nodes + 1``, ....
    Returns ``(G, number_of_motif_nodes)``.
    """
    for offset, key in enumerate(motif_counts):
        motif_node = num_nodes + offset
        G.add_node(motif_node)
        G = add_edge_with_tf_idf(G, motif_node, motif_counts[key], motif_counts, num_nodes)
    return G, len(motif_counts)


def _rescale_and_prune_motif_edges(G, motif_nodes, scale=30.0, percentile=35):
    """Rescale and prune the edges touching ``motif_nodes``.

    Weights are min-max rescaled to ``[0, scale]``. Edges whose rescaled
    weight is below the given percentile are then removed.
    """
    motif_edges = [(u, v) for u in motif_nodes for v in G[u]]
    if not motif_edges:
        return G
    weights = np.array([G[u][v]['weight'] for u, v in motif_edges], dtype=float)
    lo, hi = weights.min(), weights.max()
    if hi > lo:
        weights = (weights - lo) / (hi - lo) * scale
    else:
        # All weights equal: avoid 0/0; nothing falls below the threshold.
        weights = np.zeros_like(weights)
    for (u, v), w in zip(motif_edges, weights):
        G[u][v]['weight'] = w
    # Threshold is computed once, after every motif edge has been rescaled.
    threshold = np.percentile(weights, percentile)
    G.remove_edges_from([(u, v) for (u, v), w in zip(motif_edges, weights) if w < threshold])
    return G


def train_eval():
    """Build the motif-augmented graph, train with dev-set early stopping, then test.

    Steps:
      1. load the split and build a dense adjacency from ``edge_index``;
      2. pseudo-label nodes with KMeans and count labelled 3-node motifs;
      3. add one node per motif type with TF-IDF-weighted edges, then rescale
         those weights and prune the weakest edges;
      4. train, running ``fs_dev`` periodically and checkpointing the best
         model by dev accuracy;
      5. reload that checkpoint and run ``fs_test`` on the test classes.

    Reads module-level globals set in ``__main__`` (``args``, ``setting`` and
    the hyperparameters).

    Returns:
        (final_mean, final_std) as returned by ``fs_test``.
    """
    # Import explicitly rather than relying on ``from data_split import *``.
    import scipy.sparse as sp

    dataset, _undirected, train_idx, id_by_class, train_class, dev_class, test_class = split(args.dataset)
    data = dataset[0] if args.dataset in ['coauthor-cs', 'Cora'] else dataset
    original_node_num = data.x.shape[0]

    edge_weight = torch.ones(data.edge_index.size(1), dtype=torch.float)
    adj = torch.sparse_coo_tensor(
        indices=data.edge_index,
        values=edge_weight,
        size=(original_node_num, original_node_num),
    ).coalesce().to_dense()

    G = _pseudo_labelled_graph(adj, data.x)
    motif_counts = _count_3_node_motifs(G, adj, original_node_num)
    G, motif_num = _attach_motif_nodes(G, motif_counts, original_node_num)
    # Motif nodes are original_node_num .. original_node_num + motif_num - 1,
    # including the first one.
    G = _rescale_and_prune_motif_edges(G, range(original_node_num, original_node_num + motif_num))

    adj_motif = nx.adjacency_matrix(G)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device is ", device)
    # Add self-loops. Only the edge index (sparsity pattern) is kept.
    adj_motif = sparse_mx_to_torch_sparse_tensor(adj_motif + sp.eye(adj_motif.shape[0]))
    adj_motif = adj_motif.coalesce().indices().to(device)

    encoder = Encoder(dataset.num_features, num_hidden, motif_num, activation,
                      base_model=base_model, k=num_layers)
    decoder = Decoder(set_name, num_hidden, num_heads)
    model = Model(encoder, decoder, num_hidden, num_proj_hidden, tau).to(device)

    data = data.to(device)

    contrast_labels = relabeling(data.y, train_class, dev_class, test_class, id_by_class)
    contrast_labels = torch.LongTensor(contrast_labels).to(device)

    # Parameters whose name contains 'motif' get their own lr / weight decay.
    base_params = []
    special_params = []
    for name, param in model.named_parameters():
        if 'motif' in name:
            special_params.append(param)
        else:
            base_params.append(param)
    optimizer = torch.optim.Adam([
        {'params': base_params, 'lr': learning_rate, 'weight_decay': weight_decay},
        {'params': special_params, 'lr': 0.005, 'weight_decay': 0.005},
    ])

    save_dir = './savepoint/'
    checkpoint = save_dir + args.dataset + '_model.pkl'
    os.makedirs(save_dir, exist_ok=True)

    # No motif embeddings are computed here; the model receives None.
    motif_emb_total = None
    cnt_wait = 0
    best_acc = 0
    for epoch in range(1, num_epochs + 1):
        train(model, data.x, contrast_labels, data.edge_index, motif_emb_total, adj_motif, motif_num, optimizer)
        # NOTE(review): this evaluates on every epoch EXCEPT 1, 11, 21, ...;
        # if "every 10th epoch" was intended, the condition should be ``== 0``.
        if (epoch - 1) % 10:
            final_mean, final_std = fs_dev(model, data.x, data.edge_index, motif_emb_total, adj_motif, motif_num, data.y, setting['test_num'], id_by_class, dev_class, setting['n_way'], setting['k_shot'], setting['m_qry'])
            print("===="*20)
            print("novel_dev_acc: " + str(final_mean))
            print("novel_dev_std: " + str(final_std))
            if best_acc < final_mean:
                best_acc = final_mean
                cnt_wait = 0
                torch.save(model.state_dict(), checkpoint)
            else:
                cnt_wait += 1

        if cnt_wait == setting['patience']:
            print('Early stopping!')
            break

    print("=== Final Test ===")
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    print("model load success!")
    final_mean, final_std = fs_test(model, data.x, data.edge_index, motif_emb_total, adj_motif, motif_num, data.y, train_idx, setting['test_num'], id_by_class, test_class, setting['n_way'], setting['k_shot'], setting['m_qry'], device, args)
    print("novel_test_acc: " + str(final_mean))
    print("novel_test_std: " + str(final_std))

    return final_mean, final_std
    
    
if __name__ == '__main__':

    args = get_args()
    print(args)
    setting = {'n_way': args.way, 'k_shot': args.shot, 'm_qry': args.query, 'test_num': args.test_num, 'patience': args.patience}
    # Use a context manager so the config file handle is closed.
    with open(args.config) as config_file:
        config = yaml.load(config_file, Loader=SafeLoader)[args.dataset]

    torch.manual_seed(config['seed'])
    random.seed(config['seed'])
    np.random.seed(config['seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config['seed'])
        torch.cuda.manual_seed_all(config['seed'])

    # train() and train_eval() read the names below as module globals.
    learning_rate = config['learning_rate']
    num_hidden = config['num_hidden']
    num_proj_hidden = config['num_proj_hidden']
    activation = ({'relu': F.relu, 'prelu': nn.PReLU()})[config['activation']]
    if args.encoder == 'gcn':
        base_model = GCNConv
    elif args.encoder == 'sgc':
        base_model = SGConv
    else:
        # Fail fast instead of a NameError on ``base_model`` inside train_eval().
        raise ValueError(f"unsupported encoder {args.encoder!r}; expected 'gcn' or 'sgc'")
    num_layers = config['num_layers']

    drop_edge_rate_1 = args.drop_edge_rate_1
    drop_edge_rate_2 = args.drop_edge_rate_2
    drop_feature_rate_1 = args.drop_feature_rate_1
    drop_feature_rate_2 = args.drop_feature_rate_2
    tau = config['tau']
    num_epochs = config['num_epochs']
    weight_decay = config['weight_decay']
    set_name = args.set_encoder
    num_heads = args.num_heads
    alpha = args.alpha
    beta = args.beta

    num_runs = 3
    acc_mean = []
    for _ in range(num_runs):
        run_mean, _run_std = train_eval()
        acc_mean.append(run_mean)

    mean_acc = np.mean(acc_mean)
    std_acc = np.std(acc_mean)
    print("======"*10)
    print("acc mean: " + str(mean_acc))
    print("acc std: " + str(std_acc))

    result = {str((mean_acc, std_acc)): {
        'way': args.way,
        'shot': args.shot,
        'drop_edge_rate_1': args.drop_edge_rate_1,
        'drop_edge_rate_2': args.drop_edge_rate_2,
        'drop_feature_rate_1': args.drop_feature_rate_1,
        'drop_feature_rate_2': args.drop_feature_rate_2,
        'temperature': args.temperature,
        'beta': args.beta,
        'gamma': args.gamma,
        'alpha': args.alpha,
        'epoch': args.epoch,
        'weight_decay': args.wd}}
    os.makedirs("./res", exist_ok=True)
    with open("./res/" + args.dataset+"_res.txt", "a+") as f:
        # Newline-terminate each record so appended results stay separable.
        f.write(json.dumps(result, indent=4) + "\n")
