import sys
from torch_geometric.graphgym.config import cfg

# Names exported by ``from <module> import *``: one entry per encoder function
# defined in this module. Each encoder takes ``(items, field)`` and returns a
# list with one feature vector (a list) per item.
__all__ = [
    'none_encoding', 'bool_encoding', 'integer_encoding', 'float_encoding', 'date_encoding', 'category_encoding',
    'text_encoding_onehot'
]


def none_encoding(items, field):
    """Return a constant one-element feature vector ``[1]`` for every item.

    Args:
        items: Iterable of records; only the number of records matters.
        field: Unused; accepted so the signature matches the other encoders.

    Returns:
        A list containing a fresh ``[1]`` list for each element of ``items``.
    """
    constant_features = []
    for _record in items:
        # Build a new inner list each time so the rows do not alias each other.
        constant_features.append([1])
    return constant_features

def bool_encoding(items, field):
    """Encode a textual boolean field as ``[1]`` or ``[0]`` per item.

    The value is treated as true only when it equals exactly one of
    ``"true"``, ``"True"`` or ``"TRUE"``; anything else encodes to 0.

    Args:
        items: Iterable of records supporting ``record[field]``.
        field: Key of the value to encode.

    Returns:
        A list with one single-element list (1 or 0) per item.
    """
    true_spellings = ("true", "True", "TRUE")
    encoded = []
    for record in items:
        if record[field] in true_spellings:
            encoded.append([1])
        else:
            encoded.append([0])
    return encoded

def integer_encoding(items, field):
    """Encode a field as a single integer per item.

    Args:
        items: Iterable of records supporting ``record[field]``.
        field: Key of the value to encode.

    Returns:
        A list with one single-element list per item: ``int(value)``, or 0
        when the value is the empty string.

    Raises:
        ValueError: If a non-empty value cannot be converted by ``int``.
    """
    encoded = []
    for record in items:
        if record[field] != "":
            number = int(record[field])
        else:
            # An empty string stands for a missing value and maps to 0.
            number = 0
        encoded.append([number])
    return encoded


def float_encoding(items, field):
    """Encode a field as a single floating-point number per item.

    Args:
        items: Iterable of records supporting ``record[field]``.
        field: Key of the value to encode.

    Returns:
        A list with one single-element list per item: ``float(value)``, or the
        integer 0 when the value is the empty string.

    Raises:
        ValueError: If a non-empty value cannot be converted by ``float``.
    """
    encoded = []
    for record in items:
        if record[field] != "":
            number = float(record[field])
        else:
            # An empty string stands for a missing value and maps to 0.
            number = 0
        encoded.append([number])
    return encoded


def _encode_date_string(value):
    """Return the integer encoding of one date string, or -1 if unparseable.

    Two layouts are tried in order: ``YYYY-MM-DD`` and ``DD/MM/YYYY``. The
    parts are parsed with ``int`` but not range-checked.
    """
    # Each entry: (separator, whether the year is the first part).
    for separator, year_first in (("-", True), ("/", False)):
        try:
            parts = value.split(separator)
        except (AttributeError, TypeError):
            # Not a text value (e.g. None or a number): nothing to parse.
            return -1
        if len(parts) != 3:
            continue
        if not year_first:
            # DD/MM/YYYY -> reorder to year, month, day.
            parts = parts[::-1]
        try:
            year, month, day = (int(part) for part in parts)
        except ValueError:
            continue
        # Same formula as before, now applied to integers: it is monotonic
        # for valid dates because a month has at most 31 days and a year 12
        # months.
        return year * 12 * 31 + month * 31 + day
    return -1


def date_encoding(items, field):
    """Encode a date field as a single integer per item.

    The value ``item[field]`` is expected to be a string in ``YYYY-MM-DD`` or
    ``DD/MM/YYYY`` form and is encoded as ``year * 12 * 31 + month * 31 + day``.
    Encoding is best-effort: a missing field, a non-string value, or a string
    that matches neither layout yields ``-1`` instead of raising.

    Args:
        items: Iterable of records supporting ``record[field]``.
        field: Key of the date value to encode.

    Returns:
        A list with one single-element list per item holding the encoded date
        or ``-1``.
    """
    feature_vector = []
    for item in items:
        try:
            value = item[field]
        except (KeyError, IndexError, TypeError):
            # Field absent or record not indexable by ``field``.
            value = None
        feature_vector.append([_encode_date_string(value)])
    return feature_vector


def category_encoding(items, field):
    """Encode a categorical field as a single integer id per item.

    Ids are assigned in order of first appearance, starting at 0. When
    ``cfg.dataset.ignore_index`` is not ``None``, items whose value is the
    empty string receive that index instead and do not consume a category id.

    Args:
        items: Iterable of records supporting ``record[field]``.
        field: Key of the categorical value to encode.

    Returns:
        A list with one single-element list per item holding the category id
        (or the configured ignore index).
    """
    missing_label = cfg.dataset.ignore_index
    ids_by_category = {}
    encoded = []
    for record in items:
        category = record[field]
        if category == '' and missing_label is not None:
            encoded.append([missing_label])
            continue
        if category not in ids_by_category:
            # Next free id equals the number of categories seen so far.
            ids_by_category[category] = len(ids_by_category)
        encoded.append([ids_by_category[category]])
    return encoded


def text_encoding_onehot(items,
                         field,
                         keep_most_frequent=sys.maxsize,
                         split_char=" "):
    """Encode a text field as a multi-hot vector over its most frequent words.

    Words are obtained by splitting ``item[field]`` on ``split_char``. They
    are ranked by total occurrence count across all items (ties keep the
    order of first appearance), and the word at rank ``i`` owns position
    ``i`` of the vector. Only ranks below ``keep_most_frequent`` are kept.

    Args:
        items: Sequence of records supporting ``record[field]``; it is
            iterated twice.
        field: Key of the text value to encode.
        keep_most_frequent: Maximum number of distinct words to keep.
        split_char: Separator passed to ``str.split``.

    Returns:
        A list with one list per item, of length equal to the number of kept
        words, holding 1 where the item contains that word and 0 elsewhere.
    """
    # Pass 1: tally every word over the whole collection.
    occurrences = {}
    for record in items:
        for token in record[field].split(split_char):
            occurrences[token] = occurrences.get(token, 0) + 1

    # Rank by descending count; the stable sort keeps first-seen order on ties.
    ranked_words = sorted(occurrences, key=occurrences.get, reverse=True)
    position_of = {}
    for rank, token in enumerate(ranked_words):
        if rank < keep_most_frequent:
            position_of[token] = rank

    # Pass 2: mark the kept words present in each record.
    vocabulary_size = len(position_of)
    feature_vectors = []
    for record in items:
        row = [0] * vocabulary_size
        for token in record[field].split(split_char):
            position = position_of.get(token)
            if position is not None:
                row[position] = 1
        feature_vectors.append(row)
    return feature_vectors
