R"""

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


EVAL_TASK=rte
RTE_MODEL=textattack/roberta-base-RTE
MNLI_MODEL=textattack/roberta-base-MNLI
FISHER_DIR=~/Desktop/projects_data/extract_merge1/fishers0

# Compute RTE Fisher.
python3 ./scripts1/ogmm/compute_fisher.py  \
    --model=$RTE_MODEL \
    --glue_task="rte" \
    --fisher_path="$FISHER_DIR/rte_fisher.h5"

# Compute MNLI Fisher.
python3 ./scripts1/ogmm/compute_fisher.py  \
    --model=$MNLI_MODEL \
    --glue_task="mnli" \
    --fisher_path="$FISHER_DIR/mnli_fisher.h5"

# Fisher merge
python3 ./scripts1/ogmm/merge_and_evaluate.py  \
    --models=$RTE_MODEL,$MNLI_MODEL \
    --fishers=$FISHER_DIR/rte_fisher.h5,$FISHER_DIR/mnli_fisher.h5 \
    --glue_task=$EVAL_TASK

###############################################################################

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/misc1/sparse_dev001.py

"""
import os

from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.merging import merging
from em.util import hf_util
from em.util import vat_da_faak_vpn


# Directory holding the precomputed diagonal Fisher files.
FISHER_DIR = os.path.expanduser('~/Desktop/projects_data/extract_merge1/fishers0')
# Checkpoints to merge, RTE first and MNLI second.
MODELS = ['textattack/roberta-base-RTE', 'textattack/roberta-base-MNLI']
# GLUE task the merged model is evaluated on.
EVAL_TASK = 'rte'


# Dense diagonal Fishers, loaded in the same order as MODELS.
rte_dense, mnli_dense = [
    diagonal.DiagonalFisher.load(os.path.join(FISHER_DIR, basename))
    for basename in ('rte_fisher.h5', 'mnli_fisher.h5')
]


# NOTE(review): presumably the fraction of entries kept per Fisher -- confirm
# against sparse_diagonal.from_dense_uniformly.
sparsity = 1 / 10_000
rte_sparse, mnli_sparse = [
    sparse_diagonal.from_dense_uniformly(dense, sparsity)
    for dense in (rte_dense, mnli_dense)
]

# TODO: Handle SparseTensors directly.
# Until then, replace each sparse Fisher by its dense equivalent in place.
rte_sparse.fishers = list(map(tf.sparse.to_dense, rte_sparse.fishers))
mnli_sparse.fishers = list(map(tf.sparse.to_dense, mnli_sparse.fishers))


"""
See how much the sparse masks have in common between tasks as the Fishers get sparser.
I think there's a chance of a decent overlap in parameters that have similar values,
since they are generally useful to the model, either from a plumbing perspective or
for processing text in general. These can have high Fishers.
"""


# for f1, f2 in zip(rte_sparse.fishers, mnli_sparse.fishers):
#     m1 = f1 > 0
#     m2 = f2 > 0
#     mm = m1 & m2
#     #
#     m1 = tf.reduce_sum(tf.cast(m1, tf.int32)).numpy()
#     m2 = tf.reduce_sum(tf.cast(m2, tf.int32)).numpy()
#     mm = tf.reduce_sum(tf.cast(mm, tf.int32)).numpy()
#     print(f'{mm} / {(m1 + m2) // 2}')


###############################################################################


def load_models():
    """Load every checkpoint listed in ``MODELS`` plus one tokenizer.

    Each entry of ``MODELS`` is passed through ``os.path.expanduser`` and then
    loaded with ``TFAutoModelForSequenceClassification.from_pretrained`` using
    ``from_pt=True``. The tokenizer is loaded from the first entry only.

    Returns:
        A ``(models, tokenizer)`` tuple: the loaded models in ``MODELS`` order
        and the tokenizer belonging to the first entry.

    Raises:
        ValueError: If ``MODELS`` is empty. Without this check the function
            failed with an ``UnboundLocalError`` on ``tokenizer`` at the
            ``return`` statement.
    """
    if not MODELS:
        raise ValueError('MODELS is empty; at least one model is required.')

    models = []
    tokenizer = None
    for i, model_str in enumerate(MODELS):
        # Allows entries that are local paths starting with "~".
        model_str = os.path.expanduser(model_str)
        model = TFAutoModelForSequenceClassification.from_pretrained(
            model_str, from_pt=True
        )
        models.append(model)
        if i == 0:
            # NOTE(review): assumes all models share the first model's
            # tokenizer -- confirm when adding new checkpoints.
            tokenizer = AutoTokenizer.from_pretrained(model_str)
    return models, tokenizer


