import numpy as np
from htssr.primitives import special_symbol
from htssr.utils import (
    fast_eval_expr,
    rolling_fast_eval_expr,
    rolling_tree_eval_expr,
)


def normalize_y(y):
    """Standardise ``y`` and append its log-compressed mean and std.

    The returned vector is ``(y - mean) / (std + 1e-8)`` followed by two
    summary entries: ``sign(mean) * log(|mean| + 1)`` and
    ``sign(std) * log(|std| + 1)``. Everything is cast to float32.

    Args:
        y: NumPy array (``.mean()`` / ``.std()`` are called on it).

    Returns:
        float32 array of length ``y.size + 2`` for 1-D input.
    """
    center = y.mean()
    spread = y.std()
    # Small epsilon keeps a constant ``y`` (std == 0) from dividing by zero.
    standardized = (y - center) / (spread + 1e-8)

    def _signed_log(value):
        # Symmetric log compression so large statistics stay in a small range.
        return np.sign(value) * np.log((np.abs(value) + 1.0))

    summary = np.array([_signed_log(center), _signed_log(spread)])
    return np.concatenate([standardized, summary], dtype=np.float32)

def to_digits(y, base=10.0, min_exp=-10, max_exp=10):
    """Expand each value of ``y`` into its shifted residues modulo ``base``.

    For every element ``v`` of ``y`` and every exponent ``k`` in the inclusive
    range ``[min_exp, max_exp]`` this computes ``(v * base**k) % base``: the
    value shifted by ``k`` positions and reduced modulo ``base``. Results are
    not rounded, so each entry also carries the fractional tail below that
    position (e.g. ``123.456`` at ``k=0`` gives ``3.456``).

    Args:
        y: 1-D array-like of values.
        base: Positional base. It is converted to float so that an integer
            base works too; NumPy raises for integers raised to negative
            integer powers.
        min_exp: Smallest exponent (inclusive). Default ``-10``.
        max_exp: Largest exponent (inclusive). Default ``10``.

    Returns:
        Flat float array of length ``len(y) * (max_exp - min_exp + 1)``,
        laid out value-major: all positions of ``y[0]``, then ``y[1]``, ...
    """
    base = float(base)
    scales = np.power(base, np.arange(min_exp, max_exp + 1))
    # Row i holds y[i] * base**k for every k; flattening keeps value-major order.
    shifted = np.outer(y, scales)
    return (shifted % base).reshape(-1)

def make_dummy(all_ids, D):
    """Build a hashed two-hot placeholder vector for each id sequence.

    Each sequence is hashed into one of ``D**2`` buckets. The bucket index is
    split into ``(high, low) = divmod(bucket, D)``. The row then gets ``1.0``
    at ``low`` and ``2.0`` at ``high``. When ``high == low``, the ``2.0``
    overwrites the ``1.0``.

    Note: this relies on Python's built-in ``hash``. For tuples containing
    ``str`` that hash is salted per process (PYTHONHASHSEED), so output is
    only reproducible across runs for hash-stable ids such as ints.

    Args:
        all_ids: Sized iterable of sequences (each converted to a tuple).
        D: Row width.

    Returns:
        float64 array of shape ``(len(all_ids), D)``.
    """
    encoded = np.zeros((len(all_ids), D))
    n_buckets = D ** 2
    # Iterating a 2-D array yields row views, so writes land in ``encoded``.
    for row, ids in zip(encoded, all_ids):
        high, low = divmod(abs(hash(tuple(ids))) % n_buckets, D)
        row[low] = 1.0
        row[high] = 2.0
    return encoded

# def _make_vec(all_ids, domain, dummy=False):
#     n_points = len(domain["x"])
#     if dummy:
#         all_vals = make_dummy(all_ids, (n_points + 2))
#     else:
#         all_vals = [normalize_y(fast_eval_expr(ids, domain)) for ids in all_ids]
#         all_vals = np.array(all_vals)
#     return all_vals

