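R"""
Interactive exploration of NMF components of per-example Fishers for the `ead` project.

Run from the project root with the project on PYTHONPATH. `python -i` keeps an
interactive session open after the script finishes, so the loaded objects
(`pef`, `nmfs`, `container`, ...) can be explored by hand:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/ead_comps001.py

"""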
import collections
from importlib import reload
import itertools
import os
import time
from typing import Sequence

import numpy as np
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets.antiderivative import antiderivative_ds
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util.color_util import cu
from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.projects.ead import ead_misc1 as eadm
from em.util import vat_da_faak_vpn


###############################################################################

def make_symbols_dict(symbols: Sequence[str], maybe_zero: Sequence[str] = ()):
    """Build an ordered mapping from symbol name to a SymPy ``Symbol``.

    Every symbol is created with ``real=True``. By default each one is also
    declared ``nonzero=True``. Names listed in ``maybe_zero`` are created
    without the ``nonzero`` assumption, so SymPy leaves open whether they
    can be zero.

    Passing ``nonzero=False`` would NOT mean "may be zero". SymPy defines
    ``nonzero`` as ``real & ~zero``, so ``real=True, nonzero=False`` makes
    SymPy infer that the symbol is exactly zero. That is why the assumption
    is omitted rather than negated.

    Args:
        symbols: Symbol names, in the order they should appear in the result.
            Any iterable of names works. A bare ``str`` is iterated one
            character at a time.
        maybe_zero: Names (a subset of ``symbols``) that should not carry the
            ``nonzero=True`` assumption. The default (none) reproduces the
            previous behavior exactly.

    Returns:
        ``collections.OrderedDict`` mapping each name to its ``sp.Symbol``.

    Raises:
        ValueError: If ``maybe_zero`` contains a name that is not in
            ``symbols``, which is most likely a typo.
    """
    # Materialize once: `symbols` may be a one-shot iterator, and it is read
    # twice below (validation and construction).
    names = list(symbols)
    allow_zero = frozenset(maybe_zero)

    unknown = allow_zero.difference(names)
    if unknown:
        raise ValueError(f'maybe_zero contains names not in symbols: {sorted(unknown)}')

    def _symbol(name):
        # Omit (rather than negate) `nonzero` for names that may be zero.
        if name in allow_zero:
            return sp.Symbol(name, real=True)
        return sp.Symbol(name, real=True, nonzero=True)

    return collections.OrderedDict((name, _symbol(name)) for name in names)

###############################################################################


# Root of all on-disk artifacts for this experiment, followed by its
# standard subdirectories (derived in a single pass).
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)
###############################################################################

# Pretrained checkpoint name (used to load the tokenizer).
PRETRAINED_MODEL = 'bert-base-uncased'


# Name of the analyzed model; the file names below are derived from it.
MODEL = "best_base_ead_infix_ds5s01_35k_dev001"
# File of saved per-example Fishers (expected in PER_EXAMPLES_FISHERS_DIR).
PER_EXAMPLES_FISHERS = f"{MODEL}.ds5s01.no_embeddings.sparse_dynamic_raw.16k.16k.h5"
# File of NMF decompositions computed from PER_EXAMPLES_FISHERS
# (expected in the same directory).
DECOMP_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256." + PER_EXAMPLES_FISHERS


# NOTE(review): not referenced elsewhere in this script as shown.
FROM_PT = False

# Number of NMF decompositions to expose from DECOMP_FILENAME.
N_DECOMPS = 25

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

###############################################################################

# Resolve both input files the same way. `~` is expanded for each, so
# EXPS_DIR may be written relative to the home directory.
_pef_path = os.path.expanduser(os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS))
_decomp_path = os.path.expanduser(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME))

print('Starting to load saved per-example Fishers.')
# perf_counter is monotonic and meant for measuring durations; time.time()
# is wall-clock time and can jump if the system clock is adjusted.
start = time.perf_counter()
pef = per_example.PerExampleFlatFishers.load(
    _pef_path,
    # NOTE(review): appears to correspond to the "16k" in PER_EXAMPLES_FISHERS — confirm.
    n_examples=16 * 1024,
    # An empty Fisher index range means the Fishers themselves are not
    # loaded, which ends up being much faster.
    start_fisher_index=0,
    end_fisher_index=0,
)
print('Load saved per-example Fishers time: ', time.perf_counter() - start)

# Container of N_DECOMPS NMF decompositions. It is presumably lazy, per the
# helper's name; note that `_LazyNmfList` is a private helper of anli_misc1.
nmfs = am._LazyNmfList(_decomp_path, n_nmfs=N_DECOMPS)

# Bundles the per-example Fishers, the NMFs and the tokenizer for analysis.
container = eadm.EadAnalysisContainer(
    pef=pef,
    nmfs=nmfs,
    tokenizer=tokenizer,
)

# Load every decomposition up front rather than on first access.
container.nmfs.force_load_all()

###############################################################################

# REPL snippet: after editing ead_misc1, re-import it and rebind the existing
# container to the reloaded class, without reloading the Fishers or NMFs:
#   reload(eadm); container.__class__ = eadm.EadAnalysisContainer


# Per-example indicator of incorrect predictions, as computed by the container.
incorrects = container.get_incorrect_prediction_indicator()
# Components that ncf judges to "appear tuned" to the incorrect examples.
tuned_comp_infos = ncf.get_components_appearing_tuned(
    container,
    indicator=incorrects,
    # Selection thresholds (keyword arguments, so their order is irrelevant).
    p_value_threshold=0.01,
    frac_threshold=0.25,
    coeff_factor=0.5,
)

# Alternative view: only the sorted component indices within each NMF.
# component_indices_per_subset = [
#     list(sorted(info.component_index for info in nmf_comp_infos))
#     for nmf_comp_infos in ncf.group_by_nmf(container, tuned_comp_infos)
# ]

# One dict per group returned by ncf.group_by_nmf, mapping each selected
# component's index to the example indices its info reports as relevant.
example_indices_per_component_per_subset = [
    {
        info.component_index: info.get_all_relevant_example_indices(container)
        for info in nmf_comp_infos
    }
    for nmf_comp_infos in ncf.group_by_nmf(container, tuned_comp_infos)
]

# 1.15.93 Component 92

###############################################################################
# # nmf = nmfs[18]

# symbols = make_symbols_dict(['x'])

# # reload(eadm); container.__class__ = eadm.EadAnalysisContainer


# for i, nmf in enumerate(container.nmfs):
#     if np.any(~np.isfinite(nmf.W)):
#         print(cu.hlb(i))


# def print_top_examples(nmf_index: int, component_index: int, n_examples: int):
#     top_inds = container.get_top_example_indices(nmf_index, component_index, n_examples)
#     for i in top_inds:
#         container.print_example_for_component(nmf_index, component_index, i)


# print_top_examples(n_examples=16, nmf_index=18, component_index=0)

# print_top_examples(n_examples=16, nmf_index=16, component_index=5)  # Sort of
# print_top_examples(n_examples=16, nmf_index=16, component_index=6)
# print_top_examples(n_examples=16, nmf_index=16, component_index=7)

# print_top_examples(n_examples=16, nmf_index=16, component_index=11)  # For only the top 2.

# print_top_examples(n_examples=16, nmf_index=16, component_index=16)

# print_top_examples(n_examples=16, nmf_index=16, component_index=18)

# print_top_examples(n_examples=16, nmf_index=20, component_index=2)

# print_top_examples(n_examples=16, nmf_index=6, component_index=6)

# print_top_examples(n_examples=16, nmf_index=6, component_index=224)

# print_top_examples(n_examples=16, nmf_index=6, component_index=169)


# The NMF at index 14 appears to have many good components.

# print_top_examples(n_examples=16, nmf_index=14, component_index=66)
# print_top_examples(n_examples=16, nmf_index=14, component_index=67)
# print_top_examples(n_examples=16, nmf_index=14, component_index=68)
# print_top_examples(n_examples=16, nmf_index=14, component_index=76)
# print_top_examples(n_examples=16, nmf_index=14, component_index=77)
# print_top_examples(n_examples=16, nmf_index=14, component_index=83)

"""
Decompose expressions into their terms by splitting on the top-level sum (probably without applying any simplifications).

"""
