"""Dense gradients.

These gradients can either be raw (i.e. the same size as the model parameters) or
be a random projection to a lower dimension.
"""
import dataclasses
import json
import os
from typing import Dict, List, Optional

import h5py
import numpy as np
import torch
from transformers import PreTrainedModel

from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import random_projectors

from npeff_torch.peis.gradients import gradient_computers
from npeff_torch.peis.gradients import logit_functions
from npeff_torch.peis.gradients.formats import gradient_format_common

from npeff_torch.util import hdf5_utils


###############################################################################
# Keys of the example/batch dict that are considered for saving alongside the
# gradients. Depending on the model/tokenizer, only a subset of these may actually
# be present in a given example. StreamingGradientSaver skips every batch entry
# whose key is not listed here when it writes examples to the file.

_TYPICAL_EXAMPLE_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

###############################################################################


@dataclasses.dataclass
class StreamingGradientSaver(gradient_format_common.StreamingGradientSaverAbc):
    """Saves dense per-example gradients to an HDF5 file in a streaming manner.

    Everything is written under the ``data`` group of the file. Each dataset is
    created lazily on its first write with a leading dimension of ``n_examples``
    and is then filled one row per successfully processed example:

      - ``gradients``: the (optionally randomly projected) dense gradient.
      - ``norms``: the vector norm of the gradient as stored, i.e. after projection.
      - ``fn_values``: the ``fn_value`` reported together with the gradient.
      - ``token_positions``: only when a position is reported.
      - ``examples/<key>``: only when ``save_examples`` is set.
      - ``labels``: only when ``save_labels`` is set and a label is reported.
      - ``logits``: only when ``save_logits`` is set.
      - ``top_log_probs/class_indices`` and ``top_log_probs/values``: only when
        ``save_top_n_log_probs`` is set.

    The gradient information comes from ``self.compute_dense_gradient_info``, which
    is not defined in this class (presumably it is provided by the base class).
    """

    # NOTE(review): `model`, `gradient_computer`, `position_selector`, `label_key` and
    # `device` are not read anywhere in this class; presumably the base class uses them
    # to compute the gradients -- confirm against StreamingGradientSaverAbc.
    model: PreTrainedModel

    gradient_computer: 'gradient_computers.GradientComputer'
    position_selector: 'position_selectors.PositionSelectorAbc'

    label_key: Optional[str]

    device: torch.device

    # Whether to write the entries of the batch listed in _TYPICAL_EXAMPLE_KEYS.
    save_examples: bool
    # Whether to write the label reported with the gradient (skipped when it is None).
    save_labels: bool
    # Whether to write the full `log_probs` vector (stored in the 'logits' dataset).
    save_logits: bool

    # Serialized as JSON into the 'parameter_infos' attribute of the data group.
    parameter_infos: List['parameter_infos.ParameterInfo']

    # When not None, every gradient is randomly projected before being written.
    random_projection_params: Optional['random_projectors.RandomProjectionParams']

    # Useful when used instead of save_logits for a large number classes, like what is
    # commonly seen for a language model.
    save_top_n_log_probs: Optional[int] = None

    # This only makes sense for lm_suffix_mc models, where each "example" is the set of
    # sequences consisting of a common prefix and the various suffixes to evaluate. If True, then only
    # the first of these sequences will be saved. If False, then all of the options will be saved.
    lm_suffix_mc_save_only_first_example: bool = False

    def __post_init__(self):
        """Builds the random projector when projection parameters were provided."""
        if self.random_projection_params is not None:
            self._random_projector = random_projectors.RandomProjector(params=self.random_projection_params)
        else:
            self._random_projector = None

    def _initialize_file(self, file: h5py.File, n_examples: int):
        """Resets the streaming state and writes the file-level metadata.

        Args:
            file: Open, writable HDF5 file. A 'data' group is created in it.
            n_examples: Leading dimension used for every dataset created later.
        """
        # This is the dimension of the gradients as presented.
        self._n_parameters = None
        # This is the dimension of the gradients from the original model. If no random projection is
        # done, this will equal self._n_parameters.
        self._n_og_parameters = None

        self._n_examples_processed = 0
        self._n_examples = n_examples

        self._data_grp = file.create_group('data')

        self._data_grp.attrs['gradient_format'] = "dn_gradients"
        self._data_grp.attrs['gradient_format_version'] = "0.0.1"

        self._data_grp.attrs['parameter_infos'] = json.dumps([p.to_json() for p in self.parameter_infos])

        # h5py cannot store a Python None as an attribute value (it maps to a NumPy
        # object dtype, which has no HDF5 equivalent and raises TypeError). The
        # attribute is therefore always JSON: the serialized params, or "null" when
        # no random projection is used.
        projection_params = self.random_projection_params
        projection_json = projection_params.to_json() if projection_params is not None else None
        self._data_grp.attrs['random_projection_params'] = json.dumps(projection_json)

    def _finalize_file(self, file: h5py.File):
        """Records the gradient dimensions once all examples have been written.

        NOTE(review): both sizes are only set by `_process_batch`; if no example was
        successfully processed they are still None and `int(None)` raises TypeError.
        """
        file['data'].attrs['n_parameters'] = int(self._n_parameters)
        file['data'].attrs['n_og_parameters'] = int(self._n_og_parameters)

    #######################################################

    def _write_and_make_if_needed(self, ds_name: str, x: torch.Tensor, dtype=None):
        """Writes `x` as the row of the current example, creating the dataset if needed.

        Args:
            ds_name: Dataset name relative to the 'data' group.
            x: Value for the current example. Its shape fixes the per-example shape
                of the dataset when the dataset is created.
            dtype: Dtype of the dataset when it gets created. Falls back to the dtype
                of `x` (after conversion to NumPy) when not given.
        """
        x = x.detach().cpu().numpy()
        if ds_name not in self._data_grp:
            self._data_grp.create_dataset(ds_name, [self._n_examples, *x.shape], dtype=dtype or x.dtype)

        h5_ds = self._data_grp[ds_name]
        # Cast to the dataset's dtype, which may differ from the tensor's dtype.
        h5_ds[self._n_examples_processed] = x.astype(h5_ds.dtype)

    def _write_top_log_probs(self, log_probs: torch.Tensor):
        """Writes the `save_top_n_log_probs` largest log-probs and their class indices.

        Entries are stored in descending order of log-probability.
        """
        top_inds = torch.argsort(log_probs, descending=True)[:self.save_top_n_log_probs]
        top_values = log_probs[top_inds]
        self._write_and_make_if_needed('top_log_probs/class_indices', top_inds, np.int32)
        self._write_and_make_if_needed('top_log_probs/values', top_values, np.float32)

    def _make_examples_to_save(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Returns the view of `batch` whose entries should be saved as the example.

        When `lm_suffix_mc_save_only_first_example` is False the batch is returned
        unchanged. Otherwise a shallow copy is returned in which every entry keyed by
        one of _TYPICAL_EXAMPLE_KEYS is reduced to its first option.
        """
        if not self.lm_suffix_mc_save_only_first_example:
            return batch

        ret = batch.copy()

        for k, v in batch.items():
            if k not in _TYPICAL_EXAMPLE_KEYS:
                continue
            # shape = [batch (=1), options, sequence]
            assert len(v.shape) == 3
            # Keep only the first option -> shape = [batch, sequence].
            ret[k] = v[:, 0, :]

        return ret

    #######################################################

    def _process_batch(self, batch: Dict[str, torch.Tensor]) -> bool:
        """Computes the gradient for `batch` and writes it plus the requested extras.

        Returns:
            False when no gradient information is available for the batch (nothing is
            written and the example counter is not advanced), True otherwise.
        """
        dense_gradient_info = self.compute_dense_gradient_info(batch)

        if dense_gradient_info is None:
            return False

        dense_gradient = dense_gradient_info.dense_gradient

        # This should be the same for all examples.
        self._n_og_parameters, = dense_gradient.shape

        # Project if desired.
        if self._random_projector is not None:
            dense_gradient = self._random_projector.project(dense_gradient)

        # NOTE: This is computed for the projected gradient if a random projection is used.
        norm = torch.linalg.vector_norm(dense_gradient)

        # This should be the same for all examples.
        self._n_parameters, = dense_gradient.shape

        self._write_and_make_if_needed('gradients', dense_gradient, np.float32)
        self._write_and_make_if_needed('norms', norm, np.float32)
        self._write_and_make_if_needed('fn_values', dense_gradient_info.fn_value, np.float32)

        if dense_gradient_info.position is not None:
            self._write_and_make_if_needed('token_positions', dense_gradient_info.position, np.int32)

        if self.save_examples:
            # TODO: Find a better way of filtering the items in the batch dict.
            for k, v in self._make_examples_to_save(batch).items():
                if k not in _TYPICAL_EXAMPLE_KEYS:
                    continue
                # Drop the leading batch dimension (a no-op unless it has size 1).
                v = torch.squeeze(v, dim=0)
                self._write_and_make_if_needed(f'examples/{k}', v)

        if self.save_labels and dense_gradient_info.label is not None:
            self._write_and_make_if_needed('labels', dense_gradient_info.label, np.int32)

        if self.save_logits:
            self._write_and_make_if_needed('logits', dense_gradient_info.log_probs, np.float32)

        if self.save_top_n_log_probs:
            self._write_top_log_probs(dense_gradient_info.log_probs)

        # Advance the row index only after every dataset was written for this example.
        self._n_examples_processed += 1

        return True


###############################################################################


def load_gradients(filepath: str) -> torch.Tensor:
    """Loads the `data/gradients` dataset of an HDF5 file as a torch tensor.

    Args:
        filepath: Path to the HDF5 file. A leading `~` is expanded to the user's
            home directory.

    Returns:
        Tensor built from the array that `hdf5_utils.load_h5_ds` returns for the
        dataset.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        gradients = hdf5_utils.load_h5_ds(h5_file['data/gradients'])
        return torch.from_numpy(gradients)


def load_norms(filepath: str) -> torch.Tensor:
    """Loads the `data/norms` dataset of an HDF5 file as a torch tensor.

    Args:
        filepath: Path to the HDF5 file. A leading `~` is expanded to the user's
            home directory.

    Returns:
        Tensor built from the array that `hdf5_utils.load_h5_ds` returns for the
        dataset.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        norms = hdf5_utils.load_h5_ds(h5_file['data/norms'])
        return torch.from_numpy(norms)
