"""Results of collateral damage assessment."""
import dataclasses
import json
import os
from typing import Dict, List

import h5py
import torch

from npeff_torch.util import hdf5_utils

###############################################################################


@dataclasses.dataclass
class EvaluationKls:
    """Per-example ``kls`` values together with the indices of those examples."""

    # Indices of the evaluated examples. These are expected to be the examples
    # for which we computed PEFs and whose NPEFF component coefficients are
    # stored somewhere.
    # shape = [n_eval_examples], dtype=torch.int64
    example_indices: torch.Tensor

    # One value per entry of `example_indices` (both are [n_eval_examples]).
    # shape = [n_eval_examples], dtype=torch.float32
    kls: torch.Tensor

    @property
    def n_examples(self) -> int:
        """Number of examples: the length of the first axis of `example_indices`."""
        return self.example_indices.size(0)

    def to(self, device: torch.device) -> 'EvaluationKls':
        """Move both tensors to `device` in place and return `self` for chaining."""
        for field_name in ('example_indices', 'kls'):
            setattr(self, field_name, getattr(self, field_name).to(device))
        return self


###############################################################################


@dataclasses.dataclass
class CdaRunResults:
    """Results of a single collateral damage assessment (CDA) trial.

    Stores the ``EvaluationKls`` measured on the forget set and on the
    evaluation set, plus a dict of named tensors recorded during unlearning.
    Serialized to an HDF5 group with one sub-group per ``EvaluationKls`` field
    and an optional ``unlearning_training_infos`` sub-group.
    """

    forget_set_kls: 'EvaluationKls'

    evaluation_set_kls: 'EvaluationKls'

    # Named tensors recorded during unlearning; each one is saved as the dataset
    # `unlearning_training_infos/<key>`.
    unlearning_training_infos: Dict[str, torch.Tensor]

    @property
    def forget_set_size(self) -> int:
        """Number of examples in the forget set."""
        return self.forget_set_kls.n_examples

    def to(self, device: torch.device) -> 'CdaRunResults':
        """Move every tensor to ``device`` in place and return ``self``."""
        self.forget_set_kls = self.forget_set_kls.to(device)
        self.evaluation_set_kls = self.evaluation_set_kls.to(device)
        moved_infos = {}
        for info_name, info_tensor in self.unlearning_training_infos.items():
            moved_infos[info_name] = info_tensor.to(device)
        self.unlearning_training_infos = moved_infos
        return self

    def save_to_group(self, group: h5py.Group):
        """Write these results into ``group``.

        Layout: ``<name>/example_indices`` and ``<name>/kls`` for each of
        ``forget_set_kls`` and ``evaluation_set_kls``, then one dataset per
        entry of ``unlearning_training_infos``. When the
        ``unlearning_training_infos`` sub-group exists, its ``keys`` attribute
        holds the JSON-encoded key list, which ``load_from_group`` uses to read
        the entries back.
        """
        for kls_name, evaluation_kls in (
            ('forget_set_kls', self.forget_set_kls),
            ('evaluation_set_kls', self.evaluation_set_kls),
        ):
            hdf5_utils.save_h5_ds(group, f'{kls_name}/example_indices', evaluation_kls.example_indices.detach().cpu().numpy())
            hdf5_utils.save_h5_ds(group, f'{kls_name}/kls', evaluation_kls.kls.detach().cpu().numpy())

        infos_name = 'unlearning_training_infos'
        for info_name, info_tensor in self.unlearning_training_infos.items():
            hdf5_utils.save_h5_ds(group, f'{infos_name}/{info_name}', info_tensor.detach().cpu().numpy())

        # NOTE(review): presumably the sub-group only exists when save_h5_ds
        # wrote at least one info tensor above (or it pre-existed) — confirm.
        if infos_name in group:
            group[infos_name].attrs['keys'] = json.dumps(list(self.unlearning_training_infos))

    @classmethod
    def load_from_group(cls, group: h5py.Group) -> 'CdaRunResults':
        """Reconstruct results previously written by ``save_to_group``."""

        def read_kls(kls_name: str) -> 'EvaluationKls':
            return EvaluationKls(
                example_indices=hdf5_utils.load_h5_ds_as_tensor(group[f'{kls_name}/example_indices']),
                kls=hdf5_utils.load_h5_ds_as_tensor(group[f'{kls_name}/kls']),
            )

        forget_set_kls = read_kls('forget_set_kls')
        evaluation_set_kls = read_kls('evaluation_set_kls')

        infos_name = 'unlearning_training_infos'
        infos = {}
        if infos_name in group:
            # Read entries in the order recorded in the JSON `keys` attribute.
            stored_keys = json.loads(group[infos_name].attrs['keys'])
            infos = {
                info_name: hdf5_utils.load_h5_ds_as_tensor(group[f'{infos_name}/{info_name}'])
                for info_name in stored_keys
            }

        return cls(
            forget_set_kls=forget_set_kls,
            evaluation_set_kls=evaluation_set_kls,
            unlearning_training_infos=infos,
        )

    @classmethod
    def load_results_from_file(cls, filepath: str) -> List['CdaRunResults']:
        """Load every trial's CdaRunResults stored in an HDF5 file.

        Trials are read from groups at ``data/examples/<example>/trial/<trial>``,
        iterating examples and then trials in the order h5py lists their keys.

        Right now, supports files created by the scripts:
            - run_gradient_ascent_single_example_cda.py
        """
        results = []
        with h5py.File(os.path.expanduser(filepath), "r") as h5_file:
            examples_group = h5_file['data/examples']
            for example_key in examples_group.keys():
                trials_group = h5_file[f'data/examples/{example_key}/trial']
                for trial_key in trials_group.keys():
                    results.append(cls.load_from_group(trials_group[trial_key]))
        return results

