import pickle as pkl
import sys

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F

# Positive floor applied before taking a logarithm (default for log_max's
# SMALL parameter) so that zero / negative inputs do not reach torch.log.
SMALL = 1.0e-10


def debug_print(content=None, ckpt=''):
    """Print a three-line DEBUG-INFO banner followed by a blank line.

    Args:
        content: value shown (via str.format) inside single quotes on the
            CONTENT line.
        ckpt: value shown on the CHECKPOINT line.
    """
    banner_lines = [
        "DEBUG-INFO",
        " CHECKPOINT = {}".format(ckpt),
        " CONTENT = '{}'".format(content),
    ]
    # end="\n\n": one newline closes the banner, the second leaves a blank line.
    print("\n".join(banner_lines), end="\n\n")


def debugger(x, tensor_name='', ckpt='', abnormal_exit=False):
    """Detect NaN / Inf entries in a tensor, optionally report them and abort.

    Args:
        x (torch.Tensor): tensor to inspect; it is detached and copied to CPU.
        tensor_name (str): label used in the report message.
        ckpt: checkpoint label forwarded to ``debug_print``. Pass ``None`` to
            suppress the report.
        abnormal_exit (bool): if True, terminate the process with exit
            status 1 when any NaN or Inf is found.

    Returns:
        np.ndarray: int64 array ``[has_nan, has_inf]`` with 0/1 entries.

    Raises:
        SystemExit: with code 1 when ``abnormal_exit`` is True and NaN/Inf
            values are present.
    """
    y = x.detach().cpu().numpy()
    nan_num = np.sum(np.isnan(y))
    inf_num = np.sum(np.isinf(y))
    nan_stat, inf_stat = (nan_num > 0), (inf_num > 0)
    found = nan_stat or inf_stat

    if found and (ckpt is not None):
        content = "{} nan and {} inf appears in {}".format(nan_num, inf_num, tensor_name)
        debug_print(content=content, ckpt=ckpt)

    if found and abnormal_exit:
        # sys.exit is always available (the bare exit() builtin is injected by
        # the `site` module), and a non-zero status signals the failure instead
        # of exit()'s default success status.
        sys.exit(1)

    return np.array([nan_stat, inf_stat]).astype(np.int64)




def ck_relative_(tensor_1_name, tensor_1, tensor_2_name, tensor_2, threshold=0., mute_t1_info=False):
    """
    Print tensor_2's values at positions that stand out in tensor_1.

    Two reports are printed:
      * up to 5 tensor_2 values at positions where tensor_1 is NaN
        (only when tensor_1 contains NaN);
      * the (up to) 5 largest-magnitude entries of tensor_1 and the
        tensor_2 values at the same positions (order not sorted).

    Args:
        tensor_1_name (str): label for tensor_1 in the printout.
        tensor_1 (torch.Tensor): tensor whose values select the positions.
        tensor_2_name (str): label for tensor_2 in the printout.
        tensor_2 (torch.Tensor): tensor whose values are reported. When both
            tensors have at least 2 dims and tensor_2 has size 1 along dim 1,
            it is repeated along dim 1 to match tensor_1.
        threshold (float): currently unused; kept for backward compatibility.
        mute_t1_info (bool): if True, do not print tensor_1's extreme values.

    Raises:
        AssertionError: if the shapes still differ after the dim-1 expansion.
    """
    tensor1_np = tensor_1.detach().cpu().numpy()
    tensor2_np = tensor_2.detach().cpu().numpy()

    # Shape match: expand a single-column tensor_2 across tensor_1's columns.
    # The ndim guards avoid IndexError on shape[1] for 0-D / 1-D inputs.
    if tensor1_np.ndim >= 2 and tensor2_np.ndim >= 2 and tensor2_np.shape[1] == 1:
        tensor2_np = np.repeat(tensor2_np, tensor1_np.shape[1], axis=1)

    assert tensor1_np.shape == tensor2_np.shape, 'two tensors have different shapes'

    tensor1_np = tensor1_np.flatten()
    tensor2_np = tensor2_np.flatten()

    nan_map = np.isnan(tensor1_np)
    if nan_map.any():
        print("{:s}'s value where nan detected in {:s}:".format(tensor_2_name, tensor_1_name))
        display_array = tensor2_np[nan_map][:5]
        print("%.4f  " * len(display_array) % tuple(display_array))

    # argpartition needs kth inside the array, so cap k at the element count.
    k = min(5, tensor1_np.size)
    if k == 0:
        return

    top_idx = np.argpartition(np.abs(tensor1_np), -k)[-k:]
    tensor1_big = tensor1_np[top_idx]
    display_array = tensor2_np[top_idx]

    if not mute_t1_info:
        print("{:s} extreme {:d}".format(tensor_1_name, k))
        print("%.4f  " * len(display_array) % tuple(tensor1_big))

    print("and corresponding values in {:s}".format(tensor_2_name))
    print("%.4f  " * len(display_array) % tuple(display_array))




