"""Common code and utilities for the coeff_kl_relationship.py script and its outputs."""
import dataclasses
import os
from typing import List, Optional

import h5py
import numpy as np
from scipy import stats
from sklearn.feature_selection import mutual_info_regression
import tensorflow as tf


from em.util import hdf5_util


# Short module-level aliases for the HDF5 dataset helpers used by save/load below.
save_h5_ds, load_h5_ds = hdf5_util.save_h5_ds, hdf5_util.load_h5_ds

"""
Output structure:
    - Script arguments (e.g., component_index, ...)
        - Includes links back to the pef and nmf input files.
        - Some of these could be grouped into a dict; others should definitely have their own fields.
    - W
    - Evaluation indices.
    - Original ("og") logits and labels for the evaluation indices.
    - Runs:
        - lmbda
        - delta
        - logits for the evaluation indices
"""

##########################################################################


@dataclasses.dataclass
class MetricsInfo:
    """Relationship metrics between component coefficients and KLs.

    A field left as None means that metric was not computed.
    """
    mutual_info: Optional[float] = None
    spearman: Optional[float] = None
    pearson: Optional[float] = None

    def log(self):
        """Print each computed metric on its own line, then a blank line."""
        labelled_values = (
            ('Mutual Info', self.mutual_info),
            ('Spearman Corr', self.spearman),
            ('Pearson Corr', self.pearson),
        )
        for label, value in labelled_values:
            if value is None:
                continue
            print(f'{label}: {value}')
        print('')


##########################################################################

@dataclasses.dataclass
class RunOutput:
    """Output of an individual evaluation run."""
    # Run hyperparameters; OutputForComponent.save stores them as HDF5 attributes.
    lmbda: float
    delta: float

    # These are restricted to the evaluation examples only. See comment
    # on some fields of OutputForComponent for more info.
    #
    # shape = [n_eval_ex, n_classes], dtype=float32
    logits: np.ndarray


