R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/supervised_transfer/wrong_preds_latex001.py
CUDA_VISIBLE_DEVICES=1 python -i local_scripts/ead/supervised_transfer/wrong_preds_latex001.py

"""
import collections
from importlib import reload
import itertools
import os
import time
from typing import Sequence

import numpy as np
from scipy import special
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from tqdm import tqdm

from em import datasets as em_datasets
from em.datasets.antiderivative import antiderivative_ds
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util.color_util import cu
from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.projects.ead import ead_misc1 as eadm
from em.util import latex_util
from em.util import vat_da_faak_vpn

###############################################################################

# Root directory for this experiment's artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
# Fine-tuned checkpoints live under MODELS_DIR/<MODEL>.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# NOTE(review): the two Fisher directories below are not referenced later in
# this script; presumably kept for parity with sibling scripts.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hub name used only to load the tokenizer; the classifier weights come from
# the fine-tuned checkpoint directory MODELS_DIR/MODEL.
PRETRAINED_MODEL = 'bert-base-uncased'
MODEL = "best_base_ead_infix_75k_dev003"

# Split passed to `em_datasets.load(split=...)`; swap the comment to inspect
# the validation split instead of train.
# DATASET = 'ead_ds_002.validation'
DATASET = 'ead_ds_002.train'

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# from_pt=False: the checkpoint is loaded as TensorFlow weights, not converted
# from a PyTorch checkpoint.
model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=False)

# Passed to the dataset loader as `sequence_length` (presumably the padded /
# truncated token length -- TODO confirm); `batch_size` only controls how many
# examples go through the model per forward call.
sequence_length = 128
batch_size = 128

###############################################################################

# Tokenized split. Elements are presumably (features_dict, label) pairs --
# inferred from the `b[0]['input_ids']` / `b[1]` indexing below; verify.
ds = em_datasets.load(
    'ead/infix',
    split=DATASET,
    tokenizer=tokenizer,
    sequence_length=sequence_length,
)

# Materialize every batch once so the three passes below see the exact same
# batches (re-iterating a tf.data pipeline would re-run it).
ds = list(tqdm(ds.batch(batch_size)))

input_ids_batches = [b[0]['input_ids'] for b in ds]
label_batches = [b[1].numpy() for b in ds]
# training=False runs the model in inference mode (e.g. dropout disabled).
logits_batches = [model(b[0], training=False).logits.numpy() for b in tqdm(ds)]

# Concatenate per-batch arrays along axis 0 so row i of each array refers to
# the same example.
labels = np.concatenate(label_batches, axis=0)
predicted_logits = np.concatenate(logits_batches, axis=0)
input_ids = np.concatenate(input_ids_batches, axis=0)


###############################################################################

def _to_latex_label(label: int) -> str:
    R"""Render a class label as colored, bold LaTeX text.

    The padded display name and the color are looked up in
    `eadm._LABEL_TO_PADDED_LABEL_NAME_LATEX` and `eadm._LABEL_TO_LATEX_COLOR`.
    The result is wrapped in its own `{...}` group so the `\color` switch does
    not leak into the surrounding text.

    NOTE: the `\textbf` requires `\usepackage{bold-extra}` in the document.
    """
    display_name = eadm._LABEL_TO_PADDED_LABEL_NAME_LATEX[label]
    latex_color = eadm._LABEL_TO_LATEX_COLOR[label]
    pieces = (R'{\color{', latex_color, R'}\textbf{', display_name, R'}}')
    return ''.join(pieces)


def get_example_as_string(example_index: int) -> str:
    """Decode one tokenized example back into a compact expression string.

    Decodes `input_ids[example_index]` with the module-level tokenizer, then
    removes the pad, CLS and SEP token strings (in that order) and finally
    every space character.
    """
    text = tokenizer.decode(input_ids[example_index])
    for special_token in (tokenizer.pad_token,
                          tokenizer.cls_token,
                          tokenizer.sep_token):
        text = text.replace(special_token, '')
    return text.replace(' ', '')


def get_example_as_latex_string(example_index: int) -> str:
    """Return the example's expression typeset by SymPy as LaTeX.

    Best effort: if SymPy cannot parse or render the decoded string, a single
    space is returned so the caller still emits a (blank) math block.
    """
    source_text = get_example_as_string(example_index)
    try:
        # NOTE: sp.sympify evaluates string input with `eval`; fine for this
        # script's own dataset, but never feed it untrusted text.
        return sp.latex(sp.sympify(source_text))
    except Exception:  # deliberate catch-all: any failure -> blank output
        return ' '


def get_example_as_mathematica(example_index: int) -> str:
    """Return the example's expression in Mathematica syntax, ready for LaTeX.

    The SymPy expression is printed with `sp.mathematica_code` and then passed
    through `latex_util.escape` (presumably escaping LaTeX special characters).
    Best effort: any failure yields a single space.
    """
    source_text = get_example_as_string(example_index)
    try:
        # NOTE: sp.sympify evaluates string input with `eval`; trusted data only.
        mathematica_src = sp.mathematica_code(sp.sympify(source_text))
        return latex_util.escape(mathematica_src)
    except Exception:  # deliberate catch-all: any failure -> blank output
        return ' '


def to_latex_string(example_index: int) -> str:
    R"""Format one example as a LaTeX snippet for the wrong-predictions report.

    The snippet is wrapped in `\noindent\texttt{...}` and contains, one per line:
      1. the gold label and the argmax prediction (via `_to_latex_label`),
      2. the softmax distribution P(Y|X) over all classes,
      3. the expression in Mathematica syntax (tiny, light gray),
      4. the expression typeset as display math.

    Args:
      example_index: Row index into the module-level `labels`,
        `predicted_logits` and `input_ids` arrays.

    Returns:
      The snippet as a newline-joined string.
    """
    label = labels[example_index]
    logits = predicted_logits[example_index]
    pyx = special.softmax(logits)
    pred = np.argmax(logits)
    #
    line1 = R' {} '.join([
        f'[LABEL] {_to_latex_label(label)}',
        f'[PRED] {_to_latex_label(pred)}',
    ])
    # Format every class probability rather than only the first two, so the
    # line stays correct when the model has more than two labels. For exactly
    # two classes this yields `P(Y|X): \{p0, p1\}` as before.
    probs = ', '.join(f'{p:.3f}' for p in pyx)
    line2 = R'P(Y|X): \{' + probs + R'\}'
    line3 = R'{\tiny\color{lightgray}' + get_example_as_mathematica(example_index) + '}'
    #
    line4 = f'$${get_example_as_latex_string(example_index)}$$'
    #
    return "\n".join([
        R'\noindent\texttt{' + line1 + R' \\',
        line2 + R' \\',
        # f'Index: {example_index}' + R' \\',
        R'{}' + line3 + R'\vspace{2mm} \\',
        R'{}' + line4 + R'\vspace{2mm} \\',
        R'}',
    ])

###############################################################################


# Indices of examples whose argmax prediction disagrees with the gold label,
# ordered by gold label. Python's sort is stable, so dataset order is kept
# within each label.
incorrect_inds, = np.nonzero(np.argmax(predicted_logits, axis=-1) != labels)
incorrect_inds = sorted(incorrect_inds, key=lambda i: labels[i])

latex_examples_strs = [to_latex_string(i) for i in incorrect_inds]

# A blank line between snippets makes each one its own LaTeX paragraph.
latex_body = '\n\n'.join(latex_examples_strs)


OUT_FILEPATH = '/fruitbasket/users/m/tmp/wrong_preds_latex_body.txt'
# Create the output directory if needed instead of failing after all the
# expensive inference above has already run.
os.makedirs(os.path.dirname(OUT_FILEPATH), exist_ok=True)
# Explicit UTF-8: decoded tokens / SymPy output are not guaranteed to be ASCII,
# and the platform default encoding varies between machines.
with open(OUT_FILEPATH, 'wt', encoding='utf-8') as f:
    f.write(latex_body)

"""

rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/wrong_preds_latex_body.txt" \
    "/tmp/wrong_preds_latex_body.txt"

subl /tmp/wrong_preds_latex_body.txt

"""

# for _ in range(10):
#     print(cu.hlr(80 * '#'))

# print(2 * '\n')

# print(latex_body)
