import random

import torch

import gcip.utils.io as pb_io
from gcip.utils.graph import remove_nodes_from_batch
from .ratio import RewardRatio
import numpy as np

class RewardGCIP(RewardRatio):
    """Reward that blends a performance term with a sparsity term.

    The performance term is ``1 - entropy(logits)`` (normalized entropy as
    returned by the ``entropy`` callable). The sparsity term is
    ``(1 - ratio ** exp_cte) ** 2`` where ``ratio`` is the fraction of nodes
    (or edges) of the initial state ``state_0`` that are still present in
    the current state. Both terms are handed to ``_reward_ratio`` (not
    defined in this class; presumably inherited from ``RewardRatio``).
    """

    # Value taken by ``1 - ratio ** exp_cte`` when ``ratio`` equals
    # ``desired_ratio``; it fixes the exponent computed in ``__init__``.
    _VALUE_AT_DESIRED_RATIO = 0.95

    def __init__(self, entropy, *args, **kwargs):
        """Store the entropy callable and derive the sparsity exponent.

        Args:
            entropy: callable invoked as ``entropy(logits=..., normalize=True)``.
            *args: forwarded unchanged to ``RewardRatio.__init__``.
            **kwargs: forwarded unchanged to ``RewardRatio.__init__``.
        """
        # Assigned before the parent constructor runs, as in the original code.
        self.entropy = entropy
        super(RewardGCIP, self).__init__(*args, **kwargs)

        # Exponent chosen so that ``desired_ratio ** exp_cte`` equals
        # ``1 - _VALUE_AT_DESIRED_RATIO``. ``desired_ratio`` is not set in this
        # class (presumably by the parent constructor).
        # NOTE: a ``desired_ratio`` of exactly 1 makes the denominator zero.
        self.exp_cte = (np.log(1.0 - self._VALUE_AT_DESIRED_RATIO)
                        / np.log(self.desired_ratio))

    @staticmethod
    def kwargs(cfg, preparator, graph_clf):
        """Build the constructor keyword arguments.

        Returns:
            dict with ``'entropy'`` taken from ``preparator.get_entropy()``,
            updated with everything returned by ``RewardRatio.kwargs`` (whose
            entries take precedence on key collisions).
        """
        my_dict = {'entropy': preparator.get_entropy()}
        my_dict.update(RewardRatio.kwargs(cfg, preparator, graph_clf))

        return my_dict

    def _get_state_rnd(self, num_nodes):
        """Return a copy of ``state_0`` reduced to ``num_nodes`` random nodes.

        Args:
            num_nodes: number of nodes the returned graph should keep.

        Returns:
            A clone of ``self.state_0`` when no node has to be removed,
            otherwise the result of ``remove_nodes_from_batch`` applied to a
            clone of ``self.state_0`` with uniformly sampled node indices.

        Raises:
            ValueError: raised by ``random.sample`` when ``num_nodes`` exceeds
                the number of nodes in ``state_0``.
        """
        num_nodes_0 = self.state_0.x.shape[0]
        if num_nodes_0 == num_nodes:
            return self.state_0.clone()

        num_nodes_to_remove = num_nodes_0 - num_nodes
        # ``range`` is a sequence, so no intermediate list is required.
        nodes_to_remove_rnd = random.sample(range(num_nodes_0), num_nodes_to_remove)

        return remove_nodes_from_batch(batch=self.state_0.clone(),
                                       nodes_idx=nodes_to_remove_rnd,
                                       relabel_nodes=True,
                                       has_batch_att=True)

    def _remaining_ratio(self, state):
        """Return the fraction of ``state_0`` nodes or edges kept in ``state``.

        Raises:
            ValueError: if ``self.action_refers_to`` is neither ``'node'``
                nor ``'edge'``.
        """
        if self.action_refers_to == 'node':
            return state.x.shape[0] / self.state_0.x.shape[0]
        if self.action_refers_to == 'edge':
            return state.edge_index.shape[1] / self.state_0.edge_index.shape[1]

        # Previously this case fell through and failed later with an
        # UnboundLocalError on ``ratio``.
        raise ValueError(
            "action_refers_to must be 'node' or 'edge', got {!r}".format(
                self.action_refers_to))

    def _compute(self, logits, state, **kwargs):
        """Compute the reward for ``state`` given the classifier ``logits``.

        Args:
            logits: tensor passed to ``self._num_samples`` and ``self.entropy``;
                its device is used for the sparsity tensor.
            state: current graph; ``state.x`` or ``state.edge_index`` is read
                depending on ``self.action_refers_to``.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Whatever ``self._reward_ratio`` returns for the two reward terms.

        Raises:
            ValueError: if ``self.action_refers_to`` is neither ``'node'``
                nor ``'edge'``.
        """
        num_samples = self._num_samples(logits)
        ratio = self._remaining_ratio(state)

        # Scalar sparsity score, broadcast to one entry per sample. The tensor
        # is created on the logits' device instead of being copied there.
        sparsity_value = (1.0 - (ratio ** self.exp_cte)) ** 2
        reward_sparsity = sparsity_value * torch.ones(num_samples,
                                                      device=logits.device)

        # No random-subgraph baseline is subtracted here; ``_get_state_rnd``
        # is not used by this method.
        entropy = self.entropy(logits=logits, normalize=True)
        reward_performance = 1 - entropy

        return self._reward_ratio(logits=logits,
                                  r_perf=reward_performance,
                                  r_spar=reward_sparsity)

    @staticmethod
    def _weighted_reward(weight, scale, r_perf, r_spar):
        """Return ``scale * (weight * r_perf + (1 - weight) * r_spar)``."""
        return scale * (weight * r_perf + (1 - weight) * r_spar)

    # The three hooks below share one formula and differ only in the
    # (lambda, k) pair supplied by the caller.

    def _reward_1(self, lambda_1, k_1, r_perf, r_spar):
        """Convex combination of both terms weighted by ``lambda_1``, scaled by ``k_1``."""
        return self._weighted_reward(lambda_1, k_1, r_perf, r_spar)

    def _reward_2(self, lambda_2, k_2, r_perf, r_spar):
        """Convex combination of both terms weighted by ``lambda_2``, scaled by ``k_2``."""
        return self._weighted_reward(lambda_2, k_2, r_perf, r_spar)

    def _reward_3(self, lambda_3, k_3, r_perf, r_spar):
        """Convex combination of both terms weighted by ``lambda_3``, scaled by ``k_3``."""
        return self._weighted_reward(lambda_3, k_3, r_perf, r_spar)
