"""Utilities for running HANS ablation experiments and saving/loading their results."""
import dataclasses
import json
import os
from typing import List, Sequence

import h5py
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, PreTrainedTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hdf5_util
from em.util import hf_util

from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_merging_context as HMC

from em.util.color_util import cu


###############################################################################

# TODO: Hardcoding some stuff now, maybe make settable later.
# Root directory of this project's experiment data; the directories below are
# sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# NOTE(review): not referenced in the visible part of this file.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Used as the default value of `AblationExperimentConfig.fishers_dir`.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# NOTE(review): not referenced in the visible part of this file.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################


@dataclasses.dataclass
class AblationEvaluationResult:
    """HANS and MNLI evaluation results for one set of ablation coefficients.

    Attributes:
        coefficients: The coefficients the evaluated model was produced with.
            They are floats: `save_results_list_to_h5` stores them as float32
            and `read_results_list_from_h5` hands back rows of that array.
        hans_results: Evaluation results on HANS.
        mnli_results: Evaluation results on MNLI.
    """
    coefficients: Sequence[float]
    hans_results: HCC.HansLoneEvaluationResults
    mnli_results: HMC.MnliEvaluationResults

    def print_summary(self, verbosity=1, *, sorted_hans_example_indices=None):
        """Print a colored table of KL and accuracy for MNLI and HANS.

        Args:
            verbosity: 0 prints nothing; any positive value prints the summary.
                Negative values fail an assertion.
            sorted_hans_example_indices: Optional sliceable sequence of HANS
                example indices. When given, extra rows report the metrics over
                its first 12, 24 and 48 entries.
        """
        if verbosity == 0:
            return

        assert verbosity > 0, f'verbosity must be non-negative, got {verbosity}'

        def _label(text, width=4 + 16):
            # Pad with spaces and then clip, so every row label is exactly
            # `width` characters wide and the metric columns line up.
            return cu.hlb(f'{text}{width * " "}'[:width])

        def _metrics(kl, acc):
            return f'{kl:.6f}    {acc:.3f}'

        hans = self.hans_results
        mnli = self.mnli_results

        def _print_hans_subset(title, example_inds):
            print(
                _label(title),
                _metrics(hans.kl_for_examples(example_inds), hans.acc_for_examples(example_inds)),
            )

        print(cu.hly(self.coefficients))
        print(cu.hlg('  MNLI:'))
        print(_label('    Full Dataset'), _metrics(mnli.kl(), mnli.acc()))

        # Label 0 is reported as "Entailing", label 1 as "Non-Entailing".
        entailing_inds, = np.nonzero(hans.labels == 0)
        non_entailing_inds, = np.nonzero(hans.labels == 1)

        print(cu.hlg('  HANS:'))
        print(_label('    Full Dataset'), _metrics(hans.kl(), hans.acc()))
        _print_hans_subset('    Entailing', entailing_inds)
        _print_hans_subset('    Non-Entailing', non_entailing_inds)

        if sorted_hans_example_indices is not None:
            for n_ex in [12, 24, 48]:
                _print_hans_subset(f'    Top {n_ex} Ex.', sorted_hans_example_indices[:n_ex])

        print()


###############################################################################


@dataclasses.dataclass
class AblationExperimentConfig:
    """Settings for an `AblationExperiment`.

    The whole config is serialized with `dataclasses.asdict` into the metadata
    saved next to the results and rebuilt from that dict on load, so every
    field must stay JSON-serializable.
    """
    # Substituted for `{model_number}` in `fisher_pattern`.
    model_number: int

    # Index of the component being ablated.
    component_index: int

    # Passed as `n_examples` to the MNLI evaluation context.
    n_mnli_examples: int

    # How many of the first sorted example indices for the component are used
    # to build the sign-guide dataset.
    n_top_examples_sign_guide: int

    # Passed to `merging.create_pairwise_grid_coeffs` to build the coefficient grid.
    n_coefficients: int
    # Passed as `fisher_floor` to `merging._merge_with_coeffs`.
    fisher_floor: float = 1e-7

    # Sequence length used when loading the HANS dataset.
    sequence_length: int = 64

    # Batch size of the dataset used to compute the top-examples Fisher.
    dense_fisher_batch_size: int = 4
    # Batch size of the dataset used to compute the loss gradient (sign guide).
    get_loss_gradient_batch_size: int = 32

    # File name template (formatted with `model_number`) and directory of the
    # saved full-dataset Fisher.
    fisher_pattern: str = "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.h5"
    fishers_dir: str = FISHERS_DIR


