from gfn_subtb_grid.metrics import BaseMetric
from typing import Dict

class Losses(BaseMetric):
    """Report mean values of the most recently supplied policy losses.

    The metric stores the latest ``target_policy_losses`` and
    ``behavior_policy_losses`` passed to :meth:`update` and, on request,
    reduces each of them to scalar means: one over every loss, plus one per
    configured "first k" / "last k" window.

    The loss containers only need to support slicing and
    ``.mean().item()``; presumably they are 1-D tensors or arrays with one
    entry per loss term -- TODO confirm against the caller.
    """

    def __init__(self, config: Dict[str, object]):
        """Read the window sizes from ``config``.

        Args:
            config: Mapping that may contain ``'first_k_losses_mean_k_vals'``
                and/or ``'last_k_losses_mean_k_vals'``, each an iterable of
                integer window sizes. A missing key means "no windows of
                that kind".
        """
        # NOTE(review): BaseMetric.__init__ is not invoked here; confirm the
        # base class needs no set-up of its own.

        # Latest losses per policy; None until update() supplies them.
        self.target_policy_losses = None
        self.behavior_policy_losses = None

        self.first_k_losses_mean_k_vals = config.get(
            'first_k_losses_mean_k_vals',
            []
        )
        self.last_k_losses_mean_k_vals = config.get(
            'last_k_losses_mean_k_vals',
            []
        )

    def update(self, update_infos: Dict[str, object]) -> None:
        """Store whichever policy losses are present in ``update_infos``.

        Args:
            update_infos: Mapping that may contain ``'target_policy_losses'``
                and/or ``'behavior_policy_losses'``. A key that is absent
                leaves the previously stored value untouched.
        """
        if 'target_policy_losses' in update_infos:
            self.target_policy_losses = update_infos['target_policy_losses']
        if 'behavior_policy_losses' in update_infos:
            self.behavior_policy_losses = update_infos['behavior_policy_losses']

    def compute_for_step_result_dict(self) -> Dict[str, object]:
        """Build the dictionary of loss means for the stored losses.

        For every policy whose losses have been stored, the result holds:

        * ``'<policy>_policy_loss_all_mean'``
        * ``'<policy>_policy_loss_first_<k>_mean'`` for each configured k
        * ``'<policy>_policy_loss_last_<k>_mean'`` for each configured k

        where ``<policy>`` is ``'target'`` or ``'behavior'``.

        Returns:
            Mapping from metric name to the scalar produced by
            ``.mean().item()`` on the corresponding slice.

        Raises:
            AssertionError: If neither policy's losses have been stored yet.
        """
        losses_by_policy = {
            'target': self.target_policy_losses,
            'behavior': self.behavior_policy_losses,
        }
        available = {
            name: losses
            for name, losses in losses_by_policy.items()
            if losses is not None
        }

        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept unchanged.
        if not available:
            raise AssertionError(
                'Losses metric has no target or behavior policy losses; '
                'call update() with at least one of them first.'
            )

        result_dict = {}
        for policy_name, pol_losses in available.items():
            result_dict['%s_policy_loss_all_mean' % policy_name] = \
                pol_losses.mean().item()

            for k in self.first_k_losses_mean_k_vals:
                key = '%s_policy_loss_first_%d_mean' % (policy_name, k)
                result_dict[key] = pol_losses[:k].mean().item()

            for k in self.last_k_losses_mean_k_vals:
                key = '%s_policy_loss_last_%d_mean' % (policy_name, k)
                # ``pol_losses[-0:]`` is ``pol_losses[0:]`` -- the whole
                # sequence -- so k == 0 needs an explicitly empty window to
                # mirror what ``pol_losses[:0]`` does for the first-k case.
                window = pol_losses[-k:] if k else pol_losses[:0]
                result_dict[key] = window.mean().item()

        return result_dict
