import ssl
import sys
import torch
import urllib
import warnings
import numpy as np
import pickle as pkl
import os.path as osp
import scipy.sparse as sp
import torch.nn.functional as F

from torch import Tensor
from torch import FloatTensor
from scipy.sparse import coo_matrix
from torch_scatter import scatter_add
from sklearn.preprocessing import StandardScaler
from torch_geometric.utils import remove_self_loops
from torch_geometric.utils import to_scipy_sparse_matrix


def file_exist(filepaths):
    """Check whether one path, or every path in a list, exists on disk.

    Args:
        filepaths: A single path, or a ``list`` of paths. Only a real
            ``list`` is treated as a collection; anything else is handed
            directly to ``os.path.exists``.

    Returns:
        ``True`` if the path exists (for a list: if every listed path
        exists, so an empty list yields ``True``), otherwise ``False``.
    """
    if not isinstance(filepaths, list):
        return osp.exists(filepaths)
    # all() stops at the first missing path.
    return all(osp.exists(path) for path in filepaths)
        

def pkl_read_file(filepath):
    """Load a single pickled object from ``filepath``.

    The stream is unpickled with ``encoding="latin1"``. The ``pickle``
    documentation recommends this when reading NumPy arrays or datetime
    objects that were pickled under Python 2.

    SECURITY: unpickling can execute arbitrary code. Only call this on
    files from a trusted source.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        The unpickled object.

    Raises:
        SystemExit: with status 1 if an ``IOError``/``OSError`` occurs while
            reading the pickle stream. The error is printed to stderr first.
            Errors raised by ``open`` itself (e.g. a missing file) propagate
            unchanged.
    """
    with open(filepath, 'rb') as rf:
        try:
            return pkl.load(rf, encoding="latin1")
        except IOError as e:
            print(e, file=sys.stderr)
            # sys.exit, not the site-provided exit(), which is intended only
            # for the interactive interpreter and may not exist.
            sys.exit(1)


def pkl_write_file(file, filepath):
    """Pickle the object ``file`` into ``filepath``, overwriting it.

    Args:
        file: Any picklable object (the name is historical; it is not a
            file handle).
        filepath: Destination path; created or truncated.

    Raises:
        SystemExit: with status 1 if an ``IOError``/``OSError`` occurs while
            writing the pickle stream. The error is printed to stderr first.
            Errors raised by ``open`` itself propagate unchanged.
    """
    with open(filepath, 'wb') as wf:
        try:
            pkl.dump(file, wf)
        except IOError as e:
            print(e, file=sys.stderr)
            # sys.exit, not the site-provided exit(), which is intended only
            # for the interactive interpreter and may not exist.
            sys.exit(1)


def download_to(url, path, *, verify=False):
    """Fetch ``url`` and write the response body to ``path``.

    Args:
        url: URL to download (any scheme ``urllib.request`` supports).
        path: Destination file; created or truncated.
        verify: When ``False`` (the default, kept for backward
            compatibility) TLS certificates are NOT verified. Pass ``True``
            to verify them against the default CA store.

    Raises:
        urllib.error.URLError: propagated if the request itself fails.
        SystemExit: with status 1 if an ``IOError``/``OSError`` occurs while
            streaming the body to disk. The error is printed to stderr first.
    """
    import shutil
    # `import urllib` alone does not load the `request` submodule.
    import urllib.request

    # SECURITY: with verify=False this uses ssl's private unverified-context
    # helper, which disables certificate checks and allows MITM attacks.
    if verify:
        context = ssl.create_default_context()
    else:
        context = ssl._create_unverified_context()

    # Open the URL before the output file so a failed request leaves no file.
    with urllib.request.urlopen(url, context=context) as response:
        with open(path, 'wb') as wf:
            try:
                # Stream in chunks instead of buffering the whole body.
                shutil.copyfileobj(response, wf)
            except IOError as e:
                print(e, file=sys.stderr)
                sys.exit(1)

