from inspect import signature
import numpy as np
import torch


# ----- START OF SYMBASE DEFINITION -----

####### Scalability, 4 variables

# Finite stand-in for infinity. The "safe" wrappers clamp results into
# [-inf, inf] with this bound, so guarded expressions never yield real infs.
inf = 1.0e10
# The same bound as a 0-d float32 torch tensor, for the torch-side wrappers.
tc_inf = torch.tensor(inf, dtype=torch.float32)

def make_safe_un(op):
    """Wrap a unary NumPy op so that its result is always finite.

    The returned function evaluates ``op(x)``, clamps the result into
    ``[-inf, inf]`` (``inf`` is the module-level 1e10 bound) and replaces
    NaN with 0.0.
    """
    def new_op(x):
        clamped = np.maximum(np.minimum(op(x), inf), -inf)
        return np.nan_to_num(clamped, nan=0.0)
    return new_op

def tc_make_safe_un(op):
    """Torch counterpart of ``make_safe_un``.

    The returned function evaluates ``op(x)``, clamps it into
    ``[-tc_inf, tc_inf]`` and replaces NaN with 0.0.
    """
    def new_op(x):
        clamped = torch.maximum(torch.minimum(op(x), tc_inf), -tc_inf)
        return torch.nan_to_num(clamped, nan=0.0)
    return new_op

def make_safe_bin(op):
    """Wrap a binary NumPy op so that its result is always finite.

    The returned function evaluates ``op(x, y)``, clamps the result into
    ``[-inf, inf]`` and replaces NaN with 0.0.
    """
    def new_op(x, y):
        clamped = np.maximum(np.minimum(op(x, y), inf), -inf)
        return np.nan_to_num(clamped, nan=0.0)
    return new_op

def tc_make_safe_bin(op):
    """Torch counterpart of ``make_safe_bin``.

    The returned function evaluates ``op(x, y)``, clamps it into
    ``[-tc_inf, tc_inf]`` and replaces NaN with 0.0.
    """
    def new_op(x, y):
        clamped = torch.maximum(torch.minimum(op(x, y), tc_inf), -tc_inf)
        return torch.nan_to_num(clamped, nan=0.0)
    return new_op

def _inv(x):
    ans = np.divide(1.0, x)
    return ans

def _tc_inv(x):
    ans = torch.divide(1.0, x)
    return ans

def _div(x, y):
    ans = np.divide(x, y)
    if np.allclose(x, y):
        ans = 0.0 * ans + 1.0
    return ans

def _tc_div(x, y):
    ans = torch.divide(x, y)
    if torch.allclose(x, y):
        ans = 0.0 * ans + 1.0
    return ans

def _mul(x, y):
    ans = x * y
    return ans

def _tc_mul(x, y):
    ans = x * y
    return ans

def _square(x):
    ans = x ** 2
    return ans

def _tc_square(x):
    ans = x ** 2
    return ans

def _cube(x):
    ans = x ** 3
    return ans

def _tc_cube(x):
    ans = x ** 3
    return ans

def _exp(x):
    ans = np.exp(x)
    return ans

def _tc_exp(x):
    ans = torch.exp(x)
    return ans

def _ln(x):
    ans = np.log(x)
    return ans

def _tc_ln(x):
    ans = torch.log(x)
    return ans

def _sqrt(x):
    x = np.maximum(x, 0.0)
    ans = x ** 0.5
    return ans

def _tc_sqrt(x):
    x = torch.maximum(
        x,
        torch.tensor(0.0).float()
    )
    ans = x ** 0.5
    return ans

def _identity(x):
    return x

# Table entry format: (symbol, NumPy eval function, LaTeX function, PyTorch eval function)

# Token for the primary input variable; it must be one of the `_variables`
# symbols (asserted once the tables are assembled).
special_symbol = "x"
# Token used as the constant/parameter placeholder (rendered as \Box in LaTeX).
special_parameter_symbol = "CONST"

# Constant tokens: (symbol, NumPy eval fn taking no args, LaTeX fn, PyTorch-side fn).
# The PyTorch-side fn returns 1.109801 plus U(-0.1, 0.1) noise. Presumably
# this is an initial value for a trainable constant (TODO confirm with the
# caller). The commented-out entries are alternative experiment configs.
# Only one "CONST" may be active, because vocabulary symbols must be unique.
_constants = [
    ####### Singles VIII (??)
    # ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801 + np.random.uniform(-0.5, 0.5)),
    ####### Singles IX, X, XI
    ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801 + np.random.uniform(-0.1, 0.1)),
    ####### Singles V
    # ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801),
]

# Input-variable tokens: (symbol, NumPy eval fn, LaTeX fn, PyTorch eval fn).
# Both eval fns are the 1-argument identity. Presumably the evaluator passes
# in the matching input column (TODO confirm). Their arity is overridden to
# 0 when `arities` is built, since they are leaves.
_variables = [
    ("x", _identity, lambda: "x", _identity),
    ("y", _identity, lambda: "y", _identity),
    ("z", _identity, lambda: "z", _identity),
    ("w", _identity, lambda: "w", _identity),
]

# Binary operators: (symbol, NumPy eval fn, LaTeX fn, PyTorch eval fn).
# "+" and "-" are evaluated unguarded. "*" and "/" go through the safe
# wrappers, which clamp results to [-1e10, 1e10] and map NaN to 0.
_binops = [
    ("+", lambda x, y: x + y, lambda x, y: f"({x} + {y})", lambda x, y: x + y),
    ("-", lambda x, y: x - y, lambda x, y: f"({x} - {y})", lambda x, y: x - y,),
    ("*", make_safe_bin(_mul), lambda x, y: f"({x} \\cdot {y})", tc_make_safe_bin(_tc_mul)),
    ("/", make_safe_bin(_div), lambda x, y: "\\frac{" + str(x) + "}" + "{" + str(y) + "}", tc_make_safe_bin(_tc_div)),
]

# Unary operators: (symbol, NumPy eval fn, LaTeX fn, PyTorch eval fn).
# Entries prefixed with "##" are disabled. Before re-enabling any of them:
#  - unary "-" would clash with binary "-" (vocabulary symbols must be unique);
#  - the "cube" entry has no PyTorch fn, so the 4-way unpacking of _unops
#    done after the tables are assembled would fail.
# sin/cos are not wrapped by make_safe_un/tc_make_safe_un. NOTE(review):
# non-finite inputs therefore propagate NaN through them.
_unops = [
    ## ("-", lambda x: -x, lambda x: f"-({x})", lambda x: -x),
    ## ("inv", make_safe_un(_inv), lambda x: f"({x})^{{-1}}", tc_make_safe_un(_tc_inv)),
    ("square", make_safe_un(_square), lambda x: f"({x})^2", tc_make_safe_un(_tc_square)),
    ## ("cube", make_safe_un(_cube), lambda x: f"({x})^3"),
    ("sqrt", make_safe_un(_sqrt), lambda x: "\\sqrt{" + str(x) + "}", tc_make_safe_un(_tc_sqrt)),
    ("sin", lambda x: np.sin(x), lambda x: "\\sin (" + str(x) + ")", lambda x: torch.sin(x)),
    ("cos", lambda x: np.cos(x), lambda x: "\\cos (" + str(x) + ")", lambda x: torch.cos(x)),
    ("exp", make_safe_un(_exp), lambda x: "e^{" + str(x) + "}", tc_make_safe_un(_tc_exp)),
]

# Per-operator nesting levels, keyed by symbol. They are re-keyed by token id
# into `nesting_rules`. Presumably each value is the maximum nesting depth
# allowed for that operator (TODO confirm with the consumer of
# `nesting_rules`). The commented-out tuple form appears to be an
# alternative format that groups several operators under one limit.
_nesting_rules = {
    # "inv": 2,
    "/": 2,
    "square": 2,
    "sqrt": 2,
    "sin": 1,
    "cos": 1,
    "exp": 1,
    # "sin": (("sin", "cos"), 1),
}

# Limits for leaf tokens. They are merged (via ** unpacking) into every row of
# `_complexity_rules`. "x" is not listed, so it keeps the default value
# (`inf`, i.e. 1e10) when the dense rule arrays are built.
_zero_arity_rules = {
    "CONST": 2,
    "y": 2,
    "z": 2,
    "w": 2,
}

# Parent operator -> {child token: limit}. Each row is expanded into a dense
# array over the vocabulary initialised to `inf` (1e10), so tokens not listed
# are effectively unrestricted. Presumably a value is the maximum count of
# that token within the parent's subtree, with 0 meaning "forbidden". For
# example, "sin" sets sin/cos/exp/square/sqrt to 0. TODO confirm against the
# consumer of `complexity_rules`. Commented-out entries are alternative
# settings kept for reference.
_complexity_rules = {
    "+": {
        **_zero_arity_rules,
        # "+": 2,
        # "*": 3,
        # "-": 3,
        # "/": 3,
    },
    "-": {
        **_zero_arity_rules,
        # "+": 3,
        # "*": 3,
        # "-": 2,
        # "/": 3,
    },
    "*": {
        **_zero_arity_rules,
        # "+": 3,
        # "*": 2,
        # "-": 3,
        # "/": 3,
    },
    "/": {
        **_zero_arity_rules,
        # "+": 3,
        # "*": 3,
        # "-": 3,
        # "/": 2,
    },
    "sin": {
        **_zero_arity_rules,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        # "square": 1,
        # "sqrt": 1,
        "square": 0,
        "sqrt": 0,
    },
    "cos": {
        **_zero_arity_rules,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        # "square": 1,
        # "sqrt": 1,
        "square": 0,
        "sqrt": 0,
    },
    "exp": {
        **_zero_arity_rules,
        # "sin": 1,
        # "cos": 1,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        "square": 1,
        # "sqrt": 1,
        "sqrt": 0,
    },
    "sqrt": {
        **_zero_arity_rules,
        # "sqrt": 1,
        "sqrt": 0,
    },
    "square": {
        **_zero_arity_rules,
        # "square": 1,
        "square": 0,
    }
}

# Inverse functions that recover an operator's input from its output `ans`.
# - Unary entries map ans -> input:
#   - sin/cos use arcsin/arccos, principal branch only. They are not
#     safe-wrapped, so they give NaN outside [-1, 1].
#   - exp uses safe ln, so ans <= 0 clamps to -1e10 or becomes 0.
#   - sqrt uses safe square.
#   - square uses safe sqrt, the non-negative root; negatives clip to 0.
# - Binary entries are [f0, f1], each called as f(ans, x) where x is the
#   known operand. Going by the "-" and "/" formulas:
#   - f0 solves for the right operand given the left one (x - ans, x / ans);
#   - f1 solves for the left operand given the right one (ans + x, ans * x).
_inversions = {
    "sin": lambda x: np.arcsin(x),
    "cos": lambda x: np.arccos(x),
    "exp": make_safe_un(_ln),
    "sqrt": make_safe_un(_square),
    "square": make_safe_un(_sqrt),
    "+": [
        lambda ans, x: ans - x,
        lambda ans, x: ans - x,
    ],
    "*": [
        make_safe_bin(_div),
        make_safe_bin(_div),
    ],
    "-": [
        lambda ans, x: x - ans,
        lambda ans, x: ans + x,
    ],
    "/": [
        lambda ans, x: make_safe_bin(_div)(x, ans),
        lambda ans, x: make_safe_bin(_mul)(ans, x),
    ],
}

# ----- END OF SYMBASE DEFINITION -----

# Symbols of the input-variable tokens; the special input symbol must be one of them.
_var_vocab = [name for name, *_rest in _variables]
assert special_symbol in _var_vocab

# Cumulative group boundaries inside `_symbase`:
# [0, end of constants, end of variables, end of binary ops].
# Unary operators occupy the remaining ids up to len(vocab).
_pos_ptrs = [0]
for _sym in (_constants, _variables, _binops):
    _pos_ptrs += [_pos_ptrs[-1] + len(_sym)]

# All tokens in a fixed order: constants, variables, binary ops, unary ops.
# A token's id is its index in this list.
_symbase = [*_constants, *_variables, *_binops, *_unops]
vocab = [entry[0] for entry in _symbase]
functions = [entry[1] for entry in _symbase]
tc_functions = [entry[3] for entry in _symbase]
# Element [2] is the LaTeX renderer; element [-1] is now the PyTorch fn.
latexes = [entry[2] for entry in _symbase]

assert special_symbol in vocab
# Symbols double as lookup keys (symbol -> id), so they must be unique.
assert len(vocab) == len(set(vocab))

# Two-way lookup: token id -> symbol and symbol -> token id. Ids are ints and
# symbols are strings, so both directions share one dict without collisions.
bijection = {}
for idx, token in enumerate(vocab):
    bijection.update({idx: token, token: idx})

#######
# Re-key the symbol-based rule tables by token id.
nesting_rules = {
    bijection[op]: lvls
    for op, lvls in _nesting_rules.items()
}
# complexity_rules[parent_id] is a dense float array indexed by token id.
# Every entry defaults to `inf` (1e10); only the tokens listed in
# _complexity_rules[parent] get their explicit value.
complexity_rules = {}
for op, rule in _complexity_rules.items():
    # complexity_rules[bijection[op]] = np.zeros(len(vocab))
    complexity_rules[bijection[op]] = np.full(len(vocab), inf)
    for subop, val in rule.items():
        complexity_rules[bijection[op]][bijection[subop]] = val
#######
# inversions[token_id] -> the same inverse function(s) as _inversions[symbol].
inversions = {}
for op, inv in _inversions.items():
    inversions[bijection[op]] = inv
#######

# Token id of the special input symbol (first match in vocab).
special_id = None
for _id, symbol in enumerate(vocab):
    if symbol != special_symbol:
        continue
    special_id = _id
    break

assert special_id is not None

# Token id of the constant/parameter placeholder (first match in vocab).
for _id, symbol in enumerate(vocab):
    if symbol == special_parameter_symbol:
        special_parameter_id = _id
        break
else:
    special_parameter_id = None

assert special_parameter_id is not None

# Sanity-check the NumPy eval-function signatures against each table:
# constants take no arguments, variables and unary ops take one, and binary
# ops take two. The same parameter counts drive `arities` below.
for _, ctt, _, _ in _constants:
    assert len(signature(ctt).parameters) == 0

for _, var, _, _ in _variables:
    assert len(signature(var).parameters) == 1

for _, op, _, _ in _binops:
    assert len(signature(op).parameters) == 2

for _, op, _, _ in _unops:
    assert len(signature(op).parameters) == 1

# Token-id ranges of constants, variables and binary ops, in `_symbase` order.
constants, variables, bops = (
    list(range(lo, hi)) for lo, hi in zip(_pos_ptrs, _pos_ptrs[1:])
)

# Leaf (zero-arity) tokens: constants followed by variables.
zops = constants + variables

# Unary operators fill the rest of the vocabulary.
uops = list(range(_pos_ptrs[3], len(vocab)))

# Arities: the number of operands each token consumes, read from the
# parameter count of its NumPy eval function (0 for constants, 1 for unary
# ops, 2 for binary ops).
arities = [len(signature(op).parameters) for op in functions]
# Variables evaluate through the 1-argument `_identity`, but as leaves they
# consume no operands, so their arity is overridden to 0.
for idx in variables:
    arities[idx] = 0

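# NOTE: stale. Table entries no longer carry a commutativity flag, and
# element [2] is now the LaTeX function, so this must be updated before it
# is re-enabled.
# cops = []
# for idx in bops:
#     if _symbase[idx][2]:
#        cops.append(idx) 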
