R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/pb/signal_peptide/devmains/explore_comps001.py
"""
import dataclasses
from importlib import reload
import random
import os

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import per_example
from em.tools.nmf import nmf_common

from em.projects.pb.signal_peptide.contexts import sp_npeff_context


##########################################################################

# Root directory holding every artifact of the signal-peptide experiments.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pb/signal_peptide'

# Sub-directories of EXPS_DIR, one per artifact type.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Tokenizer identifier handed to the context loader below.
TOKENIZER = 'Rostlab/prot_bert'

##########################################################################

# Which checkpoint epoch and which dataset split to explore.
EPOCH = 0
SPLIT = 'train'

# File with the per-example Fishers for the chosen epoch.
# NOTE: the split name ("train") is hard-coded in this file name and is NOT
# derived from SPLIT; keep the two in sync by hand when changing either.
PEF_FILENAME = 'prot_bert.epoch_{epoch}.train.20290ex.131072.h5'.format(epoch=EPOCH)

# File with the NMF decomposition computed from PEF_FILENAME. The numeric tags
# are part of the file name on disk.
# NOTE(review): presumably c128 = number of components and 2500Iters = number
# of NMF iterations -- confirm against the script that produced the file.
NMF_FILENAME = 'spH.nmf_decomp.c128_2500Iters_65536pe_mvpp4_20290ex.' + PEF_FILENAME

##########################################################################


def print_highest(ctx, component_index: int, n_examples: int):
    """Print the examples with the largest coefficients for one component.

    Reads column ``component_index`` of ``ctx.nmf.W``, orders the row indices
    by decreasing coefficient and prints one space-separated line per example
    for the first ``n_examples`` of them: prediction, label, coefficient
    (4 decimal places), amino-acid sequence and kingdom.

    Args:
        ctx: Object exposing ``nmf.W`` (indexed as ``W[:, component_index]``)
            and ``examples`` (indexed by row position).
            NOTE(review): presumably an ``sp_npeff_context.NpeffContext``.
        component_index: Column of ``ctx.nmf.W`` to rank by.
        n_examples: Maximum number of examples to print.
    """
    weights = ctx.nmf.W[:, component_index]
    # argsort is ascending, so sort the negated weights to get the largest first.
    ranking = np.argsort(-weights)
    for row in ranking[:n_examples]:
        ex = ctx.examples[row]
        print(f'{ex.prediction} {ex.label} {weights[row]:0.4f} {ex.aa_sequence} {ex.kingdom}')


def print_highest_as_csv(ctx, component_index: int, n_examples: int):
    """Print the top examples of one component to stdout as CSV.

    A header row is written first, followed by one row per example returned
    by ``ctx.get_top_examples(component_index, n_examples)``, in the order
    that call yields them.

    Rows are emitted through the standard ``csv`` writer, so a field that
    contains a comma, a double quote or a line break is quoted/escaped
    instead of silently shifting the columns. Fields without such characters
    are written exactly as ``str(value)``.

    Args:
        ctx: Object exposing ``nmf.W`` (indexed as ``W[:, component_index]``)
            and ``get_top_examples``.
            NOTE(review): presumably an ``sp_npeff_context.NpeffContext``.
        component_index: Column of ``ctx.nmf.W`` the coefficients come from.
        n_examples: Number of examples requested from ``get_top_examples``.
    """
    # Function-scope imports keep this helper self-contained.
    import csv
    import sys

    coeffs = ctx.nmf.W[:, component_index]
    rows = [['Prediction', 'Label', 'Coeff', 'Kingdom', 'UniProt AC', 'AA Sequence']]
    for ex in ctx.get_top_examples(component_index, n_examples):
        rows.append([
            ex.prediction,
            ex.get_label_as_str(),
            coeffs[ex.index],
            ex.kingdom,
            ex.uniprot_ac,
            ex.aa_sequence,
        ])
    # Stringify up front: the csv writer would otherwise render None as an
    # empty field, whereas this function has always printed str(value).
    # All rows are built before anything is written, so a failure while
    # collecting examples produces no partial output.
    writer = csv.writer(sys.stdout, lineterminator='\n')
    writer.writerows([str(cell) for cell in row] for row in rows)


##########################################################################


# Build the exploration context from the per-example Fishers file and its NMF
# decomposition (both are read from PER_EXAMPLES_FISHERS_DIR). This runs at
# module level so that `ctx` is available in the interactive session started
# with `python -i` (see the run instructions at the top of the file).
ctx = sp_npeff_context.NpeffContext.load(
    pef_path=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_path=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=TOKENIZER,
    split=SPLIT,
)

# NOTE(review): the numbers below are presumably component indices that were
# worth a closer look -- confirm with the author.
# 23, 24, 32, 35, 42, 57, 81, 114, 119

# Alternative, human-readable dump of a component (kept for quick toggling):
# print_highest(ctx, n_examples=16, component_index=23)
# Dump the top 64 examples of component 114 as CSV on stdout.
print_highest_as_csv(ctx, n_examples=64, component_index=114)


"""
print_highest(ctx, n_examples=10, component_index=23)
1 0.3876 MVCLKLPGGSCMTALTVTLMVLSSPLALSGDTRPRFLWQPKRECHFFNGTERVRFLDRYFYNQEESVRFD
1 0.3819 MVCLKLPGGSCMTALTVTLMVLSSPLALAGDTRPRFLWQLKFECHFFNGTERVRLLERCIYNQEESVRFD
1 0.3648 MVCLKLPGGSYMAKLTVTLMVLSSPLALAGDTRPRFLQQDKYECHFFNGTERVRFLHRDIYNQEEDLRFD
1 0.3579 MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPREE
1 0.3534 MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME
1 0.3378 MMLRVLVGAVLPAMLLAAPPPINKLALFPDKSAWCEAKNITQIVGHSGCEAKSIQNRACLGQCFSYSVPN
1 0.3309 MKETKHQHTFSIRKSAYGAASVMVASCIFVIGGGVAEANDSTTQTTTPLEVAQTSQQETHTHQTPVTSLH
1 0.3261 MIHSKKLTLGICLVLLIILIGGCVIMTKTNGRNAQIKENFNKTLSVYLTKNLDDFYDKEGFRDQEFDKRD
1 0.3190 MIHSKKLTLGICLVLLIILIGGCIIMTKINSRNAQIKDTFNQTLNVYPTKNLDDFYDKEGFRDQEFDKRD
1 0.3187 MSWKKALRIPGGLRAATVTLMLAMLSTPVAEGRDSPEDFVYQFKAMCYFTNGTERVRYVTRYIYNREEYA

