from inspect import signature
import numpy as np
import torch


# ----- START OF SYMBASE DEFINITION -----

####### Feynman

# Magnitude used to clamp the outputs of the "safe" operator wrappers below.
# Despite the name this is a large *finite* bound, not IEEE infinity.
inf = 1e10
# The same bound as a 0-d float32 torch tensor, for the torch-side clamps.
tc_inf = torch.tensor(inf, dtype=torch.float32)

def make_safe_un(op):
    """Wrap a unary NumPy-side operation so its result is bounded and NaN-free.

    The returned callable evaluates ``op(x)``, clamps the result to
    ``[-inf, inf]`` (``inf`` is the module-level bound ``1e10``) and then
    replaces any NaN with ``0.0``.

    Args:
        op: Callable taking exactly one argument.

    Returns:
        A callable with a single explicit parameter. The explicit parameter
        list matters because this module derives operator arities from
        ``inspect.signature`` of the stored evaluators.
    """
    def safe_unary(x):
        # NaN passes through minimum/maximum untouched and is zeroed last.
        bounded = np.maximum(np.minimum(op(x), inf), -inf)
        return np.nan_to_num(bounded, nan=0.0)

    return safe_unary

def tc_make_safe_un(op):
    """Wrap a unary PyTorch operation so its result is bounded and NaN-free.

    The returned callable evaluates ``op(x)``, clamps the result to
    ``[-tc_inf, tc_inf]`` (the module-level 0-d float32 tensor holding
    ``1e10``) and then replaces any NaN with ``0.0``.

    Args:
        op: Callable taking exactly one argument and returning a tensor.

    Returns:
        A callable with a single explicit parameter, so that signature-based
        arity inspection keeps working.
    """
    def safe_unary(x):
        # torch.minimum/maximum propagate NaN; nan_to_num zeroes it afterwards.
        bounded = torch.maximum(torch.minimum(op(x), tc_inf), -tc_inf)
        return torch.nan_to_num(bounded, nan=0.0)

    return safe_unary

def make_safe_bin(op):
    """Wrap a binary NumPy-side operation so its result is bounded and NaN-free.

    The returned callable evaluates ``op(x, y)``, clamps the result to
    ``[-inf, inf]`` (``inf`` is the module-level bound ``1e10``) and then
    replaces any NaN with ``0.0``.

    Args:
        op: Callable taking exactly two arguments.

    Returns:
        A callable with two explicit parameters. The explicit parameter list
        matters because this module derives operator arities from
        ``inspect.signature`` of the stored evaluators.
    """
    def safe_binary(x, y):
        # NaN passes through minimum/maximum untouched and is zeroed last.
        bounded = np.maximum(np.minimum(op(x, y), inf), -inf)
        return np.nan_to_num(bounded, nan=0.0)

    return safe_binary

def tc_make_safe_bin(op):
    """Wrap a binary PyTorch operation so its result is bounded and NaN-free.

    The returned callable evaluates ``op(x, y)``, clamps the result to
    ``[-tc_inf, tc_inf]`` (the module-level 0-d float32 tensor holding
    ``1e10``) and then replaces any NaN with ``0.0``.

    Args:
        op: Callable taking exactly two arguments and returning a tensor.

    Returns:
        A callable with two explicit parameters, so that signature-based
        arity inspection keeps working.
    """
    def safe_binary(x, y):
        # torch.minimum/maximum propagate NaN; nan_to_num zeroes it afterwards.
        bounded = torch.maximum(torch.minimum(op(x, y), tc_inf), -tc_inf)
        return torch.nan_to_num(bounded, nan=0.0)

    return safe_binary

def _div(x, y):
    ans = np.divide(x, y)
    if np.allclose(x, y):
        ans = 0.0 * ans + 1.0
    return ans

def _tc_div(x, y):
    ans = torch.divide(x, y)
    if torch.allclose(x, y):
        ans = 0.0 * ans + 1.0
    return ans

def _mul(x, y):
    ans = x * y
    return ans

def _tc_mul(x, y):
    ans = x * y
    return ans

def _square(x):
    ans = x ** 2
    return ans

def _tc_square(x):
    ans = x ** 2
    return ans

def _identity(x):
    return x

# Format: (symbol, NumPy eval function, LaTeX function, PyTorch eval function)

special_symbol = "x"
special_parameter_symbol = "CONST"

_constants = [
    ####### Singles VIII (??)
    # ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801 + np.random.uniform(-0.5, 0.5)),
    ####### Singles IX, X, XI
    ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801 + np.random.uniform(-0.1, 0.1)),
    ####### Singles V
    # ("CONST", lambda: 1.109801, lambda: "\\Box", lambda: 1.109801),
]

# Variables: evaluated by the identity on both the NumPy and PyTorch side.
_variables = [
    ("x", _identity, lambda: "x", _identity),
    ("y", _identity, lambda: "y", _identity),
]

# Binary operators. "+" and "-" are used as-is; "*" and "/" go through the
# clamping / NaN-zeroing wrappers.
_binops = [
    ("+", lambda x, y: x + y, lambda x, y: f"({x} + {y})", lambda x, y: x + y),
    ("-", lambda x, y: x - y, lambda x, y: f"({x} - {y})", lambda x, y: x - y),
    (
        "*",
        make_safe_bin(_mul),
        lambda x, y: f"({x} \\cdot {y})",
        tc_make_safe_bin(_tc_mul),
    ),
    (
        "/",
        make_safe_bin(_div),
        lambda x, y: f"\\frac{{{x!s}}}{{{y!s}}}",
        tc_make_safe_bin(_tc_div),
    ),
]

