from abc import ABC, abstractmethod
import torch


class Reward(ABC):
    """Abstract base for rewards computed from a graph classifier's output.

    Subclasses implement :meth:`_compute`, which receives the classifier
    logits for a state and returns a reward. :meth:`compute` runs the
    classifier, delegates to :meth:`_compute` and reduces the reward to at
    most one dimension.

    A reference state can be registered with :meth:`set_state_0`; its target
    (from ``batch_to_target``) and classifier logits are cached on the
    instance for use by :attr:`full_graph_is_correct` and subclasses.
    """

    def __init__(self,
                 graph_clf,
                 logits_to_soft_pred,
                 logits_to_hard_pred,
                 batch_to_target,
                 loss_fn,
                 device='cpu'):
        """Store the classifier and its helper callables.

        Args:
            graph_clf: callable mapping a state (already ``.clone().to(device)``)
                to logits.
            logits_to_soft_pred: callable turning logits into soft predictions
                (not used by this base class; stored for subclasses).
            logits_to_hard_pred: callable turning logits into hard predictions.
            batch_to_target: callable extracting the target from a state.
            loss_fn: loss function (not used by this base class; stored for
                subclasses).
            device: device the classifier inputs are moved to; ``'auto'``
                selects ``'cuda'`` when available and ``'cpu'`` otherwise.
        """
        self.graph_clf = graph_clf

        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.logits_to_soft_pred = logits_to_soft_pred
        self.logits_to_hard_pred = logits_to_hard_pred
        self.batch_to_target = batch_to_target

        self.loss_fn = loss_fn

        # Populated by set_state_0().
        self.state_0 = None
        self.target = None
        self.logits_0 = None

    @staticmethod
    def kwargs(cfg, preparator, graph_clf):
        """Build constructor keyword arguments from a config and a preparator.

        Reads ``cfg.device`` and the output-activation, hard-prediction,
        target and loss helpers exposed by ``preparator``.

        Returns:
            dict with the keys accepted by :meth:`__init__`.
        """
        return {
            'graph_clf': graph_clf,
            'logits_to_soft_pred': preparator.get_output_act_fn(),
            'logits_to_hard_pred': preparator.get_logits_to_hard_pred_fn(),
            'batch_to_target': preparator.get_target,
            'loss_fn': preparator.get_loss_fn(),
            'device': cfg.device,
        }

    @property
    def full_graph_is_correct(self):
        """Whether the classifier's hard prediction on ``state_0`` matches the target.

        True when more than half of the hard predictions derived from
        ``logits_0`` equal ``target`` (for tensor inputs the result is a
        0-dim bool tensor). Requires :meth:`set_state_0` to have been called.
        """
        target_clf = self.logits_to_hard_pred(self.logits_0)
        return self._majority_match(target_clf, self.target)

    @abstractmethod
    def _compute(self, logits, state, **kwargs):
        """Return the reward for ``logits`` produced by the classifier on ``state``."""
        pass

    def compute(self, state, **kwargs):
        """Run the classifier on ``state`` and return the subclass reward.

        The classifier sees a copy of ``state`` moved to ``self.device``;
        :meth:`_compute` receives the logits (after :meth:`_flatten`), a
        separate copy of ``state`` as ``state=`` and any extra ``kwargs``.

        Returns:
            The reward with at most one dimension: a 2-D reward of shape
            ``(N, 1)`` is flattened to ``(N,)``.

        Raises:
            ValueError: if the reward has more than two dimensions, or is 2-D
                with a last dimension other than 1.
        """
        logits = self.graph_clf(state.clone().to(self.device))
        logits = self._flatten(logits)
        reward = self._compute(logits, state=state.clone(), **kwargs)
        # Explicit checks instead of ``assert`` so they survive ``python -O``;
        # otherwise an (N, K) reward would be silently flattened to N*K values.
        if reward.ndim > 2:
            raise ValueError(
                f"reward must have at most 2 dimensions, got shape {tuple(reward.shape)}")
        if reward.ndim == 2:
            if reward.shape[-1] != 1:
                raise ValueError(
                    f"2-D reward must have shape (N, 1), got {tuple(reward.shape)}")
            reward = reward.flatten()

        return reward

    def _flatten(self, logits):
        """Hook for subclasses to reshape logits; the base version returns them unchanged."""
        return logits

    def _num_samples(self, logits):
        """Return the number of samples in ``logits``.

        Logits with at most one dimension (including 0-dim scalars) count as a
        single sample; otherwise the size of the first dimension is returned.
        """
        if logits.ndim <= 1:
            return 1
        return logits.shape[0]

    def _subgraph_is_correct(self, pred_hard):
        """Whether more than half of ``pred_hard`` matches the stored target."""
        return self._majority_match(pred_hard, self.target)

    @staticmethod
    def _majority_match(pred_hard, target):
        """Return ``(pred_hard == target).float().mean() > 0.5``, device-safely.

        ``target`` comes from the raw ``state_0`` while predictions come from
        logits computed on ``self.device``, so the two tensors may live on
        different devices; the target is moved to the prediction's device
        instead of letting the comparison raise.
        """
        if (isinstance(pred_hard, torch.Tensor)
                and isinstance(target, torch.Tensor)
                and pred_hard.device != target.device):
            target = target.to(pred_hard.device)
        return (pred_hard == target).float().mean() > 0.5

    def set_state_0(self, state_0):
        """Register the reference state and cache its target and logits.

        Sets ``state_0``, ``target`` (``batch_to_target(state_0)``) and
        ``logits_0`` (the classifier output on a copy of ``state_0`` moved to
        ``self.device``).
        """
        self.state_0 = state_0
        self.target = self.batch_to_target(state_0)
        logits = self.graph_clf(state_0.clone().to(self.device))
        self.logits_0 = logits

    def __str__(self):
        my_str = f"{self.__class__}\n"
        return my_str
