"""Fishers coming from a subset of components from an NMF.

This module assumes that the PefNmfAnalysisContainer may contain unlabeled
examples, so all code here must handle them correctly.
"""
import dataclasses
import json
import os
from typing import Any, List, Optional, Sequence, Set

import h5py
import numpy as np
import tensorflow as tf

from em.util import flat_pack
from em.util import hdf5_util

from em.projects.anli import anli_misc1 as am
from em.projects.anli import nli_example

from em.util.color_util import cu

# typedefs
# Alias used for JSON-compatible values. It is just `Any`, so nothing is
# enforced by the type checker.
Json = Any
# Local short name for the project's NLI example type.
NliExample = nli_example.NliExample


###############################################################################

def _safe_asdict(x):
    """Return ``dataclasses.asdict(x)``, passing ``None`` through unchanged."""
    return None if x is None else dataclasses.asdict(x)


def _replace_nans(x):
    """Set every NaN entry of ``x`` to 0 in place and return the same array."""
    nan_mask = np.isnan(x)
    x[nan_mask] = 0
    return x


###############################################################################

@dataclasses.dataclass
class ExamplesAccuracyInfo:
    """Accuracy summary for the examples picked out by the selected NMF components.

    Produced by NmfComponentsFisher.compute_examples_accuracy_info.
    """
    # Mean of `is_correctly_labeled()` over the selected components' examples.
    component_examples_accuracy: float
    # Mean of `is_correctly_labeled()` over the labeled examples that are not
    # among the selected components' examples.
    remaining_examples_accuracy: float

    # Number of distinct examples averaged in `component_examples_accuracy`.
    n_component_examples: int
    # Number of labeled examples in the container.
    n_total_examples: int


###########################################################

@dataclasses.dataclass
class SelectionParameters:
    """Thresholds deciding whether an NMF component appears tuned to an indicator.

    Consumed by get_components_appearing_tuned.
    """
    # Passed as `factor` to container.get_top_examples_based_on_relative_coefficient
    # when collecting each component's top examples.
    coeff_factor: float

    # A component is selected only if the fraction of its top examples whose
    # indicator is true is >= frac_threshold ...
    frac_threshold: float
    # ... and the p-value (am._binomial_pmf summed over counts >= the observed
    # count) is <= p_value_threshold.
    p_value_threshold: float

    # Passed as `max_examples` to the same container method; presumably None
    # means "no cap" -- TODO confirm against PefNmfAnalysisContainer.
    max_examples: Optional[int] = None


@dataclasses.dataclass
class ComponentSelectivityInfo:
    """Selectivity statistics for one component of one NMF in the container.

    Instances are created by get_components_appearing_tuned for the
    components that pass the selection thresholds.
    """
    nmf_index: int
    component_index: int

    fraction: float
    p_value: float

    # The component's top examples. They are all labeled when selection used
    # ignore_unlabeled=True; otherwise they may include unlabeled examples.
    # Not serialized by to_serializable_dict.
    labeled_examples: Sequence[NliExample] = ()

    ##################################################

    def get_all_relevant_example_indices(self, container: am.PefNmfAnalysisContainer) -> Sequence[int]:
        """Return indices of all examples scoring at least as high as the weakest top example.

        The cutoff is the smallest coefficient (for this component) among
        `labeled_examples`. Every example whose coefficient is >= that cutoff
        is returned, ordered by decreasing coefficient.
        """
        assert len(self.labeled_examples)
        coeffs = container.nmfs[self.nmf_index].W[:, self.component_index]
        cutoff = min(coeffs[example.index] for example in self.labeled_examples)
        (candidates,) = np.nonzero(coeffs >= cutoff)
        return sorted(candidates, key=lambda i: -coeffs[i])

    ##################################################

    def to_serializable_dict(self) -> Json:
        """Return the scalar fields as plain Python types (labeled_examples is omitted)."""
        return dict(
            nmf_index=int(self.nmf_index),
            component_index=int(self.component_index),
            fraction=float(self.fraction),
            p_value=float(self.p_value),
        )

    @classmethod
    def from_serializable_dict(cls, dikt: Json):
        """Inverse of to_serializable_dict; labeled_examples keeps its empty default."""
        return cls(**dikt)