@dataclasses.dataclass
class AblationExperiment:
    """Ablation experiment for a single component of a single model.

    Construction sorts the HANS example indices for the configured component
    and sets up the HANS and MNLI evaluation contexts. Gradients and Fishers
    are computed (or loaded) lazily on first access and cached afterwards.

    Two kinds of runs are offered, both sweeping the same coefficient grid:
        - `perform_merge_based_ablation_run`
        - `perform_gradient_following_run`
    """
    config: AblationExperimentConfig

    # NOTE(review): this field is never read in this class; the code below
    # uses `hacc.tokenizer` and `mc.tokenizer` instead.
    tokenizer: PreTrainedTokenizer
    hacc: HCC.HansLoneComponentContext
    mc: HCC.HansLoneModelContext

    def __post_init__(self):
        self.component_index = self.config.component_index

        self.sorted_example_indices = self.mc.sort_example_indices_for_component(self.component_index)
        self.eval_ctx = self.hacc.get_evaluation_context(og_logits=self.mc.container.predicted_logits)

        self.mnli_eval_ctx = HMC.MnliEvaluationContext(
            n_examples=self.config.n_mnli_examples,
            tokenizer=self.hacc.tokenizer
        )
        self.mnli_eval_ctx.set_up_og_data(self.mc.model)

        # Snapshot of the list of variables of the original model.
        self.retaining_variables = list(self.mc.variables)

        # Caches backing the lazy properties below; None means "not computed yet".
        self._sign_guide = None
        self._top_component_examples_full_fisher = None
        self._full_dataset_fisher = None
        self._component_h_fisher = None
        self._compl_component_h_fisher = None
        self._top_component_examples_loss_gradient = None

    def get_ablating_variables(self, delta: float):
        """Return `HMC.apply_sign_guide` of the retaining variables, the sign guide and `delta`."""
        return HMC.apply_sign_guide(self.retaining_variables, self.sign_guide, delta)

    @property
    def sign_guide(self):
        """Loss gradient over the sign-guide dataset (computed once, then cached)."""
        if self._sign_guide is None:
            self._sign_guide = self._get_sign_guide()
        return self._sign_guide

    @property
    def top_component_examples_full_fisher(self):
        """Diagonal Fisher over the sign-guide dataset (computed once, then cached)."""
        if self._top_component_examples_full_fisher is None:
            self._top_component_examples_full_fisher = self._get_top_component_examples_full_fisher()
        return self._top_component_examples_full_fisher

    @property
    def full_dataset_fisher(self):
        """Fisher loaded from the file named by the config (loaded once, then cached)."""
        if self._full_dataset_fisher is None:
            self._full_dataset_fisher = self._load_full_dataset_fisher()
        return self._full_dataset_fisher

    @property
    def component_h_fisher(self):
        """Fisher made by the model context for just this component (cached)."""
        if self._component_h_fisher is None:
            self._component_h_fisher = self.mc.make_fisher_for_components([self.component_index])
        return self._component_h_fisher

    @property
    def compl_component_h_fisher(self):
        """Fisher made by the model context for every component except this one (cached)."""
        if self._compl_component_h_fisher is None:
            self._compl_component_h_fisher = self.mc.make_fisher_for_components(
                set(range(self.mc.n_components)) - {self.component_index}
            )
        return self._compl_component_h_fisher

    @property
    def top_component_examples_loss_gradient(self):
        """The sign guide divided by its overall norm (computed once, then cached)."""
        # This is normalized to unit l2 norm.
        if self._top_component_examples_loss_gradient is None:
            self._top_component_examples_loss_gradient = self._get_top_component_examples_loss_gradient()
        return self._top_component_examples_loss_gradient

    ############################################################

    def _get_sg_ds(self):
        """Load the (unbatched) sign-guide dataset.

        This is the validation split of 'hans/lexical_overlap_ne_with_flipped',
        restricted to the first `n_top_examples_sign_guide` entries of
        `sorted_example_indices`.
        """
        sg_ds = em_datasets.load('hans/lexical_overlap_ne_with_flipped', split='validation',
                                 sequence_length=self.config.sequence_length, tokenizer=self.mc.tokenizer)
        top_indices = self.sorted_example_indices[:self.config.n_top_examples_sign_guide]
        return HMC.get_ds_by_example_indices(sg_ds, top_indices)

    def _get_top_component_examples_full_fisher(self) -> Sequence[tf.Tensor]:
        """Compute the diagonal Fisher of the model over the sign-guide dataset."""
        sg_ds = self._get_sg_ds().batch(self.config.dense_fisher_batch_size)
        return diagonal.compute_fisher_for_model(self.mc.model, sg_ds, variables=self.mc.variables)

    def _get_sign_guide(self):
        """Compute the loss gradient of the model over the sign-guide dataset."""
        sg_ds = self._get_sg_ds().batch(self.config.get_loss_gradient_batch_size)
        return HMC.get_loss_gradient(
            self.mc.model,
            self.mc.variables,
            sg_ds,
        )

    def _get_top_component_examples_loss_gradient(self):
        """Divide every gradient in the sign guide by `merging._l2_norm_of_fisher` of the whole list."""
        grads = self.sign_guide
        mag = merging._l2_norm_of_fisher(grads)
        return [g / mag for g in grads]

    def _load_full_dataset_fisher(self) -> Sequence[tf.Tensor]:
        """Load the saved Fisher for `config.model_number` from `config.fishers_dir`."""
        filename = self.config.fisher_pattern.format(model_number=self.config.model_number)
        return diagonal.DiagonalFisher.load(os.path.join(self.config.fishers_dir, filename)).fishers

    ############################################################

    # Two main types of ablation runs:
    #     - Merging-based.
    #     - Gradient-following (i.e. gradient descent or gradient ascent).
    #
    # Merging-based run types:
    #     - H-based ablating fisher
    #     - Top example based ablating fisher
    #     -
    #     - [maybe ignore this option in the first pass] H-based retaining fisher, excludes component.
    #     - Full HANS dataset dense retaining fisher.
    #
    # Gradient-following run types:
    #     - Pure gradient following.
    #     - [Might not make sense, don't have method yet] Gradient following making use of (full HANS
    #        dataset dense) retaining fisher.

    # TODO: Some automatic coefficient/delta selection process.

    def _evaluate_and_report(self, output_model, coefficients) -> AblationEvaluationResult:
        """Evaluate `output_model` on HANS and MNLI, print the summary, and return the result."""
        result = AblationEvaluationResult(
            coefficients=coefficients,
            hans_results=self.eval_ctx.evaluate(output_model),
            mnli_results=self.mnli_eval_ctx.evaluate(output_model),
        )
        result.print_summary(sorted_hans_example_indices=self.sorted_example_indices)
        return result

    def perform_merge_based_ablation_run(
        self,
        retaining_fisher: Sequence[tf.Tensor],
        ablating_fisher: Sequence[tf.Tensor],
        delta: float,
    ) -> List[AblationEvaluationResult]:
        """Merge the retaining and ablating variables over the coefficient grid.

        For every coefficient pair produced by `merging.create_pairwise_grid_coeffs`,
        the two variable sets are merged into a freshly loaded model using the
        given Fishers, and the merged model is evaluated on HANS and MNLI.

        Args:
            retaining_fisher: Fisher paired with the retaining (original) variables.
            ablating_fisher: Fisher paired with the ablating variables.
            delta: Forwarded to `get_ablating_variables`.

        Returns:
            One `AblationEvaluationResult` per coefficient pair, in grid order.
        """
        output_model = self.mc.load_model()
        output_variables = hf_util.get_all_variables(output_model)

        variables_to_merge = [self.retaining_variables, self.get_ablating_variables(delta)]
        fishers_to_merge = [retaining_fisher, ablating_fisher]

        # Computed once; the Fishers do not change across the grid.
        norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

        results = []
        for coefficients in merging.create_pairwise_grid_coeffs(self.config.n_coefficients):
            # Overwrites `output_variables` (and therefore `output_model`) in place.
            merging._merge_with_coeffs(
                output_variables,
                variables_to_merge,
                coefficients=coefficients,
                fishers=fishers_to_merge,
                fisher_floor=self.config.fisher_floor,
                favor_target_model=True,
                normalization_constants=norm_constants,
            )
            results.append(self._evaluate_and_report(output_model, coefficients))

        return results

    def perform_gradient_following_run(
        self,
        delta: float,
    ) -> List[AblationEvaluationResult]:
        """Step the original variables along the normalized loss gradient.

        For every coefficient pair on the grid, each variable of a freshly
        loaded model is set to `original + delta * coeff * grad`, where `coeff`
        is the second entry of the pair, and the model is evaluated on HANS and
        MNLI.

        Args:
            delta: Scale of the step; its sign selects the direction.

        Returns:
            One `AblationEvaluationResult` per coefficient pair, in grid order.
        """
        output_model = self.mc.load_model()
        output_variables = hf_util.get_all_variables(output_model)
        grads = self.top_component_examples_loss_gradient

        results = []
        for coefficients in merging.create_pairwise_grid_coeffs(self.config.n_coefficients):
            # Only the second coefficient of the pair scales the step.
            _, coeff = coefficients

            for outv, ogv, grad in zip(output_variables, self.retaining_variables, grads):
                outv.assign(ogv + delta * coeff * grad)

            results.append(self._evaluate_and_report(output_model, coefficients))

        return results

    ############################################################

    def make_metadata_for_saving(self, **kwargs):
        """Return `kwargs` plus a 'config' entry holding the config as a plain dict.

        A 'config' key passed in `kwargs` is overridden by the experiment's config.
        """
        return {
            **kwargs,
            'config': dataclasses.asdict(self.config),
        }


