"""Stuff for guided ablations and controls."""
import dataclasses
import os
import posixpath
from typing import List, Optional, Sequence, Tuple

import h5py
import numpy as np
import tensorflow as tf

from em.evaluation import tf_metrics
from em.util import hdf5_util

from em.projects.pi.exps import kl_targeting

from . import ablation_exp_util


# Short module-level aliases for the HDF5 dataset helpers from
# em.util.hdf5_util; the save/load methods below call them by these names.
save_h5_ds = hdf5_util.save_h5_ds
load_h5_ds = hdf5_util.load_h5_ds


# Accepted values for ExperimentHelper1.gradient_target.
GRADIENT_TARGETS = (
    'kl',
    'loss',
)

# Accepted values for ExperimentHelper1.random_example_selection.
RANDOM_EXAMPLE_SELECTIONS = (
    'same_preds',
    'uniform',
)

# Accepted values for ExperimentHelper1.ablating_variable_style.
ABLATING_VARIABLE_STYLES = (
    'fixed_offset',
    'gradient',
)

# Names of the ablation experiment variants.
ABLATION_EXP_TYPES = (
    'component_examples',
    'component_examples_H',
    'random_examples',
    'random_examples_H',
)


@dataclasses.dataclass
class ExperimentHelper1:
    """Precomputes example selections, Fishers and gradients for one component.

    On construction the helper selects the component's example indices
    (the first ``n_selected_examples`` returned by
    ``exp.get_top_example_indices``) and a control set of random examples,
    and computes a Fisher and a unit-norm gradient for each set. The
    ``get_*_ablator`` methods then bundle those into ``Ablator1`` objects.
    """

    exp: ablation_exp_util.Experiment1

    component_index: int

    # Passed to the KL targeter as the KL range to hit (see Ablator1.find_model).
    kl_target_range: Tuple[float, float]

    # TODO: Have some slightly more sophisticated way to set this value.
    n_selected_examples: int

    # One of GRADIENT_TARGETS: "kl" or "loss".
    gradient_target: str = "kl"

    # Whether to ablate in the direction of decreasing or increasing gradient.
    negate_gradient: bool = False

    # One of RANDOM_EXAMPLE_SELECTIONS: "same_preds" or "uniform".
    random_example_selection: str = "same_preds"

    # One of ABLATING_VARIABLE_STYLES: "fixed_offset" or "gradient".
    ablating_variable_style: str = "fixed_offset"

    def __post_init__(self):
        """Validate the string options and precompute per-selection state."""
        assert self.gradient_target in GRADIENT_TARGETS
        assert self.random_example_selection in RANDOM_EXAMPLE_SELECTIONS
        assert self.ablating_variable_style in ABLATING_VARIABLE_STYLES

        experiment = self.exp
        self.eval_ctx = experiment.mc.get_evaluation_context()

        ranked_inds = experiment.get_top_example_indices(self.component_index)
        self.top_example_inds = ranked_inds[:self.n_selected_examples]
        self.top_examples_fisher = self._compute_dense_fisher(self.top_example_inds)
        self.top_examples_gradient = self._compute_normalized_gradient(self.top_example_inds)

        self.resample_random_examples()

    ######################################################################

    def resample_random_examples(self):
        """Draw a new random example set and recompute its Fisher and gradient."""
        sampled_inds = self._make_random_example_inds()
        self.random_example_inds = sampled_inds
        self.random_examples_fisher = self._compute_dense_fisher(sampled_inds)
        self.random_examples_gradient = self._compute_normalized_gradient(sampled_inds)

    def _make_random_example_inds(self):
        """Return randomly chosen example indices for the control selection.

        With "uniform" the choice is delegated to
        ``exp.random_example_indices``. With "same_preds", for every class the
        same number of examples is drawn (without replacement, via
        ``np.random.permutation``) as the top examples have predictions of
        that class, out of all examples predicted as that class; the combined
        indices are shuffled before being returned. Nothing here excludes the
        top examples themselves from being drawn.

        Raises:
            ValueError: if ``random_example_selection`` is not a known mode.
        """
        mode = self.random_example_selection
        experiment = self.exp

        if mode == "uniform":
            return experiment.random_example_indices(self.n_selected_examples)

        if mode != 'same_preds':
            raise ValueError(mode)

        n_classes = experiment.predicted_logits.shape[-1]
        top_predictions = experiment.predictions[self.top_example_inds]

        per_class_draws = []
        for label in range(n_classes):
            # How many of the top examples were predicted as this class.
            quota = (top_predictions == label).astype(np.int64).sum()
            candidates = np.nonzero(experiment.predictions == label)[0]
            per_class_draws.append(np.random.permutation(candidates)[:quota])

        combined = np.concatenate(per_class_draws, axis=0)
        return np.random.permutation(combined)

    def _compute_dense_fisher(self, example_inds: np.ndarray):
        """Return ``exp.compute_fisher`` evaluated on the given example indices."""
        return self.exp.compute_fisher(example_inds)

    def _compute_normalized_gradient(self, example_inds: np.ndarray):
        """Return the target gradient on the examples, scaled to unit global L2 norm.

        The gradient is a list of tensors taken from
        ``exp.compute_loss_gradient`` or ``exp.compute_kl_gradient`` depending
        on ``gradient_target``; it is negated when ``negate_gradient`` is set.
        The scale factor is the reciprocal square root of the summed squares
        over all tensors, so an all-zero gradient is not handled specially.

        Raises:
            ValueError: if ``gradient_target`` is not a known target.
        """
        target = self.gradient_target
        experiment = self.exp

        if target == 'loss':
            grads = experiment.compute_loss_gradient(example_inds)
        elif target == 'kl':
            grads = experiment.compute_kl_gradient(example_inds, allow_recompile=True)
        else:
            raise ValueError(target)

        if self.negate_gradient:
            grads = [-g for g in grads]

        # Normalization over the concatenation of all gradient tensors.
        squared_norm = tf.reduce_sum([tf.reduce_sum(tf.square(g)) for g in grads])
        scale = tf.math.rsqrt(squared_norm)
        return [scale * g for g in grads]

    ######################################################################

    def _build_ablator(self, example_inds, examples_fisher, examples_gradient) -> 'Ablator1':
        """Bundle one example selection with a Fisher and a gradient."""
        return Ablator1(
            helper=self,
            example_inds=example_inds,
            examples_fisher=examples_fisher,
            examples_gradient=examples_gradient,
        )

    def get_component_examples_ablator(self) -> 'Ablator1':
        """Ablator using the top examples with their own Fisher and gradient."""
        return self._build_ablator(
            self.top_example_inds,
            self.top_examples_fisher,
            self.top_examples_gradient,
        )

    def get_random_examples_ablator(self) -> 'Ablator1':
        """Ablator using the random examples with their own Fisher and gradient."""
        return self._build_ablator(
            self.random_example_inds,
            self.random_examples_fisher,
            self.random_examples_gradient,
        )

    ######################################################################

    def get_component_H_ablator(self) -> 'Ablator1':
        """Ablator using the top examples' gradient with the component Fisher.

        The Fisher comes from ``exp.mc.make_fisher_for_components`` for this
        component and is rebuilt on every call.
        """
        component_fisher = self.exp.mc.make_fisher_for_components([self.component_index])
        return self._build_ablator(
            self.top_example_inds,
            component_fisher,
            self.top_examples_gradient,
        )

    def get_random_examples_H_ablator(self) -> 'Ablator1':
        """Ablator using the random examples' gradient with the component Fisher.

        The Fisher comes from ``exp.mc.make_fisher_for_components`` for this
        component and is rebuilt on every call.
        """
        component_fisher = self.exp.mc.make_fisher_for_components([self.component_index])
        return self._build_ablator(
            self.random_example_inds,
            component_fisher,
            self.random_examples_gradient,
        )