def log_max(input, SMALL=SMALL):
    """Elementwise natural log of ``input`` after flooring it at ``SMALL``.

    Computes ``log(max(input, SMALL))``, so zeros and negative entries are
    mapped to ``log(SMALL)`` instead of ``-inf`` / NaN.

    NOTE(review): the floor is a 1-element tensor built with torch's default
    dtype, so by torch's broadcasting/type-promotion rules a 0-d input yields
    a shape-(1,) result, and lower-precision float inputs are promoted.
    """
    floor = torch.tensor([SMALL], device=input.device)
    clipped = torch.max(input, floor)
    return torch.log(clipped)
    



def grad_moderator(model, inf_grad_subs, nan_grad_subs):
    """
    Replace NaN / +inf / -inf entries in every parameter gradient of ``model``.

    NaN entries become ``nan_grad_subs``, +inf entries become
    ``inf_grad_subs`` and -inf entries become ``-inf_grad_subs``. Parameters
    whose ``.grad`` is None (no gradient accumulated) are skipped. The name of
    every parameter that needs fixing is printed.

    Args:
        model (torch.nn.Module): module whose ``named_parameters()`` are scanned.
        inf_grad_subs (torch.Tensor or float): substitute magnitude for infinities.
        nan_grad_subs (torch.Tensor or float): substitute value for NaNs.
    """
    for name, param in model.named_parameters():
        param_grads = param.grad
        if param_grads is None:
            continue

        gnan = torch.isnan(param_grads).any()
        ginf = torch.isinf(param_grads).any()
        if not (gnan or ginf):
            continue
        print("examining parameter: {:s}".format(name))

        # Cast substitutes to the gradient's dtype/device so torch.where keeps
        # the gradient's dtype and does not mix devices.
        nan_sub = torch.as_tensor(nan_grad_subs, dtype=param_grads.dtype, device=param_grads.device)
        pos_sub = torch.as_tensor(inf_grad_subs, dtype=param_grads.dtype, device=param_grads.device)
        neg_sub = -pos_sub

        if gnan:
            param_grads = torch.where(torch.isnan(param_grads), nan_sub, param_grads)

        if ginf:
            param_grads = torch.where(param_grads == np.inf, pos_sub, param_grads)
            param_grads = torch.where(param_grads == -np.inf, neg_sub, param_grads)

        param.grad = param_grads



def phir_compare(trained_phir, stored_phir):
    """Compare trained phir tensors with stored arrays, one pair per community.

    For each pair, the elementwise squared difference is computed and its
    max / min / mean are printed. The per-pair score is max + min + mean;
    it is printed again when non-zero.

    Args:
        trained_phir: iterable of torch tensors; detached and copied to CPU.
        stored_phir: iterable of array-likes compatible with the tensors'
            numpy conversions. NOTE(review): pairs are formed with zip, so
            extra entries in the longer sequence are silently ignored.

    Returns:
        float: sum of the per-pair scores (0. when there are no pairs).
    """
    # detach() so tensors that still require grad can be converted to numpy.
    phirt_np_ls = [phir.detach().cpu().numpy() for phir in trained_phir]
    phirs_np_ls = stored_phir

    stat = 0.
    for i, (phirt_np, phirs_np) in enumerate(zip(phirt_np_ls, phirs_np_ls)):
        phir_diff = (phirt_np - phirs_np) ** 2.
        print("\n\ndifference stats in phir feature for community {:d}".format((i+1)))
        print("max = {:2f}, min = {:2f}, mean = {:2f}".format(
            phir_diff.max(), phir_diff.min(), phir_diff.mean()
        ))
        d_stat = phir_diff.max() + phir_diff.min() + phir_diff.mean()
        stat += d_stat
        if d_stat != 0:
            print(d_stat)

    return stat



