"""Helpers for evaluating models and NMF components on HANS lexical-overlap examples."""
import collections
import dataclasses
import functools

import numpy as np

from em.projects.ll import hans_util
from em.projects.ll import hans_labeling

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi import qqp_components_context as QCC


###############################################################################

def _get_hans_le_examples__process_fn(x):
    """Return whether example ``x`` is tagged with the 'lexical_overlap' heuristic.

    Kept as a module-level function (rather than a lambda) so it can be handed
    to ``ds.filter`` by name in ``get_hans_le_examples``.
    """
    heuristic = x['heuristic']
    return heuristic == 'lexical_overlap'


def get_hans_le_examples(split: str):
    """Fetch HANS examples for ``split`` restricted to the lexical-overlap heuristic.

    Delegates to ``hans_util.get_first_hans_examples`` with ``n_examples=10_000``
    and a ``process_fn`` that filters the dataset down to examples whose
    ``'heuristic'`` field equals ``'lexical_overlap'``.

    NOTE(review): whether the filter runs before or after the first-N selection
    is decided inside ``hans_util`` -- confirm there if it matters.
    """

    def _keep_lexical_overlap(ds):
        # Filter with the module-level predicate rather than an inline lambda.
        return ds.filter(_get_hans_le_examples__process_fn)

    return hans_util.get_first_hans_examples(
        split,
        n_examples=10_000,
        process_fn=_keep_lexical_overlap,
    )


###############################################################################


def _mask_dict_to_inds(masks):
    """Convert a dict of indicator arrays into a dict of index arrays.

    Args:
        masks: Mapping from key to an array-like indicator/mask.

    Returns:
        A dict with the same keys (in the same order) whose values are the
        indices of the nonzero entries along the first axis of each mask.
    """
    inds_by_key = {}
    for key, mask in masks.items():
        # np.nonzero returns one index array per dimension; keep the first.
        nonzero_per_axis = np.nonzero(mask)
        inds_by_key[key] = nonzero_per_axis[0]
    return inds_by_key


@dataclasses.dataclass
class HansEvalResults:
    """Evaluation results paired with the helper that defines the HANS example groups.

    Fields:
        eval_results: Results object; it must expose the callables
            ``acc_for_examples``, ``loss_for_examples`` and ``kl_for_examples``,
            each taking a collection of example indices.
        helper: The ``HansHelper1`` whose ``subcase_inds`` / ``template_inds``
            dictionaries define the groups of example indices.
    """
    eval_results: QCC.QqpEvaluationResults
    helper: 'HansHelper1'

    def _get_metric_by(self, metric: str, inds: str):
        """Compute ``metric`` separately for every group of example indices.

        Args:
            metric: Metric prefix; ``f'{metric}_for_examples'`` is looked up on
                ``self.eval_results``.
            inds: Name of the attribute on ``self.helper`` holding a dict that
                maps group key -> example indices.

        Returns:
            A dict mapping each group key to the metric evaluated on that
            group's example indices.
        """
        index_groups = getattr(self.helper, inds)
        per_group = {}
        for group_key, example_inds in index_groups.items():
            metric_fn = getattr(self.eval_results, f'{metric}_for_examples')
            per_group[group_key] = metric_fn(example_inds)
        return per_group

    # Public accessors: one per (metric, grouping) combination.
    get_acc_by_subcase = functools.partialmethod(_get_metric_by, 'acc', 'subcase_inds')
    get_acc_by_template = functools.partialmethod(_get_metric_by, 'acc', 'template_inds')

    get_loss_by_subcase = functools.partialmethod(_get_metric_by, 'loss', 'subcase_inds')
    get_loss_by_template = functools.partialmethod(_get_metric_by, 'loss', 'template_inds')

    get_kl_by_subcase = functools.partialmethod(_get_metric_by, 'kl', 'subcase_inds')
    get_kl_by_template = functools.partialmethod(_get_metric_by, 'kl', 'template_inds')


