"""Fixed rank, fixed nnz LRM-pefs."""
import dataclasses
import json
import os
from typing import Dict, List, Optional

import h5py
import numpy as np
import torch
from transformers import PreTrainedModel

from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import sparsifiers
from npeff_torch.peis.fishers import class_selectors
from npeff_torch.peis.fishers.computers import pef_computer_common
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils


###############################################################################
# Keys in examples that are typically used. Depending on the model/tokenizer, only
# a subset of these might be typically used.

_TYPICAL_EXAMPLE_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

###############################################################################


@dataclasses.dataclass
class SparseLrmPefs:
    """Sparse LRM-PEFs for a collection of examples, as stored in an HDF5 file.

    Row ``i`` of ``values``, ``col_offsets`` and ``row_indices`` together hold
    the compressed-sparse-column components of example ``i``'s PEF. Use
    :meth:`load` to build an instance from disk.
    """

    # Descriptions of the parameters the Fishers were computed over.
    parameter_infos: List['parameter_infos.ParameterInfo']

    # Stored non-zero entries; shape = [n_examples, nnz_pef_example].
    values: np.ndarray
    # Column offsets; shape = [n_examples, rank + 1].
    col_offsets: np.ndarray
    # Row indices paired with `values`; shape = [n_examples, nnz_pef_example].
    row_indices: np.ndarray

    # Frobenius norm of each PEF; shape = [n_examples].
    norms: np.ndarray

    #######################################################
    # Optional per-example data (defaults to None).

    token_positions: Optional[np.ndarray] = None

    examples: Optional[Dict[str, np.ndarray]] = None
    labels: Optional[np.ndarray] = None
    logits: Optional[np.ndarray] = None

    #######################################################

    def __post_init__(self):
        # Derived attribute (not a dataclass field): total element count
        # summed over every entry of `parameter_infos`.
        total = 0
        for info in self.parameter_infos:
            total += info.n_elements()
        self.n_parameters = total

    #######################################################

    @classmethod
    def load(cls, filepath: str):
        """Build an instance from the HDF5 file at `filepath`.

        A leading ``~`` in the path is expanded. The parameter infos are parsed
        from the JSON-encoded ``parameter_infos`` attribute of the ``data``
        group. ``examples`` is None when the file has no ``data/examples``
        group; token positions, labels and logits are read through
        ``hdf5_utils.load_h5_ds_if_exists``.

        Args:
          filepath: Path of the HDF5 file to read.

        Returns:
          A new `cls` instance populated from the file.
        """
        resolved_path = os.path.expanduser(filepath)
        with h5py.File(resolved_path, "r") as h5_file:
            serialized_infos = json.loads(h5_file['data'].attrs['parameter_infos'])
            infos = [parameter_infos.ParameterInfo.from_json(entry) for entry in serialized_infos]

            examples = None
            if 'data/examples' in h5_file:
                examples = {
                    key: hdf5_utils.load_h5_ds(dataset)
                    for key, dataset in h5_file['data/examples'].items()
                }

            # Keyword arguments are evaluated left to right, so the datasets
            # are read in the same order as they are listed here.
            return cls(
                parameter_infos=infos,
                # Required sparse components.
                values=hdf5_utils.load_h5_ds(h5_file['data/values']),
                col_offsets=hdf5_utils.load_h5_ds(h5_file['data/col_offsets']),
                row_indices=hdf5_utils.load_h5_ds(h5_file['data/row_indices']),
                norms=hdf5_utils.load_h5_ds(h5_file['data/pef_frobenius_norms']),
                # Optional per-example data.
                token_positions=hdf5_utils.load_h5_ds_if_exists(h5_file, 'data/token_positions'),
                labels=hdf5_utils.load_h5_ds_if_exists(h5_file, 'data/labels'),
                logits=hdf5_utils.load_h5_ds_if_exists(h5_file, 'data/logits'),
                examples=examples,
            )


###############################################################################