def parse_index_file(filename):
    """Parse an index file containing one integer per line.

    Surrounding whitespace is ignored, and blank or whitespace-only lines
    (e.g. trailing empty lines) are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of ints, in file order.

    Raises:
        ValueError: if a non-blank line is not an integer.
    """
    # `with` guarantees the handle is closed, even if int() raises.
    with open(filename) as index_file:
        return [int(line) for line in index_file if line.strip()]


def to_undirected(edge_index):
    """Return an edge index that contains every edge in both directions.

    Args:
        edge_index: Either a SciPy sparse matrix of any format (its stored
            entries are used via COO), or a ``(row, col)`` pair of 1-D index
            tensors / NumPy arrays, such as a ``2 x E`` tensor or array.

    Returns:
        A ``2 x 2E`` tensor. The first ``E`` columns are the original
        ``(row, col)`` pairs and the last ``E`` columns are the reversed
        ``(col, row)`` pairs. Duplicates and self-loops are not removed.
    """
    if sp.issparse(edge_index):
        # csr_matrix (and other non-COO formats) has no .row/.col attributes.
        coo = edge_index.tocoo()
        row, col = torch.from_numpy(coo.row), torch.from_numpy(coo.col)
    else:
        row, col = edge_index
        # as_tensor is a no-op for tensors and shares memory with ndarrays.
        row, col = torch.as_tensor(row), torch.as_tensor(col)
    new_row = torch.hstack((row, col))
    new_col = torch.hstack((col, row))
    return torch.stack((new_row, new_col), dim=0)


def coomatrix_to_torch_tensor(edge_index):
    """Convert an edge list into a ``2 x E`` torch tensor.

    Args:
        edge_index: Either a SciPy sparse matrix of any format (its stored
            entries are used via COO), or a ``(row, col)`` pair of 1-D index
            tensors / NumPy arrays, such as a ``2 x E`` tensor or array.

    Returns:
        A tensor whose row 0 holds the row indices and row 1 the column
        indices.
    """
    if sp.issparse(edge_index):
        # csr_matrix (and other non-COO formats) has no .row/.col attributes.
        coo = edge_index.tocoo()
        row, col = torch.from_numpy(coo.row), torch.from_numpy(coo.col)
    else:
        row, col = edge_index
        # torch.stack rejects ndarrays; as_tensor is a no-op for tensors.
        row, col = torch.as_tensor(row), torch.as_tensor(col)
    return torch.stack((row, col), dim=0)


def remove_self_loops(edge_index, edge_attr=None):
    """Drop every edge whose two endpoints are the same node.

    Note: this module-level definition shadows the ``remove_self_loops``
    imported from ``torch_geometric.utils`` at the top of the file.

    Args:
        edge_index: Indexable as ``edge_index[0]`` / ``edge_index[1]`` and
            ``edge_index[:, mask]``, e.g. a ``2 x E`` tensor.
        edge_attr: Optional per-edge values, filtered with the same mask.

    Returns:
        ``(filtered_edge_index, filtered_edge_attr)``. The second item is
        ``None`` when no ``edge_attr`` was given.
    """
    keep = edge_index[0] != edge_index[1]
    kept_edges = edge_index[:, keep]
    kept_attr = None if edge_attr is None else edge_attr[keep]
    return kept_edges, kept_attr


def remove_self_loops_weights(edge_weight, edge_index):
    """Return only the weights of edges that are not self-loops.

    Args:
        edge_weight: Per-edge values, aligned with the columns of
            ``edge_index``.
        edge_index: Indexable as ``edge_index[0]`` / ``edge_index[1]``,
            e.g. a ``2 x E`` tensor.

    Returns:
        ``edge_weight`` masked to the edges whose endpoints differ. The
        edge index itself is not returned.
    """
    sources, targets = edge_index[0], edge_index[1]
    return edge_weight[sources != targets]