@dataclasses.dataclass
class SubsetInfo:
    """Fisher estimates for one NMF, split into selected and remaining components.

    `correct_fisher` comes from the selected components and `erroring_fisher`
    from all other components. When `reduce_kept_indices` is set, both are
    stored in reduced form and get scattered back into a vector of size
    `full_dense_size` before unpacking into per-variable tensors.
    """
    correct_fisher: np.ndarray
    erroring_fisher: np.ndarray

    correct_component_infos: Sequence[ComponentSelectivityInfo]

    reduce_kept_indices: Optional[np.ndarray] = None
    full_dense_size: Optional[int] = None

    variable_shapes: Optional[Sequence[tf.TensorShape]] = None

    def get_all_labeled_example_indices(self) -> Set[int]:
        """Return the union of example indices over every selected component."""
        return {
            example.index
            for info in self.correct_component_infos
            for example in info.labeled_examples
        }

    ##################################################

    def to_dense_diagonal_fisher(self, fisher: np.ndarray) -> Sequence[tf.Tensor]:
        """Unpack a flat (possibly reduced) diagonal Fisher into per-variable tensors.

        `fisher` should be either self.correct_fisher or self.erroring_fisher.
        Both `variable_shapes` and `full_dense_size` must be set.
        """
        assert self.variable_shapes is not None
        assert self.full_dense_size is not None

        # TODO: Make sure full_dense_size and self.variable_shapes are compatible.

        dense = fisher
        if self.reduce_kept_indices is not None:
            # Entries dropped by the reduction are filled with zeros.
            dense = np.zeros([self.full_dense_size], dtype=fisher.dtype)
            dense[self.reduce_kept_indices] = fisher

        return flat_pack.FlatPacker(self.variable_shapes).decode_tf(dense)

    def to_correct_dense_diagonal_fisher(self) -> Sequence[tf.Tensor]:
        """Per-variable tensors for the selected components' Fisher."""
        return self.to_dense_diagonal_fisher(self.correct_fisher)

    def to_erroring_dense_diagonal_fisher(self) -> Sequence[tf.Tensor]:
        """Per-variable tensors for the remaining components' Fisher."""
        return self.to_dense_diagonal_fisher(self.erroring_fisher)

    ##################################################

    def _to_json_for_saving(self) -> Json:
        """Return the non-array metadata as JSON-compatible values.

        The np arrays are NOT included; they are saved separately as h5 datasets.
        """
        shapes = (
            None if self.variable_shapes is None
            else [[int(d) for d in list(s)] for s in self.variable_shapes]
        )
        size = None if self.full_dense_size is None else int(self.full_dense_size)

        return {
            'correct_component_infos': [
                info.to_serializable_dict() for info in self.correct_component_infos
            ],
            'full_dense_size': size,
            'variable_shapes': shapes,
        }


