"""Computing the Fisher given a single example."""
import os
import dataclasses
import functools
from typing import Dict, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em.fishers import diagonal
from em.util import hdf5_util
from em.util import hf_util

# typedefs:

# (indices, values, dense_shape)
# Plain-NumPy description of a 2-D sparse matrix, as returned by
# PerExampleFlatFishers.to_sparse_fishers. There, `indices` is built by
# _to_coo_indices (row ids stacked on top of column ids along axis 0) and
# `values` is the flattened array of stored entries.
SparseTensorInfo = Tuple[np.ndarray, np.ndarray, Tuple[int, int]]


###############################################################################


def _to_coo_indices(fisher_indices):
    first_index = np.ones_like(fisher_indices) * np.arange(fisher_indices.shape[0], dtype=fisher_indices.dtype)[:, None]
    return np.stack([
        first_index.reshape([-1]),
        fisher_indices.reshape([-1]),
    ], axis=0)


def _save_h5_ds(group, name, ndarray):
    """Create dataset `name` under `group` and fill it with `ndarray`.

    The dataset is created with the array's shape and dtype, and the data is
    written through `hdf5_util.set_h5_ds`. The name, shape and dtype are
    printed before the dataset is created.

    Returns:
        The newly created dataset.
    """
    shape, dtype = ndarray.shape, ndarray.dtype
    print(name, shape, dtype)
    dataset = group.create_dataset(name, shape, dtype=dtype)
    hdf5_util.set_h5_ds(dataset, ndarray)
    return dataset


