"""Helpers specific to the SciTail ablation experiments."""
from typing import Any, Dict, Optional, Sequence, Tuple

import numpy as np
from scipy import special
import tensorflow as tf

from em import datasets as em_datasets
from em.util import flat_pack

from em.projects.ll import hans_labeling_analysis as hla

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC

###############################################################################


def _load_eval_ctx(
    mc,
    dataset_name,
    n_examples=None,
    split='validation',
    sequence_length=64,
):
    """Load `dataset_name` and wrap it in an evaluation context for `mc.model`.

    Shared implementation behind the per-dataset `get_*_eval_ctx` helpers.

    Args:
      mc: model context; `mc.tokenizer` is used to tokenize the dataset and
        `mc.model` is handed to the evaluation context.
      dataset_name: name passed to `em_datasets.load` (e.g. 'glue/rte').
      n_examples: if not None, only the first `n_examples` examples are kept.
      split: dataset split to load.
      sequence_length: sequence length passed to the loader.

    Returns:
      The result of `QCC.EvaluationContext2.create_from_ds` on the cached
      (and optionally truncated) dataset.
    """
    ds = em_datasets.load(
        dataset_name,
        split=split,
        sequence_length=sequence_length,
        tokenizer=mc.tokenizer,
    )
    if n_examples is not None:
        ds = ds.take(n_examples)
    return QCC.EvaluationContext2.create_from_ds(
        ds=ds.cache(),
        model=mc.model,
    )


def get_scitail_eval_ctx(
    mc,
    n_examples=None,
    split='validation',
    sequence_length=64,
):
    """Build an evaluation context over the SciTail ('sci_tail/default') dataset."""
    return _load_eval_ctx(
        mc,
        'sci_tail/default',
        n_examples=n_examples,
        split=split,
        sequence_length=sequence_length,
    )


def get_rte_eval_ctx(
    mc,
    n_examples=None,
    split='validation',
    sequence_length=64,
):
    """Build an evaluation context over the GLUE RTE ('glue/rte') dataset."""
    return _load_eval_ctx(
        mc,
        'glue/rte',
        n_examples=n_examples,
        split=split,
        sequence_length=sequence_length,
    )

###############################################################################


def sparse_tensor_to_fisher(mc, spF: tf.sparse.SparseTensor):
    """Densify a flat sparse Fisher and unpack it into per-variable tensors.

    A `FlatPacker` built from the shapes of `mc.variables` decodes the dense
    form of `spF`. The flat layout is presumably the same one the packer
    produces; verify against how `spF` was built.
    """
    variable_shapes = [variable.shape for variable in mc.variables]
    unpacker = flat_pack.FlatPacker(variable_shapes)
    dense_flat = tf.sparse.to_dense(spF)
    return unpacker.decode_tf(dense_flat)


def make_fisher_linear_combination(mc, component_indices, coeffs):
    """Build a Fisher as a weighted sum of NMF components.

    Computes `sum_i coeffs[i] * H[component_indices[i]]` over the sparse
    component tensors returned by the container's single NMF, then densifies
    and unpacks the result with `sparse_tensor_to_fisher`.

    Args:
      mc: model context; must expose `container.nmfs` (exactly one NMF) and
        the attributes `sparse_tensor_to_fisher` needs.
      component_indices: array-like of integer indices into the NMF components.
      coeffs: array-like of weights with the same shape as
        `component_indices`; `coeffs[i]` weights `component_indices[i]`.

    Returns:
      The per-variable Fisher produced by `sparse_tensor_to_fisher`.

    Raises:
      ValueError: if the shapes of `component_indices` and `coeffs` differ,
        or if no components are selected.
    """
    component_indices = np.asarray(component_indices)
    coeffs = np.asarray(coeffs)
    if component_indices.shape != coeffs.shape:
        raise ValueError(
            f'component_indices shape {component_indices.shape} does not match '
            f'coeffs shape {coeffs.shape}.')
    component_indices = component_indices.ravel()
    coeffs = coeffs.ravel()
    if component_indices.size == 0:
        raise ValueError('At least one component index is required.')

    # Sort indices and coefficients together so every weight stays paired with
    # its own component; sorting the indices alone would misalign them.
    order = np.argsort(component_indices, kind='stable')
    component_indices = component_indices[order].astype(np.int32)
    coeffs = coeffs[order]

    nmf, = mc.container.nmfs
    spH = nmf.get_full_sparse_H()

    terms = [
        tf.cast(c, tf.float32) * spH[i]
        for i, c in zip(component_indices, coeffs)
    ]

    ret = terms[0]
    for term in terms[1:]:
        ret = tf.sparse.add(ret, term)
    return sparse_tensor_to_fisher(mc, ret)


def make_fisher_geometric_mean(mc, component_indices, coeffs, fisher_floor: float = 1e-7):
    """Build a Fisher as a weighted geometric mean of NMF components.

    Elementwise, with `H_i` the dense form of component `component_indices[i]`:

        F = exp( sum_i c_i * log(max(H_i, fisher_floor)) / sum_i c_i )

    The result is then unpacked into per-variable tensors with a `FlatPacker`
    built from the shapes of `mc.variables`.

    Args:
      mc: model context; must expose `container.nmfs` (exactly one NMF) and
        `variables`.
      component_indices: array-like of integer indices into the NMF components.
      coeffs: array-like of weights with the same shape as
        `component_indices`; `coeffs[i]` weights `component_indices[i]`.
      fisher_floor: lower bound applied to every component entry before the
        log, so zeros in the sparse components do not yield -inf/NaN.

    Returns:
      The per-variable Fisher returned by `FlatPacker.decode_tf`.

    Raises:
      ValueError: if the shapes of `component_indices` and `coeffs` differ,
        if no components are selected, or if the weights sum to zero.
    """
    component_indices = np.asarray(component_indices)
    coeffs = np.asarray(coeffs)
    if component_indices.shape != coeffs.shape:
        raise ValueError(
            f'component_indices shape {component_indices.shape} does not match '
            f'coeffs shape {coeffs.shape}.')
    component_indices = component_indices.ravel()
    coeffs = coeffs.ravel()
    if component_indices.size == 0:
        raise ValueError('At least one component index is required.')

    total_weight = float(np.sum(coeffs))
    if total_weight == 0.0:
        raise ValueError('coeffs must not sum to zero.')

    # Sort indices and coefficients together so every weight stays paired with
    # its own component; sorting the indices alone would misalign them.
    order = np.argsort(component_indices, kind='stable')
    component_indices = component_indices[order].astype(np.int32)
    coeffs = coeffs[order]

    nmf, = mc.container.nmfs
    spH = nmf.get_full_sparse_H()

    # Accumulate the weighted log-sum one component at a time rather than
    # stacking all dense components into a single (k, D) tensor.
    log_F = None
    for i, c in zip(component_indices, coeffs):
        # Floor (not cap) the entries so log() never sees zeros.
        log_h = tf.math.log(tf.maximum(tf.sparse.to_dense(spH[i]), fisher_floor))
        term = tf.cast(c, tf.float32) * log_h
        log_F = term if log_F is None else log_F + term

    # Dividing by a Python float keeps the computation in float32 regardless
    # of the dtype of `coeffs`.
    F = tf.math.exp(log_F / total_weight)

    packer = flat_pack.FlatPacker([v.shape for v in mc.variables])
    return packer.decode_tf(F)