@dataclasses.dataclass
class Ablator1:
    """Searches for an ablated model whose KL on ``example_inds`` hits a target range.

    Holds one example selection together with the Fisher and the gradient
    used to build the ablating variables; everything else is read from the
    owning ``ExperimentHelper1``.
    """

    helper: ExperimentHelper1

    example_inds: np.ndarray
    examples_fisher: List[tf.Tensor]
    examples_gradient: List[tf.Tensor]

    # Disabled fields, kept for reference:
    # in_group_eval_example_inds: np.ndarray
    # out_group_eval_example_inds: np.ndarray

    def __post_init__(self):
        """Cache the helper's experiment for direct access."""
        self.exp = self.helper.exp

    def _get_ablating_variables(self, delta: float):
        """Build the ablating variables from the gradient for a given ``delta``.

        Dispatches on the helper's ``ablating_variable_style``:
        "fixed_offset" uses ``exp.apply_sign_guide`` and "gradient" uses
        ``exp.apply_gradient``.

        Raises:
            ValueError: if the style is not one of the two known values.
        """
        style = self.helper.ablating_variable_style
        if style == "gradient":
            return self.exp.apply_gradient(self.examples_gradient, delta)
        if style == 'fixed_offset':
            return self.exp.apply_sign_guide(self.examples_gradient, delta)
        raise ValueError(style)

    def _create_model(self, delta: float, lmbda: float):
        """Merge with weights ``(1 - lmbda, lmbda)`` and return ``exp.output_model``.

        Only the first item produced by ``exp.stream_merge`` is consumed; the
        model is read from ``exp.output_model`` at that point. If the stream
        yields nothing, the method falls through and returns ``None``.
        """
        merge_stream = self.exp.stream_merge(
            ablating_variables=self._get_ablating_variables(delta),
            ablating_fisher=self.examples_fisher,
            coefficients=[(1 - lmbda, lmbda)],
        )
        for _ in merge_stream:
            return self.exp.output_model

    def _get_kl_fn(self):
        """Return ``kl_fn(delta, lmbda)`` giving the KL of the ablated model on ``example_inds``."""
        eval_ctx_owner = self.helper

        def kl_fn(delta: float, lmbda: float):
            ablated_model = self._create_model(delta, lmbda)
            evaluation = eval_ctx_owner.eval_ctx.evaluate(ablated_model, self.example_inds)
            return evaluation.kl()

        return kl_fn

    def find_model(self, max_iters: int = 25, *, max_delta: float = 3):
        """Run the KL-targeting search and return ``exp.output_model`` afterwards.

        Args:
            max_iters: iteration budget handed to the targeter's ``search``.
            max_delta: upper end of the delta magnitude range; the lower end
                is fixed at 1e-5.

        Returns:
            ``exp.output_model`` as it stands once the search returns.
            NOTE(review): presumably this is the model from the targeter's
            last ``kl_fn`` call -- confirm that it is the one that met
            ``kl_target_range``.
        """
        search = kl_targeting.GenericKlTargeter(
            kl_fn=self._get_kl_fn(),
            kl_range=self.helper.kl_target_range,
            delta_mag_range=[1e-5, max_delta],
        )
        search.search(max_iters=max_iters)
        return self.exp.output_model