def set_in_out_degree_features(num_node, edge_index, signed, edge_weight):
    """Build per-node degree features from an edge list.

    The edges are assembled into a ``num_node x num_node`` SciPy CSR matrix
    ``A`` with ``A[edge_index[0], edge_index[1]] = weight``. Duplicate edges
    are summed by the COO -> CSR conversion. The "in" features are row sums
    (``axis=1``, grouped by ``edge_index[0]``). The "out" features are
    column sums (``axis=0``, grouped by ``edge_index[1]``).

    NOTE(review): if ``edge_index[0]`` is the source node (the usual PyG
    convention), the "in"/"out" labels look swapped. The column order is
    kept as-is because callers may depend on it; confirm the intended
    convention.

    Args:
        num_node: Number of nodes; sets the matrix size.
        edge_index: Tensor indexable as ``edge_index[0]`` / ``edge_index[1]``
            (moved to CPU first), presumably ``2 x E`` node indices.
        signed: If True, split weights into positive and negative parts.
        edge_weight: Optional tensor of per-edge weights. It is required when
            ``signed`` is True. When ``signed`` is False its absolute value
            is used, and ``None`` means every edge has weight 1.

    Returns:
        A float32 tensor. When ``signed`` is True it has shape
        ``(num_node, 4)`` with columns ``[in_pos, in_neg, out_pos, out_neg]``;
        negative parts are reported as non-negative magnitudes. Otherwise it
        has shape ``(num_node, 2)`` with columns ``[in, out]``.

    Raises:
        ValueError: if ``signed`` is True and ``edge_weight`` is None.
    """
    edges = edge_index.cpu()

    def build_adjacency(weights):
        # Sparse num_node x num_node matrix; converting to CSR sums duplicates.
        return coo_matrix((weights, (edges[0], edges[1])),
                          shape=(num_node, num_node), dtype=np.float32).tocsr()

    if not signed:
        if edge_weight is not None:
            weights = np.abs(edge_weight.cpu().numpy())
        else:
            weights = np.ones(len(edges.T))
        adj = build_adjacency(weights)
        row_totals = np.sum(adj, axis=1)
        col_totals = np.sum(adj, axis=0).T
        return torch.from_numpy(np.c_[row_totals, col_totals]).float()

    if edge_weight is None:
        raise ValueError('Edge weight input should not be None when generating features based on edge signs!')
    adj = build_adjacency(edge_weight.cpu().numpy())
    magnitude = adj.copy()
    magnitude.data = np.abs(magnitude.data)
    # (|a| + a) / 2 keeps positive entries; (|a| - a) / 2 keeps the
    # magnitudes of negative entries.
    positive_part = (magnitude + adj) / 2
    negative_part = (magnitude - adj) / 2
    return torch.from_numpy(np.c_[
        np.sum(positive_part, axis=1),
        np.sum(negative_part, axis=1),
        np.sum(positive_part, axis=0).T,
        np.sum(negative_part, axis=0).T,
    ]).float()

def feature_normalization(features):
    """Row-L1-normalise a feature matrix, then standardise each column.

    Each row is first scaled to unit L1 norm (``F.normalize(p=1, dim=1)``).
    Each column is then shifted to zero mean and divided by its population
    standard deviation (``ddof=0``). ``1e-12`` is added to the deviation so
    constant columns map to 0 instead of dividing by zero.

    Args:
        features: Floating-point tensor, presumably
            ``(num_nodes, num_features)``. It may require grad or live on a
            non-CPU device; it is not modified.

    Returns:
        A new CPU float32 ``torch.FloatTensor`` of the same shape, detached
        from any autograd graph.
    """
    features = F.normalize(features, p=1, dim=1)
    # Tensor.numpy() refuses tensors that require grad or are not on the CPU.
    features = features.detach().cpu().numpy()
    m = features.mean(axis=0)
    s = features.std(axis=0, ddof=0, keepdims=True) + 1e-12
    # In-place ops act on the normalised copy, never on the caller's tensor.
    features -= m
    features /= s
    return torch.FloatTensor(features)
