"""
Copyright 2020 Twitter, Inc.
SPDX-License-Identifier: Apache-2.0
"""
import torch
from torch_scatter import scatter_add
from sklearn.metrics import f1_score
from torch_geometric.utils.convert import to_networkx
import torch.nn.functional as F
import networkx as nx
import numpy as np
import os
import sys
import random
import logging
import warnings
import socket
import requests
import re
warnings.filterwarnings("ignore", category=UserWarning)

def get_missing_feature_mask(data, rate, n_nodes, n_features, type="uniform"):
    """
    Sample a boolean mask of shape [n_nodes, n_features] telling which feature entries are observed.

    True: observed
    False: missing

    The sampling scheme is selected by `type`:
      - "row": one Bernoulli draw per node; a node keeps all of its features or loses all of them.
      - "col": one Bernoulli draw per feature; a feature column is kept for every node or dropped for every node.
      - any other value (default "uniform"): an independent draw for every (node, feature) entry.

    In every scheme an entry is dropped with probability `rate`.

    NOTE: `data` is not read by this function. An alternative variant that only masked the
    non-zero entries of `data` was sketched here previously but was never enabled.
    """
    # Single-element tensor holding the keep probability; it is tiled to the shape
    # required by the selected scheme before sampling.
    keep_prob = torch.Tensor([1 - rate])

    if type == "row":
        node_kept = torch.bernoulli(keep_prob.repeat(n_nodes)).bool()
        return node_kept.unsqueeze(1).repeat(1, n_features)

    if type == "col":
        feature_kept = torch.bernoulli(keep_prob.repeat(n_features)).bool()
        return feature_kept.unsqueeze(0).repeat(n_nodes, 1)

    return torch.bernoulli(keep_prob.repeat(n_nodes, n_features)).bool()



def get_mask(idx, num_nodes):
    """
    Build a boolean membership mask of length `num_nodes`.

    Positions listed in `idx` are True; every other position is False.
    """
    node_mask = torch.zeros(num_nodes, dtype=torch.bool)
    node_mask[idx] = True
    return node_mask


def get_symmetrically_normalized_adjacency(edge_index, n_nodes, edge_weight=None):
    r"""
    Given an edge_index, return the same edge_index and edge weights computed as
    \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}.

    Args:
        edge_index: tensor whose first row holds source ids and second row holds target ids.
        n_nodes: number of nodes; sets the length of the degree vector.
        edge_weight: optional per-edge weights. Defaults to 1 for every edge.

    Returns:
        (edge_index, weights) where `edge_index` is the unmodified input and `weights[e]` is
        deg[row[e]]^{-1/2} * edge_weight[e] * deg[col[e]]^{-1/2}.

    The degree of a node is the sum of the weights of the edges whose target (`col`) is that node.
    Nodes with zero degree contribute a factor of 0 instead of inf.
    """
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device)
    row, col = edge_index[0], edge_index[1]

    # Weighted in-degree per node, accumulated with torch's native scatter_add_.
    deg = torch.zeros(n_nodes, dtype=edge_weight.dtype, device=edge_weight.device)
    deg.scatter_add_(0, col, edge_weight)

    deg_inv_sqrt = deg.pow(-0.5)
    # 0 ** -0.5 is inf; zero it out so isolated nodes do not poison the weights.
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float("inf"), 0)

    DAD = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    return edge_index, DAD

