import torch

from .base import Reward

import gcip.utils.io as pb_io


class RewardRatio(Reward):
    """Reward base that selects one of three weighted reward terms.

    Which term is used depends on whether the classifier is correct on the
    full graph and whether it is correct on the sub-graph (see
    ``_reward_ratio``). The terms themselves (``_reward_1``, ``_reward_2``
    and ``_reward_3``) raise ``NotImplementedError`` in this class, so a
    subclass has to provide them.

    Args:
        *args: Positional arguments forwarded to ``Reward.__init__``.
        lambda_1: Weight handed to ``_reward_1``; must lie in [0, 1].
        lambda_2: Weight handed to ``_reward_2``; must lie in [0, 1].
        lambda_3: Weight handed to ``_reward_3``; must lie in [0, 1].
        k_1: Coefficient handed to ``_reward_1``; must be >= 0.
        k_2: Coefficient handed to ``_reward_2``; must be >= 0.
        k_3: Coefficient handed to ``_reward_3``; must be >= 0.
        desired_ratio: Strictly positive value stored on the instance; it
            is not read by any method defined in this class.
        action_refers_to: Either ``'node'`` or ``'edge'``. The default
            ``None`` fails validation, so a value must always be passed.
        **kwargs: Keyword arguments forwarded to ``Reward.__init__``.
    """

    def __init__(self,
                 *args,
                 lambda_1=0.5,
                 lambda_2=0.5,
                 lambda_3=0.5,
                 k_1=0.5,
                 k_2=0.5,
                 k_3=0.0,
                 desired_ratio=1.0,
                 action_refers_to=None,
                 **kwargs):
        weights = (lambda_1, lambda_2, lambda_3)
        coefficients = (k_1, k_2, k_3)

        # Validate everything before touching the instance.
        assert desired_ratio > 0.0
        assert action_refers_to in ('node', 'edge')
        for weight in weights:
            assert 0.0 <= weight <= 1.0
        for coefficient in coefficients:
            assert coefficient >= 0.0

        # Attributes are set before the base initialiser runs, as the base
        # class may already read them there (not visible from this file).
        self.lambda_1, self.lambda_2, self.lambda_3 = weights
        self.k_1, self.k_2, self.k_3 = coefficients
        self.desired_ratio = desired_ratio
        self.action_refers_to = action_refers_to

        super().__init__(*args, **kwargs)

    @staticmethod
    def kwargs(cfg, preparator, graph_clf):
        """Collect the constructor keyword arguments from a configuration.

        When ``cfg.reward.lambda_`` is a ``float`` it is used for all three
        lambdas; otherwise ``cfg.reward.lambda_1``, ``lambda_2`` and
        ``lambda_3`` are read individually.

        Args:
            cfg: Configuration object exposing ``reward`` and ``env``.
            preparator: Passed through to ``Reward.kwargs``.
            graph_clf: Passed through to ``Reward.kwargs``.

        Returns:
            dict with this class' arguments, updated with (and therefore
            possibly overridden by) the entries of ``Reward.kwargs``.
        """
        reward_cfg = cfg.reward

        shared_lambda = reward_cfg.lambda_
        if isinstance(shared_lambda, float):
            weights = (shared_lambda,) * 3
        else:
            weights = (reward_cfg.lambda_1,
                       reward_cfg.lambda_2,
                       reward_cfg.lambda_3)

        own_kwargs = {
            'lambda_1': weights[0],
            'lambda_2': weights[1],
            'lambda_3': weights[2],
            'k_1': reward_cfg.k_1,
            'k_2': reward_cfg.k_2,
            'k_3': reward_cfg.k_3,
            'desired_ratio': reward_cfg.desired_ratio,
            'action_refers_to': cfg.env.action_refers_to,
        }

        # Base-class entries are merged last so they win on key clashes.
        return {**own_kwargs, **Reward.kwargs(cfg, preparator, graph_clf)}

    def _reward_ratio(self, logits,
                      r_perf,
                      r_spar):
        """Pick the reward term matching the full-graph / sub-graph outcome.

        Cases, in terms of ``self.full_graph_is_correct`` and the result of
        ``self._subgraph_is_correct``:
            full correct, sub-graph correct: ``_reward_1`` (asserted >= 0)
            full correct, sub-graph wrong:   negated ``_reward_2``
            full wrong,   sub-graph correct: ``_reward_3`` (asserted >= 0)
            full wrong,   sub-graph wrong:   zeros

        Args:
            logits: Logits of the clf with the sub-graph.
            r_perf: Forwarded untouched to the selected ``_reward_*`` hook.
            r_spar: Forwarded untouched to the selected ``_reward_*`` hook.

        Returns:
            Whatever the selected hook returns (negated for ``_reward_2``),
            or a zero tensor shaped like the hard prediction.
        """
        pred_hard = self.logits_to_hard_pred(logits)

        # Every coefficient is expanded to the shape of the hard prediction.
        template = torch.ones_like(pred_hard)
        k_1, k_2, k_3 = (
            value * template for value in (self.k_1, self.k_2, self.k_3)
        )
        lambda_1, lambda_2, lambda_3 = (
            value * template
            for value in (self.lambda_1, self.lambda_2, self.lambda_3)
        )

        sub_graph_is_correct = self._subgraph_is_correct(pred_hard)

        if not self.full_graph_is_correct:
            # The classifier is wrong on the full graph.
            if not sub_graph_is_correct:
                return 0.0 * template
            reward = self._reward_3(lambda_3, k_3, r_perf, r_spar)
            assert reward >= 0, f"reward: {reward}"
            return reward

        # From here on the classifier is right on the full graph.
        if not sub_graph_is_correct:
            return -self._reward_2(lambda_2, k_2, r_perf, r_spar)

        reward = self._reward_1(lambda_1, k_1, r_perf, r_spar)
        assert reward >= 0, f"reward: {reward} | {k_1} {r_spar}"
        return reward

    def _reward_performance(self):
        """Placeholder; does nothing and returns ``None`` in this class."""
        return None

    def _reward_sparsity(self):
        """Placeholder; does nothing and returns ``None`` in this class."""
        return None

    def _reward_1(self, lambda_1, k_1, r_perf, r_spar):
        """Term used when full graph and sub-graph are both correct.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def _reward_2(self, lambda_2, k_2, r_perf, r_spar):
        """Term negated when the full graph is correct but the sub-graph is not.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def _reward_3(self, lambda_3, k_3, r_perf, r_spar):
        """Term used when the full graph is wrong but the sub-graph is correct.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def __str__(self):
        """Return the base description followed by one line per parameter."""
        header = super().__str__()
        details = [
            f"\tlambda_1={self.lambda_1}\n",
            f"\tlambda_2={self.lambda_2}\n",
            f"\tlambda_3={self.lambda_3}\n",
            f"\tk_1={self.k_1}\n",
            f"\tk_2={self.k_2}\n",
            f"\tk_3={self.k_3}\n",
            f"\tdesired_ratio={self.desired_ratio}\n",
        ]
        return header + "".join(details)
