import torch
from torch_geometric.utils import from_networkx, to_networkx
import numpy as np
import networkx as nx
from torch_geometric.utils import (negative_sampling, remove_self_loops,
                                   add_self_loops)
from torch_sparse import SparseTensor
import os, sys
import pickle as pkl
from model.encoder import *

def get_real_adj(data):
    """Return the dense adjacency matrix of ``data`` with a self-loop on every node.

    Every directed edge ``(edge_index[0, k], edge_index[1, k])`` becomes a 1 in
    the result, and the whole diagonal is set to 1 for all ``data.num_nodes``
    nodes. Duplicate edges still yield 1, because entries are assigned rather
    than accumulated.

    Args:
        data: graph object exposing ``edge_index`` (2 x E integer tensor) and
            ``num_nodes``.

    Returns:
        A ``(num_nodes, num_nodes)`` tensor of torch's default float dtype, on
        the same device as ``edge_index``.
    """
    num_nodes = data.num_nodes
    edge_index = data.edge_index
    real_adj = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
    real_adj[edge_index[0], edge_index[1]] = 1.0
    # Put a self-loop on EVERY node. This includes isolated nodes whose index
    # exceeds edge_index.max(). Calling add_self_loops without num_nodes
    # silently skipped those nodes.
    real_adj.fill_diagonal_(1.0)
    return real_adj



def load_real_world(args):
    """Load the pickled graph ``<args.DS>.pkl`` from ``args.data_dir``.

    Args:
        args: namespace providing ``data_dir`` (directory) and ``DS`` (dataset
            name, used as the file stem).

    Returns:
        ``(data, anomaly_label)``. ``data`` is the unpickled object.
        ``anomaly_label`` is an integer numpy array that is 1 where
        ``data.anomaly_flag`` equals True and 0 elsewhere.
    """
    pkl_path = os.path.join(args.data_dir + '/', "{}.pkl".format(args.DS))

    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    with open(pkl_path, 'rb') as pkl_file:
        graph = pkl.load(pkl_file)

    flags = graph.anomaly_flag.numpy()
    # Elementwise comparison: ``flags == True`` is intentional for numpy arrays.
    return graph, np.where(flags == True, 1, 0)  # noqa: E712


import matplotlib.pyplot as plt
import numpy as np
import torch
import json
import os
import random
import shutil
import scipy.io as sio
import scipy.sparse as sp
import pickle as pkl
from torch_geometric.utils import (negative_sampling, remove_self_loops,
                                   add_self_loops)
from torch_sparse import SparseTensor

from torch_geometric.data import Data
 


def dense_to_one_hot(labels_dense, num_classes):
    """Convert integer class labels to a one-hot matrix.

    Args:
        labels_dense: numpy integer array of N class ids, each in
            ``[0, num_classes)``. Any shape is accepted if ``shape[0] == N`` and
            ``ravel()`` yields those N ids in order, e.g. ``(N,)`` or ``(N, 1)``.
        num_classes: number of columns in the output.

    Returns:
        A float64 array of shape ``(N, num_classes)`` with exactly one 1 per row.

    Raises:
        ValueError: if any label lies outside ``[0, num_classes)``. Without this
            check, such a label would silently set a bit in a neighbouring row
            through flat indexing, or raise an opaque IndexError.
    """
    flat_labels = labels_dense.ravel()
    out_of_range = (flat_labels < 0) | (flat_labels >= num_classes)
    if np.any(out_of_range):
        raise ValueError(
            "labels must lie in [0, {}); got min={}, max={}".format(
                num_classes, flat_labels.min(), flat_labels.max()))

    num_labels = labels_dense.shape[0]
    # Row i starts at flat position i * num_classes; add the label for the column.
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + flat_labels] = 1
    return labels_one_hot

def sparse_to_tuple(sparse_mx, insert_batch=False):
    """Convert a scipy sparse matrix (or a list of them) to ``(coords, values, shape)``.

    ``coords`` is an (nnz, 2) array of ``(row, col)`` pairs. With
    ``insert_batch=True`` it is an (nnz, 3) array whose leading column is a
    zero batch index, and ``shape`` is prefixed with 1.

    A list argument is converted element by element IN PLACE, and that same
    list object is returned.
    """
    def _as_tuple(matrix):
        coo = matrix if sp.isspmatrix_coo(matrix) else matrix.tocoo()
        if not insert_batch:
            return np.vstack((coo.row, coo.col)).transpose(), coo.data, coo.shape
        batch_column = np.zeros(coo.row.shape[0])
        coords = np.vstack((batch_column, coo.row, coo.col)).transpose()
        return coords, coo.data, (1,) + coo.shape

    if not isinstance(sparse_mx, list):
        return _as_tuple(sparse_mx)

    for position, item in enumerate(sparse_mx):
        sparse_mx[position] = _as_tuple(item)
    return sparse_mx
def preprocess_features(features):
    """Row-normalize a scipy sparse feature matrix and return it in two forms.

    Each row is divided by its sum. Rows that sum to zero stay all-zero
    instead of becoming inf/NaN.

    Args:
        features: scipy sparse matrix (e.g. ``lil_matrix``) of shape (N, F).
            Integer-valued matrices are accepted.

    Returns:
        ``(dense, tuple_repr)``. ``dense`` is ``todense()`` of the normalized
        matrix. ``tuple_repr`` is ``sparse_to_tuple`` applied to the normalized
        sparse matrix.
    """
    rowsum = np.array(features.sum(1))
    # np.power(int_array, -1) raises ValueError, so promote integer sums to
    # float64. Float inputs keep their dtype and produce unchanged results.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero rows produce inf here; they are zeroed out just below, so the
    # divide-by-zero warning is expected noise.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense(), sparse_to_tuple(features)