@dataclasses.dataclass
class PerExampleFlatFishers:
    """Utility class for interacting with saved per-example Fishers.

    The files will probably be created by scripts1/data_gen/save_per_example_fishers_to_disk.py.

    Every per-example array has one leading row per example. The Fishers are
    stored either densely (`fisher_indices is None`, `fishers` holds the dense
    rows) or sparsely (`fisher_indices` holds, per example, the dense column
    index of each entry of `fishers`).
    """
    # Token ids per example. `load` sets this to None when the file has no
    # 'data/input_ids' dataset.
    input_ids: Optional[np.ndarray]
    # Dense Fisher rows, or the stored values when `fisher_indices` is set.
    fishers: np.ndarray

    # One norm per example; `load` divides `fishers` by these when asked to
    # normalize.
    dense_fisher_norms: np.ndarray

    labels: np.ndarray
    predicted_logits: np.ndarray

    # Number of columns of the dense Fisher matrix.
    fisher_dense_size: int

    # Having this be something other than None indicates
    # that we are using dynamically sparse fishers. In that case
    # the fishers attribute will be the fisher values.
    fisher_indices: Optional[np.ndarray] = None

    # These should only be in the saved h5 file if the flavor
    # was "sparse_dynamic_metric_derived".
    sq_parameter_deltas: Optional[np.ndarray] = None
    dense_metric_derived_norms: Optional[np.ndarray] = None

    # Stores data fields other than those here by default. Although
    # we support accessing them directly via attributes, this allows
    # them to be stored in a consistent place.
    extra_data_fields: Optional[Dict[str, np.ndarray]] = None

    def __post_init__(self):
        if self.extra_data_fields is None:
            self.extra_data_fields = {}

    def __getattr__(self, field: str):
        """Expose entries of `extra_data_fields` as attributes.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: if `field` is not an extra data field.
        """
        # Read the mapping from the instance dict. On an instance whose
        # __init__ has not run (e.g. while being unpickled or copied),
        # `self.extra_data_fields` resolves to the class-level default None,
        # and `field in None` would raise TypeError instead of the
        # AttributeError that hasattr/getattr/pickle expect.
        extras = self.__dict__.get('extra_data_fields')
        if extras is not None and field in extras:
            return extras[field]
        raise AttributeError(field)

    def fishers_dense_shape(self) -> Tuple[int, int]:
        """Return (number of examples, dense Fisher size)."""
        return (self.fishers.shape[0], self.fisher_dense_size)

    def is_sparse(self) -> bool:
        """Return True when the Fishers are stored as (indices, values)."""
        return self.fisher_indices is not None

    def compute_sparse_hits_per_parameter(self) -> np.ndarray:
        """Count, per dense position, how many examples store an entry there.

        An example contributes at most one hit to a position, even if its
        index row repeats that position.

        Raises:
            ValueError: if the Fishers are not stored sparsely.
        """
        if not self.is_sparse():
            raise ValueError('Must have sparse per-example Fishers.')

        counts = np.zeros([self.fisher_dense_size], dtype=np.int32)
        for indices in self.fisher_indices:
            counts[indices] += 1

        return counts

    def to_sparse_fishers(self) -> "SparseTensorInfo":
        """Return the sparse Fishers as (indices, values, dense_shape)."""
        indices = _to_coo_indices(self.fisher_indices)
        values = self.fishers.reshape([-1])
        return indices, values, self.fishers_dense_shape()

    def _reduction_mapping(self, threshold: int) -> Tuple[np.ndarray, np.ndarray, int]:
        """Compute the dense -> reduced column mapping shared by the reducers.

        Returns:
            (reduction_info, og_to_reduced_index, reduced_size) where
            `reduction_info` lists the kept dense positions,
            `og_to_reduced_index[p]` is the reduced column of kept position
            `p`, and `reduced_size` is the number of kept positions.
        """
        assert threshold == 1, 'TODO: Support higher thresholds.'
        # Reduced means that positions that are zero for all examples are removed.
        hit_counts = self.compute_sparse_hits_per_parameter()
        dense_keep_mask = hit_counts >= threshold
        reduction_info, = np.nonzero(dense_keep_mask)

        og_to_reduced_index = np.cumsum(dense_keep_mask.astype(np.int32)) - 1
        # Counting the kept positions also works when the dense size is zero,
        # where indexing the last cumsum entry would fail.
        reduced_size = int(np.count_nonzero(dense_keep_mask))
        return reduction_info, og_to_reduced_index, reduced_size

    def to_reduced_fishers(self, threshold: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        """Return a dense matrix restricted to positions hit by some example.

        Returns:
            (reduction_info, reduced): `reduction_info[j]` is the dense
            position stored in column `j` of `reduced`.

        Raises:
            ValueError: if the Fishers are not stored sparsely.
        """
        reduction_info, og_to_reduced_index, reduced_size = self._reduction_mapping(threshold)

        n_examples = self.fishers.shape[0]
        reduced = np.zeros([n_examples, reduced_size], dtype=self.fishers.dtype)
        for i in range(n_examples):
            reduced_indices = og_to_reduced_index[self.fisher_indices[i]]
            reduced[i, reduced_indices] = self.fishers[i]

        return reduction_info, reduced

    def to_sparse_reduced_fishers(self, threshold: int = 1) -> Tuple[np.ndarray, "SparseTensorInfo"]:
        """Like `to_reduced_fishers`, but the reduced matrix stays sparse.

        Returns:
            (reduction_info, (indices, values, dense_shape)).

        Raises:
            ValueError: if the Fishers are not stored sparsely.
        """
        reduction_info, og_to_reduced_index, reduced_size = self._reduction_mapping(threshold)

        # Remap every stored dense index to its reduced column, keeping the
        # dtype of the stored indices.
        reduced_indices = og_to_reduced_index[self.fisher_indices].astype(self.fisher_indices.dtype)

        indices = _to_coo_indices(reduced_indices)
        values = self.fishers.reshape([-1])
        # Kept as a list, which is what this method has always returned.
        dense_shape = [self.fishers.shape[0], reduced_size]

        return reduction_info, (indices, values, dense_shape)

    def create_for_subset(self, subset_inds: Sequence[int]) -> 'PerExampleFlatFishers':
        """Return a new instance containing only the examples in `subset_inds`.

        Per-example arrays are indexed along their first axis; optional
        fields that are None stay None. `sq_parameter_deltas` is passed
        through unsliced, as is `fisher_dense_size`.
        """
        if not isinstance(subset_inds, np.ndarray):
            subset_inds = np.array(subset_inds, dtype=np.int32)

        def maybe_slice(a):
            return a if a is None else a[subset_inds]

        return PerExampleFlatFishers(
            # input_ids is None when loaded from a file without token ids.
            input_ids=maybe_slice(self.input_ids),
            fishers=self.fishers[subset_inds, :],
            dense_fisher_norms=self.dense_fisher_norms[subset_inds],
            labels=self.labels[subset_inds],
            predicted_logits=self.predicted_logits[subset_inds, :],
            fisher_dense_size=self.fisher_dense_size,
            # fisher_indices is None for densely stored Fishers.
            fisher_indices=maybe_slice(self.fisher_indices),
            sq_parameter_deltas=self.sq_parameter_deltas,
            dense_metric_derived_norms=maybe_slice(self.dense_metric_derived_norms),
            # NOTE: Not sure how I'm using these, so this might need to be changed in
            # the future.
            extra_data_fields={
                k: maybe_slice(v)
                for k, v in self.extra_data_fields.items()
            },
        )

    def get_full_examples(self, tokenizer: "PreTrainedTokenizer", trim=False):
        """Rebuild model inputs from the stored token ids.

        Args:
            tokenizer: provides `pad_token_id` and `sep_token_id`.
            trim: if True, drop the columns that are padding in every example.

        Returns:
            (examples, labels) where `examples` maps 'input_ids',
            'token_type_ids' and 'attention_mask' to 2-D arrays.

        Raises:
            ValueError: if an example contains more than two separator tokens.
        """
        input_ids = self.input_ids
        attention_mask = (input_ids != tokenizer.pad_token_id).astype(np.int32)

        is_sep = input_ids == tokenizer.sep_token_id
        if np.any(is_sep.sum(axis=1) > 2):
            raise ValueError('Expected at most two separator tokens per example.')

        # Up to and including the first sep_token have type id 0, then type id 1.
        # A position has type 1 iff a separator occurs strictly before it.
        seps_before = np.cumsum(is_sep, axis=1) - is_sep
        token_type_ids = (seps_before > 0).astype(input_ids.dtype)

        examples = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
        }

        if trim:
            mask = (attention_mask == 1).any(axis=0)
            examples = {
                k: v[:, mask]
                for k, v in examples.items()
            }

        return examples, self.labels

    def get_full_examples_ds(self, tokenizer: "PreTrainedTokenizer"):
        """Return `get_full_examples(tokenizer)` as a tf.data.Dataset of slices."""
        return tf.data.Dataset.from_tensor_slices(self.get_full_examples(tokenizer))

    def save(self, filepath: str):
        """Write the sparse representation to an HDF5 file at `filepath`.

        Raises:
            ValueError: if the Fishers are not stored sparsely.
        """
        # TODO: This is just written to support the sparse_dynamic_raw type of pef; this
        # method doesn't fully support everything.
        if not self.is_sparse():
            # Fail before opening the file in "w" mode so that an unsupported
            # call cannot create or truncate the target.
            raise ValueError('Saving is only supported for sparse per-example Fishers.')

        with h5py.File(os.path.expanduser(filepath), "w") as f:
            _save_h5_ds(f, 'data/input_ids', self.input_ids)
            _save_h5_ds(f, 'data/dense_fisher_norms', self.dense_fisher_norms)

            _save_h5_ds(f, 'data/labels', self.labels)
            _save_h5_ds(f, 'data/predicted_logits', self.predicted_logits)

            _save_h5_ds(f, 'data/fisher/values', self.fishers)
            _save_h5_ds(f, 'data/fisher/indices', self.fisher_indices)
            f['data/fisher'].attrs['dense_fisher_size'] = self.fisher_dense_size

    @classmethod
    def _load_data_field(cls, file: "h5py.File", field: str, n_examples: int, dtype=None):
        """Read the first `n_examples` rows of dataset 'data/<field>'.

        Args:
            dtype: dtype of the returned array; defaults to the dataset's own.
        """
        ds = file[f'data/{field}']
        data = np.zeros([n_examples, *ds.shape[1:]], dtype=ds.dtype if dtype is None else dtype)
        ds.read_direct(data, np.s_[:n_examples])
        return data

    @classmethod
    def _load_extra_data_fields(cls, file: "h5py.File", extra_data_fields: Sequence[str], n_examples: int):
        """Load each named extra field into a {name: array} dict."""
        return {
            field: cls._load_data_field(file, field, n_examples)
            for field in extra_data_fields
        }

    @classmethod
    def load(
        cls,
        filepath: str,
        n_examples: Optional[int] = None,
        start_fisher_index: Optional[int] = None,
        end_fisher_index: Optional[int] = None,
        *,
        extra_data_fields: Optional[Sequence[str]] = None,
        normalize_fishers: bool = True,
    ) -> "PerExampleFlatFishers":
        """Load per-example Fishers from an HDF5 file.

        Args:
            filepath: path of the file; `~` is expanded.
            n_examples: number of leading examples to read. Defaults to all
                rows of the stored Fisher dataset.
            start_fisher_index: first stored Fisher column to read. Negative
                values count from the end; defaults to 0.
            end_fisher_index: one past the last stored Fisher column to read.
                Negative values count from the end; defaults to all columns.
            extra_data_fields: names of additional 'data/<name>' datasets to
                load into `extra_data_fields`.
            normalize_fishers: if True, divide each example's Fisher values by
                its entry in 'data/dense_fisher_norms' (plus 1e-12).

        Returns:
            The loaded instance. It is sparse when the file contains
            'data/fisher/values', dense otherwise.
        """
        # TODO: Make the interface for selecting slices of examples and fishers better.
        filepath = os.path.expanduser(filepath)

        if extra_data_fields is None:
            extra_data_fields = []

        # NOTE: This only supports statically shaped fishers.
        with h5py.File(filepath, "r") as f:
            has_sparsely_stored_fishers = 'data/fisher/values' in f
            values_ds = f['data/fisher/values' if has_sparsely_stored_fishers else 'data/fishers']

            fisher_size = values_ds.shape[1]
            if n_examples is None:
                n_examples = values_ds.shape[0]

            if start_fisher_index is None:
                start_fisher_index = 0
            elif start_fisher_index < 0:
                start_fisher_index = fisher_size + start_fisher_index

            if end_fisher_index is None:
                end_fisher_index = fisher_size
            elif end_fisher_index < 0:
                end_fisher_index = fisher_size + end_fisher_index

            read_fisher_size = end_fisher_index - start_fisher_index
            fisher_sel = np.s_[:n_examples, start_fisher_index:end_fisher_index]

            if 'data/input_ids' in f:
                input_ids = cls._load_data_field(f, 'input_ids', n_examples, dtype=np.int32)
            else:
                input_ids = None

            if has_sparsely_stored_fishers:
                fisher_dense_size = f['data/fisher'].attrs['dense_fisher_size']
            else:
                # Dense files keep the Fishers in the 'data/fishers' dataset,
                # whose width is the dense size.
                fisher_dense_size = values_ds.shape[-1]

            fishers = np.zeros([n_examples, read_fisher_size], dtype=np.float32)
            if has_sparsely_stored_fishers:
                fisher_indices = np.zeros([n_examples, read_fisher_size], dtype=np.int32)
            else:
                fisher_indices = None

            # An empty column range lets us get away with not reading the fisher
            # information at all, which can save a significant amount of time.
            if read_fisher_size != 0:
                values_ds.read_direct(fishers, fisher_sel)
                if fisher_indices is not None:
                    f['data/fisher/indices'].read_direct(fisher_indices, fisher_sel)

            dense_fisher_norms = cls._load_data_field(f, 'dense_fisher_norms', n_examples, dtype=np.float32)

            if 'data/dense_metric_derived_norms' in f:
                dense_metric_derived_norms = cls._load_data_field(
                    f, 'dense_metric_derived_norms', n_examples, dtype=np.float32)
            else:
                dense_metric_derived_norms = None

            if normalize_fishers:
                # Normalize by the norm of the full dense Fisher (not of the
                # slice that was read); the epsilon guards against zero norms.
                fishers /= (dense_fisher_norms[:, None] + 1e-12)

            labels = cls._load_data_field(f, 'labels', n_examples, dtype=np.int32)
            predicted_logits = cls._load_data_field(f, 'predicted_logits', n_examples, dtype=np.float32)

            if 'data/sq_parameter_deltas' in f:
                # Stored per dense position rather than per example, so it is
                # read in full.
                sq_parameter_deltas = np.zeros([fisher_dense_size], dtype=np.float32)
                f['data/sq_parameter_deltas'].read_direct(sq_parameter_deltas)
            else:
                sq_parameter_deltas = None

            loaded_extra_data_fields = cls._load_extra_data_fields(f, extra_data_fields, n_examples)

            return cls(
                input_ids=input_ids,
                fishers=fishers,
                dense_fisher_norms=dense_fisher_norms,
                labels=labels,
                predicted_logits=predicted_logits,
                fisher_indices=fisher_indices,
                fisher_dense_size=fisher_dense_size,
                sq_parameter_deltas=sq_parameter_deltas,
                dense_metric_derived_norms=dense_metric_derived_norms,
                extra_data_fields=loaded_extra_data_fields
            )
            

