R"""


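Count examples of the 'ead/infix' dataset whose SymPy parse matches a nested
exp(...) pattern, and print that count. Meant to be run interactively
(`python -i`) so the intermediate lists can be inspected afterwards.

Usage (CPU only: CUDA_VISIBLE_DEVICES is deliberately left empty):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES= python -i local_scripts/ead/supervised_transfer/matching_dev001.py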

"""
import collections
from importlib import reload
import itertools
import multiprocessing as mp
import os
import time
from typing import Sequence

import numpy as np
from scipy import special
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from tqdm import tqdm

from em import datasets as em_datasets
from em.datasets.antiderivative import antiderivative_ds
from em.datasets.antiderivative import misc_util
from em.datasets.antiderivative.expression_metadata import matching as M
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util.color_util import cu
from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.projects.ead import ead_misc1 as eadm
from em.util import latex_util
from em.util import vat_da_faak_vpn

###############################################################################

# Root directory holding this experiment family's artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'

# Per-artifact subdirectories under EXPS_DIR (not referenced further in the
# visible part of this script).
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Hugging Face checkpoint whose tokenizer is loaded below.
PRETRAINED_MODEL = 'bert-base-uncased'
# Fine-tuned model name for this experiment (not used in the visible part of
# this script; presumably kept for interactive use).
MODEL = 'best_base_ead_infix_75k_dev003'

# Dataset split to scan. Swap in 'ead_ds_002.validation' to scan the
# validation split instead.
DATASET = 'ead_ds_002.train'

###############################################################################

# One tokenizer instance, shared by the dataset loader (to encode) and by
# get_example_as_string (to decode examples back into text).
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# sequence_length is handed to the dataset loader (presumably the pad/truncate
# length — confirm in em_datasets); batch_size is the tf.data batch size.
sequence_length, batch_size = 128, 128

###############################################################################


def get_example_as_string(tokens) -> str:
    """Decode token ids into a compact string with no special markers or spaces.

    The module-level ``tokenizer`` decodes ``tokens``. The pad, CLS and SEP
    marker strings are then removed, in that order. Finally every space
    character is dropped, producing the text that ``to_expr`` later parses.
    """
    text = tokenizer.decode(tokens)
    for marker in (tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token):
        text = text.replace(marker, '')
    return text.replace(' ', '')


@misc_util.timeout(1)
def _sympify(s):
    """Return ``sympy.sympify(s)``, wrapped in ``misc_util.timeout(1)``.

    NOTE(review): the decorator presumably aborts calls that run longer than
    about one second by raising; confirm the exception type in misc_util.
    ``sympify`` evaluates string input with ``eval``, so only pass it
    trusted text.
    """
    parsed = sp.sympify(s)
    return parsed


def to_expr(s):
    """Best-effort parse of ``s`` with ``_sympify``; ``None`` on any failure.

    Every raised exception is converted to ``None``. That includes SymPy
    parse errors and whatever ``_sympify``'s timeout decorator raises.

    NOTE(review): catching ``BaseException`` also swallows
    ``KeyboardInterrupt`` / ``SystemExit``. A Ctrl-C that lands inside a
    parse therefore only skips that example. This is kept on purpose because
    the timeout decorator may itself signal via a non-``Exception`` error.
    Confirm against ``misc_util.timeout`` before narrowing.
    """
    try:
        expr = _sympify(s)
    except BaseException:
        expr = None
    return expr

###############################################################################


# Load the chosen split of 'ead/infix', batch it, and materialize every batch
# in memory; tqdm reports progress while the batches are pulled.
ds = list(
    tqdm(
        em_datasets.load(
            'ead/infix',
            split=DATASET,
            tokenizer=tokenizer,
            sequence_length=sequence_length,
        ).batch(batch_size)
    )
)

# Each batch is indexed as batch[0] (a mapping holding 'input_ids') and
# batch[1] (presumably the labels). label_batches is not used further below;
# presumably kept for inspection in the interactive session.
input_ids_batches = [batch[0]['input_ids'] for batch in ds]
label_batches = [batch[1].numpy() for batch in ds]

# Flatten all batches into one list of decoded example strings.
example_texts = [
    get_example_as_string(example_ids)
    for example_ids in itertools.chain.from_iterable(
        batch.numpy() for batch in tqdm(input_ids_batches)
    )
]

# Parse each string with SymPy and keep only the successful parses.
# NOTE: dropping failures means example_exprs[i] no longer lines up with
# example_texts[i].
example_exprs = [
    expr
    for expr in (to_expr(text) for text in tqdm(example_texts))
    if expr is not None
]

# Earlier, looser variants of the pattern, kept for reference:
#   M.AnyUntil(M.Function(sp.exp, M.AnyUntil(M.Function(sp.exp))))
#   M.RingOpsUntil(M.Function(sp.exp, M.RingOpsUntil(M.Function(sp.exp))))

# Nested pattern built from the expression_metadata.matching combinators.
# Reading of the combinator names (confirm in that module): ring operations
# around exp(A + exp(B)), where A and B are each ring operations on x.
matcher = M.RingOpsUntil(
    M.Function(
        sp.exp,
        M.Add([
            M.RingOpsUntil(M.Literal('x')),
            M.Function(sp.exp, M.RingOpsUntil(M.Literal('x'))),
        ]),
    ),
)

# The symbol handed to matcher.match alongside each expression.
x = sp.Symbol('x')

# Keep the parsed expressions that match, then report how many there are.
matched_exprs = [
    candidate
    for candidate in tqdm(example_exprs)
    if matcher.match(candidate, x)
]
print(len(matched_exprs))