@dataclasses.dataclass
class OutputForComponent:
    """Output of the coeff/KL relationship script for a single NMF component.

    Bundles the following:
      - the evaluation-example subset;
      - the reference ("og") logits and labels for that subset;
      - one RunOutput per evaluation run;
      - the input file paths and the model/tokenizer identifiers.

    Can be round-tripped through an HDF5 file with `save` and `load`.
    """

    # Column of W this output was produced for; also the default component
    # used by the KL/metric helpers below.
    component_index: int

    # Evaluation runs, in the order they are saved/loaded.
    runs: List[RunOutput]

    # shape = [n_eval_ex], dtype=int32
    evaluation_ex_indices: np.ndarray

    # Note that the following have been mapped and filtered using
    # the evaluation_ex_indices. For example,
    #   self.W = nmf.W[evaluation_ex_indices, :]
    #
    # shape = [n_eval_ex, n_components], dtype=float32.
    # None when loaded via `load(..., include_W=False)`.
    W: Optional[np.ndarray]
    # shape = [n_eval_ex], dtype=int32
    labels: np.ndarray
    # shape = [n_eval_ex, n_classes], dtype=float32
    og_logits: np.ndarray

    # Filepaths of inputs.
    pef_path: str
    nmf_path: str
    retaining_fisher_path: str

    model: str
    tokenizer: str

    def _component_coeffs(self, component_index: Optional[int] = None) -> np.ndarray:
        """Return column `component_index` of W (default: self.component_index).

        Raises:
            ValueError: if W was not loaded (e.g. `load(..., include_W=False)`).
        """
        if component_index is None:
            component_index = self.component_index
        if self.W is None:
            raise ValueError('W is not available; load with include_W=True.')
        return self.W[:, component_index]

    def compute_metrics(
        self,
        component_index: Optional[int] = None,
        *,
        mutual_info: bool = False,
        spearman: bool = True,
        pearson: bool = True,
    ):
        """Measure how a component's coefficients relate to per-example KLs.

        Every (run, evaluation example) pair is one sample. Its KL comes from
        that run's logits. Its coefficient is the example's entry in column
        `component_index` of W.

        Args:
            component_index: column of W to use; defaults to self.component_index.
            mutual_info: compute sklearn's mutual_info_regression estimate.
            spearman: compute the Spearman rank correlation.
            pearson: compute the Pearson correlation.

        Returns:
            MetricsInfo with the requested metrics set and the others left None.

        Raises:
            ValueError: if there are no runs or W was not loaded.
        """
        if not self.runs:
            raise ValueError('No runs to compute metrics over.')

        # Stack all runs' logits along axis 0. Tile the coefficient column
        # the same number of times so coeffs[i] lines up with kls[i].
        kls = self._compute_kl(np.concatenate([r.logits for r in self.runs], axis=0))
        coeffs = np.tile(self._component_coeffs(component_index), len(self.runs))

        kwargs = {}
        if mutual_info:
            kwargs['mutual_info'] = mutual_info_regression(coeffs[:, None], kls)[0]
        if spearman:
            kwargs['spearman'] = stats.spearmanr(coeffs, kls)[0]
        if pearson:
            kwargs['pearson'] = stats.pearsonr(coeffs, kls)[0]

        return MetricsInfo(**kwargs)

    #

    def _compute_kls_for_tops(self, top_k: int, component_index: Optional[int] = None,
                              *, fisher_norm_corrections: Optional[np.ndarray] = None):
        """Compute per-run mean KLs over the top-k examples and over all examples.

        The top-k examples are those with the largest coefficient in column
        `component_index` of W (default: self.component_index).

        Args:
            top_k: number of highest-coefficient examples to average over; must be >= 1.
            component_index: column of W used to rank the examples.
            fisher_norm_corrections: if given, each run's per-example KLs are
                divided by it before averaging. It presumably has shape
                [n_eval_ex] or is broadcastable to it — TODO confirm.

        Returns:
            (top_kls, kls): arrays of shape [n_runs]. For each run, `top_kls`
            holds the mean KL over the top-k examples and `kls` the mean KL
            over all evaluation examples.

        Raises:
            ValueError: if top_k < 1 or W was not loaded.
        """
        # TODO: Need to filter out runs or something that did not meet the targeted KL-range.
        if top_k < 1:
            # A non-positive top_k would silently slice the wrong examples
            # below (e.g. [:-2] keeps all but the last two) or average nothing.
            raise ValueError(f'top_k must be >= 1, got {top_k}.')

        top_inds = np.argsort(-self._component_coeffs(component_index))[:top_k]

        top_kls, kls = [], []
        for r in self.runs:
            r_kls = self._compute_kl(r.logits)
            if fisher_norm_corrections is not None:
                # Out-of-place division: don't mutate the buffer returned by .numpy().
                r_kls = r_kls / fisher_norm_corrections
            kls.append(np.mean(r_kls))
            top_kls.append(np.mean(r_kls[top_inds]))

        return np.array(top_kls), np.array(kls)

    def compute_kl_ratio_as_tuple__avg_then_ratio(self, top_k: int, component_index: Optional[int] = None,
                                                  *, fisher_norm_corrections: Optional[np.ndarray] = None):
        """Return (mean over runs of top-k mean KL, mean over runs of overall mean KL).

        The ratio itself is left to the caller; see `_compute_kls_for_tops` for
        the meaning of the arguments.
        """
        top_kls, kls = self._compute_kls_for_tops(top_k, component_index, fisher_norm_corrections=fisher_norm_corrections)
        return np.mean(top_kls), np.mean(kls)

    def compute_kl_ratioe__ratio_then_geom_avg(self, top_k: int, component_index: Optional[int] = None,
                                               *, fisher_norm_corrections: Optional[np.ndarray] = None):
        """Compute the geometric mean over runs of (top-k mean KL / overall mean KL).

        See `_compute_kls_for_tops` for the meaning of the arguments.
        """
        # TODO: Need to filter out runs or something that did not meet the targeted KL-range.
        top_kls, kls = self._compute_kls_for_tops(top_k, component_index, fisher_norm_corrections=fisher_norm_corrections)
        return stats.gmean(top_kls / kls)

    #

    def _compute_kl(self, logits):
        """Compute the per-row KL divergence between softmax(logits) and softmax(self.og_logits).

        `logits` may hold several runs stacked along axis 0. Its row count must
        be a multiple of the number of evaluation examples, and
        self.og_logits is repeated along axis 0 to match.

        tf.keras.losses.kl_divergence(y_true, y_pred) computes
        sum(y_true * log(y_true / y_pred)) over the last axis. This therefore
        returns KL(softmax(logits) || softmax(og_logits)).
        NOTE(review): confirm this direction (run || original) is intended.

        Raises:
            ValueError: if self.og_logits has no rows, or if the row count of
                `logits` is not a multiple of it.
        """
        n_eval = self.og_logits.shape[0]
        n_rows = logits.shape[0]
        if n_eval == 0 or n_rows % n_eval != 0:
            raise ValueError(
                f'logits has {n_rows} rows, which is not a multiple of the '
                f'{n_eval} evaluation examples in og_logits.')
        n_repeats = n_rows // n_eval
        og_logits = np.concatenate(n_repeats * [self.og_logits], axis=0)
        return tf.keras.losses.kl_divergence(tf.math.softmax(logits), tf.math.softmax(og_logits)).numpy()

    def save(self, filepath: str):
        """Write this output to an HDF5 file at `filepath` (`~` is expanded).

        Any existing file at that path is truncated (h5py mode "w").
        Scalars and paths go in the attributes of the 'attrs' group. Arrays go
        under 'data/'. Run i is stored as 'data/runs/{i}/logits', with lmbda and
        delta as attributes of the 'data/runs/{i}' group.
        """
        with h5py.File(os.path.expanduser(filepath), "w") as f:
            # Save non-array data.
            attrs = f.create_group('attrs').attrs
            attrs['component_index'] = self.component_index

            attrs['pef_path'] = self.pef_path
            attrs['nmf_path'] = self.nmf_path
            attrs['retaining_fisher_path'] = self.retaining_fisher_path

            attrs['model'] = self.model
            attrs['tokenizer'] = self.tokenizer

            # Save array data.
            save_h5_ds(f, 'data/evaluation_ex_indices', self.evaluation_ex_indices)

            save_h5_ds(f, 'data/W', self.W)
            save_h5_ds(f, 'data/labels', self.labels)
            save_h5_ds(f, 'data/og_logits', self.og_logits)

            # Save runs.
            for i, run in enumerate(self.runs):
                key = f'data/runs/{i}'
                save_h5_ds(f, f'{key}/logits', run.logits)
                # The group at `key` is assumed to have been created
                # implicitly by save_h5_ds above — TODO confirm.
                f[key].attrs['lmbda'] = run.lmbda
                f[key].attrs['delta'] = run.delta

    @classmethod
    def load(cls, filepath: str, *, include_W: bool = True):
        """Read an output previously written by `save`.

        Args:
            filepath: path to the HDF5 file (`~` is expanded).
            include_W: when False, W is not read and is set to None. This saves
                memory when only the logits are needed.

        Returns:
            The reconstructed OutputForComponent. Runs are read from
            'data/runs/0', 'data/runs/1', ... until the first missing index.
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            attrs = f['attrs'].attrs
            ret = cls(
                evaluation_ex_indices=load_h5_ds(f['data/evaluation_ex_indices']),
                #
                W=load_h5_ds(f['data/W']) if include_W else None,
                labels=load_h5_ds(f['data/labels']),
                og_logits=load_h5_ds(f['data/og_logits']),
                #
                # h5py returns numpy scalars for numeric attributes; cast to
                # match the field annotation.
                component_index=int(attrs['component_index']),
                #
                pef_path=attrs['pef_path'],
                nmf_path=attrs['nmf_path'],
                retaining_fisher_path=attrs['retaining_fisher_path'],
                #
                model=attrs['model'],
                tokenizer=attrs['tokenizer'],
                #
                runs=[],
            )

            i = 0
            while True:
                key = f'data/runs/{i}'
                if key not in f:
                    break
                ret.runs.append(
                    RunOutput(
                        lmbda=float(f[key].attrs['lmbda']),
                        delta=float(f[key].attrs['delta']),
                        logits=load_h5_ds(f[f'{key}/logits']),
                    )
                )
                i += 1

            return ret