###############################################################################

def stream_per_example_diagonal_fishers(
    model,
    dataset: tf.data.Dataset,
    variables: Optional[Sequence[tf.Variable]] = None,
    *,
    expectation_wrt_logits=True,
    unbatch=True,
):
    """Yield per-example diagonal Fishers for every batch of `dataset`.

    Args:
        model: model handed to `diagonal.compute_exact_fisher_for_batch`.
        dataset: iterable of batches. A tuple element is unpacked as a pair
            and only its first item is used.
        variables: variables to compute Fishers for. Defaults to
            `hf_util.get_mergeable_variables(model)`.
        expectation_wrt_logits: forwarded to
            `diagonal.compute_exact_fisher_for_batch`.
        unbatch: if True, yield one list per example; otherwise yield the
            per-batch result unchanged.

    Yields:
        With `unbatch`, a list holding the i-th entry of each per-batch
        Fisher (presumably one Fisher per variable; confirm against
        `diagonal`). Otherwise the value returned for the whole batch.
    """
    if variables is None:
        variables = hf_util.get_mergeable_variables(model)

    for element in dataset:
        inputs = element
        if isinstance(element, tuple):
            # Drop the second item of the pair; only the inputs are needed.
            inputs, _ = element

        per_batch_fishers = diagonal.compute_exact_fisher_for_batch(
            inputs,
            model,
            variables,
            expectation_wrt_logits=expectation_wrt_logits,
            per_example=True,
        )

        if not unbatch:
            yield per_batch_fishers
            continue

        n_in_batch = diagonal.batch_size_from_batch(inputs)
        for example_idx in range(n_in_batch):
            yield [fisher[example_idx] for fisher in per_batch_fishers]


