R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/by_layer_metrics002.py

"""
import collections
from importlib import reload
import itertools
import os
import time
from typing import Sequence

import numpy as np
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
import matplotlib.pyplot as plt

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

from em.datasets.antiderivative import antiderivative_ds
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util.color_util import cu
from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.projects.ead import ead_misc1 as eadm
from em.util import vat_da_faak_vpn

from em.datasets.antiderivative import expression_metadata as emd

###############################################################################

# Root directory of the `ead1` experiment artifacts. The per-example Fishers
# and their NMF decompositions are both read from PER_EXAMPLES_FISHERS_DIR by
# the loading code below; MODELS_DIR and FISHERS_DIR are not read in this
# script.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# HuggingFace checkpoint name; in this script it is only used to build the
# tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'


# Earlier configuration (ds5s01 split), kept for reference:
# MODEL = "best_base_ead_infix_ds5s01_35k_dev001"
# PER_EXAMPLES_FISHERS = f"{MODEL}.ds5s01.no_embeddings.sparse_dynamic_raw.16k.16k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"

# Active configuration: the model name, the .h5 file holding its per-example
# Fishers, and the file holding the NMF decompositions of those Fishers.
MODEL = "best_base_ead_infix_150k_dev002"
PER_EXAMPLES_FISHERS = f"{MODEL}.val.no_embeddings.sparse_dynamic_raw.16k.16k.h5"
DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"


# NOTE(review): presumably "load weights from a PyTorch checkpoint"; it is not
# read anywhere in this script.
FROM_PT = False

# Number of NMF decompositions to read from DECOMP_FILENAME (passed as n_nmfs).
N_DECOMPS = 25

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

###############################################################################

# Resolve both input files the same way (including `~` expansion), so the
# Fishers and their NMF decompositions always come from the same directory.
per_example_fishers_path = os.path.expanduser(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS))
decomp_path = os.path.expanduser(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME))

print('Starting to load saved per-example Fishers.')
start = time.time()
pef = per_example.PerExampleFlatFishers.load(
    per_example_fishers_path,
    # NOTE(review): presumably matches the "16k" in PER_EXAMPLES_FISHERS.
    n_examples=16 * 1024,
    # This leads to the Fishers not being loaded, which ends up being much faster.
    # The example data itself (e.g. `input_ids`, used by `to_exprs`) is still
    # needed by the analysis below.
    start_fisher_index=0,
    end_fisher_index=0,
)
print('Load saved per-example Fishers time: ', time.time() - start)

nmfs = am._LazyNmfList(decomp_path, n_nmfs=N_DECOMPS)

# Bundles the examples, the NMF decompositions and the tokenizer (used to turn
# example token ids back into strings) for the analysis below.
container = eadm.EadAnalysisContainer(
    pef=pef,
    nmfs=nmfs,
    tokenizer=tokenizer,
)

# NOTE(review): presumably loads every decomposition now rather than lazily on
# first access.
container.nmfs.force_load_all()

###############################################################################


def to_exprs(container):
    """Parse every example in `container` into a sympy expression.

    Each example is rendered to text with `container.get_example_as_string`
    and passed to `sp.sympify`. An example that fails to parse is reported
    (with its index and source string) and represented by `None`. The result
    therefore stays index-aligned with the examples: entry i belongs to
    example i.

    Args:
      container: analysis container exposing `pef.input_ids` (its first
        dimension is taken as the number of examples) and
        `get_example_as_string(index)`.

    Returns:
      A list with one entry per example: the sympified object, or None when
      parsing failed.
    """
    n_examples = container.pef.input_ids.shape[0]
    exprs = []
    for index in range(n_examples):
        expr_str = container.get_example_as_string(index)
        try:
            # NOTE: `sympify` evaluates string input with `eval`; only use it
            # on trusted data such as our own dataset, never external input.
            expr = sp.sympify(expr_str)
        except Exception as e:
            # Best effort: exceptions other than SympifyError may escape during
            # evaluation, so keep the broad catch. Report which example failed
            # so it can be inspected later.
            print(cu.hlr(
                f'Failed to parse example {index} ({expr_str!r}): '
                f'{type(e).__name__}: {e}'
            ))
            expr = None
        exprs.append(expr)
    return exprs


def indicator_to_inds(indicator, value=True):
    """Return the positions in `indicator` whose entry equals `value`.

    `None` entries are always skipped, so they never match, not even when
    `value` itself is None.

    Args:
      indicator: iterable of per-item flags, possibly containing None.
      value: the entry value to select (compared with `==`).

    Returns:
      List of matching indices in increasing order.
    """
    return [
        position
        for position, flag in enumerate(indicator)
        if flag is not None and flag == value
    ]


start = time.time()
exprs = to_exprs(container)
print(time.time() - start)

# The free variable the per-example predicate is evaluated against.
x = sp.Symbol('x')

# One flag per example, aligned with `exprs`: the result of
# `expr.is_rational_function(x)`, or None when the example failed to parse.
# Despite the variable name, this currently selects rational functions
# (a superset of polynomials). Use `expr.is_polynomial(x)` instead to restrict
# the selection to polynomials.
is_polynomials = []
for expr in tqdm(exprs):
    is_polynomials.append(
        None if expr is None else expr.is_rational_function(x))


# Indices of the examples flagged True above. Parse failures (None) are never
# selected by `indicator_to_inds`.
poly_inds = indicator_to_inds(is_polynomials)

# Mass fraction per subset, first with the method's default example set
# (presumably all examples; TODO confirm against
# `compute_mass_fraction_by_subset`) and then restricted to the selected
# examples.
fraction_per_subset = container.compute_mass_fraction_by_subset()
fraction_per_subset_polys = container.compute_mass_fraction_by_subset(poly_inds)

# Label both curves: without a legend the two lines on the shared axes cannot
# be told apart. Zero fractions become -inf under np.log and are not drawn.
plt.plot(np.log(fraction_per_subset), label='all examples')
plt.plot(np.log(fraction_per_subset_polys), label='rational-function examples')
plt.xlabel('subset index')
plt.ylabel('log mass fraction')
plt.legend()
plt.show()

