import os
import re
import json
import os.path as osp
import shutil

import numpy as np
import scipy.sparse as sp
import torch
import random

from collections import Counter

from IPython import embed


def ensure_path(path):
    """Make sure ``path`` exists as a directory.

    If ``path`` does not exist it is created with ``os.mkdir`` (so its parent
    must already exist). If it does exist, the user is prompted on stdin;
    any answer other than ``'n'`` deletes the whole tree and recreates an
    empty directory, while ``'n'`` leaves it untouched.
    """
    if not osp.exists(path):
        os.mkdir(path)
        return

    answer = input('{} exists, remove? ([y]/n)'.format(path))
    if answer == 'n':
        return
    shutil.rmtree(path)
    os.mkdir(path)


def set_gpu(gpu):
    """Expose only the given GPU id(s) to CUDA and report the choice.

    Arguments:
        gpu {str} -- value written to ``CUDA_VISIBLE_DEVICES`` (e.g. '0' or '0,1')
    """
    os.environ.update(CUDA_VISIBLE_DEVICES=gpu)
    message = 'using gpu {}'.format(gpu)
    print(message)


def pick_vectors(dic, wnids, is_tensor=False):
    """Gather the vectors for ``wnids`` from ``dic`` into one 2-D tensor.

    Ids missing from ``dic`` get a zero vector. The vector length ``dim`` is
    taken from the first value stored in ``dic``.

    Arguments:
        dic {dict} -- id -> vector; a sequence of numbers, or a 1-D tensor
            when ``is_tensor`` is True
        wnids {iterable} -- ids to look up, in output order
        is_tensor {bool} -- whether the values of ``dic`` are torch tensors

    Returns:
        torch.Tensor -- shape (len(wnids), dim); float32 when ``is_tensor`` is
            False (and for empty ``wnids``), otherwise stacked from the values

    Raises:
        ValueError -- if ``dic`` is empty, since ``dim`` cannot be inferred
    """
    if not dic:
        # next(iter(...)) would otherwise leak a bare StopIteration.
        raise ValueError('pick_vectors needs a non-empty dict to infer the vector size')
    dim = len(next(iter(dic.values())))

    ret = []
    for wnid in wnids:
        v = dic.get(wnid)
        if v is None:
            v = torch.zeros(dim) if is_tensor else [0] * dim
        ret.append(v)

    if not ret:
        # torch.FloatTensor([]) would be 1-D and torch.stack([]) raises;
        # keep the (n, dim) contract for n == 0.
        return torch.zeros(0, dim)
    if not is_tensor:
        return torch.FloatTensor(ret)
    return torch.stack(ret)


def _inv_power_diag(rowsum, power):
    """Return a sparse diagonal matrix of ``rowsum ** power`` with inf mapped to 0."""
    # A float exponent keeps integer matrices working (NumPy refuses integer
    # arrays raised to negative integer powers) while float32 stays float32.
    with np.errstate(divide='ignore'):
        inv = np.power(rowsum, float(power)).flatten()
    # Zero sums give inf; map them to 0 so empty rows/columns stay empty.
    inv[np.isinf(inv)] = 0.
    return sp.diags(inv)


def normt_spm(mx, method='in'):
    """Normalise a scipy sparse matrix.

    Arguments:
        mx {scipy.sparse matrix} -- matrix to normalise (must be square for 'sym')
        method {str} -- 'in': transpose ``mx`` and divide every row of the
            transposed matrix by its sum, i.e. ``D^-1 mx^T`` with ``D`` the
            row sums of ``mx^T`` (column sums of ``mx``).
            'sym': ``D^-1/2 mx^T D^-1/2`` with ``D`` the row sums of ``mx``.

    Rows whose sum is zero get a zero scaling factor instead of inf.

    Returns:
        scipy.sparse matrix -- the normalised matrix

    Raises:
        ValueError -- if ``method`` is neither 'in' nor 'sym'
    """
    if method == 'in':
        mx = mx.transpose()
        r_mat_inv = _inv_power_diag(np.array(mx.sum(1)), -1)
        return r_mat_inv.dot(mx)

    if method == 'sym':
        r_mat_inv = _inv_power_diag(np.array(mx.sum(1)), -0.5)
        # (mx D)^T D == D mx^T D because D is diagonal.
        return mx.dot(r_mat_inv).transpose().dot(r_mat_inv)

    # Previously fell through and returned None, failing later far from here.
    raise ValueError(
        "unknown normalisation method {!r}; expected 'in' or 'sym'".format(method))