print_highest(ctx, n_examples=10, component_index=24)
1 0.5763 MANNSSYGENVRRKSHTPSAIVIGSGFAGIAAANALRNASFEVVLLESRDRIGGRIHTDYSFGFPVDLGA
1 0.5398 MDPNSLKTGGLLLPTIERQCASPPSVIVIGGGISGVAAARALSNASFEVTVLESRDRVGGRVHTDYSFGC
1 0.5062 MESRKNSDRQMRRANCFSAGERMKTRSPSVIVIGGGFGGISAARTLQDASFQVMVLESRDRIGGRVHTDY
1 0.4905 MAPSAGEDKHSSAKRVAVIGAGVSGLAAAYKLKIHGLNVTVFEAEGKAGGKLRSVSQDGLIWDEGANTMT
1 0.4485 MDQPSNGFAAGGLFLRHIDGQNASPPSVIVIGGGISGIAAARALSNASFKVTLLESRDRLGGRVHTDYSF
1 0.3608 MDKKKNSFPDNLPEGTISELMQKQNNVQPSVIVIGSGISGLAAARNLSEASFKVTVLESRDRIGGRIHTD
1 0.3417 MESGGKTNRQLRKAICVSTDEKMKKKRSPSVIVIGGGMAGISAARTLQDASFQVVVLESRDRIGGRVHTD
1 0.3333 MFYPDPFDVIIIGGGHAGTEAAMAAARMGQQTLLLTHNIDTLGQMSCNPAIGGIGKGHLVKEVDALGGLM
1 0.3332 MDFCWKREMEGKLAHDHRGMTSPRRICVVTGPVIVGAGPSGLATAACLKERGITSVLLERSNCIASLWQL
1 0.3138 MIKHYDVVIAGGGVIGASCAYQLSKRKDLKVALIDAKRPGNASRASAGGLWAIGESVGLGCGVIFFRMMS

print_highest(ctx, n_examples=10, component_index=32)
1 0.1289 MKLTKTALCTALFATFTFSANAQTYPDLPVGIKGGTGALIGDTVYVGLGSGGDKFYTLDLKDPSAQWKEI
1 0.1202 MAAASPSVFLLMITGQVESAQFPEYDDLYCKYCFVYGQDWAPTAGLEEGISQIASKSQDVRQALVWNFPI
1 0.1180 MKNVFKTLAVLLTLFSLTGCGLKGPLYFPPADKNAPPPTKKVDSQTQSTMPDKNDRATGDGPSQVNY
1 0.1160 MAPQTSNLWILLLLVVVMMMSQGCCQHWSYGLSPGGKRDLDSLSDTLGNIIERFPHVDSPCSVLGCVEEP
1 0.1144 MRKRISAIIMTLFMVFMSCNNGGPELKSDEVAKSDGTVLDLAKVSKKIKEASAFAASVKEVETLVKSVDE
1 0.1104 MKKTKFFLLGLAALAMTACNKDNEAEPVTEGNATISVVLKTSNSNRAFGVGDDESKVAKLTVMVYNGEQQ
1 0.1102 MKKLTTLLLASTLLIAACGNDDSKKDDSKTSKKDDGVKAELKQATKAYDKYTDEQLNEFLKGTEKFVKAI
1 0.1085 MNKIHVTYKNLLLPITFIAATLISACDNDKDAMAEAEKNQEKYMQKIQQKEHQQSMFFYDKAEMQKAIAN
1 0.1080 MKLRWFAFLVVILAGCSSKQDYRNPPWNAEVPVKRAMQWMPISEKAGAAWGVDPHLITAIIAIESGGNPN
1 0.1069 MKKLTTLLLASTLLIAACGNDDSKKDDSKTSKKDDGVKAELKQATKAYDKYTDEQLNEFLKGTEKFVKAI

