R"""

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


EVAL_TASK=rte
RTE_MODEL=textattack/roberta-base-RTE
MNLI_MODEL=textattack/roberta-base-MNLI
FISHER_DIR=~/Desktop/projects_data/extract_merge1/fishers0

# Compute RTE Fisher.
python3 ./scripts1/ogmm/compute_fisher.py  \
    --model=$RTE_MODEL \
    --glue_task="rte" \
    --fisher_path="$FISHER_DIR/rte_fisher.h5"

# Compute MNLI Fisher.
python3 ./scripts1/ogmm/compute_fisher.py  \
    --model=$MNLI_MODEL \
    --glue_task="mnli" \
    --fisher_path="$FISHER_DIR/mnli_fisher.h5"

# Fisher-merge the RTE and MNLI models and evaluate the result on $EVAL_TASK.
python3 ./scripts1/ogmm/merge_and_evaluate.py  \
    --models=$RTE_MODEL,$MNLI_MODEL \
    --fishers=$FISHER_DIR/rte_fisher.h5,$FISHER_DIR/mnli_fisher.h5 \
    --glue_task=$EVAL_TASK

###############################################################################

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/misc1/sparse_dev002.py

"""
import os

from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.merging import merging
from em.util import hf_util
from em.util import vat_da_faak_vpn


# GLUE task whose validation split the merged models are scored on.
EVAL_TASK = 'rte'
# Directory holding the Fisher files (rte_fisher.h5, mnli_fisher.h5) written by
# the compute_fisher.py commands in the module docstring.
FISHER_DIR = os.path.expanduser(
    '~/Desktop/projects_data/extract_merge1/fishers0')


# Load the precomputed diagonal Fishers from FISHER_DIR, RTE first, then MNLI.
rte_dense, mnli_dense = (
    diagonal.DiagonalFisher.load(os.path.join(FISHER_DIR, fisher_file))
    for fisher_file in ('rte_fisher.h5', 'mnli_fisher.h5')
)


# Hugging Face Hub checkpoints used by this experiment. The names mirror the
# RTE_MODEL / MNLI_MODEL shell variables in the module docstring. Keeping each
# name in exactly one place guarantees the tokenizer below is always loaded
# from the same checkpoint as rte_model.
RTE_MODEL = 'textattack/roberta-base-RTE'
MNLI_MODEL = 'textattack/roberta-base-MNLI'
# Used as the reference model for the sparse Fisher approximation; presumably
# the shared pre-trained starting point of both fine-tuned models (confirm).
BASE_MODEL = 'roberta-base'

# from_pt=True builds each TF model from PyTorch checkpoint weights.
rte_model = TFAutoModelForSequenceClassification.from_pretrained(
    RTE_MODEL, from_pt=True)
mnli_model = TFAutoModelForSequenceClassification.from_pretrained(
    MNLI_MODEL, from_pt=True)
base_model = TFAutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL, from_pt=True)

# One tokenizer serves the whole evaluation. It comes from the RTE checkpoint;
# both fine-tuned models are presumably roberta-base derivatives sharing its
# vocabulary (re-check if either checkpoint is swapped out).
tokenizer = AutoTokenizer.from_pretrained(RTE_MODEL)


# Sparsity level forwarded to the sparse Fisher approximation. Whether it is
# the fraction of entries kept or dropped is defined by sparse_diagonal
# (TODO confirm). Previously tried values:
# sparsity = 1 / 100_000
# sparsity = 1 / 10_000_000
# sparsity = 1 / 100
# sparsity = 1 / 10
sparsity = 1 / 20


def _sparsify_fisher(dense_fisher, model, sparsity):
    """Build a sparse approximation of `dense_fisher`, returned with dense tensors.

    Calls sparse_diagonal.from_dense_by_metric_approximation with `model`'s
    mergeable variables and base_model's mergeable variables as the reference
    variables. Then every SparseTensor in the result's `fishers` list is
    converted to a dense tensor. The TODO below indicates SparseTensors are not
    yet consumed directly downstream.

    Both tasks go through this one helper so their sparsification settings
    cannot drift apart.

    Args:
      dense_fisher: a loaded diagonal.DiagonalFisher.
      model: the fine-tuned model the Fisher was computed for.
      sparsity: forwarded unchanged to the approximation.

    Returns:
      The approximation object, with its `fishers` attribute replaced by a
      list of dense tensors.
    """
    approx = sparse_diagonal.from_dense_by_metric_approximation(
        dense_fisher,
        hf_util.get_mergeable_variables(model),
        # Alternative tried before: the *other* task's model variables here,
        # in place of base_model's.
        hf_util.get_mergeable_variables(base_model),
        sparsity,
    )
    # TODO: Handle SparseTensors directly.
    approx.fishers = [tf.sparse.to_dense(f) for f in approx.fishers]
    return approx


rte_sparse = _sparsify_fisher(rte_dense, rte_model, sparsity)
mnli_sparse = _sparsify_fisher(mnli_dense, mnli_model, sparsity)

###############################################################################


# Evaluation data: the EVAL_TASK validation split, tokenized with
# max_length=128, truncated to at most 4096 examples and batched by 32.
ds = glue.load_glue_dataset(
    task=EVAL_TASK,
    split='validation',
    tokenizer=tokenizer,
    max_length=128,
).take(4096).batch(32)

# GLUE metric for EVAL_TASK, used to score every candidate merge.
metric = evaluation.load_metric_for_glue_task(EVAL_TASK)

# Pairwise merging-coefficient grid; 50 is forwarded as-is
# (presumably the grid resolution -- confirm in merging).
coefficients_set = merging.create_pairwise_grid_coeffs(50)

# Search over the coefficient grid for the Fisher-weighted merge of the two
# models, using the densified sparse Fishers built above.
# Settings tried before: fisher_floor=1e-15; normalize_fishers=False.
results = merging.merging_coefficients_search(
    [rte_model, mnli_model],
    fishers=[rte_sparse.fishers, mnli_sparse.fishers],
    coefficients_set=coefficients_set,
    dataset=ds,
    metric=metric,
    favor_target_model=True,
    normalize_fishers=True,
    fisher_floor=1e-6,
)