# Unary operators. Only "square" is wrapped; sin/cos are already bounded.
_unops = [
    (
        "square",
        make_safe_un(_square),
        lambda x: f"({x})^2",
        tc_make_safe_un(_tc_square),
    ),
    ("sin", lambda x: np.sin(x), lambda x: f"\\sin ({x!s})", lambda x: torch.sin(x)),
    ("cos", lambda x: np.cos(x), lambda x: f"\\cos ({x!s})", lambda x: torch.cos(x)),
]

_nesting_rules = {
    "/": 2,
    "square": 2,
    "sqrt": 2,
    "sin": 1,
    "cos": 1,
    "exp": 1,
}

_zero_arity_rules = {
    "CONST": 2,
    "y": 2,
}

_complexity_rules = {
    "+": {
        **_zero_arity_rules,
    },
    "-": {
        **_zero_arity_rules,
    },
    "*": {
        **_zero_arity_rules,
    },
    "/": {
        **_zero_arity_rules,
    },
    "sin": {
        **_zero_arity_rules,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        # "square": 1,
        # "sqrt": 1,
        "square": 0,
        "sqrt": 0,
    },
    "cos": {
        **_zero_arity_rules,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        # "square": 1,
        # "sqrt": 1,
        "square": 0,
        "sqrt": 0,
    },
    "exp": {
        **_zero_arity_rules,
        # "sin": 1,
        # "cos": 1,
        "sin": 0,
        "cos": 0,
        "exp": 0,
        "square": 1,
        # "sqrt": 1,
        "sqrt": 0,
    },
    "sqrt": {
        **_zero_arity_rules,
        # "sqrt": 1,
        "sqrt": 0,
    },
    "square": {
        **_zero_arity_rules,
        # "square": 1,
        "square": 0,
    }
}

# ----- END OF SYMBASE DEFINITION -----

# Names of all variable tokens; the input variable must be one of them.
_var_vocab = [name for name, *_rest in _variables]
assert special_symbol in _var_vocab

# Start offsets of each symbol group inside `_symbase`:
# [constants, variables, binary ops, unary ops].
_pos_ptrs = [0]
for _group in (_constants, _variables, _binops):
    _pos_ptrs += [_pos_ptrs[-1] + len(_group)]

# Full ordered symbol table; a token's position here is its vocabulary id.
_symbase = [*_constants, *_variables, *_binops, *_unops]

vocab = [entry[0] for entry in _symbase]         # token strings
functions = [entry[1] for entry in _symbase]     # NumPy evaluators
tc_functions = [entry[3] for entry in _symbase]  # PyTorch evaluators
latexes = [entry[2] for entry in _symbase]       # LaTeX renderers

assert special_symbol in vocab
# Tokens must be unique: the bijection below maps token <-> id both ways.
assert len(set(vocab)) == len(vocab)

# Two-way lookup in a single dict: id -> token and token -> id, with entries
# inserted in vocabulary order (id first, then its token).
bijection = {
    key: value
    for position, token in enumerate(vocab)
    for key, value in ((position, token), (token, position))
}

###
# Translate the symbol-keyed rule tables into vocabulary-id keyed ones.
# The tables may mention operators that are not part of the current symbol
# base (e.g. "sqrt" and "exp" when they are absent from `_unops`). Such entries
# are skipped; previously `bijection[...]` raised KeyError at import time.
nesting_rules = {
    bijection[op]: lvls
    for op, lvls in _nesting_rules.items()
    if op in vocab
}
complexity_rules = {}
for op, rule in _complexity_rules.items():
    if op not in vocab:
        continue
    # One slot per vocabulary id; children not listed in the rule keep the
    # large `inf` bound (presumably "not allowed" for the consumer — confirm).
    row = np.full(len(vocab), inf)
    for subop, val in rule.items():
        if subop in vocab:
            row[bijection[subop]] = val
    complexity_rules[bijection[op]] = row
###

# Vocabulary id of the input-variable token (first occurrence), else None.
special_id = vocab.index(special_symbol) if special_symbol in vocab else None
assert special_id is not None

# Vocabulary id of the constant-placeholder token (first occurrence), else None.
special_parameter_id = (
    vocab.index(special_parameter_symbol)
    if special_parameter_symbol in vocab
    else None
)
assert special_parameter_id is not None

# Sanity-check the NumPy evaluator of every symbol: constants take no
# argument, variables and unary operators take one, binary operators two.
for _group, _expected_arity in (
    (_constants, 0),
    (_variables, 1),
    (_binops, 2),
    (_unops, 1),
):
    for _sym_name, _evaluator, _latex_fn, _tc_evaluator in _group:
        assert len(signature(_evaluator).parameters) == _expected_arity

# Vocabulary ids of each symbol group, taken from the `_pos_ptrs` offsets.
constants = list(range(*_pos_ptrs[0:2]))
variables = list(range(*_pos_ptrs[1:3]))
zops = constants + variables  # all zero-arity (leaf) tokens
bops = list(range(*_pos_ptrs[2:4]))
uops = list(range(_pos_ptrs[3], len(vocab)))

# Arity per vocabulary id, read from each NumPy evaluator's signature.
# Variable evaluators are one-argument identities, but variables are leaves,
# so their arity is overridden to 0.
arities = [len(signature(fn).parameters) for fn in functions]
arities[_pos_ptrs[1]:_pos_ptrs[2]] = [0] * len(variables)

# NOTE: an earlier revision derived commutative binary operators ("cops")
# from a per-symbol commutativity flag; the current 4-tuple symbol format has
# no such field.
