"""Basic/common utilities: tensor-to-NumPy conversion, deep container conversion, and affine least-squares fitting."""
import tensorflow as tf


def to_np(x):
    """Return ``x.numpy()`` when ``x`` is a ``tf.Tensor``; otherwise return ``x`` as-is.

    NOTE(review): ``.numpy()`` is presumably only valid on eager tensors, so
    this should not be called on symbolic tensors inside ``tf.function`` —
    confirm against callers.
    """
    is_tensor = isinstance(x, tf.Tensor)
    return x.numpy() if is_tensor else x


def to_deep_tuple(t):
    """Recursively convert a nested iterable into nested tuples.

    Non-iterable values are leaves and are returned unchanged. Strings are
    also treated as leaves: iterating a string yields one-character strings,
    which are themselves iterable, so recursing into them would never
    terminate.

    Args:
        t: A leaf value or an arbitrarily nested iterable (lists, tuples,
            generators, ...). Iterating a dict visits only its keys, and a
            one-shot iterator/generator is consumed.

    Returns:
        ``t`` itself if it is a leaf, otherwise a tuple whose elements are the
        recursively converted elements of ``t``.
    """
    if isinstance(t, str):
        return t
    # EAFP iterability check: iter() raises TypeError for non-iterables.
    try:
        items = iter(t)
    except TypeError:
        return t
    return tuple(to_deep_tuple(s) for s in items)


def to_deep_frozenset(t):
    """Recursively convert a nested iterable into nested frozensets.

    Non-iterable values are leaves and are returned unchanged. Strings are
    also treated as leaves: iterating a string yields one-character strings,
    which are themselves iterable, so recursing into them would never
    terminate.

    Args:
        t: A leaf value or an arbitrarily nested iterable. Leaves must be
            hashable, since they end up as frozenset members. Order and
            duplicates are discarded at every level.

    Returns:
        ``t`` itself if it is a leaf, otherwise a frozenset of the recursively
        converted elements of ``t``.
    """
    if isinstance(t, str):
        return t
    # EAFP iterability check: iter() raises TypeError for non-iterables.
    try:
        items = iter(t)
    except TypeError:
        return t
    return frozenset(to_deep_frozenset(s) for s in items)


###############################################################################


def homogenize(X, ones=None):
    """Append a trailing column of ones to ``X`` (homogeneous coordinates).

    Args:
        X: Tensor of shape ``(..., D)``.
        ones: Optional tensor to concatenate on the last axis instead of the
            default ones column. It must have ``X``'s dtype and match ``X`` on
            every axis except the last.

    Returns:
        Tensor of shape ``(..., D + k)``, where ``k`` is the last-axis size of
        ``ones`` (1 by default).
    """
    if ones is None:
        # Build the shape from the dynamic tf.shape rather than the static
        # X.shape: inside tf.function the static shape may contain None
        # (e.g. an unknown batch size), which tf.ones cannot accept.
        ones_shape = tf.concat([tf.shape(X)[:-1], [1]], axis=0)
        ones = tf.ones(ones_shape, dtype=X.dtype)
    return tf.concat([X, ones], axis=-1)


def fit_affine(X, Y, *, fast=False):
    """Least-squares fit of an affine map ``Y ≈ X @ w + b``.

    Appends a ones column to ``X`` (see ``homogenize``) and solves the
    resulting linear least-squares problem with ``tf.linalg.lstsq``,
    independently for each leading batch index.

    Args:
        X: Tensor of shape ``(..., M, D)``: M samples with D features each.
        Y: Tensor of shape ``(..., M)``: one scalar target per sample, with
            the same dtype as ``X`` (``tf.linalg.lstsq`` takes both operands
            in a single dtype).
        fast: Forwarded to ``tf.linalg.lstsq``. ``False`` (the default, and
            the previously hard-coded value) uses a complete orthogonal
            decomposition and returns the minimum-norm solution even when
            the system is rank deficient. TensorFlow defines no gradient for
            this path. ``True`` solves the normal equations via Cholesky. It
            is faster and differentiable, but requires the homogenized ``X``
            to have full rank.

    Returns:
        ``(w, b)`` where ``w`` has shape ``(..., D)`` and ``b`` has shape
        ``(...)``.
    """
    Xh = homogenize(X)

    # Y[..., None] makes the targets a single-column right-hand side of
    # shape (..., M, 1); the solution comes back as (..., D + 1, 1).
    solution = tf.linalg.lstsq(Xh, Y[..., None], fast=fast)
    solution = tf.squeeze(solution, axis=-1)
    # The last coefficient multiplies the appended ones column: the bias.
    w, b = solution[..., :-1], solution[..., -1]
    return w, b