@dataclasses.dataclass
class HansHelper1:
    """Ties an ablation experiment to the lexical-overlap HANS examples.

    Fields:
        exp: Experiment object; this class reads ``exp.mc.get_evaluation_context()``
            and ``exp.nmf.W``.
        split: Split name forwarded to ``get_hans_le_examples``.

    Attributes set in ``__post_init__`` (not dataclass fields):
        hans_examples: Result of ``get_hans_le_examples(split)``.
        subcase_inds: Dict mapping subcase key -> array of example indices.
        template_inds: Dict mapping template key -> array of example indices.
        eval_ctx: Evaluation context obtained from ``exp.mc``.
    """
    exp: ablation_exp_util.Experiment1

    split: str

    def __post_init__(self):
        """Load the HANS examples and precompute per-subcase/template indices."""
        self.hans_examples = get_hans_le_examples(self.split)

        self.subcase_inds = _mask_dict_to_inds(hans_labeling.compute_subcase_indicators(self.hans_examples))
        self.template_inds = _mask_dict_to_inds(hans_labeling.compute_template_indicators(self.hans_examples))

        # TODO: Some validation that the hans_examples are consistent with the examples in the eval_ctx.
        self.eval_ctx = self.exp.mc.get_evaluation_context()

    def evaluate(self, model):
        """Evaluate ``model`` in ``self.eval_ctx`` and wrap the result with this helper."""
        return HansEvalResults(
            eval_results=self.eval_ctx.evaluate(model),
            helper=self,
        )

    ######################################################################

    def _get_is_tuned_dict(self, top_ex_inds, inds_dict, min_fraction):
        """Decide, per group, whether enough of the top examples fall in that group.

        Args:
            top_ex_inds: Iterable of example indices (the top examples).
            inds_dict: Dict mapping group key -> collection of example indices.
            min_fraction: Minimum fraction of the top examples that must belong
                to a group for that group to count as tuned.

        Returns:
            Dict mapping each group key to a bool.
        """
        top_ex_inds = set(top_ex_inds)
        n_top = len(top_ex_inds)
        ret = {}
        for k, v in inds_dict.items():
            # With no top examples the fraction is defined as 0.0 instead of
            # raising ZeroDivisionError (e.g. n_examples == 0 or an empty W).
            frac = len(top_ex_inds.intersection(v)) / n_top if n_top else 0.0
            ret[k] = frac >= min_fraction
        return ret

    def _compute_component_tuning_infos(self, inds_dict: str, n_examples: int, min_fraction: float):
        """Find, for each group, the components whose top examples concentrate in it.

        For every column of ``self.exp.nmf.W`` the ``n_examples`` rows with the
        largest coefficients are selected; the component is assigned to each
        group that contains at least ``min_fraction`` of those rows.

        Args:
            inds_dict: Name of the attribute on ``self`` holding the
                group key -> example indices dict (e.g. ``'subcase_inds'``).
            n_examples: Number of top-coefficient examples taken per component.
            min_fraction: Threshold passed on to ``_get_is_tuned_dict``.

        Returns:
            Dict mapping group key -> ``np.int32`` array of component indices.
            Groups with no tuned component are absent from the result.

        NOTE(review): assumes ``W`` is 2-D with examples along axis 0 and
        components along the last axis -- confirm against the NMF producer.
        """
        W = self.exp.nmf.W
        # Convert once to sets so the per-component intersection is cheap.
        index_sets = {k: set(v) for k, v in getattr(self, inds_dict).items()}
        ret = collections.defaultdict(list)
        for comp_index in range(W.shape[-1]):
            coeffs = W[:, comp_index]
            # Negate so that argsort yields the largest coefficients first.
            top_ex_inds = np.argsort(-coeffs)[:n_examples]
            for k, is_tuned in self._get_is_tuned_dict(top_ex_inds, index_sets, min_fraction).items():
                if is_tuned:
                    ret[k].append(comp_index)
        return {k: np.array(v, dtype=np.int32) for k, v in ret.items()}

    compute_component_tuning_infos_by_subcase = functools.partialmethod(_compute_component_tuning_infos, 'subcase_inds')
    compute_component_tuning_infos_by_template = functools.partialmethod(_compute_component_tuning_infos, 'template_inds')
