R"""


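Scratch script: runs a TCAV experiment on BERT CLS activations, using the
top-weighted examples of a single NMF component as the concept set.

Usage (run with `python -i` so the results stay available in the session):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/tcav/bert_tcav_test001.py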
"""
from importlib import reload
import os
import time

import numpy as np
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.models import em_models
from em.tools.nmf import nmf_common
from em.activations import bert_activations
from em.analysis.tcav import bert_tcav

from em.util.color_util import cu

###############################################################################

# Root directory of the experiment outputs; the three sub-directories below
# are all direct children of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################
# Identifiers handed to the respective `from_pretrained` loaders.
TOKENIZER = "bert-base-uncased"
MODEL = "connectivity/feather_berts_0"

###############################################################################
# The file names nest: each later name is a prefix followed by the previous
# name, so NMF_FILENAME ends with OG_NMF_FILENAME, which ends with
# OG_PEF_FILENAME.
OG_PEF_FILENAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
OG_NMF_FILENAME = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{OG_PEF_FILENAME}"
NMF_FILENAME = f"fit_w.skip50000.50000ex.65536vpe.{OG_NMF_FILENAME}"

# The CLS activations live under a different project root than EXPS_DIR.
CLS_ACTS_FILENAME = "feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"
CLS_ACTS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers'

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# Load the model exactly once, through the project loader. This binding is the
# only `model` the rest of the script uses; an additional
# TFAutoModelForSequenceClassification load assigned to the same name would be
# rebound here before anything reads it, costing a full checkpoint load for
# nothing.
# NOTE(review): assumes em_models.from_pretrained needs no prior direct
# Transformers load of the same checkpoint -- confirm.
model = em_models.from_pretrained(MODEL)

# NMF decomposition of the per-example Fishers and the stored CLS activations.
nmf = nmf_common.SparseNmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME))
acts = bert_activations.BertClsActivations.load(os.path.join(CLS_ACTS_DIR, CLS_ACTS_FILENAME))

###############################################################################

# Experiment settings.
COMP_INDEX = 11
N_TOP_EXAMPLES = 128
N_RUNS = 50

# Keep only the trailing 768 columns of the stored activations, i.e. the
# representations from the last layer.
labels = acts.labels
activations = acts.activations[:, -768:]

exp = bert_tcav.BertTcavExperiment(
    activations=activations,
    labels=labels,
    model=model,
    n_negative_examples=1024,
    n_scoring_examples=5000,
)

# Concept examples: the N_TOP_EXAMPLES rows with the largest value in column
# COMP_INDEX of nmf.W (argsort of the negated column gives descending order).
# Alternative (disabled): concept_example_indices=np.arange(20000, 20000 + 128)
component_weights = nmf.W[:, COMP_INDEX]
top_example_indices = np.argsort(-component_weights)[:N_TOP_EXAMPLES]

comp_exp = bert_tcav.BertTcavForComponent2(
    exp=exp,
    concept_example_indices=top_example_indices,
    n_runs=N_RUNS,
)
comp_exp.learn_cavs()
scores = comp_exp.compute_per_run_scores()

# comp_exp = bert_tcav.BertTcavForComponent(
#     exp=exp,
#     concept_example_indices=np.argsort(-nmf.W[:, COMP_INDEX])[:N_TOP_EXAMPLES]
#     # concept_example_indices=np.arange(10000, 10000 + 128,)
# )

# comp_exp.learn_cav()


# scores = comp_exp.compute_per_example_scores()


# def adsf(label):
#     mask = scores.labels[:exp.n_scoring_examples] == label
#     label_scores = scores.scores[mask, label]
#     return float((label_scores > 0).sum()) / float(label_scores.shape[0])


# for i in range(3):
#     print(adsf(i))



# start = time.time()

# all_ex_inds = np.concatenate([comp_exp.concept_example_indices, comp_exp.negative_example_indices], axis=0)

# labels = np.zeros([all_ex_inds.shape[0]], dtype=np.int32)
# labels[:comp_exp.concept_example_indices.shape[0]] = 1

# examples = comp_exp.activations[all_ex_inds]


# # Try scikitlearn for binary regression.
# clf = LogisticRegression(max_iter=1000, fit_intercept=False).fit(examples, labels)
# print(time.time() - start)


# # liblinear