def load_mat(dataset, train_rate=0.3, val_rate=0.1, data_dir="./data"):
    """Load an attributed graph with anomaly labels from ``<data_dir>/<dataset>.mat``.

    The .mat file may store each field under either of two key names:
    anomaly labels as ``Label`` or ``gnd``, attributes as ``Attributes`` or
    ``X``, and the adjacency as ``Network`` or ``A``. A ``Class`` field is
    required. Optional ``str_anomaly_label`` / ``attr_anomaly_label`` fields
    are returned when present.

    Args:
        dataset: file stem of the .mat file.
        train_rate: fraction used for the (currently unused) train split.
        val_rate: fraction used for the (currently unused) validation split.
        data_dir: directory containing the .mat file (default ``./data``).

    Returns:
        A tuple ``(adj, feat, labels, ano_labels, str_ano_labels,
        attr_ano_labels)`` with these elements:

        * ``adj``: ``csr_matrix``.
        * ``feat``: ``lil_matrix``.
        * ``labels``: one-hot class matrix from ``dense_to_one_hot``.
        * ``ano_labels``: squeezed numpy array.
        * ``str_ano_labels`` / ``attr_ano_labels``: squeezed numpy arrays, or
          ``None`` if the fields are missing.
    """
    data = sio.loadmat(os.path.join(data_dir, "{}.mat".format(dataset)))
    label = data['Label'] if ('Label' in data) else data['gnd']
    attr = data['Attributes'] if ('Attributes' in data) else data['X']
    network = data['Network'] if ('Network' in data) else data['A']

    adj = sp.csr_matrix(network)
    feat = sp.lil_matrix(attr)

    # Shift class ids down by one (presumably stored 1-based in the .mat file)
    # so they index columns of the one-hot matrix.
    labels = np.squeeze(np.array(data['Class'], dtype=np.int64) - 1)
    num_classes = np.max(labels) + 1
    labels = dense_to_one_hot(labels, num_classes)

    ano_labels = np.squeeze(np.array(label))

    if 'str_anomaly_label' in data:
        str_ano_labels = np.squeeze(np.array(data['str_anomaly_label']))
        attr_ano_labels = np.squeeze(np.array(data['attr_anomaly_label']))
    else:
        str_ano_labels = None
        attr_ano_labels = None

    # NOTE(review): a train/val/test split used to be sliced from this shuffle
    # and then discarded, so train_rate/val_rate have no effect on the output.
    # The shuffle is kept because it advances the global ``random`` state,
    # which seeded callers may rely on for reproducibility.
    all_idx = list(range(adj.shape[0]))
    random.shuffle(all_idx)

    return adj, feat, labels, ano_labels, str_ano_labels, attr_ano_labels

def load_dataset(dataset):
    """Load ``dataset`` with ``load_mat`` and wrap it as a PyG ``Data`` graph.

    Features are row-normalized by ``preprocess_features`` and converted to a
    float32 tensor. As a side effect, the anomaly-label shape and the counts
    of 0s, 1s and their sum are printed.

    Returns:
        ``(data, features, ano_label, str_ano_label, attr_ano_label)``.
    """
    print('Dataset: {}'.format(dataset), flush=True)
    adj, raw_features, _, ano_label, str_ano_label, attr_ano_label = load_mat(dataset)
    dense_features, _ = preprocess_features(raw_features)
    features = torch.tensor(dense_features, dtype=torch.float32)

    print(ano_label.shape)
    label_list = np.array(ano_label).tolist()
    zeros_count = label_list.count(0)
    ones_count = label_list.count(1)
    print(zeros_count)
    print(ones_count)
    print(zeros_count + ones_count)

    # Nonzero entries of the adjacency become directed edges (row -> col).
    rows, cols = np.nonzero(adj)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.int64)
    data = Data(x=features, edge_index=edge_index, anomaly_flag=ano_label)
    return data, features, ano_label, str_ano_label, attr_ano_label


def node_normalize(features):
    """Scale ``features`` node-wise by weights derived from ``GraphConv``.

    ``GraphConv`` comes from ``from model.encoder import *``. It is applied to
    ``features`` with an identity adjacency, and the diagonal of its output
    gives one weight per node.

    NOTE(review): this assumes ``features`` is a 2-D (num_nodes, num_feats)
    tensor and that ``GraphConv()(features, adj)`` returns a 2-D matrix. If it
    returned a 1-D tensor, ``torch.diag`` would BUILD a diagonal matrix instead
    of extracting one. TODO confirm against model.encoder.

    Returns:
        The rescaled features, detached from the autograd graph.
    """
    adj = torch.eye(features.shape[0])  # identity adjacency (self-connections only)
    aggregate_opt = GraphConv()  # a new GraphConv instance is constructed on every call
    weight = aggregate_opt(features, adj)
    weight = torch.diag(weight)
    # Transpose so the per-node weights broadcast along the node axis, then transpose back.
    normalized_f = (features.swapaxes(1, 0) * weight).swapaxes(1, 0).detach()
    return normalized_f