"""Fixed rank dense LRM-PEFs.

These LRM-PEFs can either be raw (i.e. the same size as the model parameters) or
be a random projection to a lower dimension.
"""
import dataclasses
import json
import os
from typing import Dict, List, Optional

import h5py
import numpy as np
import torch
from transformers import PreTrainedModel

from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import random_projectors
from npeff_torch.peis.fishers import class_selectors
from npeff_torch.peis.fishers.computers import pef_computer_common
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.util import hdf5_utils


###############################################################################
# Keys of a batch dict that are treated as example inputs. Only entries stored under
# these keys are written out when examples are saved. Depending on the model/tokenizer,
# a batch may contain only a subset of them.

_TYPICAL_EXAMPLE_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

###############################################################################


@dataclasses.dataclass
class DenseLrmPefs:
    """Placeholder for a container of dense LRM-PEFs; it declares no fields yet."""
    # TODO

###############################################################################


@dataclasses.dataclass
class StreamingLrmPefSaver(pef_format_common.StreamingPefSaverAbc):
    """Saves LRM-PEFs to an HDF5 file in a streaming manner.

    Everything is stored under the file's ``data`` group. Per-example datasets
    are created lazily on their first write with a leading dimension of
    ``n_examples``; each processed example fills the next row.

    The driver loop that calls ``_initialize_file``, ``_process_batch`` and
    ``_finalize_file``, as well as ``compute_dense_pef_info``, are not defined
    in this class (presumably they come from the base class -- TODO confirm).
    """

    # NOTE(review): model, position_selector, class_subset_selector, label_key and
    # device are not referenced directly in this class; presumably they are consumed
    # by the base class -- TODO confirm.
    model: PreTrainedModel

    # Queried here for the number of parameters of the original (unprojected) model.
    fisher_computer: 'pef_computer_common.PefComputerAbc'

    position_selector: 'position_selectors.PositionSelectorAbc'
    class_subset_selector: 'class_selectors.ClassSubsetSelectorAbc'

    label_key: Optional[str]

    device: torch.device

    # Toggles for the optional per-example datasets ('examples/*', 'labels', 'logits').
    save_examples: bool
    save_labels: bool
    save_logits: bool

    # Serialized to JSON and stored as the 'parameter_infos' attribute.
    parameter_infos: List['parameter_infos.ParameterInfo']

    # When not None, every PEF is randomly projected before being written, and the
    # params are stored (as JSON) in the 'random_projection_params' attribute.
    random_projection_params: Optional['random_projectors.RandomProjectionParams']

    # Useful when used instead of save_logits for a large number classes, like what is
    # commonly seen for a language model.
    save_top_n_log_probs: Optional[int] = None

    # This only makes sense for lm_suffix_mc models, where each "example" is the set of
    # sequences consisting of a common prefix and the various suffixes to evaluate. If True, then only
    # the first of these sequences will be saved. If False, then all of the options will be saved.
    lm_suffix_mc_save_only_first_example: bool = False

    def __post_init__(self):
        """Build the random projector if projection parameters were supplied."""
        if self.random_projection_params is not None:
            self._random_projector = random_projectors.RandomProjector(params=self.random_projection_params)
        else:
            self._random_projector = None

    def _initialize_file(self, file: h5py.File, n_examples: int):
        """Reset the streaming state and write the file-level metadata.

        Args:
            file: Open, writable HDF5 file; a ``data`` group is created in it.
            n_examples: Number of rows allocated for every per-example dataset.
        """
        self._rank = None
        # This is the dimension of the LRM-PEFs as presented.
        self._n_parameters = None
        # This is the dimension of the LRM-PEFs from the original model. If no random projection is
        # done, this will equal self._n_parameters.
        self._n_og_parameters = None

        self._n_examples_processed = 0
        self._n_examples = n_examples

        self._data_grp = file.create_group('data')
        attrs = self._data_grp.attrs

        attrs['pef_format'] = "frdn_lrm"
        attrs['pef_format_version'] = "0.0.1"

        attrs['parameter_infos'] = json.dumps([p.to_json() for p in self.parameter_infos])

        # Query the computer once and reuse the value for the attribute, the
        # instance state and the log line.
        n_og_parameters = int(self.fisher_computer.get_n_original_parameters())
        self._n_og_parameters = n_og_parameters
        attrs['n_og_parameters'] = n_og_parameters
        print('n_og_parameters:', n_og_parameters)

        if self.random_projection_params is not None:
            attrs['random_projection_params'] = json.dumps(self.random_projection_params.to_json())

    def _finalize_file(self, file: h5py.File):
        """Write the attributes that are only known after processing examples.

        ``self._rank`` and ``self._n_parameters`` stay ``None`` until
        ``_process_batch`` has written at least one example, in which case the
        ``int(...)`` conversions below raise ``TypeError``.
        """
        file['data'].attrs['rank'] = int(self._rank)
        # Write also to n_classes for back-compatability purposes.
        file['data'].attrs['n_classes'] = int(self._rank)
        file['data'].attrs['n_parameters'] = int(self._n_parameters)

    #######################################################

    def _write_and_make_if_needed(self, ds_name: str, x: torch.Tensor, dtype=None):
        """Write ``x`` as the row for the current example, creating the dataset if needed.

        Args:
            ds_name: Dataset path relative to the ``data`` group.
            x: Tensor holding the value for the current example.
            dtype: Dtype of the dataset if it has to be created. Defaults to
                the dtype of ``x`` after conversion to NumPy.
        """
        x = x.detach().cpu().numpy()
        if ds_name not in self._data_grp:
            # Compare against None explicitly instead of relying on the truthiness
            # of a dtype object, so that an explicitly passed dtype is always honored.
            ds_dtype = dtype if dtype is not None else x.dtype
            # The first value written fixes the per-example shape of the dataset.
            self._data_grp.create_dataset(ds_name, [self._n_examples, *x.shape], dtype=ds_dtype)

        h5_ds = self._data_grp[ds_name]
        h5_ds[self._n_examples_processed] = x.astype(h5_ds.dtype)

    def _write_top_log_probs(self, log_probs: torch.Tensor):
        """Write the ``save_top_n_log_probs`` largest log-probs and their class indices.

        NOTE(review): the sort runs along the last dimension while the slice and
        the lookup use the first one, so this assumes ``log_probs`` is 1-D --
        TODO confirm.
        """
        top_inds = torch.argsort(log_probs, descending=True)[:self.save_top_n_log_probs]
        top_values = log_probs[top_inds]
        self._write_and_make_if_needed('top_log_probs/class_indices', top_inds, np.int32)
        self._write_and_make_if_needed('top_log_probs/values', top_values, np.float32)

    def _make_examples_to_save(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return the view of ``batch`` whose entries should be saved as the example.

        When ``lm_suffix_mc_save_only_first_example`` is False the batch is
        returned as is. Otherwise a shallow copy is returned in which every
        entry under one of ``_TYPICAL_EXAMPLE_KEYS`` is reduced to its first
        option; the input dict itself is not modified.
        """
        if not self.lm_suffix_mc_save_only_first_example:
            return batch

        ret = batch.copy()

        for k, v in batch.items():
            if k not in _TYPICAL_EXAMPLE_KEYS:
                continue
            # shape = [batch (=1), options, sequence]
            assert len(v.shape) == 3
            ret[k] = v[:, 0, :]

        return ret

    #######################################################

    def _process_batch(self, batch: Dict[str, torch.Tensor]) -> bool:
        """Compute the PEF for ``batch`` and write it with its optional metadata.

        Returns:
            False if ``compute_dense_pef_info`` returned None, in which case
            nothing is written and the example counter is not advanced; True
            once the example has been written.
        """
        dense_pef_info = self.compute_dense_pef_info(batch)
        if dense_pef_info is None:
            return False

        dense_pef = dense_pef_info.dense_pef

        # Project if desired.
        if self._random_projector is not None:
            dense_pef = self._random_projector.project(dense_pef)

        # NOTE: This is computed for the projected PEFs if a random projection is used.
        frobenius_norm = pef_computer_common.compute_lrm_pef_frobenius_norm(dense_pef)

        # These should be the same for all examples.
        self._rank, self._n_parameters = dense_pef.shape

        self._write_and_make_if_needed('pefs', dense_pef, np.float32)
        self._write_and_make_if_needed('pef_frobenius_norms', frobenius_norm, np.float32)

        if dense_pef_info.position is not None:
            self._write_and_make_if_needed('token_positions', dense_pef_info.position, np.int32)

        if self.save_examples:
            # TODO: Find a better way of filtering the items in the batch dict.
            for k, v in self._make_examples_to_save(batch).items():
                if k not in _TYPICAL_EXAMPLE_KEYS:
                    continue
                # Drop the leading batch dimension before storing the example.
                v = torch.squeeze(v, dim=0)
                self._write_and_make_if_needed(f'examples/{k}', v)

        if self.save_labels and dense_pef_info.label is not None:
            self._write_and_make_if_needed('labels', dense_pef_info.label, np.int32)

        if self.save_logits:
            # The dataset is named 'logits' but is filled from the log_probs field.
            self._write_and_make_if_needed('logits', dense_pef_info.log_probs, np.float32)

        if self.save_top_n_log_probs:
            self._write_top_log_probs(dense_pef_info.log_probs)

        self._n_examples_processed += 1

        return True


###############################################################################


def read_random_projector_params_from_file(filepath: str) -> 'random_projectors.RandomProjectionParams':
    """Load the random projection parameters stored in an LRM-PEFs file.

    Args:
        filepath: Path to the HDF5 file; a leading ``~`` is expanded.

    Returns:
        The parameters rebuilt from the JSON string held in the
        ``random_projection_params`` attribute of the ``data`` group.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        serialized_params = h5_file['data'].attrs['random_projection_params']
        params_json = json.loads(serialized_params)
        return random_projectors.RandomProjectionParams.from_json(params_json)


def read_n_og_parameters_from_file(filepath: str) -> int:
    """Read the parameter count of the original model from an LRM-PEFs file.

    Args:
        filepath: Path to the HDF5 file; a leading ``~`` is expanded.

    Returns:
        The ``n_og_parameters`` attribute of the ``data`` group as a builtin
        ``int``.
    """
    with h5py.File(os.path.expanduser(filepath), "r") as f:
        # h5py hands scalar attributes back as NumPy scalars; convert so the
        # result matches the annotated return type.
        return int(f['data'].attrs['n_og_parameters'])


###############################################################################


def load_pefs(filepath: str) -> torch.Tensor:
    """Load the ``data/pefs`` dataset of an LRM-PEFs file as a tensor.

    Args:
        filepath: Path to the HDF5 file; a leading ``~`` is expanded.

    Returns:
        The dataset contents converted with ``torch.from_numpy``.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        pefs_array = hdf5_utils.load_h5_ds(h5_file['data/pefs'])
        pefs = torch.from_numpy(pefs_array)
    return pefs


def load_pef_frobenius_norms(filepath: str) -> torch.Tensor:
    """Load the ``data/pef_frobenius_norms`` dataset of an LRM-PEFs file as a tensor.

    Args:
        filepath: Path to the HDF5 file; a leading ``~`` is expanded.

    Returns:
        The dataset contents converted with ``torch.from_numpy``.
    """
    resolved_path = os.path.expanduser(filepath)
    with h5py.File(resolved_path, "r") as h5_file:
        norms_array = hdf5_utils.load_h5_ds(h5_file['data/pef_frobenius_norms'])
        norms = torch.from_numpy(norms_array)
    return norms