###############################################################################


@dataclasses.dataclass
class SpecialRetainCdaRunResults:
    """Results of a single CDA trial that also tracks two "special" example sets.

    Besides the forget set and evaluation set ``EvaluationKls`` (as in
    ``CdaRunResults``), this stores ``EvaluationKls`` for a special retain set
    and a special evaluation set. Serialized to an HDF5 group with one
    sub-group per ``EvaluationKls`` field and an optional
    ``unlearning_training_infos`` sub-group.
    """

    forget_set_kls: 'EvaluationKls'

    special_retain_set_kls: 'EvaluationKls'
    special_evaluation_set_kls: 'EvaluationKls'

    evaluation_set_kls: 'EvaluationKls'

    # Named tensors recorded during unlearning; each one is saved as the dataset
    # `unlearning_training_infos/<key>`.
    unlearning_training_infos: Dict[str, torch.Tensor]

    #######################################################

    @property
    def forget_set_size(self) -> int:
        """Number of examples in the forget set."""
        return self.forget_set_kls.n_examples

    #######################################################

    def to(self, device: torch.device) -> 'SpecialRetainCdaRunResults':
        """Move every tensor to ``device`` in place and return ``self``."""
        self.forget_set_kls = self.forget_set_kls.to(device)
        self.special_retain_set_kls = self.special_retain_set_kls.to(device)
        self.special_evaluation_set_kls = self.special_evaluation_set_kls.to(device)
        self.evaluation_set_kls = self.evaluation_set_kls.to(device)
        self.unlearning_training_infos = {k: v.to(device) for k, v in self.unlearning_training_infos.items()}
        return self

    #######################################################

    @staticmethod
    def _save_kls(group: h5py.Group, name: str, evaluation_kls: 'EvaluationKls'):
        """Write ``evaluation_kls`` as datasets ``<name>/example_indices`` and ``<name>/kls``."""
        hdf5_utils.save_h5_ds(group, f'{name}/example_indices', evaluation_kls.example_indices.detach().cpu().numpy())
        hdf5_utils.save_h5_ds(group, f'{name}/kls', evaluation_kls.kls.detach().cpu().numpy())

    @staticmethod
    def _load_kls(group: h5py.Group, name: str) -> 'EvaluationKls':
        """Read back an ``EvaluationKls`` written by ``_save_kls`` under ``name``."""
        return EvaluationKls(
            example_indices=hdf5_utils.load_h5_ds_as_tensor(group[f'{name}/example_indices']),
            kls=hdf5_utils.load_h5_ds_as_tensor(group[f'{name}/kls']),
        )

    def save_to_group(self, group: h5py.Group):
        """Write these results into ``group``.

        Each ``EvaluationKls`` field is written under a sub-group named after
        the field. Each ``unlearning_training_infos`` entry is written as its
        own dataset. When the ``unlearning_training_infos`` sub-group exists,
        its ``keys`` attribute holds the JSON-encoded key list that
        ``load_from_group`` reads back.
        """
        self._save_kls(group, 'forget_set_kls', self.forget_set_kls)
        self._save_kls(group, 'special_retain_set_kls', self.special_retain_set_kls)
        self._save_kls(group, 'special_evaluation_set_kls', self.special_evaluation_set_kls)
        self._save_kls(group, 'evaluation_set_kls', self.evaluation_set_kls)

        for k, v in self.unlearning_training_infos.items():
            hdf5_utils.save_h5_ds(group, f'unlearning_training_infos/{k}', v.detach().cpu().numpy())

        # NOTE(review): presumably the sub-group only exists when save_h5_ds
        # wrote at least one info tensor above (or it pre-existed) — confirm.
        if 'unlearning_training_infos' in group:
            group['unlearning_training_infos'].attrs['keys'] = json.dumps(list(self.unlearning_training_infos.keys()))

    @classmethod
    def load_from_group(cls, group: h5py.Group) -> 'SpecialRetainCdaRunResults':
        """Reconstruct results previously written by ``save_to_group``."""
        forget_set_kls = cls._load_kls(group, 'forget_set_kls')
        special_retain_set_kls = cls._load_kls(group, 'special_retain_set_kls')
        special_evaluation_set_kls = cls._load_kls(group, 'special_evaluation_set_kls')
        evaluation_set_kls = cls._load_kls(group, 'evaluation_set_kls')

        unlearning_training_infos = {}
        if 'unlearning_training_infos' in group:
            # Read entries in the order recorded in the JSON `keys` attribute.
            for key in json.loads(group['unlearning_training_infos'].attrs['keys']):
                unlearning_training_infos[key] = hdf5_utils.load_h5_ds_as_tensor(group[f'unlearning_training_infos/{key}'])

        return cls(
            forget_set_kls=forget_set_kls,
            special_retain_set_kls=special_retain_set_kls,
            special_evaluation_set_kls=special_evaluation_set_kls,
            evaluation_set_kls=evaluation_set_kls,
            unlearning_training_infos=unlearning_training_infos,
        )

    @classmethod
    def load_results_from_file(cls, filepath: str) -> List['SpecialRetainCdaRunResults']:
        """Load every trial's SpecialRetainCdaRunResults stored in an HDF5 file.

        Trials are read from groups at ``data/examples/<example>/trial/<trial>``,
        iterating examples and then trials in the order h5py lists their keys.

        Right now, supports files created by the scripts:
            - run_single_example_special_retain.py
        """
        ret = []
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            for s_example_index in f['data/examples'].keys():
                for s_trial_index in f[f'data/examples/{s_example_index}/trial'].keys():
                    group = f[f'data/examples/{s_example_index}/trial/{s_trial_index}']
                    ret.append(cls.load_from_group(group))
        return ret
