import torch

from gcip.utils.graph import remove_edges_from_batch, remove_nodes_from_batch
from .node_env_base import GraphEnvBase

import gcip.utils.io as playbook_io
import numpy as np


class GraphEnv(GraphEnvBase):
    """Graph environment whose step removes the elements picked by a binary action.

    Whether an action entry refers to a node or to an edge of the current
    state is decided by ``self.action_refers_to`` (``'node'`` or ``'edge'``).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``GraphEnvBase.__init__``.

        The names below are the arguments this environment is expected to
        receive; they are consumed by the base class, not by this method.

        Args:
            loader:
            graph_clf:
            reward_fn:
            desired_perc_nodes:
            max_episode_length: Positive integer that controls the maximum
                number of steps an episode can have.
            use_intrinsic_reward: Boolean that controls whether intermediate
                (intrinsic) rewards are used.
            device:
        """
        super().__init__(*args, **kwargs)

    @staticmethod
    def kwargs(cfg, preparator, graph_clf, reward):
        """Return the constructor keyword arguments for this environment.

        No environment-specific entries are added: the result is a fresh
        dict holding exactly what ``GraphEnvBase.kwargs`` returns.
        """
        return dict(GraphEnvBase.kwargs(cfg, preparator, graph_clf, reward))

    def _remove(self, state, idx_to_remove, relabel_nodes):
        """Return a copy of ``state`` without the selected nodes or edges.

        Args:
            state: Current batch; it is cloned before being handed to the
                removal helper.
            idx_to_remove: Indices of the elements to drop. They are treated
                as edge indices when ``self.action_refers_to == 'edge'`` and
                as node indices when it is ``'node'``.
            relabel_nodes: Forwarded unchanged to the removal helper.

        Returns:
            Whatever the selected removal helper returns for the cloned batch.

        Raises:
            NotImplementedError: If ``self.action_refers_to`` is neither
                ``'edge'`` nor ``'node'``.
        """
        if self.action_refers_to == 'edge':
            return remove_edges_from_batch(edges_to_remove=idx_to_remove,
                                           batch=state.clone(),
                                           relabel_nodes=relabel_nodes)
        if self.action_refers_to == 'node':
            return remove_nodes_from_batch(batch=state.clone(),
                                           nodes_idx=idx_to_remove,
                                           relabel_nodes=relabel_nodes)
        raise NotImplementedError

    def _penalized_reward(self, action):
        """Return the reward of the current state minus ``self.penalty_size``."""
        reward = self.reward.compute(state=self.state, action=action)
        penalty = self.penalty_size * torch.ones(1, device=self.device)
        # Out-of-place on purpose: the tensor handed back by the reward
        # object is left unmodified.
        return reward - penalty

    def _step(self, action, relabel_nodes=True):
        """Execute one time step within the environment.

        Args:
            action: Tensor containing only 0s and 1s; entries equal to 1 mark
                the elements (nodes or edges) to remove from the state.
            relabel_nodes: Forwarded to the removal helper.

        Returns:
            A tuple ``(state, reward, done, info)``:

            * If the removal would leave an empty graph, the state is kept
              as it was, the penalised reward is returned and ``done`` is
              ``True``.
            * If nothing is selected for removal, or ``self.current_iter``
              equals ``self.max_episode_length``, the episode ends and the
              reward computed on the current state is returned (penalised
              when nothing was removed and the state is still the original).
            * Otherwise ``done`` is ``False`` and a zero tensor of shape
              ``(1,)`` is returned as the step reward.
        """
        # Single vectorised check that every entry is exactly 0 or 1.
        assert bool(((action == 0) | (action == 1)).all()), \
            'action must only contain 0s and 1s'

        idx_to_remove = list(np.where(action.cpu() == 1)[0])
        num_idx_to_remove = int(action.sum().item())

        # Report the real number of selected elements instead of a fixed 1.
        info = {'Elements to remove': num_idx_to_remove,
                'Finish?': 'No',
                'current_iter': self.current_iter}

        if num_idx_to_remove > 0:
            data = self._remove(state=self.state,
                                idx_to_remove=idx_to_remove,
                                relabel_nodes=relabel_nodes)

            if self._is_empty(data):
                # Reject the removal: keep the previous state, penalise and
                # terminate the episode.
                # NOTE(review): self.current_reward is not updated on this
                # path (unchanged from the original) - confirm it is intended.
                info['Finish?'] = 'Yes, result would have no nodes/edges'
                return self.state, self._penalized_reward(action), True, info

            self.state = data
            reward = self.reward.compute(state=self.state, action=action)
        elif self._is_original(self.state):
            # Stopping without ever having removed anything is penalised.
            reward = self._penalized_reward(action)
        else:
            reward = self.reward.compute(state=self.state, action=action)

        self.current_reward = reward

        if num_idx_to_remove == 0:
            info['Finish?'] = 'Yes, No elements to remove'
            return self.state, self.current_reward, True, info

        if self.current_iter == self.max_episode_length:
            # Elements were removed, so give the actual reason for stopping.
            info['Finish?'] = 'Yes, max episode length reached'
            return self.state, self.current_reward, True, info

        # Non-terminal step: the step reward is a zero tensor.
        intrinsic_reward = torch.zeros(1, device=self.device)
        return self.state, intrinsic_reward, False, info

    def get_action_distr_name(self):
        """Return the name of the action distribution, ``'ber'``.

        The original inline note labels it "continuous bernoulli".
        NOTE(review): confirm against the code that maps this name to a
        distribution.
        """
        return 'ber'