def spm_to_tensor(sparse_mx):
    """Convert a scipy sparse matrix into a float32 torch sparse COO tensor.

    Arguments:
        sparse_mx {scipy.sparse matrix} -- any scipy sparse format (converted to COO)

    Returns:
        torch.Tensor -- sparse COO tensor with the same shape, float32 values
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col))).long()
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent (dtype follows values).
    return torch.sparse_coo_tensor(indices, values, shape)


def get_all_words(graph_data):
    """Flatten every concept in the graph into its individual words.

    Arguments:
        graph_data {dict} -- concept -> {neighbour concept: edge data}

    Returns:
        list -- words of each head concept followed by the words of each of
            its neighbours, in iteration order (duplicates kept)
    """
    words = []
    for head, neighbours in graph_data.items():
        words.extend(get_individual_words(head))
        for tail in neighbours:
            words.extend(get_individual_words(tail))
    return words


def get_all_concepts(graph_data):
    """List every concept that appears in the graph, as head or neighbour.

    Arguments:
        graph_data {dict} -- concept -> {neighbour concept: edge data}

    Returns:
        list -- each head concept followed by its neighbours (duplicates kept)
    """
    concepts = []
    for head, neighbours in graph_data.items():
        concepts.append(head)
        concepts.extend(neighbours.keys())
    return concepts


def get_all_relations(graph_data):
    """Collect the relation uri of every edge in the graph.

    Arguments:
        graph_data {dict} -- concept -> {neighbour: {'relation': {'uri': ...}, ...}}

    Returns:
        list -- one uri per edge, in iteration order (duplicates kept)
    """
    return [
        edge['relation']['uri']
        for neighbours in graph_data.values()
        for edge in neighbours.values()
    ]


def get_all_directed_relations(graph_data):
    """Collect the (possibly direction-prefixed) relation label of every edge.

    Each edge's data is passed through ``get_directed_relation`` with its
    default arguments.

    Arguments:
        graph_data {dict} -- concept -> {neighbour: edge data}

    Returns:
        list -- one label per edge, in iteration order (duplicates kept)
    """
    return [
        get_directed_relation(edge_data)
        for neighbours in graph_data.values()
        for edge_data in neighbours.values()
    ]


def concept_to_words(idx_to_conceptnet):
    """Split every concept, in index order, into its individual words.

    Arguments:
        idx_to_conceptnet -- indexable by 0..len-1 (a list, or a dict keyed by
            those ints) mapping an index to a concept string

    Returns:
        list -- list of word lists, position i holding the words of concept i
    """
    # Index explicitly rather than iterating: for a dict, iteration would
    # yield the keys instead of the concept strings.
    return [
        get_individual_words(idx_to_conceptnet[idx])
        for idx in range(len(idx_to_conceptnet))
    ]


def get_individual_words(concept):
    """Turn a concept string into its list of words.

    Any ``/c/<two lowercase letters>/`` segment is removed, as is everything
    from any other ``/`` to the end of the string. The remainder is stripped
    of surrounding whitespace and split on underscores, e.g.
    ``'/c/en/hot_dog/n'`` -> ``['hot', 'dog']``.
    """
    stripped = re.sub(r"\/c\/[a-z]{2}\/|\/.*", "", concept).strip()
    return stripped.split("_")


def get_directed_relation(rel_data, directed=True):
    """Return the relation label of one edge.

    When the relation is marked directed and ``directed`` is true, the label
    is ``rel_data['edge_type']`` concatenated with the relation uri;
    otherwise it is the bare uri.
    """
    relation = rel_data['relation']
    if not (relation['directed'] and directed):
        return relation['uri']
    return rel_data['edge_type'] + relation['uri']


def filter_rel(graph_json, relations):
    """Keep only edges whose relation uri is in ``relations``.

    Every head concept is kept, possibly with an empty neighbour dict.

    Arguments:
        graph_json {dict} -- concept -> {neighbour: {'relation': {'uri': ...}, ...}}
        relations {container} -- allowed relation uris

    Returns:
        dict -- new graph of the same shape with the other edges dropped
    """
    return {
        concept: {
            neighbour: data
            for neighbour, data in adj_list.items()
            if data['relation']['uri'] in relations
        }
        for concept, adj_list in graph_json.items()
    }


def filter_verbs(graph_json):
    """Drop every edge whose neighbour concept is tagged as a verb.

    A neighbour counts as a verb when its name contains ``/v/`` or ends with
    ``/v``. Every head concept is kept, possibly with an empty neighbour dict.

    Arguments:
        graph_json {dict} -- concept -> {neighbour concept: edge data}

    Returns:
        dict -- new graph of the same shape without verb neighbours
    """
    verb_pattern = r'\/v\/|\/v$'
    return {
        concept: {
            neighbour: data
            for neighbour, data in adj_list.items()
            if not re.search(verb_pattern, neighbour)
        }
        for concept, adj_list in graph_json.items()
    }


def filter_concepts(graph_json):
    """Keep only English concepts, both as heads and as neighbours.

    A concept is English when its name contains ``/c/en/``. Heads that are
    not English are dropped entirely; English heads keep only their English
    neighbours. The function is no longer used, but is kept so the filtering
    can be re-checked later.

    Arguments:
        graph_json {dict} -- contains the entire graph

    Returns:
        dict -- filtered dict containing only English concepts
    """
    english = r'\/c\/en\/'
    filtered = {}
    for concept, adj_list in graph_json.items():
        if not re.search(english, concept):
            continue
        filtered[concept] = {
            neighbour: data
            for neighbour, data in adj_list.items()
            if re.search(english, neighbour)
        }
    return filtered


def convert_token_to_idx(list_of_token_list, token_to_idx):
    """Convert lists of string tokens to a right-padded tensor of token ids.

    Every row is padded with the id of ``'@@PADDING@@'`` up to the length of
    the longest token list.

    Arguments:
        list_of_token_list {list} -- list of lists containing token strings
        token_to_idx {dict} -- token to id mapping from the description vocab

    Returns:
        torch.Tensor -- int64 tensor of shape
            (len(list_of_token_list), longest token list length)

    Raises:
        KeyError -- if a token or '@@PADDING@@' is missing from ``token_to_idx``
    """
    # default=0 lets an empty input produce an empty tensor instead of a
    # ValueError from max() on an empty sequence.
    max_length = max((len(tokens) for tokens in list_of_token_list), default=0)
    pad_idx = token_to_idx['@@PADDING@@']

    token_idx_list = [
        [token_to_idx[token] for token in tokens]
        + [pad_idx] * (max_length - len(tokens))
        for tokens in list_of_token_list
    ]

    # Force int64: with only empty rows torch would infer float32, which is
    # unusable as ids. reshape pins the 2-D shape even when there are no rows.
    token_tensor = torch.tensor(token_idx_list, dtype=torch.long)
    return token_tensor.reshape(len(token_idx_list), max_length)


def set_seed(seed):
    """Seed every random number generator used here and force cuDNN determinism.

    Seeds NumPy, Python's ``random``, torch (CPU) and torch CUDA (current and
    all devices), then disables cuDNN entirely, turns off its autotuner and
    requests deterministic kernels. This trades GPU speed for reproducibility.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.enabled = False
    cudnn.benchmark = False
    cudnn.deterministic = True