# Load the checkpoints in MODELS and the tokenizer that goes with them.
models, tokenizer = load_models()

# One list of Fishers per model, in the same order as `models`.
# Swap in the sparsified variant by using the commented line instead:
# fishers = [rte_sparse.fishers, mnli_sparse.fishers]
fishers = [rte_dense.fishers, mnli_dense.fishers]

# One list of mergeable variables per model, as returned by hf_util.
variables = list(map(hf_util.get_mergeable_variables, models))


def something(variables1, variables2, fisher1, fisher2):
    """Weight squared parameter differences by each model's Fisher.

    The four arguments are iterated in lockstep, stopping at the shortest.
    At every position the squared difference ``(v1 - v2)**2`` is computed once
    and multiplied by the matching ``fisher1`` entry and ``fisher2`` entry.

    NOTE(review): the result names used by the caller (``kl1g2s``/``kl2g1s``)
    suggest these are per-parameter KL-style terms -- confirm the intended
    interpretation.

    Returns:
        A pair of lists: the ``fisher1``-weighted terms and the
        ``fisher2``-weighted terms, both in input order.
    """
    weighted = []
    for p1, p2, w1, w2 in zip(variables1, variables2, fisher1, fisher2):
        diff_sq = (p1 - p2)**2
        weighted.append((w1 * diff_sq, w2 * diff_sq))
    return [pair[0] for pair in weighted], [pair[1] for pair in weighted]


# Fisher-weighted squared differences between the two models' variables:
# kl1g2s uses the first model's Fishers, kl2g1s the second's.
kl1g2s, kl2g1s = something(*variables, *fishers)

# Sparsity levels tried so far (the active one is the uncommented line):
# sparsity = 1 / 10_000
# sparsity = 1 / 100
# sparsity = 1 / 1000
# sparsity = 1 / 100_000
# sparsity = 1 / 100_000_000
sparsity = 1 / 10

# Sparsify the weighted differences (not the raw Fishers) to pick entries.
rte_sparse_mask = sparse_diagonal.from_dense_uniformly(
    diagonal.DiagonalFisher(kl1g2s), sparsity
)
mnli_sparse_mask = sparse_diagonal.from_dense_uniformly(
    diagonal.DiagonalFisher(kl2g1s), sparsity
)

# TODO: Handle SparseTensors directly.
# Densify, then keep only "is this entry positive?" as a boolean mask.
rte_sparse_mask.fishers = [
    tf.sparse.to_dense(sp) > 0 for sp in rte_sparse_mask.fishers
]
mnli_sparse_mask.fishers = [
    tf.sparse.to_dense(sp) > 0 for sp in mnli_sparse_mask.fishers
]


# NOTE: I forget to multiply by f here. Somehow this works better?
# The dense Fishers only supply the dtype; the result is a 0/1 mask.
rte_sparse2 = [
    tf.cast(mask, dense.dtype)
    for mask, dense in zip(rte_sparse_mask.fishers, rte_dense.fishers)
]
mnli_sparse2 = [
    tf.cast(mask, dense.dtype)
    for mask, dense in zip(mnli_sparse_mask.fishers, mnli_dense.fishers)
]

# From here on the 0/1 masks stand in for the Fishers.
fishers = [rte_sparse2, mnli_sparse2]

###############################################################################


# Evaluation data: validation split of EVAL_TASK, first 4096 examples,
# in batches of 32.
ds = (
    glue.load_glue_dataset(
        task=EVAL_TASK,
        split='validation',
        tokenizer=tokenizer,
        max_length=128,
    )
    .take(4096)
    .batch(32)
)

metric = evaluation.load_metric_for_glue_task(EVAL_TASK)

# NOTE(review): 50 is presumably the grid resolution for the pairwise
# coefficients -- confirm against merging.create_pairwise_grid_coeffs.
coefficients_set = merging.create_pairwise_grid_coeffs(50)

# Evaluate the merge for every coefficient pair using the current `fishers`.
results = merging.merging_coefficients_search(
    models,
    fishers=fishers,
    coefficients_set=coefficients_set,
    dataset=ds,
    metric=metric,
    favor_target_model=True,
    normalize_fishers=True,
    fisher_floor=1e-6,
    # Alternative settings tried:
    # fisher_floor=1e-15,
    # normalize_fishers=False,
)