@tf.function
def _compute_exact_sparse_fisher_for_batch(batch, model, sparse_indices, variables, expectation_wrt_logits):
    """Compute per-example Fishers and keep only the requested entries.

    The full per-example Fishers come from
    `diagonal.compute_exact_fisher_for_batch`. Each of them is paired with the
    matching element of `sparse_indices`, and `tf.gather_nd` with those
    indices is applied to every example via `tf.vectorized_map`.

    Returns:
        A list with one gathered tensor per (Fisher, indices) pair.
    """
    assert len(sparse_indices) == len(variables)
    per_example_fishers = diagonal.compute_exact_fisher_for_batch(
        batch,
        model,
        variables,
        expectation_wrt_logits=expectation_wrt_logits,
        per_example=True,
    )
    return [
        tf.vectorized_map(functools.partial(tf.gather_nd, indices=entry_indices), dense_fisher)
        for dense_fisher, entry_indices in zip(per_example_fishers, sparse_indices)
    ]


def stream_per_example_sparse_diagonal_fishers(
    model,
    dataset: tf.data.Dataset,
    sparse_indices: Sequence[Union[tf.Tensor, tf.sparse.SparseTensor]],
    variables: Optional[Sequence[tf.Variable]] = None,
    *,
    expectation_wrt_logits=True,
    unbatch=True,
):
    """Yield per-example diagonal Fishers restricted to `sparse_indices`.

    Args:
        model: model handed to the batch Fisher computation.
        dataset: iterable of batches. A tuple element is unpacked as a pair
            and only its first item is used.
        sparse_indices: one entry per variable, either an index tensor or a
            `tf.sparse.SparseTensor` whose `.indices` are used.
        variables: variables to compute Fishers for. Defaults to
            `hf_util.get_mergeable_variables(model)`.
        expectation_wrt_logits: forwarded to the batch Fisher computation.
        unbatch: if True, yield one list per example; otherwise yield the
            per-batch result unchanged.

    Yields:
        With `unbatch`, a list holding the i-th entry of each gathered
        per-batch Fisher. Otherwise the list returned for the whole batch.
    """
    if variables is None:
        variables = hf_util.get_mergeable_variables(model)

    assert len(sparse_indices) == len(variables)
    sparse_indices = list(map(_get_sparse_indices, sparse_indices))

    for element in dataset:
        inputs = element
        if isinstance(element, tuple):
            # Drop the second item of the pair; only the inputs are needed.
            inputs, _ = element

        gathered_fishers = _compute_exact_sparse_fisher_for_batch(
            inputs,
            model,
            sparse_indices,
            variables,
            expectation_wrt_logits=expectation_wrt_logits,
        )

        if not unbatch:
            yield gathered_fishers
            continue

        n_in_batch = diagonal.batch_size_from_batch(inputs)
        for example_idx in range(n_in_batch):
            yield [fisher[example_idx] for fisher in gathered_fishers]

###############################################################################


def _get_sparse_indices(indices_or_sparse: Union[tf.Tensor, tf.sparse.SparseTensor]) -> tf.Tensor:
    """Return the `.indices` of a SparseTensor; pass anything else through."""
    is_sparse_tensor = isinstance(indices_or_sparse, tf.sparse.SparseTensor)
    return indices_or_sparse.indices if is_sparse_tensor else indices_or_sparse