def mask_l2_loss(a, b, mask):
    """Return ``l2_loss`` computed only over the rows of ``a``/``b`` selected by ``mask``."""
    selected_a = a[mask]
    selected_b = b[mask]
    return l2_loss(selected_a, selected_b)


def l2_loss(a, b):
    """Return half the mean (over the first dimension) of the summed squared error.

    Computes ``sum((a - b) ** 2) / (2 * len(a))``.
    """
    squared_error = ((a - b) ** 2).sum()
    return squared_error / (2 * len(a))


def get_device():
    """Pick the compute device.

    Returns:
        tuple -- (torch.device('cuda:0'), 0) when CUDA is available,
            otherwise (torch.device('cpu'), -1)
    """
    if not torch.cuda.is_available():
        return torch.device('cpu'), -1
    return torch.device("cuda:0"), 0


def convert_index_to_int(adj_lists):
    """Return a copy of ``adj_lists`` whose keys are converted with ``int()``.

    Values are carried over unchanged (not copied).
    """
    return {int(node): neighbours for node, neighbours in adj_lists.items()}


def save_model(model, save_path):
    """Serialise the model's ``state_dict`` (parameters only) with ``torch.save``.

    Arguments:
        model {nn.Module} -- the model
        save_path {str} -- destination file path
    """
    state = model.state_dict()
    torch.save(state, save_path)