print_highest(ctx, n_examples=10, component_index=35)
1 0.3901 MSAGSPKFTVSRIAALSLVSLWLAGCTSSSNPPAPVTSVDSGSSSNTNSGMLITPPPKMGATTQQTPQQA
1 0.3752 MTLPDFRLIRLLPLASLVLTACTLPVHKEPGKSPDSPQWRQHQQEVRNLNQYQTRGAFAYISDDQKVYAR
1 0.3704 MKTKTILTALLSAIALTGCANNNDTKQVSERNDSLEDFNRTMWKFNYNVIDRYVLEPAAKGWNNYVPKPI
1 0.3677 MKKKLLAGAITLLSVATLAACSKGSEGADLISMKGDVITEHQFYEQVKNNPSAQQVLLNMTIQKVFEKQY
1 0.3660 MKLNKKHLVAILSVLSLSIIVVPLLTSCTGDIPELNPAEIINTLFPNVWVFIAQVIAMCVVFSLVLWLVW
1 0.3652 MPLPDFRLIRLLPLAALVLTACSVTTPKGPGKSPDSPQWRQHQQDVRNLNQYQTRGAFAYISDQQKVYAR
1 0.3602 MKKALLALFMVVSIAALAACGAGNDNQSKDNAKDGDLWASIKKKGVLTVGTEGTYEPFTYHDKDTDKLTG
1 0.3580 MNNLKRFTKSIFSCIALSGLLFLGGCETLPPTTDLSPITVDNAAQAKAWELQGKLAIRTPEDKLSANLYW
1 0.3578 MPLPDFRLIRLLPLASLVLTACTITSPKGPGKSPDSPQWRQHQQDVRNLNQYQTRGAFAYISDQQKVYAR
1 0.3567 MTLRSFLIFFLSSLILAGCSSVPESVTSVEWQAHEQRLETIHDFQATGKLGYIGPDQRQSLNFFWKHSTA

print_highest(ctx, n_examples=10, component_index=42)
1 0.0807 MARSLVCLGVIILLSAFSGPGVRGGPMPKLADRKLCADQECSHPISMAVALQDYMAPDCRFLTIHRGQVV
1 0.0752 MASTKLFFSVITVMMLIAMASEMVNGSAFTVWSGPGCNNRAERYSKCGCSAIHQKGGYDFSYTGQTAALY
1 0.0712 MKPPRPVRTCSKVLVLLSLLAIHQTTTAEKNGIDIYSLTVDSRVSSRFAHTVVTSRVVNRANTVQEATFQ
1 0.0699 MKTLLLTLVVVTIVCLDFGHTMICYNQQSSQPPTTTTCSEGQCYKQRWRDHRGWRTERGCGCPKAIPEVK
1 0.0688 MRALEGPGLSLLCLVLALPALLPVPAVRGVAETPTYPWRDAETGERLVCAQCPPGTFVQRPCRRDSPTTC
1 0.0683 MVMGLGVLLLVFVLGLGLTPPTLAQDNSRYTHFLTQHYDAKPQGRDDRYCESIMRRRGLTSPCKDINTFI
1 0.0682 MARNMAHILHILVISLSYSFLFVSSSSQDSQSLYHNSQPTSSKPNLLVLPVQEDASTGLHWANIHKRTPL
1 0.0680 MSGMWVHPGRTLIWALWVLAAVIKGPAADAPVRSTRLGWVRGKQTTVLGSTVPVNMFLGIPYAAPPLGPL
1 0.0677 MGSLANNIMVVGAVLAALVAGGSCGPPKVPPGPNITTNYNGKWLTARATWYGQPNGAGAPDNGGACGIKN
1 0.0655 MASAKIFLIFLLAALIATPAAFAILVPTLVSTHISGLVFCSVNGNLDVINGLSPQVFPNASVQLRCGATN



"""
