import os
import json
import random
import pandas as pd
import numpy as np
import torch
import scipy.sparse as sp

def convert_index_to_int(adj_lists):
    """Return a new adjacency mapping whose node keys are cast with ``int()``.

    Args:
        adj_lists (dict): mapping from node id (any value ``int()`` accepts)
            to that node's neighbours.

    Returns:
        dict: ``{int(node): neighbours}``. The neighbour objects are reused
        as-is, not copied. If two keys cast to the same int, the one seen
        last wins.
    """
    return {int(node_id): neighbours for node_id, neighbours in adj_lists.items()}


def set_seed(seed):
    """Seed the NumPy, Python, PyTorch CPU and PyTorch CUDA generators.

    Args:
        seed (int): seed value
    """
    # The CUDA seeders are included for GPU runs; per the PyTorch docs they
    # are silently ignored when CUDA is unavailable.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def change_graph(model, adj_lists):
    """Swap the graph used by both encoders of the model's label encoder.

    The very same ``adj_lists`` object (no copy) is assigned to
    ``model.label_encoder.enc2.adj_lists`` and then to
    ``model.label_encoder.enc1.adj_lists``.

    Args:
        model (nn.Module): bilinear model exposing ``label_encoder.enc1`` and
            ``label_encoder.enc2``
        adj_lists (dict): the weighted graph
    """
    label_encoder = model.label_encoder
    for encoder in (label_encoder.enc2, label_encoder.enc1):
        encoder.adj_lists = adj_lists


def get_save_path(dir_path, options):
    """Build the checkpoint path for a run from its hyper-parameters.

    The file name starts with ``options['label_encoder_type']``. It gets a
    ``_<name>_<value>`` suffix for every hyper-parameter group that differs
    from the reference values hard-coded below. It always ends with
    ``_seed_<seed>.pt``.

    Args:
        dir_path (str): directory the file name is joined onto.
        options (dict): run configuration. Must contain
            ``label_encoder_type``, ``pd1``, ``pd2``, ``fh1``, ``fh2``,
            ``decay``, ``num_layers``, ``dp1``, ``dp2`` and ``seed``;
            ``n1`` / ``n2`` are optional.

    Returns:
        str: ``os.path.join(dir_path, <file name>)``.
    """
    pieces = [options['label_encoder_type']]

    # n1/n2 are only recorded when both keys exist and at least one is truthy.
    has_n_keys = 'n1' in options and 'n2' in options
    if has_n_keys and (options['n1'] or options['n2']):
        pieces.append('_n1_{}_n2_{}'.format(options['n1'], options['n2']))

    # The four size settings are written out together as soon as any one of
    # them differs from its reference value.
    size_refs = (('pd1', 150), ('pd2', 64), ('fh1', 150), ('fh2', 64))
    if any(options[key] != ref for key, ref in size_refs):
        pieces.extend('_' + key + '_' + str(options[key]) for key, _ in size_refs)

    if options['decay'] != 0.0:
        pieces.append('_decay_' + str(options['decay']))

    if options['num_layers'] != 1:
        pieces.append('_layers_' + str(options['num_layers']))

    if options['dp1'] != 0.1 or options['dp2'] != 0.1:
        pieces.extend(['_dp1_' + str(options['dp1']),
                       '_dp2_' + str(options['dp2'])])

    pieces.append('_seed_' + str(options['seed']))
    file_name = ''.join(pieces) + '.pt'
    return os.path.join(dir_path, file_name)


def create_dirs(dir_path):
    """Create ``dir_path`` (and any missing parents) if it does not exist yet.

    ``exist_ok=True`` replaces the former ``os.path.exists`` pre-check. With
    check-then-create, another process creating the same directory between
    the check and ``makedirs`` would make this call fail with
    ``FileExistsError``.

    Args:
        dir_path (str): directory to create.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)


def init_device(gpu=0):
    """Choose the torch device to run on.

    Args:
        gpu (int): CUDA device index used when CUDA is available.

    Returns:
        tuple: ``(torch.device('cuda:<gpu>'), gpu)`` when CUDA is available,
        otherwise ``(torch.device('cpu'), -1)``.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu'), -1
    return torch.device('cuda:{}'.format(gpu)), gpu


