"""Saving and loading my Protein-BERT checkpoints.

The checkpoints are saved in an hdf5 file.

body/weights
classifier/weights


"""
import dataclasses
import json
import os
from typing import List, Optional

import h5py
import numpy as np

from em.util import hdf5_util

from . import pb_models


@dataclasses.dataclass
class CheckpointData:
    """Configs and weights of a Protein-BERT model, as stored in a checkpoint.

    The body part is mandatory. The classifier part is optional, but its
    config and weights must be given together or omitted together.

    Raises:
        ValueError: if exactly one of ``classifier_config`` and
            ``classifier_weights`` is provided.
    """

    body_config: pb_models.BodyConfig
    body_weights: List[np.ndarray]

    classifier_config: Optional[pb_models.ClassifierConfig] = None
    classifier_weights: Optional[List[np.ndarray]] = None

    def __post_init__(self):
        """Check that the classifier config and weights come as a pair."""
        if (self.classifier_config is None) != (self.classifier_weights is None):
            raise ValueError(
                "classifier_config and classifier_weights must be either "
                "both provided or both None"
            )

    def includes_classifier(self) -> bool:
        """Return True if this checkpoint carries a classifier part."""
        return self.classifier_config is not None

    @classmethod
    def from_model(cls, model: pb_models.ProteinBertForSequenceClassification):
        """Placeholder for building checkpoint data from a model.

        TODO: not implemented yet; currently does nothing and returns None.
        """
        return None

    def _save(
        self, f: h5py.File, name: str, config, weights: List[np.ndarray]
    ) -> None:
        """Write one model part into a new group called `name` in `f`.

        The config is stored as a JSON string (from ``config.to_dict()``) in
        the group's ``config`` attribute, and the arrays are handed to
        ``hdf5_util.save_np_arrays_to_group`` for a ``weights`` subgroup.
        """
        group = f.create_group(name)
        group.attrs['config'] = json.dumps(config.to_dict())
        weights_group = group.create_group('weights')
        hdf5_util.save_np_arrays_to_group(weights_group, weights)

    def save(self, filepath: str) -> None:
        """Write the checkpoint to `filepath` (``~`` is expanded).

        The file is opened in ``"w"`` mode. The classifier group is written
        only when this checkpoint includes a classifier.
        """
        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "w") as f:
            self._save(f, 'body', self.body_config, self.body_weights)
            if self.includes_classifier():
                self._save(f, 'classifier', self.classifier_config, self.classifier_weights)

    @classmethod
    def _load_kwargs(cls, f: h5py.File, name: str, ConfigCls) -> dict:
        """Read one model part from `f` as constructor keyword arguments.

        Returns a dict with keys ``{name}_config`` (built with
        ``ConfigCls.from_dict`` from the group's JSON ``config`` attribute)
        and ``{name}_weights`` (loaded from the ``{name}/weights`` group).
        """
        return {
            f'{name}_config': ConfigCls.from_dict(json.loads(f[name].attrs['config'])),
            f'{name}_weights': hdf5_util.load_np_arrays_from_group(f[f'{name}/weights']),
        }

    @classmethod
    def load(cls, filepath: str) -> "CheckpointData":
        """Read a checkpoint from `filepath` (``~`` is expanded).

        The body part is always read. The classifier part is read only if the
        file contains a ``classifier`` group.
        """
        kwargs = {}
        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "r") as f:
            kwargs.update(cls._load_kwargs(f, 'body', pb_models.BodyConfig))
            if 'classifier' in f:
                kwargs.update(cls._load_kwargs(f, 'classifier', pb_models.ClassifierConfig))
        return cls(**kwargs)