@dataclasses.dataclass
class StreamingLrmPefSaver(pef_format_common.StreamingPefSaverAbc):
    """Saves LRM-PEFs to an HDF5 file in a streaming manner.

    For every example accepted by `_process_batch`, the dense PEF is sparsified
    with a budget of `nnz_per_example` non-zeros (CSC layout via
    `sparsifiers.SingleLrmPefSparsifierCsc`). The result is written as one row
    of the per-example datasets under the HDF5 group ``data``. File-level
    metadata (format tag/version, parameter infos, rank, number of parameters)
    is stored as attributes on that group.
    """

    # NOTE(review): not read in this class; presumably consumed by the base
    # class's `compute_dense_pef_info` — confirm.
    model: PreTrainedModel

    fisher_computer: 'pef_computer_common.PefComputerAbc'

    # Number of non-zeros kept per example; fixes the width of the `values`
    # dataset created in `_initialize_file`.
    nnz_per_example: int

    position_selector: 'position_selectors.PositionSelectorAbc'
    class_subset_selector: 'class_selectors.ClassSubsetSelectorAbc'

    # NOTE(review): not read in this class; presumably used by the base class.
    label_key: Optional[str]

    device: torch.device

    # Which optional per-example data to write alongside the PEFs.
    save_examples: bool
    save_labels: bool
    save_logits: bool

    # Parameters the Fishers are computed over; JSON-serialized into the
    # `parameter_infos` attribute of the `data` group.
    parameter_infos: List['parameter_infos.ParameterInfo']

    def __post_init__(self):
        self._sparsifier = sparsifiers.SingleLrmPefSparsifierCsc()

    def _initialize_file(self, file: h5py.File, n_examples: int):
        """Create the `data` group, its metadata attributes and fixed-shape datasets.

        Not every dataset is created here. Some depend on what is being computed
        and on the dtypes of the computed values, e.g. `col_offsets`,
        `row_indices`, `token_positions` and `examples/*`. Those are created
        lazily by `_write_and_make_if_needed`.

        Args:
            file: Open, writable HDF5 file.
            n_examples: Number of rows to allocate in every per-example dataset.
        """
        self._file = file

        # Filled in from the first processed example (see `_process_batch`).
        self._rank = None
        self._n_parameters = None

        self._n_examples_processed = 0
        self._n_examples = n_examples

        self._data_grp = file.create_group('data')

        self._data_grp.attrs['pef_format'] = "frfn_lrm"
        self._data_grp.attrs['pef_format_version'] = "0.0.1"

        self._data_grp.attrs['parameter_infos'] = json.dumps([p.to_json() for p in self.parameter_infos])

        self._values_ds = self._data_grp.create_dataset(
            'values', [n_examples, self.nnz_per_example], dtype=np.float32)
        self._pef_frobenius_norms_ds = self._data_grp.create_dataset(
            'pef_frobenius_norms', [n_examples], dtype=np.float32)

    def _finalize_file(self, file: h5py.File):
        """Write the file-level `rank`, `n_classes` and `n_parameters` attributes.

        Raises:
            RuntimeError: if no example was processed, in which case the rank
                and number of parameters are unknown.
        """
        if self._rank is None or self._n_parameters is None:
            raise RuntimeError(
                'Cannot finalize LRM-PEF file: no examples were processed, so the '
                'rank and number of parameters are unknown.')

        data_attrs = file['data'].attrs
        data_attrs['rank'] = int(self._rank)
        # Write also to n_classes for back-compatibility purposes.
        data_attrs['n_classes'] = int(self._rank)
        data_attrs['n_parameters'] = int(self._n_parameters)

    #######################################################

    def _write_and_make_if_needed(self, ds_name: str, x: torch.Tensor, dtype=None):
        """Write `x` into row `_n_examples_processed` of dataset `ds_name`.

        If the dataset does not exist yet, it is created with shape
        ``[n_examples, *x.shape]``. Its dtype is `dtype`, or `x`'s own dtype when
        `dtype` is None. `ds_name` may contain ``/`` (e.g. ``examples/input_ids``).
        The value is always cast to the existing dataset's dtype before writing.
        """
        x = x.detach().cpu().numpy()
        if ds_name not in self._data_grp:
            # Explicit `is None` check: a non-structured `np.dtype` instance has
            # len() == 0 and is therefore falsy, so `dtype or x.dtype` would
            # silently ignore it.
            ds_dtype = x.dtype if dtype is None else dtype
            self._data_grp.create_dataset(ds_name, [self._n_examples, *x.shape], dtype=ds_dtype)

        h5_ds = self._data_grp[ds_name]
        h5_ds[self._n_examples_processed] = x.astype(h5_ds.dtype)

    #######################################################

    def _process_batch(self, batch: Dict[str, torch.Tensor]) -> bool:
        """Compute, sparsify and write the PEF of one example.

        Returns:
            False if `compute_dense_pef_info` produced nothing for this batch
            (nothing is written); True once a row has been written.

        Raises:
            ValueError: if the sparsified PEF's (rank, n_parameters) differs from
                the one seen for earlier examples.
        """
        dense_pef_info = self.compute_dense_pef_info(batch)
        if dense_pef_info is None:
            return False

        frobenius_norm = pef_computer_common.compute_lrm_pef_frobenius_norm(dense_pef_info.dense_pef)

        pef_csc = self._sparsifier.sparsify(dense_pef_info.dense_pef, self.nnz_per_example)

        # These are stored once as file-level attributes, so they must be the
        # same for all examples; check before writing anything for this example.
        rank, n_parameters = pef_csc.shape
        if self._rank is not None and (rank != self._rank or n_parameters != self._n_parameters):
            raise ValueError(
                f'Inconsistent LRM-PEF shape at example {self._n_examples_processed}: '
                f'got (rank={rank}, n_parameters={n_parameters}), expected '
                f'(rank={self._rank}, n_parameters={self._n_parameters}).')
        self._rank, self._n_parameters = rank, n_parameters

        self._write_and_make_if_needed('values', pef_csc.values, np.float32)
        self._write_and_make_if_needed('col_offsets', pef_csc.col_offsets)
        self._write_and_make_if_needed('row_indices', pef_csc.row_indices)

        self._write_and_make_if_needed('pef_frobenius_norms', frobenius_norm, np.float32)

        if dense_pef_info.position is not None:
            self._write_and_make_if_needed('token_positions', dense_pef_info.position, np.int32)

        if self.save_examples:
            # TODO: Find a better way of filtering the items in the batch dict.
            for k, v in batch.items():
                if k not in _TYPICAL_EXAMPLE_KEYS:
                    continue
                # Drops a leading dim of size 1; presumably each batch holds a
                # single example — confirm.
                v = torch.squeeze(v, dim=0)
                self._write_and_make_if_needed(f'examples/{k}', v)

        if self.save_labels and dense_pef_info.label is not None:
            self._write_and_make_if_needed('labels', dense_pef_info.label, np.int32)

        if self.save_logits:
            # NOTE(review): stored under the 'logits' key, but the value written
            # is `log_probs`.
            self._write_and_make_if_needed('logits', dense_pef_info.log_probs, np.float32)

        self._n_examples_processed += 1
        return True