def make_vals(all_ids, domain):
    """Evaluate every expression in ``all_ids`` on ``domain`` and stack them.

    Each entry of ``all_ids`` is passed to ``rolling_fast_eval_expr`` together
    with ``domain``. The per-expression results are combined with
    ``np.array``.

    Returns:
        ``np.array`` of the evaluations. Presumably shape ``(len(all_ids),
        n_points)`` when each evaluation is 1-D; confirm against
        ``rolling_fast_eval_expr``.
    """
    evaluated = []
    for ids in all_ids:
        evaluated.append(rolling_fast_eval_expr(ids, domain))
    return np.array(evaluated)

def make_vec_from_vals(vals):
    """Apply ``to_digits`` (default base and exponent range) to each row of ``vals``.

    Returns:
        ``np.array`` whose i-th entry is ``to_digits(vals[i])``.
    """
    return np.array(list(map(to_digits, vals)))

def make_noisy_vec(all_ids, domain, noise_level=1e-3):
    """Evaluate expressions and add Gaussian noise scaled to each row's RMS.

    For each evaluated row ``v`` the noise standard deviation is
    ``noise_level * sqrt(mean(v**2))``. Noise is drawn from NumPy's global
    ``np.random`` state, one draw of ``len(v)`` samples per row, in order.
    Results are therefore only reproducible if that global state is seeded.

    Args:
        all_ids: Iterable of expressions passed to ``rolling_fast_eval_expr``.
        domain: Evaluation domain passed through unchanged.
        noise_level: Noise std relative to each row's RMS value.

    Returns:
        ``np.array`` of the noisy rows.
    """
    clean_rows = [rolling_fast_eval_expr(ids, domain) for ids in all_ids]
    noisy_rows = []
    for vals in clean_rows:
        rms = np.mean(vals ** 2) ** 0.5
        eps = np.random.normal(loc=np.zeros(len(vals)), scale=1.0)
        noisy_rows.append(vals + (noise_level * rms) * eps)
    return np.array(noisy_rows)

def make_vec(all_ids, domain, dummy=False, noise_level=0):
    """Evaluate each expression in ``all_ids`` on ``domain`` and stack the results.

    NOTE(review): ``dummy`` and ``noise_level`` are accepted for interface
    compatibility but are currently ignored. The branches that used them
    (``make_dummy`` / ``make_noisy_vec``) are disabled. Passing them has no
    effect.

    Returns:
        ``np.array`` of ``rolling_fast_eval_expr(ids, domain)`` per entry.
    """
    # Unused below, but the lookup is kept so a ``domain`` lacking
    # ``special_symbol`` fails here exactly as before.
    _n_points = len(domain[special_symbol])  # noqa: F841
    rows = []
    for ids in all_ids:
        rows.append(rolling_fast_eval_expr(ids, domain))
    return np.array(rows)

def make_tree_vec(all_ids, domain, force_padding=None, pad_with=0.0):
    """Evaluate expression trees and pad them to a common length along axis 0.

    Each ``rolling_tree_eval_expr(ids, domain)`` result must be 2-D, because
    the pad width below covers exactly two axes. Shorter results get
    ``pad_with`` rows appended at the end of axis 0. Axis 1 is never padded.

    Args:
        all_ids: Non-empty iterable of expressions. An empty input makes
            ``max`` raise ``ValueError``.
        domain: Evaluation domain passed through unchanged.
        force_padding: Optional minimum length for axis 0. It can only
            lengthen the padding, never truncate.
        pad_with: Constant used for the padded rows.

    Returns:
        ``np.array`` stacking the padded results.
    """
    # Unused below; kept so a ``domain`` lacking ``special_symbol`` fails
    # here exactly as before.
    _n_points = len(domain[special_symbol])  # noqa: F841
    trees = [rolling_tree_eval_expr(ids, domain) for ids in all_ids]

    lengths = [len(tree) for tree in trees]
    target_len = max(lengths)
    if force_padding is not None:
        target_len = max(target_len, force_padding)

    padded = []
    for tree, length in zip(trees, lengths):
        pad_width = ((0, target_len - length), (0, 0))
        padded.append(
            np.pad(tree, pad_width, mode="constant", constant_values=pad_with)
        )
    return np.array(padded)