def fine_labels(dataset, level):
    """Split the label vocabulary of ``dataset`` by hierarchy depth.

    Each element of ``dataset`` is a JSON string with a ``'labels'`` list.
    A label whose ``label.split('/')`` has exactly ``level + 1`` pieces goes
    to the test set; every other label goes to the train set. For labels
    with a leading '/' (e.g. ``'/a/b'``), that means exactly ``level`` path
    segments.

    Args:
        dataset (iterable of str): JSON lines with a ``'labels'`` key.
        level (int): depth of the labels to hold out as test labels.

    Returns:
        tuple: ``(train_labels, test_labels)``, each a sorted list of
        unique labels.
    """
    # TODO: port to 3 levels too
    all_labels = set()
    for line in dataset:
        all_labels.update(json.loads(line)['labels'])

    train_labels, test_labels = [], []
    # Iterate in sorted order: the iteration order of a set of str depends on
    # the per-process hash seed, so list(set(...)) made the returned order
    # differ between runs even with all RNGs seeded.
    for label in sorted(all_labels):
        if len(label.split('/')) == level + 1:
            test_labels.append(label)
        else:
            train_labels.append(label)

    return train_labels, test_labels


def flatten_dataset(lines):
    """Expand sentence-level JSON lines into one JSON line per mention.

    Each input line is a JSON object with ``'tokens'``, ``'senid'``,
    ``'fileid'`` and a ``'mentions'`` list. Each mention carries
    ``'labels'``, ``'start'`` and ``'end'``. Every mention becomes its own
    JSON string that repeats the sentence-level fields.

    Args:
        lines (iterable of str): sentence-level JSON lines.

    Returns:
        list of str: mention-level JSON lines, in input order.
    """
    flat_lines = []
    for raw in lines:
        sentence = json.loads(raw)
        flat_lines.extend(
            json.dumps({
                'tokens': sentence['tokens'],
                'senid': sentence['senid'],
                'fileid': sentence['fileid'],
                'labels': mention['labels'],
                'start': mention['start'],
                'end': mention['end'],
            })
            for mention in sentence['mentions']
        )
    return flat_lines


def remove_labels(dataset, train_labels):
    """Restrict every example to ``train_labels`` and drop unlabeled ones.

    Args:
        dataset (iterable of str): JSON lines with a ``'labels'`` list.
        train_labels (iterable): labels to keep. It is materialised into a
            set once, so lists, sets and one-shot iterators all behave the
            same.

    Returns:
        list of str: re-serialised examples whose ``'labels'`` were filtered
        to ``train_labels`` (original order kept). Examples left with no
        label are omitted.
    """
    # One set lookup per label instead of a linear list scan. Building the
    # set once also means an iterator argument is not exhausted by the
    # first example.
    allowed = set(train_labels)

    clean_dataset = []
    for line in dataset:
        example = json.loads(line)
        kept = [label for label in example['labels'] if label in allowed]
        if kept:
            example['labels'] = kept
            clean_dataset.append(json.dumps(example))

    return clean_dataset


def _row_sum_power_diag(mx, exponent):
    """Return ``diag(rowsum(mx) ** exponent)``, with inf entries set to 0."""
    rowsum = np.array(mx.sum(1))
    # Zero row sums produce inf, which is zeroed out below; silence the
    # divide-by-zero warning that np.power would otherwise emit.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, exponent).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv)


def normt_spm(mx, method='in'):
    """Normalise a scipy sparse matrix.

    Args:
        mx (scipy.sparse matrix): matrix to normalise (square for ``'sym'``).
        method (str):
            * ``'in'``: returns ``D^-1 @ mx.T``, where ``D`` holds the row
              sums of ``mx.T`` (i.e. the column sums of ``mx``).
            * ``'sym'``: returns ``D^-1/2 @ mx.T @ D^-1/2``, where ``D``
              holds the row sums of ``mx``.
            Rows whose sum is zero get a scale factor of 0 instead of inf.

    Returns:
        scipy sparse matrix: the normalised matrix.

    Raises:
        ValueError: if ``method`` is neither ``'in'`` nor ``'sym'``.
    """
    if method == 'in':
        mx = mx.transpose()
        # Float exponent: np.power(int_array, -1) raises ValueError, so the
        # old integer exponent broke integer-valued sparse matrices. A float
        # exponent still keeps float32 inputs in float32.
        r_mat_inv = _row_sum_power_diag(mx, -1.0)
        return r_mat_inv.dot(mx)

    if method == 'sym':
        r_mat_inv = _row_sum_power_diag(mx, -0.5)
        return mx.dot(r_mat_inv).transpose().dot(r_mat_inv)

    # Previously fell through and returned None silently.
    raise ValueError("method must be 'in' or 'sym', got {!r}".format(method))


def spm_to_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx (scipy.sparse matrix): matrix to convert; any sparse
            format is accepted (it is converted to COO first).

    Returns:
        torch.Tensor: sparse COO tensor of dtype float32 with the same shape
        and non-zero entries as ``sparse_mx``.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Stack row/col into the (2, nnz) int64 index layout torch expects.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    values = torch.from_numpy(coo.data)
    shape = torch.Size(coo.shape)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, shape).
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