###############################################################################
###############################################################################


# def compute_prediction_selectivity(results, top_k: int):
#     coeffs = results.W[:, results.component_index]
#     top_inds = np.argsort(-coeffs)[:top_k]
#     top_logits = results.og_logits[top_inds]
#     preds = np.argmax(top_logits, axis=-1)
#     pred_rates = [
#         (preds == i).mean()
#         for i in range(top_logits.shape[-1])
#     ]
#     return max(pred_rates)

@dataclasses.dataclass
class KlSelectivityInfo:
    """KL divergence on a selected subset of examples and on the full data."""

    selected_examples_kl: float
    all_examples_kl: float

    def log(self):
        """Print both KL values, one per line, followed by a blank line."""
        print(
            f'Selection KL: {self.selected_examples_kl}',
            f'Full Data KL: {self.all_examples_kl}',
            '',
            sep='\n',
        )

    def ratio(self) -> float:
        """Return the selected-examples KL divided by the all-examples KL."""
        numerator = self.selected_examples_kl
        denominator = self.all_examples_kl
        return numerator / denominator

    @classmethod
    def mean(cls, infos: Sequence['KlSelectivityInfo']) -> 'KlSelectivityInfo':
        """Return a new info whose two fields are the means over ``infos``.

        ``infos`` is traversed twice, so it must be a re-iterable sequence.
        """
        selected_values = [info.selected_examples_kl for info in infos]
        overall_values = [info.all_examples_kl for info in infos]
        return cls(
            selected_examples_kl=np.mean(selected_values),
            all_examples_kl=np.mean(overall_values),
        )


