R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/pb/signal_peptide/devmains/explore_comps002.py
"""
import csv
import dataclasses
from importlib import reload
import io
import os
import random

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import per_example
from em.tools.nmf import nmf_common

from em.projects.pb.signal_peptide.analysis import sp_component_tunings
from em.projects.pb.signal_peptide.contexts import sp_npeff_context


##########################################################################

# Root directory of the signal-peptide experiment artifacts; the directories
# below are its immediate sub-directories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pb/signal_peptide'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Tokenizer identifier handed to the context loader.
TOKENIZER = 'Rostlab/prot_bert'

##########################################################################

# EPOCH is interpolated into PEF_FILENAME; SPLIT is forwarded to the context
# loader.
EPOCH = 7
SPLIT = 'train'

# NOTE(review): the "train" segment of the file name is hard-coded and is NOT
# derived from SPLIT, so changing SPLIT alone does not change which file is read.
PEF_FILENAME = f"prot_bert.epoch_{EPOCH}.train.20290ex.131072.h5"
# The numeric fields are fixed parts of the file name (presumably the NMF
# hyper-parameters used to produce it -- confirm before editing them).
NMF_FILENAME = f"spH.nmf_decomp.c128_2500Iters_65536pe_mvpp4_20290ex.{PEF_FILENAME}"

##########################################################################


def print_highest(ctx, component_index: int, n_examples: int, *, include_annotation_sequence: bool = False):
    """Print the examples with the largest coefficients for one NMF component.

    One line is printed per example, in descending coefficient order:
    ``prediction label coefficient aa_sequence kingdom`` (coefficient with four
    decimals).

    Args:
        ctx: Object exposing ``nmf.W`` (indexed here as ``W[:, component]``,
            one row per example) and ``examples`` (indexable by row number).
        component_index: Column of ``ctx.nmf.W`` to rank the examples by.
        n_examples: Maximum number of examples to print.
        include_annotation_sequence: When true, each example line is followed
            by a second line holding the stripped ``annotation_sequence``,
            indented by 11 spaces.
    """
    weights = ctx.nmf.W[:, component_index]
    # Negating before argsort gives descending order.
    top_rows = np.argsort(-weights)[:n_examples]
    annotation_indent = 11 * ' '

    for row in top_rows:
        example = ctx.examples[row]
        summary = f'{example.prediction} {example.label} {weights[row]:0.4f}'
        print(f'{summary} {example.aa_sequence} {example.kingdom}')
        if not include_annotation_sequence:
            continue
        print(annotation_indent + example.annotation_sequence.strip())


def print_highest_as_csv(ctx, component_index: int, n_examples: int, *, include_annotation_sequence: bool = False):
    """Print the top examples of one NMF component to stdout as CSV.

    The first row is a header; each following row describes one example
    returned by ``ctx.get_top_examples(component_index, n_examples)``, in the
    order that call yields them. Fields are written through the ``csv`` module,
    so values containing commas, quotes or newlines are quoted rather than
    silently corrupting the column layout.

    Args:
        ctx: Object exposing ``nmf.W`` (indexed here as ``W[:, component]``)
            and ``get_top_examples(component_index, n_examples)``. Each
            returned example must provide ``index``, ``prediction``,
            ``get_label_as_str()``, ``kingdom``, ``uniprot_ac`` and
            ``aa_sequence``.
        component_index: Column of ``ctx.nmf.W`` the coefficients are read from.
        n_examples: Number of examples requested from ``ctx``.
        include_annotation_sequence: When true, append an
            ``Annotation Sequence`` column holding each example's stripped
            ``annotation_sequence``. Defaults to the six-column layout.
    """
    coeffs = ctx.nmf.W[:, component_index]

    header = ['Prediction', 'Label', 'Coeff', 'Kingdom', 'UniProt AC', 'AA Sequence']
    if include_annotation_sequence:
        header.append('Annotation Sequence')

    buffer = io.StringIO()
    # '\n' terminator keeps the output identical to plain newline-joined rows.
    writer = csv.writer(buffer, lineterminator='\n')
    writer.writerow(header)

    for ex in ctx.get_top_examples(component_index, n_examples):
        row = [
            ex.prediction,
            ex.get_label_as_str(),
            coeffs[ex.index],
            ex.kingdom,
            ex.uniprot_ac,
            ex.aa_sequence,
        ]
        if include_annotation_sequence:
            row.append(ex.annotation_sequence.strip())
        # Stringify explicitly so None is rendered as "None" (csv would
        # otherwise write it as an empty field).
        writer.writerow([str(cell) for cell in row])

    # Every row already ends with a newline; avoid print() adding another.
    print(buffer.getvalue(), end='')


##########################################################################


# Load the analysis context from the configured per-example-Fisher file and its
# NMF decomposition file (both under PER_EXAMPLES_FISHERS_DIR). This runs at
# import time; the script is intended to be started with `python -i` so that
# `ctx` remains available in the interactive session afterwards.
ctx = sp_npeff_context.NpeffContext.load(
    pef_path=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_path=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=TOKENIZER,
    split=SPLIT,
)

# NOTE(review): the list below looks like a pasted result -- presumably the
# component indices returned by the commented-out
# get_component_tuned_for_positive_predictions call further down; confirm.
# [  1  22  28  41  47  56  75  78  82  84  87  92  95  97 100 108 124 127]


# Show the 16 highest-coefficient examples for component 47, each followed by
# its annotation sequence.
# print_highest(ctx, n_examples=16, component_index=47)
print_highest(ctx, include_annotation_sequence=True, n_examples=16, component_index=47)
# print_highest_as_csv(ctx, n_examples=64, component_index=114)


# tuned_component_inds = sp_component_tunings.get_component_tuned_for_positive_predictions(
#     ctx,
#     n_top_examples=16,
#     min_positive_top_examples=10,
# )
# print(tuned_component_inds)

# 1?, 22, 28, ...