@dataclasses.dataclass
class NmfComponentsFisher:
    """Component-subset Fisher estimates for every NMF in an analysis container.

    Holds one SubsetInfo per NMF, together with the SelectionParameters used
    to pick the components. It can optionally also hold an
    ExamplesAccuracyInfo summary and batch Fishers.

    Instances round-trip through an HDF5 file via `save` / `load`. The
    per-component example lists are not written, so after `load` every
    ComponentSelectivityInfo has empty `labeled_examples`.
    """
    subset_infos: Sequence[SubsetInfo]

    selection_parameters: Optional[SelectionParameters] = None

    examples_accuracy_info: Optional[ExamplesAccuracyInfo] = None

    # If not None, both are assumed to be dense fishers with an element
    # for each relevant model variable.
    batch_correct_fishers: Optional[Sequence[np.ndarray]] = None
    batch_erroring_fishers: Optional[Sequence[np.ndarray]] = None

    # Free-form metadata, stored verbatim inside the saved json_info.
    extra_info: Json = None

    def __post_init__(self):
        # Give each instance its own fresh dict.
        if self.extra_info is None:
            self.extra_info = {}

    def get_all_labeled_example_indices(self) -> Set[int]:
        """Return the union of example indices over all subsets' selected components."""
        ret = set()
        for s in self.subset_infos:
            ret.update(s.get_all_labeled_example_indices())
        return ret

    ##################################################

    def compute_examples_accuracy_info(
        self,
        container: am.PefNmfAnalysisContainer,
        *,
        update_self: bool = True,
    ) -> ExamplesAccuracyInfo:
        """Compare accuracy on the selected components' examples against the rest.

        Only labeled examples are counted on either side. The selected
        components' example lists may contain unlabeled examples (when they
        were selected with ignore_unlabeled=False), and accuracy is not
        meaningful for those. Counting only labeled examples also keeps
        `n_component_examples` consistent with `n_total_examples`.

        Args:
            container: the analysis container the components were selected from.
            update_self: if True, also store the result in
                `self.examples_accuracy_info`.

        Returns:
            The computed ExamplesAccuracyInfo. An accuracy is NaN when its
            set of examples is empty.
        """
        assert container.n_nmfs == len(self.subset_infos)

        def _mean_or_nan(values: List[float]) -> float:
            # np.mean([]) is NaN too, but it also emits a RuntimeWarning.
            return float(np.mean(values)) if values else float('nan')

        labeled_mask = ~container.unlabeled_indicator
        n_labeled = labeled_mask.sum()

        example_indices = {
            i for i in self.get_all_labeled_example_indices() if labeled_mask[i]
        }

        component_examples_accuracy = _mean_or_nan([
            float(container.examples[i].is_correctly_labeled())
            for i in example_indices
        ])

        remaining_examples_accuracy = _mean_or_nan([
            float(e.is_correctly_labeled())
            for e in container.examples
            if labeled_mask[e.index] and e.index not in example_indices
        ])

        ret = ExamplesAccuracyInfo(
            component_examples_accuracy=float(component_examples_accuracy),
            remaining_examples_accuracy=float(remaining_examples_accuracy),
            n_component_examples=len(example_indices),
            n_total_examples=int(n_labeled),
        )

        if update_self:
            self.examples_accuracy_info = ret

        return ret

    ##################################################

    def to_correct_dense_diagonal_fisher(self) -> Sequence[tf.Tensor]:
        """Concatenate every subset's per-variable correct Fisher tensors, in subset order."""
        ret = []
        for ssi in self.subset_infos:
            ret.extend(ssi.to_correct_dense_diagonal_fisher())
        return ret

    def to_erroring_dense_diagonal_fisher(self) -> Sequence[tf.Tensor]:
        """Concatenate every subset's per-variable erroring Fisher tensors, in subset order."""
        ret = []
        for ssi in self.subset_infos:
            ret.extend(ssi.to_erroring_dense_diagonal_fisher())
        return ret

    ##################################################

    def _make_json_for_saving(self) -> Json:
        """Return all non-array state as a JSON-compatible dict."""
        return {
            'subset_infos': [s._to_json_for_saving() for s in self.subset_infos],
            'selection_parameters': _safe_asdict(self.selection_parameters),
            'examples_accuracy_info': _safe_asdict(self.examples_accuracy_info),
            'extra_info': self.extra_info,
        }

    def save(self, filepath: str):
        """Write this object to an HDF5 file at `filepath` (opened in "w" mode).

        The arrays are stored as datasets in groups under `data/`. All other
        state is JSON-encoded into the `json_info` attribute of the `data`
        group.
        """
        correct_fishers = [s.correct_fisher for s in self.subset_infos]
        erroring_fishers = [s.erroring_fisher for s in self.subset_infos]

        # Either every subset has reduce_kept_indices or none of them do.
        reduce_kept_indices = [s.reduce_kept_indices for s in self.subset_infos]
        if any(r is None for r in reduce_kept_indices):
            assert all(r is None for r in reduce_kept_indices)
            reduce_kept_indices = None

        json_info = json.dumps(self._make_json_for_saving())

        with h5py.File(os.path.expanduser(filepath), "w") as f:
            data = f.create_group('data')
            data.attrs['json_info'] = json_info

            hdf5_util.save_np_arrays_to_group(
                data.create_group('correct_fishers'),
                correct_fishers
            )
            hdf5_util.save_np_arrays_to_group(
                data.create_group('erroring_fishers'),
                erroring_fishers
            )

            if reduce_kept_indices is not None:
                hdf5_util.save_np_arrays_to_group(
                    data.create_group('reduce_kept_indices'),
                    reduce_kept_indices
                )

            if self.batch_correct_fishers is not None:
                hdf5_util.save_np_arrays_to_group(
                    data.create_group('batch_correct_fishers'),
                    self.batch_correct_fishers
                )
            if self.batch_erroring_fishers is not None:
                hdf5_util.save_np_arrays_to_group(
                    data.create_group('batch_erroring_fishers'),
                    self.batch_erroring_fishers
                )

    @classmethod
    def load(cls, filepath: str):
        """Read an object written by `save`.

        NaNs in the stored correct/erroring Fishers are replaced with 0.
        Component infos come back without their example lists.
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            json_info = json.loads(f['data'].attrs['json_info'])

            correct_fishers = hdf5_util.load_np_arrays_from_group(f['data/correct_fishers'])
            erroring_fishers = hdf5_util.load_np_arrays_from_group(f['data/erroring_fishers'])

            if 'data/reduce_kept_indices' in f:
                reduce_kept_indices = hdf5_util.load_np_arrays_from_group(f['data/reduce_kept_indices'])
            else:
                reduce_kept_indices = len(correct_fishers) * [None]

            batch_correct_fishers, batch_erroring_fishers = None, None
            if 'data/batch_correct_fishers' in f:
                batch_correct_fishers = hdf5_util.load_np_arrays_from_group(f['data/batch_correct_fishers'])
            if 'data/batch_erroring_fishers' in f:
                batch_erroring_fishers = hdf5_util.load_np_arrays_from_group(f['data/batch_erroring_fishers'])

        selection_parameters = None
        if json_info['selection_parameters'] is not None:
            selection_parameters = SelectionParameters(**json_info['selection_parameters'])

        examples_accuracy_info = None
        if json_info['examples_accuracy_info'] is not None:
            examples_accuracy_info = ExamplesAccuracyInfo(**json_info['examples_accuracy_info'])

        subset_infos = []
        for ssi, cf, ef, rki in zip(json_info['subset_infos'], correct_fishers, erroring_fishers, reduce_kept_indices):
            variable_shapes = None
            if ssi['variable_shapes'] is not None:
                variable_shapes = [tf.TensorShape(s) for s in ssi['variable_shapes']]

            correct_component_infos = [
                ComponentSelectivityInfo.from_serializable_dict(s)
                for s in ssi['correct_component_infos']
            ]

            subset_infos.append(
                SubsetInfo(
                    correct_fisher=_replace_nans(cf),
                    erroring_fisher=_replace_nans(ef),
                    correct_component_infos=correct_component_infos,
                    reduce_kept_indices=rki,
                    full_dense_size=ssi['full_dense_size'],
                    variable_shapes=variable_shapes,
                )
            )

        return cls(
            subset_infos=subset_infos,
            selection_parameters=selection_parameters,
            examples_accuracy_info=examples_accuracy_info,
            extra_info=json_info['extra_info'],
            batch_correct_fishers=batch_correct_fishers,
            batch_erroring_fishers=batch_erroring_fishers,
        )


###############################################################################

def _get_p_correct(container: am.PefNmfAnalysisContainer) -> float:
    """Return the fraction of labeled examples for which `is_correctly_labeled()` is true."""
    is_labeled = ~container.unlabeled_indicator
    labeled_examples = (ex for ex in container.examples if is_labeled[ex.index])
    return np.mean([float(ex.is_correctly_labeled()) for ex in labeled_examples])


def _get_p_true(
    container: am.PefNmfAnalysisContainer,
    indicator: np.ndarray,
    ignore_unlabeled: bool,
) -> float:
    """Return the mean of `indicator` as a base rate.

    When `ignore_unlabeled` is true, the entries for unlabeled examples are
    dropped first.
    """
    values = indicator[~container.unlabeled_indicator] if ignore_unlabeled else indicator
    return np.mean(values.astype(np.float64))


def get_components_appearing_tuned(
    container: am.PefNmfAnalysisContainer,
    indicator: np.ndarray,
    selection_parameters: SelectionParameters = None,
    ignore_unlabeled: bool = True,
    **kwargs,
) -> List[ComponentSelectivityInfo]:
    """Find NMF components whose top examples are enriched for `indicator`.

    For every component of every NMF in `container`, the component's top
    examples come from
    `container.get_top_examples_based_on_relative_coefficient`. A component
    is kept when both of the following hold:

    - the fraction of those examples with a true indicator is at least
      `frac_threshold`;
    - the p-value of that count is at most `p_value_threshold`. The p-value
      is `am._binomial_pmf` summed over counts >= the observed count, with
      the overall true rate as the success probability.

    Components with fewer than two usable top examples are skipped.

    Args:
        container: the analysis container holding the NMFs and examples.
        indicator: 1-d array with one entry per example in the container,
            including unlabeled examples. With `ignore_unlabeled=True`, the
            entries for unlabeled examples are ignored, both for the base
            rate and for each component's counts.
        selection_parameters: thresholds to use. If None, they are built
            from `**kwargs`; otherwise `kwargs` must be empty.
        ignore_unlabeled: whether to drop unlabeled examples.

    Returns:
        One ComponentSelectivityInfo per selected component, ordered by NMF
        index and then by component index.
    """
    if selection_parameters is not None:
        assert len(kwargs) == 0
        params = selection_parameters
    else:
        params = SelectionParameters(**kwargs)

    labeled_mask = ~container.unlabeled_indicator
    p_true = _get_p_true(container, indicator, ignore_unlabeled=ignore_unlabeled)

    selected = []
    for nmf_index in range(container.n_nmfs):
        for component_index in range(container.nmfs[nmf_index].W.shape[-1]):
            candidates = container.get_top_examples_based_on_relative_coefficient(
                nmf_index=nmf_index,
                component_index=component_index,
                factor=params.coeff_factor,
                max_examples=params.max_examples,
            )
            if ignore_unlabeled:
                candidates = [ex for ex in candidates if labeled_mask[ex.index]]

            n_examples = len(candidates)
            if n_examples < 2:
                continue

            n_hits = sum(1 for ex in candidates if indicator[ex.index])
            fraction = n_hits / n_examples

            # Upper tail of the binomial: P(X >= n_hits) under the base rate.
            tail_counts = np.arange(n_hits, n_examples + 1)
            p_value = am._binomial_pmf(n_examples, tail_counts, p_true).sum()

            # Kept as a positive conjunction so a NaN p-value is never selected.
            if fraction >= params.frac_threshold and p_value <= params.p_value_threshold:
                selected.append(ComponentSelectivityInfo(
                    nmf_index=nmf_index,
                    component_index=component_index,
                    fraction=fraction,
                    p_value=p_value,
                    # NOTE: This can now sometimes contain unlabeled examples.
                    labeled_examples=candidates,
                ))

    return selected


def get_components_appearing_correct(
    container: am.PefNmfAnalysisContainer,
    selection_parameters: Optional[SelectionParameters] = None,
    **kwargs,
) -> List[ComponentSelectivityInfo]:
    """Select components that appear tuned to correct predictions.

    This is a thin wrapper around get_components_appearing_tuned. It uses
    `container.get_correct_prediction_indicator()` as the indicator and
    always ignores unlabeled examples.

    Args:
        container: the analysis container holding the NMFs and examples.
        selection_parameters: thresholds to use. If None, they are built
            from `**kwargs`; otherwise `kwargs` must be empty.

    Returns:
        The selected components' ComponentSelectivityInfos.
    """
    indicator = container.get_correct_prediction_indicator()
    return get_components_appearing_tuned(
        container=container,
        indicator=indicator,
        selection_parameters=selection_parameters,
        ignore_unlabeled=True,
        **kwargs,
    )


def group_by_nmf(container: am.PefNmfAnalysisContainer, infos: Sequence[ComponentSelectivityInfo]):
    """Split `infos` into one list per NMF in the container.

    Entry i holds the infos with `nmf_index == i`, in their original order.
    Infos whose nmf_index is outside range(container.n_nmfs) are dropped.
    """
    buckets = []
    for nmf_index in range(container.n_nmfs):
        buckets.append([info for info in infos if info.nmf_index == nmf_index])
    return buckets


def get_estimated_fisher_for_components(nmf, component_indices: Sequence[int]) -> np.ndarray:
    """Estimate a diagonal Fisher from a subset of NMF components.

    The result is the average over examples of the NMF reconstruction
    W @ H, restricted to the given components. That equals
    H[cols].T @ mean(W[:, cols], axis=0). This can arguably be justified
    through the definition of the batch Fisher as an expectation of
    per-example Fishers.

    Shapes: W is [n_examples, n_components]; H is [n_components, n_features].
    The result has shape [n_features]. An empty `component_indices` yields
    zeros.
    """
    cols = np.asarray(sorted(component_indices), dtype=np.int32)
    mean_coeffs = np.mean(nmf.W[:, cols], axis=0)
    return np.dot(nmf.H[cols, :].T, mean_coeffs)


def get_apparently_tuned_fisher(
    container: am.PefNmfAnalysisContainer,
    indicator: np.ndarray,
    selection_parameters: SelectionParameters = None,
    ignore_unlabeled: bool = True,
    **kwargs,
) -> NmfComponentsFisher:
    """Build component-subset Fishers from the components that appear tuned to `indicator`.

    For each NMF, the components selected by get_components_appearing_tuned
    form the "correct" Fisher, and every other component forms the
    "erroring" Fisher.

    Args:
        container: the analysis container holding the NMFs and examples.
        indicator: per-example indicator. See get_components_appearing_tuned
            for how entries for unlabeled examples are handled when
            ignore_unlabeled=True.
        selection_parameters: thresholds to use. If None, they are built
            from `**kwargs`; otherwise `kwargs` must be empty.
        ignore_unlabeled: whether to drop unlabeled examples during selection.

    Returns:
        An NmfComponentsFisher with one SubsetInfo per NMF. variable_shapes
        is left unset.
    """
    if selection_parameters is not None:
        assert len(kwargs) == 0
    else:
        selection_parameters = SelectionParameters(**kwargs)

    tuned_infos = get_components_appearing_tuned(
        container=container,
        indicator=indicator,
        selection_parameters=selection_parameters,
        ignore_unlabeled=ignore_unlabeled,
    )

    subset_infos = []
    for nmf_index, nmf_infos in enumerate(group_by_nmf(container, tuned_infos)):
        nmf = container.nmfs[nmf_index]

        tuned_indices = {info.component_index for info in nmf_infos}
        other_indices = set(range(nmf.W.shape[-1])) - tuned_indices

        subset_infos.append(SubsetInfo(
            correct_fisher=get_estimated_fisher_for_components(nmf, tuned_indices),
            erroring_fisher=get_estimated_fisher_for_components(nmf, other_indices),
            correct_component_infos=nmf_infos,
            reduce_kept_indices=nmf.reduce_kept_indices,
            full_dense_size=nmf.full_dense_size,
        ))

    return NmfComponentsFisher(
        subset_infos=subset_infos,
        selection_parameters=selection_parameters,
    )


def get_apparently_correct_fisher(
    container: am.PefNmfAnalysisContainer,
    selection_parameters: Optional[SelectionParameters] = None,
    **kwargs,
) -> NmfComponentsFisher:
    """Build component-subset Fishers from components that appear tuned to correct predictions.

    This is a thin wrapper around get_apparently_tuned_fisher. It uses
    `container.get_correct_prediction_indicator()` as the indicator and
    always ignores unlabeled examples.

    Args:
        container: the analysis container holding the NMFs and examples.
        selection_parameters: thresholds to use. If None, they are built
            from `**kwargs`; otherwise `kwargs` must be empty.

    Returns:
        An NmfComponentsFisher with one SubsetInfo per NMF.
    """
    indicator = container.get_correct_prediction_indicator()
    return get_apparently_tuned_fisher(
        container=container,
        indicator=indicator,
        selection_parameters=selection_parameters,
        ignore_unlabeled=True,
        **kwargs,
    )

###############################################################################


def _get_filepaths_of_ncf_files(filepath: str):
    """Return the paths of saved files that match `filepath` except for a selectivity tag.

    `filepath` is the path of an h5 file with the selectivity-specific part
    of the filename removed. For `dir/base.h5`, this returns every
    `dir/base.<sel_info>.h5` in `dir`, where `<sel_info>` is a single
    segment without dots. Results are sorted by filename so the output is
    deterministic. Directory entries whose names contain no '.' are skipped.
    """
    assert filepath.endswith('.h5')
    filepath = os.path.expanduser(filepath)

    basename = os.path.basename(filepath)[:-3]
    dirpath = os.path.dirname(filepath)

    ret = []

    # os.listdir('') raises, so a bare filename means the current directory.
    for filename in sorted(os.listdir(dirpath or os.curdir)):
        parts = filename.split('.')
        if len(parts) < 2:
            # No '.' at all: cannot match, and the unpacking below would fail.
            continue
        *prefix, _sel_info, suffix = parts
        if '.'.join(prefix) == basename and suffix == 'h5':
            ret.append(os.path.join(dirpath, filename))

    return ret


def print_parameters_and_accuracy_infos_matching_file(filepath: str):
    """Print the selection parameters and accuracy summary of each matching saved file.

    `filepath` is the path of an h5 file with the selectivity-specific part
    of the filename removed. Matching files are found by
    _get_filepaths_of_ncf_files. Every matching file is assumed to have
    non-null `selection_parameters` and `examples_accuracy_info` in its
    json_info.
    """
    filepaths = _get_filepaths_of_ncf_files(filepath)

    data = []
    for fp in filepaths:
        with h5py.File(fp, "r") as f:
            json_info = json.loads(f['data'].attrs['json_info'])

        # We assume that these are contained in all the saved files.
        selection_parameters = SelectionParameters(**json_info['selection_parameters'])
        examples_accuracy_info = ExamplesAccuracyInfo(**json_info['examples_accuracy_info'])

        data.append((selection_parameters, examples_accuracy_info))

    for sp, eai in data:
        cf = cu.hly(f'{sp.coeff_factor:.3f}')
        ft = cu.hly(f'{sp.frac_threshold:.3f}')
        pvt = cu.hly(f'{sp.p_value_threshold:.4f}')

        # Avoid ZeroDivisionError when the container had no labeled examples.
        if eai.n_total_examples:
            top_examples_fraction = eai.n_component_examples / eai.n_total_examples
        else:
            top_examples_fraction = float('nan')

        tef = cu.hlc(f'{top_examples_fraction:.3f}')
        cea = cu.hlc(f'{eai.component_examples_accuracy:.3f}')
        rea = cu.hlc(f'{eai.remaining_examples_accuracy:.3f}')

        print(f'Coefficient factor: {cf}   Fraction threshold: {ft}   P-value threshold: {pvt}')
        print(f'    Top examples fraction: {tef}')
        print(f'    Top examples accuracy: {cea}')
        print(f'    Remaining examples accuracy: {rea}')
        print()
