import numpy as np
from itertools import product, chain


def disbased_weight(node, neighbors, sigma):
    """Return the summed Gaussian-kernel weight between *node* and *neighbors*.

    Each neighbor contributes ``exp(-(||neighbor - node|| / sigma) ** 2)``,
    where the Euclidean norm is taken along axis 1 of ``neighbors - node``.

    Args:
        node: A single point; broadcast against ``neighbors``.
        neighbors: Array-like with one neighbor per row (the ``axis=1`` norm
            requires at least two dimensions after broadcasting).
        sigma: Kernel bandwidth; distances are divided by it before squaring.

    Returns:
        The sum of the per-neighbor weights, as a NumPy scalar.
    """
    # asarray lets callers pass plain Python lists (list - list is a TypeError).
    offsets = np.asarray(neighbors) - np.asarray(node)
    distances = np.linalg.norm(offsets, axis=1)
    # Unary minus binds looser than **, so this is -((d / sigma) ** 2).
    # np.exp is the dedicated exponential, equivalent to (and more accurate
    # than) raising np.e to a power.
    return np.sum(np.exp(-(distances / sigma) ** 2))

def _substitute_placeholder(templates, placeholder, values):
    """Replace *placeholder* in every template with each value (value-major order)."""
    # The outer loop runs over values and the inner loop over templates,
    # matching the ordering of the original nested comprehensions.
    return [template.replace(placeholder, value)
            for value in values
            for template in templates]


def expand_task_and_dataset(string, *, datasets=('ABCD1', 'ABCD2'),
                            tasks=('Rest', 'SST', 'nBack', 'MID')):
    """Expand ``[DAT]`` and ``[TSK]`` placeholders in *string*.

    Every occurrence of ``[DAT]`` is replaced by each dataset name, and every
    occurrence of ``[TSK]`` by each task name, producing the cross product.
    The ordering is task-major, then dataset: e.g. ``'[DAT]_[TSK]'`` yields
    ``['ABCD1_Rest', 'ABCD2_Rest', 'ABCD1_SST', ...]``.

    Args:
        string: Template string, optionally containing ``[DAT]``/``[TSK]``.
        datasets: Names substituted for ``[DAT]`` (keyword-only).
        tasks: Names substituted for ``[TSK]`` (keyword-only).

    Returns:
        A list of expanded strings; ``[string]`` when no placeholder is present.
    """
    strings = [string]
    if '[DAT]' in string:
        strings = _substitute_placeholder(strings, '[DAT]', datasets)
    if '[TSK]' in string:
        strings = _substitute_placeholder(strings, '[TSK]', tasks)
    return strings
    
def extract_dataset_name_and_task_name(string, *, datasets=('ABCD1', 'ABCD2'),
                                       tasks=('Rest', 'nBack', 'MID', 'SST')):
    """Return the first dataset name and first task name found in *string*.

    Matching is plain substring containment, tried in the order given; the
    first candidate contained in *string* wins (so with the defaults
    ``'ABCD1'`` takes precedence over ``'ABCD2'``, and ``'MID'`` over
    ``'SST'``).

    Args:
        string: Text to search, e.g. a path or identifier.
        datasets: Candidate dataset names in precedence order (keyword-only).
        tasks: Candidate task names in precedence order (keyword-only).

    Returns:
        A ``(dataset_name, task)`` tuple.

    Raises:
        ValueError: If no candidate dataset or no candidate task occurs in
            *string*. (Previously this surfaced as an ``UnboundLocalError``.)
    """
    dataset_name = next((name for name in datasets if name in string), None)
    if dataset_name is None:
        raise ValueError(
            f"no dataset name from {list(datasets)} found in {string!r}")
    task = next((name for name in tasks if name in string), None)
    if task is None:
        raise ValueError(
            f"no task name from {list(tasks)} found in {string!r}")
    return dataset_name, task