import torch
import numpy as np


def gumbel_softmax(logits, tau=1.0, beta=1.0, hard=False, dim=-1):
    """Draw a Gumbel-Softmax sample from ``logits``.

    Gumbel(0, 1) noise, scaled by ``beta``, is added to ``logits``; the sum is
    divided by ``tau`` and passed through a softmax along ``dim``.

    Args:
        logits: Tensor of unnormalised log-probabilities.
        tau: Temperature that divides the perturbed logits before the softmax.
        beta: Multiplier applied to the Gumbel noise before it is added to
            ``logits`` (``beta=0`` disables the noise).
        hard: If True, the returned values are one-hot along ``dim`` (at the
            argmax of the soft sample) while gradients flow through the soft
            sample (straight-through estimator).
        dim: Dimension along which the softmax is computed.

    Returns:
        Tensor with the same shape as ``logits``: the soft sample, or its
        one-hot discretisation when ``hard`` is True.
    """
    # If E ~ Exponential(1) then -log(E) ~ Gumbel(0, 1). The negation has to
    # be applied to the log; negating the uninitialised buffer before
    # ``exponential_()`` overwrites it has no effect and flips the noise sign.
    exponential = torch.empty_like(
        logits, memory_format=torch.legacy_contiguous_format).exponential_()
    gumbel_noise = -exponential.log()

    perturbed = (logits + gumbel_noise * beta) / tau
    y_soft = torch.nn.functional.softmax(perturbed, dim=dim)
    if not hard:
        return y_soft

    # Straight-through: forward value is the one-hot argmax, backward pass
    # sees only ``y_soft`` because the other two terms carry no gradient.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft

def linear_combo_features(input_data, state_size):
    """Append pairwise "greater than" indicator columns to ``input_data``.

    For every row, each of the first ``state_size`` columns is compared with
    each of the first ``state_size`` columns. The resulting indicators are
    flattened row-major, so for a 2-D input the indicator for
    ``x[i] > x[j]`` lands at offset ``i * state_size + j`` after the original
    columns.

    Args:
        input_data: NumPy array indexed as (row, column).
        state_size: Number of leading columns that take part in comparisons.

    Returns:
        ``input_data`` with the integer 0/1 indicator columns concatenated
        along axis 1.
    """
    state = input_data[:, :state_size]
    # Broadcast a "left operand" axis against a "right operand" axis so that
    # every ordered pair of state columns is compared within each row.
    left = state[:, :, np.newaxis]
    right = state[:, np.newaxis, :]
    pairwise = np.greater(left, right)
    flat = pairwise.reshape(len(input_data), -1).astype(int)
    return np.concatenate((input_data, flat), axis=1)


def get_linear_features_used(tree, num_features):
    """List the pairwise comparison features that ``tree`` relies on.

    The first ``num_features * 2`` entries of ``tree.feature_importances_``
    are skipped; every remaining entry with a non-zero importance is decoded
    from its offset ``k`` into the label ``x{k // num_features} > x{k %
    num_features}``.

    NOTE(review): the offset decoding presumably mirrors the column layout
    produced by ``linear_combo_features`` with ``state_size == num_features``
    on inputs that have ``2 * num_features`` columns -- confirm with callers.

    Args:
        tree: Object exposing a sliceable ``feature_importances_`` attribute.
        num_features: Number of variables that are compared pairwise.

    Returns:
        List of ``'x{i} > x{j}'`` strings, in importance-array order.
    """
    pairwise_importances = tree.feature_importances_[num_features*2:]
    return [
        f'x{position // num_features} > x{position % num_features}'
        for position, weight in enumerate(pairwise_importances)
        if weight != 0
    ]


def normalize(data):
    """Scale ``data`` so that its elements sum to one.

    Args:
        data: Array-like of numbers. It is copied, never modified in place.

    Returns:
        A new NumPy array equal to ``data`` divided by its total. When the
        total is zero the copy is returned unscaled, avoiding a division by
        zero.
    """
    normalized = np.array(data)
    total = normalized.sum()
    if total != 0:
        # Out-of-place division: an in-place ``/=`` raises a casting error
        # for integer input because the float result cannot be stored back
        # into an integer array.
        normalized = normalized / total
    return normalized