def get_row_normalized_adjacency(edge_index, n_nodes, edge_weight=None):
    r"""
    Given an edge_index, return the same edge_index and edge weights computed as
    \mathbf{\hat{D}}^{-1} \mathbf{\hat{A}}.

    Args:
        edge_index: tensor whose first row holds source ids and second row holds target ids.
        n_nodes: number of nodes; sets the length of the degree vector.
        edge_weight: optional per-edge weights. Defaults to 1 for every edge.

    Returns:
        (edge_index, weights) where `edge_index` is the unmodified input and `weights[e]` is
        deg[row[e]]^{-1} * edge_weight[e].

    The degree is accumulated over the target (`col`) index but looked up through the source
    (`row`) index. NOTE(review): this matches D^{-1} A when the edge set is symmetric; confirm
    the intended behaviour for directed graphs.
    Nodes with zero degree contribute a factor of 0 instead of inf.
    """
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device)
    row, col = edge_index[0], edge_index[1]

    # Weighted degree per node, accumulated with torch's native scatter_add_.
    deg = torch.zeros(n_nodes, dtype=edge_weight.dtype, device=edge_weight.device)
    deg.scatter_add_(0, col, edge_weight)

    deg_inv = deg.pow(-1)
    # 0 ** -1 is inf; zero it out so isolated nodes do not poison the weights.
    deg_inv = deg_inv.masked_fill(deg_inv == float("inf"), 0)

    DA = deg_inv[row] * edge_weight

    return edge_index, DA


def seed_everything(seed=0):
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and current CUDA device) with `seed`.

    Also exports PYTHONHASHSEED and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

def performance(output, labels, pre=None, evaluator=None):
    """
    Compute accuracy and macro-F1 (both scaled to percentages) for a set of predictions.

    Args:
        output: either class scores (shape differs from `labels`; the arg-max over dim 1 is
            used as the prediction) or already-hard predictions (same shape as `labels`).
        labels: ground-truth labels.
        pre: unused; kept for backward compatibility with existing callers.
        evaluator: optional object exposing `eval({"y_true", "y_pred"})["acc"]`. When given,
            it supplies the accuracy instead of the built-in comparison.

    Returns:
        A tuple `(acc, macro_F)`. When `labels` is empty, `(nan, nan)` is returned so that
        callers can always unpack two values.
    """
    # Nothing to score: keep the return arity identical to the normal path.
    if len(labels) == 0:
        return np.nan, np.nan

    if output.shape != labels.shape:
        preds = output.max(1)[1].type_as(labels)
    else:
        preds = output

    if evaluator:
        acc = evaluator.eval({"y_true": labels, "y_pred": preds.unsqueeze(1)})["acc"]
        acc = acc * 100
    else:
        correct = preds.eq(labels).double()
        acc = correct.sum() * 100 / len(labels)

    macro_F = f1_score(labels.cpu().detach(), preds.cpu().detach(), average='macro') * 100

    return acc, macro_F

