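R"""Interactive scratchpad for probing a fine-tuned sequence classifier.

Loads the model named by MODEL and runs it on hand-written symbolic
expressions, grouped below by the behavior being probed.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/ead/supervised_transfer/model_predictor001.py

"""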
from importlib import reload
import os
import time
from typing import Sequence

import numpy as np
import sympy as sp
import matplotlib.pyplot as plt


import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer


# Root directory of this experiment's artifacts; the sub-directories below
# are derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
# Directory containing the saved (fine-tuned) model checkpoints.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# NOTE(review): the two Fisher directories are not referenced anywhere in the
# visible part of this file — presumably kept for interactive use; confirm.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Name of the base checkpoint; only used to load the tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# Sub-directory of MODELS_DIR holding the model to probe. The commented-out
# line is an alternative checkpoint that can be swapped in.
# MODEL = "best_base_ead_infix_150k_dev002"
MODEL = "best_base_ead_infix_75k_dev003"

###############################################################################

# The tokenizer comes from the base checkpoint, the weights from
# MODELS_DIR/MODEL. `from_pt=False` asks for TensorFlow weights rather than a
# conversion from a PyTorch checkpoint.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=False)

# Fixed token length: every encoded input is truncated and then padded to
# exactly this many tokens (see _encode_single / _mutative_pad_list).
sequence_length = 128

###############################################################################


def _mutative_pad_list(x, pad_token):
    """Right-pad the list ``x`` in place with ``pad_token``.

    After the call ``x`` holds at least ``sequence_length`` items (the
    module-level constant). Nothing is returned; the caller's list is
    mutated.

    Args:
        x: List to extend in place.
        pad_token: Value appended once per missing position.
    """
    shortfall = sequence_length - len(x)
    # A non-positive shortfall gives an empty filler, so a list that is
    # already long enough is left untouched (it is never truncated here).
    filler = [pad_token] * shortfall
    x.extend(filler)


def _encode_single(s: str):
    """Tokenize one expression string into fixed-length id lists.

    The string is first parsed by SymPy and printed back, so the tokenizer
    sees SymPy's rendering of the expression rather than the text as typed.
    The result is truncated to ``sequence_length`` tokens by the tokenizer
    and then right-padded to that length.

    NOTE: ``sympify`` evaluates its input as Python code; only pass trusted
    strings.

    Args:
        s: Expression in a syntax that ``sympy.sympify`` can parse.

    Returns:
        A ``(input_ids, token_type_ids)`` pair of lists.
    """
    canonical = str(sp.sympify(s))
    encoding = tokenizer.encode_plus(
        canonical,
        add_special_tokens=True,
        max_length=sequence_length,
        return_token_type_ids=True,
        truncation=True,
    )
    ids = encoding["input_ids"]
    type_ids = encoding["token_type_ids"]
    # Padding is done by hand (in place) rather than by the tokenizer.
    _mutative_pad_list(ids, tokenizer.pad_token_id)
    _mutative_pad_list(type_ids, tokenizer.pad_token_type_id)
    return ids, type_ids


def encode_inputs_to_batch(inputs: Sequence[str]):
    """Encode several expression strings into one model-ready batch.

    Args:
        inputs: Expression strings; each one goes through ``_encode_single``.

    Returns:
        A dict with int32 tensors under ``'input_ids'`` and
        ``'token_type_ids'``, one row per input string.

    Raises:
        ValueError: If ``inputs`` is empty, because there is nothing to
            unpack after transposing.

    NOTE(review): no ``attention_mask`` is supplied, so the model presumably
    treats padding positions as real tokens — confirm this matches how the
    model was trained before changing it.
    """
    encoded_pairs = [_encode_single(expr) for expr in inputs]
    # Transpose [(ids, type_ids), ...] into (all_ids, all_type_ids).
    all_ids, all_type_ids = zip(*encoded_pairs)
    return {
        'input_ids': tf.cast(all_ids, tf.int32),
        'token_type_ids': tf.cast(all_type_ids, tf.int32),
    }


def get_predictions(inputs: Sequence[str]):
    """Run the loaded model on expression strings and return class scores.

    Args:
        inputs: Expression strings accepted by ``encode_inputs_to_batch``.

    Returns:
        The softmax of the model's logits, computed in inference mode
        (``training=False``). Presumably one row of class probabilities per
        input string — TODO confirm the label order against the training
        setup.
    """
    logits = model(encode_inputs_to_batch(inputs), training=False).logits
    return tf.math.softmax(logits)

###############################################################################
# NOTE: Beware that some of the incorrect predictions might be the result of a distribution mismatch.
###############################################################################


# Observation: the model does not seem to be the best at understanding log
# rules, and some of the ways it fails are rather odd.
#
# The inputs below differ only in the argument of cosh(ln(...)). Each string
# is passed through sympify in _encode_single before tokenization, so the
# model may see a re-printed form rather than the exact text written here.
#
# The result is bound to `q` (the later probe calls discard theirs).
q = get_predictions([
    'tanh(tan(77)) * x * cosh(ln(1/x)) /66',
    'tanh(tan(77)) * x * cosh(ln(1/x**tan(-4))) /66',
    'tanh(tan(77)) * x * cosh(ln(1/x**5)) /66',
    'tanh(tan(77)) * x * cosh(ln(x**(-1))) /66',
    'tanh(tan(77)) * x * cosh(ln(x)) /66',
    'tanh(tan(77)) * x * cosh(-ln(x)) /66',
])

###########################################################

# Looks like there are some issues with exp rules as well.
#
# "Incorrect" / "Correct" are the author's labels for each group —
# presumably whether the model's prediction matched the expected answer.
# The return value of this call is not stored.
get_predictions([
    # Incorrect:
    '(9538 * E)**x',
    '(9538 * 4)**x',
    '9538**x * E**x',
    '9538**x * exp(x)',
    # Correct:
    '(7/5)**x',
    '9538**x',
    'E**x',
    'exp(log(9538 * E) * x)',
])

###########################################################

# Looks like multiplication by a constant integer causes issues here.
# TODO: Check whether this holds for examples in general or only for a
# subset of them.
#
# The return value of this call is not stored.
get_predictions([
    # Incorrect:
    '-83514543 * x**(5/2) + 83514543 * x**2  * asin(x)',
    '-51 * x**(5/2) + 51 * x**2  * asin(x)',
    '-51 * x**(5/2) + 72 * x**2  * asin(x)',
    '-2 * x**(5/2) + x**2  * asin(x)',
    '-x**(5/2) + 5 * x**2  * asin(x)',
    '5 * x**2  * asin(x) - x**(5/2)',
    'x**(5/2) + 5 * x**2  * asin(x)',
    # Correct:
    '-x**(5/2) + x**2  * asin(x)',
    '6 * x**2  * asin(x)',
    'x**2  * asin(x)',
    '-x**(5/2)',
    '-7 * x**(5/2)',
])

###########################################################

# Looks like an issue with atan(x) / x**a:
#   - If a = 1, then it is non-elementary.
#   - For a bunch of positive rational values of a, it looks elementary.
#   - When a is irrational, it looks possibly non-elementary.
#
# The return value of this call is not stored.
get_predictions([
    # Incorrect:
    '2 * x + sin(x) + atan(x) / x**4 - 4162 / x**4',
    'atan(x) / x**4',
    '2 * x + sin(x) + atan(x) / x**2 - 4162 / x**4',
    # Correct:
    '2 * x + sin(x) - 4162 / x**4',
    '2 * x + sin(x) + atan(x) - 4162 / x**4',
])

###########################################################

# Looks to be an issue with recognizing that the expansion of
# x * (x * exp(...)) has an x**2 * exp(...) term in it.
#
# The return value of this call is not stored.
get_predictions([
    # Incorrect:
    'x * (x * exp(x**3) + 1)',
    # Correct:
    'x**2 * exp(x**3)',
    'x**2 * (exp(x**3) + 1)',
])

###########################################################

# Variations on a sum of terms divided by a nested-trig constant such as
# sqrt(tan(tan(7))). No hypothesis was recorded for this group; the
# "Incorrect" / "Correct" labels are the author's.
#
# The return value of this call is not stored.
get_predictions([
    # Incorrect:
    '(x - log(x) + sqrt(tan(tan(736))) * atan(x) + 9951) / sqrt(tan(tan(736)))',
    '(x - log(x) + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(x - log(x) + sqrt(tan(tan(16))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(x - log(x) + sqrt(tan(7)) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(x - log(x) + tan(7) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(x - log(x) + cos(7) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(-log(x) + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(log(x) + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(exp(x) + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    # Correct:
    'x - log(x) + sqrt(tan(tan(736))) * atan(x) + 9951',
    '(x - log(x) + atan(x) + 9951) / sqrt(tan(tan(736)))',
    '(x - log(x) + sqrt(7) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(x + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
    '(sqrt(x) + sqrt(tan(tan(7))) * atan(x) + 9951) / sqrt(tan(tan(7)))',
])

###############################################################################
###############################################################################

# Predictions all look false for these exp(... + exp(...)) inputs.
# TODO: See if these all belong to similar components.
#
# The return value of this call is not stored.
get_predictions([
    'exp(x + exp(x))',
    'exp(x + exp(-x))',
    'exp(x + exp(x/3))',
    'exp(5 * x + exp(x/3))',
    'exp(5 * x + exp(3 * x))',
    'exp(x + exp(x)) + sin(x)',
    'x + exp(x) + 3235*exp(x + exp(x)) + 12639',
])