###############################################################################


def _save_h5_ds(group, name, ndarray):
    """Create dataset `name` in `group` with `ndarray`'s shape and dtype and store the array in it.

    The array is handed to `hdf5_util.set_h5_ds` together with the new
    dataset; the created dataset is returned.
    """
    shape, dtype = ndarray.shape, ndarray.dtype
    dataset = group.create_dataset(name, shape, dtype=dtype)
    hdf5_util.set_h5_ds(dataset, ndarray)
    return dataset


def _read_h5_ds(group, name):
    """Read the dataset `group[name]` into a newly allocated array of the same shape and dtype."""
    dataset = group[name]
    out = np.zeros(shape=dataset.shape, dtype=dataset.dtype)
    # `read_direct` fills `out` in place.
    dataset.read_direct(out)
    return out


def save_results_list_to_h5(filepath: str, results: Sequence[AblationEvaluationResult], metadata):
    """Write a list of ablation results and their metadata to an HDF5 file.

    Layout (everything lives under the 'data' group):
        - attribute 'metadata': `metadata` encoded as a JSON string.
        - 'coefficients': float32 array with one row per result.
        - 'hans' and 'mnli' groups, each holding 'labels' and 'og_logits' taken
          from the first result, and 'logits_list', the per-result logits
          stacked along a new leading axis.

    Args:
        filepath: Destination path; '~' is expanded. An existing file is overwritten.
        results: Non-empty sequence of results.
        metadata: JSON-serializable object.
    """
    first = results[0]

    with h5py.File(os.path.expanduser(filepath), "w") as h5_file:
        root = h5_file.create_group('data')
        root.attrs['metadata'] = json.dumps(metadata)

        coefficient_rows = np.array([res.coefficients for res in results], dtype=np.float32)
        _save_h5_ds(root, 'coefficients', coefficient_rows)

        # HANS first, then MNLI; both tasks are stored with the same layout.
        for task_name, results_attr in (('hans', 'hans_results'), ('mnli', 'mnli_results')):
            task_group = root.create_group(task_name)
            reference = getattr(first, results_attr)

            _save_h5_ds(task_group, 'labels', reference.labels)
            _save_h5_ds(task_group, 'og_logits', reference.og_logits)

            # This will be a rank-3 tensor with dims [n_merge_coeffs, n_examples, n_classes]
            stacked_logits = np.stack([getattr(res, results_attr).logits for res in results], axis=0)
            _save_h5_ds(task_group, 'logits_list', stacked_logits)