@dataclasses.dataclass
class OutputForAblator:
    """Outputs of one ablator: the examples it targeted and the resulting logits."""

    selected_example_indices: np.ndarray
    # shape = [n_runs, n_examples, n_classes]
    output_logits: np.ndarray

    @staticmethod
    def _dataset_key(group_name: str, dataset_name: str) -> str:
        """Return the HDF5 path of ``dataset_name`` inside ``group_name``.

        HDF5 object paths are always '/'-separated, so the key is built with
        ``posixpath.join`` rather than ``os.path.join``, which would insert
        the host OS separator (a backslash on Windows).
        """
        return posixpath.join(group_name, dataset_name)

    def _save_to_group(self, f, group_name):
        """Write both arrays as datasets under ``group_name`` in the open file ``f``."""
        save_h5_ds(
            f,
            self._dataset_key(group_name, 'selected_example_indices'),
            self.selected_example_indices,
        )
        save_h5_ds(
            f,
            self._dataset_key(group_name, 'output_logits'),
            self.output_logits,
        )

    @classmethod
    def _load_from_group(cls, f, group_name):
        """Read an instance back from ``group_name`` in ``f``.

        Returns:
            The loaded instance, or ``None`` when ``group_name`` is not
            present in ``f``.
        """
        if group_name not in f:
            return None
        return cls(
            selected_example_indices=load_h5_ds(
                f[cls._dataset_key(group_name, 'selected_example_indices')]),
            output_logits=load_h5_ds(
                f[cls._dataset_key(group_name, 'output_logits')]),
        )


