import torch

from gcip.utils.graph import remove_edges_from_batch, remove_nodes_from_batch
from .node_env_base import GraphEnvBase


class GraphEnvOne(GraphEnvBase):
    """Graph environment in which a step removes at most one element.

    The element is an edge or a node, selected by ``self.action_refers_to``
    (an attribute that is not set in this class).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``GraphEnvBase.__init__``.

        The names below are the ones listed by the original author; the base
        class is not visible from this file, so confirm them against it.

        Args:
            loader:
            graph_clf:
            reward_fn:
            desired_perc_nodes:
            max_episode_length: Positive integer that controls the maximum
                number of steps an episode can have.
            use_intrinsic_reward: Boolean that controls whether intermediate
                (intrinsic) rewards are used.
            device:
        """
        super().__init__(*args, **kwargs)

    @staticmethod
    def kwargs(cfg, preparator, graph_clf, reward):
        """Return a new dict filled from ``GraphEnvBase.kwargs`` for the same inputs."""
        return dict(GraphEnvBase.kwargs(cfg, preparator, graph_clf, reward))

    def _remove(self, state, idx_to_remove, relabel_nodes):
        """Return a clone of ``state`` with one edge or one node removed.

        Args:
            state: Batch to remove from; it is cloned before being handed to
                the removal helper.
            idx_to_remove: Index of the single edge/node to drop.
            relabel_nodes: Passed through to the removal helper.

        Raises:
            NotImplementedError: If ``self.action_refers_to`` is neither
                ``'edge'`` nor ``'node'``.
        """
        element_kind = self.action_refers_to

        if element_kind == 'edge':
            return remove_edges_from_batch(edges_to_remove=[idx_to_remove],
                                           batch=state.clone(),
                                           relabel_nodes=relabel_nodes)

        if element_kind == 'node':
            return remove_nodes_from_batch(batch=state.clone(),
                                           nodes_idx=[idx_to_remove],
                                           relabel_nodes=relabel_nodes)

        raise NotImplementedError

    def _step(self, action, relabel_nodes=True):
        """Execute one time step within the environment.

        The highest-scoring entry of ``action`` is the removal candidate. It
        is removed only when its score is above 0.5; a score of 0.5 or less
        ends the episode.

        Args:
            action: Tensor of scores; only its maximum and arg-maximum are
                read here.
            relabel_nodes: Passed through to ``_remove``.

        Returns:
            A ``(state, reward, done, info)`` tuple:

            * removal would leave an empty result: the unchanged state, the
              computed reward minus ``self.penalty_size``, ``True``;
            * top score is 0.5 or less: the current state, the computed
              reward, ``True``;
            * otherwise: the updated state, a zero reward tensor, ``False``.
        """
        target_idx = torch.argmax(action).item()
        top_score = torch.max(action).item()

        info = {'Elements to remove': 1,
                'Finish?': 'No',
                'current_iter': self.current_iter}

        if top_score > 0.5:  # the policy asks for a removal
            candidate = self._remove(state=self.state,
                                     idx_to_remove=target_idx,
                                     relabel_nodes=relabel_nodes)

            if self._is_empty(candidate):
                # Refuse the removal: keep the current state, end the episode
                # and subtract the size penalty from the reward.
                info['Finish?'] = 'Yes, result would have no nodes/edges'
                kept_state = self.state

                penalised = self.reward.compute(state=self.state, action=action)
                penalised -= self.penalty_size * torch.ones(1, device=self.device)

                return kept_state, penalised, True, info

            self.state = candidate

        reward = self.reward.compute(state=self.state, action=action)
        zero_reward = torch.zeros(1, device=self.device)
        self.current_reward = reward

        # Kept as its own comparison (not the negation of the one above) so a
        # NaN score takes neither the removal nor the stop path.
        if top_score <= 0.5:
            info['Finish?'] = 'Yes, No elements to remove'
            return self.state, self.current_reward, True, info

        return self.state, zero_reward, False, info

    def get_action_distr_name(self):
        """Return the short name of the action distribution: continuous Bernoulli."""
        return 'cb'