def read_results_list_from_h5(filepath: str, components_context: HCC.HansLoneComponentContext):
    """Load results written by `save_results_list_to_h5`.

    Args:
        filepath: Path of the HDF5 file; '~' is expanded.
        components_context: Attached to every reconstructed HANS result.

    Returns:
        A `(results, metadata)` tuple. `results` holds one
        `AblationEvaluationResult` per stored coefficient row, and `metadata`
        is the decoded JSON metadata with its 'config' entry rebuilt as an
        `AblationExperimentConfig`.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as h5_file:
        metadata = json.loads(h5_file['data'].attrs['metadata'])

        coefficient_set = _read_h5_ds(h5_file, 'data/coefficients')

        hans_labels = _read_h5_ds(h5_file, 'data/hans/labels')
        hans_og_logits = _read_h5_ds(h5_file, 'data/hans/og_logits')
        hans_logits_list = _read_h5_ds(h5_file, 'data/hans/logits_list')

        mnli_labels = _read_h5_ds(h5_file, 'data/mnli/labels')
        mnli_og_logits = _read_h5_ds(h5_file, 'data/mnli/og_logits')
        mnli_logits_list = _read_h5_ds(h5_file, 'data/mnli/logits_list')

    metadata['config'] = AblationExperimentConfig(**metadata['config'])

    def _result_at(idx, coeffs):
        # Labels and original logits are shared by all results; only the
        # logits differ from one coefficient row to the next.
        hans = HCC.HansLoneEvaluationResults(
            components_context=components_context,
            labels=hans_labels,
            logits=hans_logits_list[idx],
            og_logits=hans_og_logits,
        )
        mnli = HMC.MnliEvaluationResults(
            labels=mnli_labels,
            logits=mnli_logits_list[idx],
            og_logits=mnli_og_logits,
        )
        return AblationEvaluationResult(
            coefficients=coeffs,
            hans_results=hans,
            mnli_results=mnli,
        )

    loaded = [_result_at(idx, coeffs) for idx, coeffs in enumerate(coefficient_set)]
    return loaded, metadata