@dataclasses.dataclass
class OutputForComponent:
    """Inputs and ablation outputs recorded for one component.

    Besides holding the data, the class can compute KL-based selectivity
    statistics from the stored logits and round-trip itself through an HDF5
    file via ``save`` / ``load``.
    """

    component_index: int

    ablating_variable_style: str
    kl_target_range: Tuple[float, float]

    # Indexed as W[:, component_index]; None when loaded with include_W=False.
    W: np.ndarray
    # shape = [n_ex], dtype=int32
    labels: np.ndarray
    # shape = [n_ex, n_classes], dtype=float32
    og_logits: np.ndarray

    # Filepaths of inputs.
    pef_path: str
    nmf_path: str
    retaining_fisher_path: str

    model: str
    tokenizer: str

    component_top_fisher_ablation: Optional[OutputForAblator] = None
    component_H_ablation: Optional[OutputForAblator] = None

    random_examples_ablations: Sequence[OutputForAblator] = ()
    random_examples_H_ablations: Sequence[OutputForAblator] = ()

    ##########################################################################
    ##########################################################################

    @staticmethod
    def _mean_kl(ablated_logits, original_logits):
        """Mean over examples of the Keras KL divergence between two softmaxes.

        The softmax of ``ablated_logits`` is passed as ``y_true`` and the
        softmax of ``original_logits`` as ``y_pred``.
        """
        per_example = tf.keras.losses.kl_divergence(
            tf.math.softmax(ablated_logits),
            tf.math.softmax(original_logits),
        )
        return per_example.numpy().mean()

    def _compute_kl_selectivities(
        self,
        ablator_output: OutputForAblator,
        selected_example_indices: Optional[np.ndarray] = None,
    ) -> List[KlSelectivityInfo]:
        """Return one ``KlSelectivityInfo`` per run stored in ``ablator_output``.

        Args:
            ablator_output: runs whose logits are compared against
                ``og_logits``.
            selected_example_indices: examples that make up the "selected"
                subset; defaults to the indices stored in ``ablator_output``.
        """
        subset = selected_example_indices
        if subset is None:
            subset = ablator_output.selected_example_indices

        infos = []
        for run_logits in ablator_output.output_logits:
            assert run_logits.shape[0] % self.og_logits.shape[0] == 0
            subset_kl = self._mean_kl(run_logits[subset], self.og_logits[subset])
            overall_kl = self._mean_kl(run_logits, self.og_logits)
            infos.append(
                KlSelectivityInfo(
                    selected_examples_kl=subset_kl,
                    all_examples_kl=overall_kl,
                )
            )
        return infos

    def _get_kl_range_met_mask(
        self,
        ablator_output: OutputForAblator,
        min_kl: float,
        max_kl: float,
    ) -> np.ndarray:
        """Return a boolean array with one entry per run.

        An entry is True when that run's mean KL on the ablator's own
        selected examples lies in ``[min_kl, max_kl]`` (both ends inclusive).
        """
        subset = ablator_output.selected_example_indices
        within_range = []
        for run_logits in ablator_output.output_logits:
            assert run_logits.shape[0] % self.og_logits.shape[0] == 0
            subset_kl = self._mean_kl(run_logits[subset], self.og_logits[subset])
            within_range.append(min_kl <= subset_kl <= max_kl)
        return np.array(within_range, dtype=bool)

    #################################################################

    def get_component_top_fisher_kl_selectivities(self) -> List[KlSelectivityInfo]:
        """Selectivities for ``component_top_fisher_ablation``."""
        return self._compute_kl_selectivities(self.component_top_fisher_ablation)

    def get_random_examples_kl_selectivities(self) -> List[KlSelectivityInfo]:
        """Selectivities for every run of every entry in ``random_examples_ablations``."""
        return [
            info
            for ablation in self.random_examples_ablations
            for info in self._compute_kl_selectivities(ablation)
        ]

    def get_component_H_kl_selectivities(self) -> List[KlSelectivityInfo]:
        """Selectivities for ``component_H_ablation``."""
        return self._compute_kl_selectivities(self.component_H_ablation)

    def get_random_examples_H_selectivities(
        self,
        selected_example_indices: Optional[np.ndarray] = None,
    ) -> List[KlSelectivityInfo]:
        """Selectivities for every run of every entry in ``random_examples_H_ablations``.

        ``selected_example_indices``, when given, replaces each ablation's own
        stored indices as the "selected" subset.
        """
        return [
            info
            for ablation in self.random_examples_H_ablations
            for info in self._compute_kl_selectivities(ablation, selected_example_indices)
        ]

    ##########################################################################
    ##########################################################################

    def _compute_spearmanr(self, ablator_output: OutputForAblator, component_index: int):
        """Spearman correlation between W[:, component_index] and per-example KL.

        The per-example KL of every run is concatenated along axis 0 and the
        coefficient column is repeated once per run to line up with it. Both
        are cast to float32 before ``tf_metrics.spearmanr_vv`` is applied.
        """
        per_run_kls = []
        for run_logits in ablator_output.output_logits:
            divergence = tf.keras.losses.kl_divergence(
                tf.math.softmax(run_logits), tf.math.softmax(self.og_logits))
            per_run_kls.append(divergence.numpy())
        kls = np.concatenate(per_run_kls, axis=0)

        n_runs = ablator_output.output_logits.shape[0]
        coeffs = np.concatenate([self.W[:, component_index]] * n_runs, axis=0)

        correlation = tf_metrics.spearmanr_vv(
            tf.cast(coeffs, tf.float32), tf.cast(kls, tf.float32))
        return correlation.numpy()

    def save(self, filepath: str):
        """Write this object to an HDF5 file at ``filepath`` ("~" is expanded).

        The file is opened in "w" mode. Scalars and strings are stored as
        attributes of the ``attrs`` group, arrays and ablation outputs under
        ``data/``.
        """
        with h5py.File(os.path.expanduser(filepath), "w") as f:
            # Save non-array data.
            attrs = f.create_group('attrs').attrs
            attribute_values = (
                ('component_index', self.component_index),
                ('pef_path', self.pef_path),
                ('nmf_path', self.nmf_path),
                ('retaining_fisher_path', self.retaining_fisher_path),
                ('model', self.model),
                ('tokenizer', self.tokenizer),
                ('ablating_variable_style', self.ablating_variable_style),
                ('kl_target_range__min', self.kl_target_range[0]),
                ('kl_target_range__max', self.kl_target_range[1]),
            )
            for attr_name, attr_value in attribute_values:
                attrs[attr_name] = attr_value

            # Save array data.
            array_values = (
                ('data/W', self.W),
                ('data/labels', self.labels),
                ('data/og_logits', self.og_logits),
            )
            for dataset_name, array in array_values:
                save_h5_ds(f, dataset_name, array)

            # Save runs data; the two single ablations are optional.
            single_ablations = (
                ('data/comp_ex_ablation', self.component_top_fisher_ablation),
                ('data/comp_H_ablation', self.component_H_ablation),
            )
            for group_name, ablation in single_ablations:
                if ablation is not None:
                    ablation._save_to_group(f, group_name)

            for i, abl in enumerate(self.random_examples_ablations):
                abl._save_to_group(f, f'data/rand_ex_ablations/{i}')

            for i, abl in enumerate(self.random_examples_H_ablations):
                abl._save_to_group(f, f'data/rand_ex_H_ablations/{i}')

    @classmethod
    def load(cls, filepath: str, *, include_W: bool = True):
        """Read an object previously written by ``save``.

        Args:
            filepath: HDF5 file to open read-only ("~" is expanded).
            include_W: when False, ``data/W`` is not read and ``W`` is None.
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            attrs = f['attrs'].attrs

            def load_random_ex_abl_list(name: str):
                # Groups are numbered 0, 1, 2, ...; stop at the first gap.
                loaded = []
                key = f'data/{name}/{len(loaded)}'
                while key in f:
                    loaded.append(OutputForAblator._load_from_group(f, key))
                    key = f'data/{name}/{len(loaded)}'
                return loaded

            if include_W:
                coefficients = load_h5_ds(f['data/W'])
            else:
                coefficients = None

            kl_target_range = (
                attrs['kl_target_range__min'],
                attrs['kl_target_range__max'],
            )

            return cls(
                component_index=attrs['component_index'],
                ablating_variable_style=attrs['ablating_variable_style'],
                kl_target_range=kl_target_range,
                W=coefficients,
                labels=load_h5_ds(f['data/labels']),
                og_logits=load_h5_ds(f['data/og_logits']),
                pef_path=attrs['pef_path'],
                nmf_path=attrs['nmf_path'],
                retaining_fisher_path=attrs['retaining_fisher_path'],
                model=attrs['model'],
                tokenizer=attrs['tokenizer'],
                component_top_fisher_ablation=OutputForAblator._load_from_group(
                    f, 'data/comp_ex_ablation'),
                component_H_ablation=OutputForAblator._load_from_group(
                    f, 'data/comp_H_ablation'),
                random_examples_ablations=load_random_ex_abl_list('rand_ex_ablations'),
                random_examples_H_ablations=load_random_ex_abl_list('rand_ex_H_ablations'),
            )
