import torch
import math
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
import torch_geometric.utils as tutils
import numpy as np


def collate(graphs):
    """Merge a sequence of graph objects into one batch.

    Thin ``collate_fn`` adapter: delegates entirely to
    ``Batch.from_data_list``. ``graphs`` is presumably a list of
    torch_geometric ``Data`` objects — confirm against the dataset.
    """
    merged = Batch.from_data_list(graphs)
    return merged


def collate_bl(batch):
    """Collate ``(graph, baseline_value)`` samples into batched outputs.

    Each item of ``batch`` must unpack into exactly two parts. The graphs
    are merged with ``Batch.from_data_list``; the second parts are stacked
    with ``torch.tensor``.

    Returns:
        Tuple ``(batched_graph, baseline_tensor)``.
    """
    graph_list, baseline_list = (list(column) for column in zip(*batch))
    merged = Batch.from_data_list(graph_list)
    return merged, torch.tensor(baseline_list)



def move_to(var, device):
    """Recursively move ``var`` (and anything nested inside it) to ``device``.

    Dicts, lists and tuples (including namedtuples) are traversed and
    rebuilt with their contents moved. Any other object must provide a
    ``.to(device)`` method (tensors, torch_geometric batches, modules).

    Args:
        var: tensor-like object or a nested dict/list/tuple of them.
        device: target passed verbatim to ``.to``.

    Returns:
        A structure of the same shape holding the moved leaves. Dicts come
        back as plain ``dict``s.
    """
    if isinstance(var, dict):
        return {k: move_to(v, device) for k, v in var.items()}
    if isinstance(var, list):
        return [move_to(v, device) for v in var]
    if isinstance(var, tuple):
        moved = [move_to(v, device) for v in var]
        # A namedtuple's constructor takes positional fields, not one iterable.
        return type(var)(*moved) if hasattr(var, '_fields') else tuple(moved)
    return var.to(device)


def get_inner_model(model):
    """Return the model wrapped by a data-parallel wrapper, else ``model`` itself.

    Both ``DataParallel`` and ``DistributedDataParallel`` store the wrapped
    network in ``.module``. Unwrapping either one lets callers reach the
    real model's attributes and ``state_dict`` keys without the
    ``module.`` prefix.
    """
    # Local import: the file-level imports only bring in DataParallel.
    from torch.nn.parallel import DistributedDataParallel
    wrappers = (DataParallel, DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model


def log_values(cost, grad_norms, epoch, batch_id, step,
               log_likelihood, reinforce_loss, tb_logger):
    """Print a training-progress summary and optionally log it to TensorBoard.

    Args:
        cost: tensor of costs; its mean is reported as ``avg_cost``.
        grad_norms: pair ``(grad_norms, grad_norms_clipped)``. Only element
            ``[0]`` of each is reported (presumably the first parameter
            group — confirm against the caller).
        epoch: epoch number, used only in the printed progress line.
        batch_id: batch index, used only in the printed progress line.
        step: global step passed to ``tb_logger.log_value``.
        log_likelihood: tensor of log-likelihoods; the negated mean is
            reported as the NLL.
        reinforce_loss: single-element tensor logged as ``actor_loss``.
        tb_logger: object exposing ``log_value(name, value, step)``, or
            ``None`` to skip TensorBoard logging.
    """
    avg_cost = cost.mean().item()
    nll = -log_likelihood.mean().item()
    grad_norms, grad_norms_clipped = grad_norms

    print('epoch: {}, train_batch_id: {}, avg_cost: {}'.format(epoch, batch_id, avg_cost))

    # Report the raw and the clipped norm side by side so clipping is visible.
    print('grad_norm: {}, clipped: {}'.format(grad_norms[0], grad_norms_clipped[0]))
    print('nll: ', nll)

    # Log values to tensorboard
    if tb_logger is not None:
        tb_logger.log_value('avg_cost', avg_cost, step)

        tb_logger.log_value('actor_loss', reinforce_loss.item(), step)
        tb_logger.log_value('nll', nll, step)

        tb_logger.log_value('grad_norm', grad_norms[0], step)
        tb_logger.log_value('grad_norm_clipped', grad_norms_clipped[0], step)


def get_neighbors(edge_index, node, node_degree=None):
    """Return the targets of the edges whose source is ``node``.

    Assumes row 0 of ``edge_index`` (sources) is sorted ascending, so all
    edges leaving ``node`` form one contiguous run. That run is located by
    binary search.

    Args:
        edge_index: edge list; row 0 holds sources, row 1 holds targets.
        node: id of the node to query.
        node_degree: optional number of edges to take, starting at the first
            edge leaving ``node``. When omitted, the whole run of edges
            leaving ``node`` is used.

    Returns:
        1-D tensor of neighbour ids with the same dtype and device as
        ``edge_index``. It is empty when no edge leaves ``node``.
    """
    row, col = edge_index[0], edge_index[1]
    key = torch.as_tensor(node, dtype=row.dtype, device=row.device).reshape(1)
    # Half-open span [start, end) of positions where row == node.
    start = int(torch.searchsorted(row, key)[0])
    end = int(torch.searchsorted(row, key, right=True)[0])
    if start == end:
        return col.new_empty(0)
    if node_degree is not None:
        # int() accepts a Python int or a 0-d tensor of any dtype.
        end = start + int(node_degree)
    return col[start:end]

# Binary search
def tensor_bs(value, elements):
    """Return the index of the first occurrence of ``value`` in ``elements``.

    ``elements`` must be sorted ascending and support ``len`` and integer
    indexing (a list or a 1-D tensor). A lower-bound binary search keeps the
    cost O(log n) even when ``value`` occurs many times.

    Returns:
        Leftmost index ``i`` with ``elements[i] == value``, or ``-1`` if
        ``value`` is absent.
    """
    lo, hi = 0, len(elements)
    # Invariant: everything before lo is < value; everything from hi on is >= value.
    while lo < hi:
        mid = (lo + hi) // 2
        if elements[mid] < value:
            lo = mid + 1
        else:
            hi = mid
    if lo < len(elements) and elements[lo] == value:
        return lo
    return -1


def get_adjacency_list(graph, device):
    """Build a per-node list of out-neighbours from ``graph.edge_index``.

    Assumes edges are sorted by source node (row 0 of ``edge_index``), so
    the targets of node ``i`` occupy one contiguous slice of row 1.

    Args:
        graph: object exposing ``num_nodes`` and a ``(2, E)`` integer
            ``edge_index`` tensor (e.g. a torch_geometric ``Data``).
        device: unused; kept so existing call sites keep working.

    Returns:
        List of length ``graph.num_nodes``. Entry ``i`` is a NumPy array of
        the targets of the edges leaving node ``i``.
    """
    n = graph.num_nodes
    # One host copy up front instead of a .cpu() call per node.
    src, dst = graph.edge_index.cpu().numpy()
    # Row pointers must count edges per *source*, because slices are grouped by source.
    out_degree = np.bincount(src, minlength=n)[:n]
    offsets = np.concatenate(([0], np.cumsum(out_degree)))
    return [dst[offsets[i]:offsets[i + 1]] for i in range(n)]