def setup_graph(graph_path, collapse=False):
    """Load a graph's adjacency lists and concept embeddings from disk.

    Arguments:
        graph_path {str} -- directory containing ``rw_adj_rel_lists.json``
            and ``concepts.pt``
        collapse {bool} -- load a collapsed graph instead; not implemented

    Returns:
        tuple -- (adjacency lists keyed by int node id, object loaded from
            ``concepts.pt`` with ``torch.load``)

    Raises:
        NotImplementedError -- if ``collapse`` is true
    """
    if collapse:
        # No collapsed-graph file is defined; fail explicitly rather than with
        # an UnboundLocalError on the missing path variable.
        raise NotImplementedError('collapsed graph loading is not implemented')

    adj_lists_path = os.path.join(graph_path, 'rw_adj_rel_lists.json')
    # Close the file deterministically; JSON is UTF-8 by specification.
    with open(adj_lists_path, encoding='utf-8') as f:
        adj_lists = json.load(f)
    # JSON object keys are always strings; turn the node ids back into ints.
    adj_lists = convert_index_to_int(adj_lists)

    # load conceptnet embs
    concept_path = os.path.join(graph_path, 'concepts.pt')
    concept_embs = torch.load(concept_path)

    return adj_lists, concept_embs


def get_rel_ids(adj_lists, neigh_sizes, node_ids):
    """Collect the relation ids reached from ``node_ids`` by sampled hops.

    For each entry of ``neigh_sizes`` (one hop each), every frontier node
    keeps at most ``sample_size`` of its ``(neighbour, rel, hp)`` entries,
    preferring the largest ``hp``. The ``rel`` of every kept entry is
    recorded and the kept neighbours form the next frontier.

    Arguments:
        adj_lists {dict} -- node -> list of (neighbour, rel, hp) entries
        neigh_sizes {list} -- list containing the sample size of the neighbours
            per hop
        node_ids {list} -- contains the initial train ids

    Returns:
        set -- returns the set of relations that are part of the training
    """
    all_rels = set()
    frontier = node_ids
    for sample_size in neigh_sizes:
        next_frontier = []
        for node in frontier:
            edges = adj_lists[node]
            if len(edges) >= sample_size:
                # Stable sort: ties keep their adjacency-list order.
                edges = sorted(edges, key=lambda edge: edge[2], reverse=True)[:sample_size]
            for neigh, rel, _hp in edges:
                all_rels.add(rel)
                next_frontier.append(neigh)
        # A repeated node would contribute exactly the same edges again, so
        # expand each node once per hop; the resulting relation set is unchanged.
        frontier = list(dict.fromkeys(next_frontier))
    return all_rels


def prune_graph(adj_lists, relations):
    """Keep only the edges whose relation is in ``relations``.

    Every node is kept, possibly with an empty edge list; surviving edges are
    returned as ``(neigh_node, rel, hp)`` tuples.

    Arguments:
        adj_lists {dict} -- node -> list of (neigh_node, rel, hp) entries
        relations {set} -- relation ids to keep

    Returns:
        dict -- pruned graph
    """
    return {
        node: [
            (neigh_node, rel, hp)
            for neigh_node, rel, hp in edges
            if rel in relations
        ]
        for node, edges in adj_lists.items()
    }