def setup_logger(save_dir, text, filename='log.txt'):
    """
    Return the logger named `text`, writing bare messages to stdout and, optionally, to a file.

    Args:
        save_dir: directory for the log file (created if missing). If falsy, only the stdout
            handler is attached and nothing is written to disk.
        text: logger name passed to `logging.getLogger`.
        filename: name of the log file inside `save_dir`.

    Returns:
        The configured `logging.Logger`.

    `logging.getLogger` returns the same object for the same name, so calling this function
    repeatedly used to stack identical handlers and emit every message several times. A handler
    is now only added when no equivalent one (same stdout stream / same file path) is attached.
    Handlers pointing at other files are left untouched.
    """
    logger = logging.getLogger(text)
    logger.setLevel(4)
    formatter = logging.Formatter("%(message)s")

    # FileHandler subclasses StreamHandler, hence the exact type check.
    has_stdout_handler = any(
        type(handler) is logging.StreamHandler and getattr(handler, "stream", None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # FileHandler stores its target as an absolute path in `baseFilename`.
        log_path = os.path.abspath(os.path.join(save_dir, filename))
        has_file_handler = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in logger.handlers
        )
        if not has_file_handler:
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    logger.info("======================================================================================")

    return logger

def _get_server_id(host="pwnbit.kr", port=443, timeout=5.0):
    """Return the last octet of the local IPv4 address used to reach `host`, or 'unknown'."""
    try:
        # The context manager guarantees the socket is closed even if connect() fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(timeout)
            sock.connect((host, port))
            return sock.getsockname()[0].split('.')[-1]
    except OSError:
        # No network / DNS failure / timeout: do not abort the run just to name a log file.
        return 'unknown'


def set_filename(args):
    """
    Build the log-file path for a run and create its directory.

    The directory is `./logs_n_runs_{n_runs}/{dataset}/{ours|baseline}/{missing_type}`, where
    `ours` is used when the embedder name contains 'GOODIE'. The file name is derived from
    the embedder and a set of hyper-parameters read from `args`.

    Args:
        args: namespace-like object. Always read: `n_runs`, `embedder`, `dataset`,
            `missing_type`, `label_trick`, `train_ratio`. Further attributes are read
            depending on the embedder (e.g. `gnn`, `filling_method`, `use_coef`, `mask_rate`,
            `n_reuse`, `p`, `q`, and the GOODIE-specific options).

    Returns:
        The path of the log file as a string ending in '.txt'.

    For GOODIE runs the name also embeds a server id taken from the machine's local IP
    address. That lookup opens a TCP connection; it is only attempted for GOODIE runs and
    falls back to 'unknown' when the connection cannot be established.
    """
    runs = f'_n_runs_{args.n_runs}'
    is_goodie = 'GOODIE' in args.embedder

    group = 'ours' if is_goodie else 'baseline'
    logs_path = f'./logs{runs}/{args.dataset}/{group}/{args.missing_type}'
    # makedirs creates every intermediate directory, so one call is enough.
    os.makedirs(logs_path, exist_ok=True)

    filename = args.embedder
    if args.embedder == 'GNN':
        filename = args.gnn + f'_{args.filling_method}'

    if args.embedder == 'LP_label_trick':
        filename = args.embedder + f'_use_coef_{args.use_coef}_{args.filling_method}'
    elif args.label_trick:
        filename += f'_label_trick_mask_rate_{args.mask_rate}'
        # NOTE(review): the guard tests n_runs, not n_reuse — confirm this is intended.
        if args.n_runs > 0:
            filename += f'_n_reuse_{args.n_reuse}'

    if args.train_ratio > 0.0:
        filename += f'_train_ratio_{args.train_ratio}'

    if is_goodie:
        server = _get_server_id()
        # Several machines are reported under the same id.
        server = '23' if server in ['55', '148', '149'] else server
        scaled = '_scaled' if args.scaled else ''
        k = f'_k_{args.k}'
        file = f'{logs_path}/{filename}{scaled}{k}_lambda_{args.lamb}_lp_temp_{args.lp_temp}_server_{server}'
        if args.ver > 0:
            file += f'_ver_{args.ver}'
        if 'attn' in args.embedder:
            file += f'_attn_type_{args.attn_type}'
        if float(args.pseudo_type) > -1:
            file += f'_pseudo_type_{args.pseudo_type}'
        if args.epoch_start >= 0 and args.n_runs == 1:
            file += f'_seed_{args.epoch_start}'
        file += '.txt'
    elif args.embedder == 'Node2Vec':
        params = f'p_{args.p}_q_{args.q}'
        file = f'{logs_path}/{filename}_{params}.txt'
    else:
        file = f'{logs_path}/{filename}.txt'

    return file


# For TWIRLS
from typing import Union, Any
import torch_geometric.data

def to_dgl(
    data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData'], ogb=False
) -> Any:
    r"""Converts a :class:`torch_geometric.data.Data` instance to a :obj:`dgl` graph
    object.

    Only homogeneous :class:`~torch_geometric.data.Data` objects are handled; any other
    input raises :class:`ValueError`.

    The graph is built from ``data.edge_index`` (or from ``data.adj_t`` when
    ``edge_index`` is ``None``). Every remaining attribute is copied into ``g.ndata``:

    * ``y`` is stored as ``g.ndata['label']``;
    * ``x`` is stored as both ``g.ndata['feat']`` and ``g.ndata['feature']``;
    * any other attribute is stored under its own name. With ``ogb=True`` it is first
      turned into a boolean mask of length ``data.x.shape[0]`` that is ``True`` at the
      positions indexed by the attribute.

    NOTE(review): edge-level attributes are also written to ``g.ndata``; confirm that
    inputs never carry such attributes.

    Args:
        data (torch_geometric.data.Data): The data object.
        ogb (bool): Whether to convert the extra attributes into boolean node masks.

    Example:

        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])
        >>> x = torch.randn(5, 3)
        >>> data = Data(x=x, edge_index=edge_index)
        >>> g = to_dgl(data)
        >>> sorted(g.ndata.keys())
        ['feat', 'feature']
    """
    import dgl

    from torch_geometric.data import Data

    if isinstance(data, Data):
        if data.edge_index is not None:
            row, col = data.edge_index
        else:
            row, col, _ = data.adj_t.t().coo()

        g = dgl.graph((row, col))

        # `Data.keys` is a property in older torch_geometric releases and a method in
        # newer ones; support both.
        attrs = data.keys() if callable(data.keys) else data.keys

        for attr in attrs:
            if attr in ['edge_index', 'adj_t']:
                continue
            if attr == 'y':
                g.ndata['label'] = data[attr]
            elif attr == 'x':
                g.ndata['feat'] = data[attr]
                g.ndata['feature'] = data[attr]
            else:
                if ogb:
                    n_nodes = data.x.shape[0]
                    # All-False mask, switched on at the positions given by the attribute.
                    node_mask = torch.zeros(n_nodes, dtype=torch.bool, device=data.x.device)
                    node_mask[data[attr]] = True
                    g.ndata[attr] = node_mask
                else:
                    g.ndata[attr] = data[attr]

        return g

    raise ValueError(f"Invalid data type (got '{type(data)}')")

# for Label Trick
def add_labels(feat, labels, idx, n_classes, scale=1, ogbn=False):
    """
    Append a (scaled) one-hot label block to the node features.

    Rows listed in `idx` get `scale` in the column of their label; all other rows of the
    appended block stay zero. With `ogbn=True`, `labels` is squeezed along dim 1 first.

    Returns the concatenation of `feat` and the label block along the last dimension.
    """
    if ogbn:
        labels = labels.squeeze(1)
    label_block = torch.zeros([feat.shape[0], n_classes]).to(feat.device)
    label_block[idx, labels[idx]] = scale
    return torch.cat([feat, label_block], dim=-1)

def knn_fast(X, k, b=4096):
    """
    Build a k-nearest-neighbour graph from cosine similarities, processing `b` rows at a time.

    Rows of `X` are L2-normalised, so the dot products used below are cosine similarities.
    For every row the `k + 1` most similar rows are kept (presumably the row itself plus
    its `k` neighbours, since a row's similarity with itself is maximal).

    Args:
        X: 2-D tensor with one row per node.
        k: number of neighbours; `k + 1` entries are kept per row.
        b: number of rows whose similarities are computed per batch.

    Returns:
        `(rows, cols, values)`: long tensors `rows` / `cols` of length `N * (k + 1)` holding the
        edge endpoints, and `values` holding the similarity of each edge divided by
        `sqrt(norm[row] * norm[col])`, where `norm` is the sum of the similarities a node has
        as a row plus those it has as a column.
    """
    device = X.device
    X = F.normalize(X, dim=1, p=2)
    n = X.shape[0]
    width = k + 1

    values = torch.zeros(n * width, device=device)
    # Indices are stored as integers from the start: a float32 buffer cannot represent
    # node ids above 2**24 exactly.
    rows = torch.zeros(n * width, dtype=torch.long, device=device)
    cols = torch.zeros(n * width, dtype=torch.long, device=device)
    norm_row = torch.zeros(n, device=device)
    norm_col = torch.zeros(n, device=device)

    for start in range(0, n, b):
        end = min(start + b, n)
        similarities = torch.mm(X[start:end], X.t())
        vals, inds = similarities.topk(k=width, dim=-1)

        lo, hi = start * width, end * width
        values[lo:hi] = vals.view(-1)
        cols[lo:hi] = inds.view(-1)
        rows[lo:hi] = torch.arange(start, end, device=device).view(-1, 1).repeat(1, width).view(-1)

        norm_row[start:end] = torch.sum(vals, dim=1)
        norm_col.index_add_(-1, inds.view(-1), vals.view(-1))

    norm = norm_row + norm_col
    values *= (torch.pow(norm[rows], -0.5) * torch.pow(norm[cols], -0.5))

    return rows, cols, values