import copy
from abc import ABC, abstractmethod

import gym
import torch
from torchlikelihoods.scalers import IdentityScaler

from torch_geometric.data import Batch

class GraphEnvBase(gym.Env, ABC):
    """Abstract gym environment whose state is a single graph.

    An action addresses every node or every edge of the current graph,
    depending on ``action_refers_to``.  Subclasses implement ``_step`` (the
    actual transition) and ``get_action_distr_name``.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self, loader,
                 graph_clf,
                 reward,
                 input_scaler=None,
                 action_refers_to=None,
                 penalty_size=0.0,
                 max_episode_length=100000,
                 use_intrinsic_reward=False,
                 device='cpu'):
        """Store the environment configuration.

        Args:
            loader: Data loader with ``batch_size == 1``.  ``reset`` draws
                graphs from it (or from ``loader.dataset`` when an index is
                given).
            graph_clf: Graph classifier.  Its ``training`` flag is read and
                its ``train()`` / ``eval()`` methods are called to switch
                modes around an episode.
            reward: Reward object exposing ``set_state_0(state)`` and
                ``compute(state=..., action=...)``.
            input_scaler: Object exposing ``transform(batch, inplace=...)``.
                Defaults to an ``IdentityScaler`` when ``None``.
            action_refers_to: Either ``'node'`` or ``'edge'``; selects which
                graph elements an action addresses.
            penalty_size: Non-negative number, stored on the instance (not
                used by this base class itself).
            max_episode_length: Positive integer that controls the maximum
                number of steps an episode can have.  Stored here; not
                enforced by this base class.
            use_intrinsic_reward: Whether intermediate (intrinsic) rewards
                are used.  Only a falsy value is currently accepted.
            device: Device identifier, or ``'auto'`` to select ``'cuda'``
                when available and ``'cpu'`` otherwise.
        """
        super(GraphEnvBase, self).__init__()
        assert loader.batch_size == 1, \
            'The loader must have batch_size == 1'
        assert penalty_size >= 0, 'penalty_size must be non-negative'
        assert not use_intrinsic_reward, \
            'Intrinsic rewards are not supported'
        assert action_refers_to in ['node', 'edge'], \
            "action_refers_to must be 'node' or 'edge'"

        # No observation space is declared for this environment.
        self.observation_space = None
        self.loader = loader
        if input_scaler is None:
            self.input_scaler = IdentityScaler()
        else:
            self.input_scaler = input_scaler
        self.graph_clf = graph_clf
        self.reward = reward
        self.state = None
        self.current_iter = 0
        self.current_reward = 0
        self.max_episode_length = max_episode_length
        self.use_intrinsic_reward = use_intrinsic_reward
        self.penalty_size = penalty_size
        self.action_refers_to = action_refers_to

        # Defined up front so that on_episode_end() is safe to call even
        # before the first reset(); reset() overwrites it.
        self.is_training = getattr(graph_clf, 'training', False)

        # Names of the attributes holding node features / edge attributes.
        self._node_feature = 'x'
        self._edge_attr = None

        # Sizes of the initial graph; filled in by reset().
        self.init_num_nodes = None
        self.init_num_edges = None

        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

    @staticmethod
    def kwargs(cfg, preparator, graph_clf, reward):
        """Build the constructor keyword arguments from a configuration.

        Args:
            cfg: Configuration exposing ``cfg.env.*`` and ``cfg.device``.
            preparator: Object exposing ``get_dataloaders(batch_size=...)``
                and ``get_scaler(fit=...)``.
            graph_clf: Forwarded unchanged as ``graph_clf``.
            reward: Forwarded unchanged as ``reward``.

        Returns:
            dict: Keyword arguments accepted by ``__init__``.
        """
        return {
            # The loader at index 1 of the returned loaders is used.
            'loader': preparator.get_dataloaders(batch_size=1)[1],
            'input_scaler': preparator.get_scaler(fit=True),
            'graph_clf': graph_clf,
            'reward': reward,
            'action_refers_to': cfg.env.action_refers_to,
            'penalty_size': cfg.env.penalty_size,
            'max_episode_length': cfg.env.max_episode_length,
            'use_intrinsic_reward': cfg.env.use_intrinsic_reward,
            'device': cfg.device,
        }

    def transform(self, batch, inplace=False):
        """Move ``batch`` to the environment device and apply the scaler."""
        return self.input_scaler.transform(batch.to(self.device),
                                           inplace=inplace)

    def __get_edge_attr(self, data):
        """Return the edge attributes of ``data``, or ``None`` if unset."""
        if self._edge_attr is None:
            return None
        # Look up the edge-attribute name (not the node-feature name).
        return getattr(data, self._edge_attr)

    def _is_empty(self, data):
        """Return whether ``data`` has no actionable elements left."""
        return self._num_elements(data) == 0

    @abstractmethod
    def _step(self, action, relabel_nodes=True):
        """Apply ``action`` to the current state (implemented by subclasses)."""
        pass

    def step(self, action, relabel_nodes=True):
        """Validate ``action``, advance the step counter and run ``_step``.

        Returns:
            Whatever the subclass ``_step`` implementation returns.
        """
        self.validate_action(action)
        self.current_iter += 1
        return self._step(action, relabel_nodes)

    def validate_action(self, action):
        """Assert that ``action`` has one entry per element of the state."""
        num_elements = self._num_elements(self.state)
        assert len(action) == num_elements, (
            f'Expected an action of length {num_elements}, '
            f'got {len(action)}'
        )

    @abstractmethod
    def get_action_distr_name(self):
        """Return the name of the action distribution (subclass-defined)."""
        pass

    def get_final_graph(self):
        """Return the current state."""
        return self.state

    def on_episode_end(self):
        """Restore the classifier mode recorded by the last ``reset``."""
        if self.is_training:
            self.graph_clf.train()
        else:
            self.graph_clf.eval()

    def _is_original(self, data):
        """Return whether ``data`` has as many elements as the initial graph."""
        return self._num_elements(data) == self.init_num_elements

    @property
    def init_num_elements(self):
        """Number of nodes or edges (per ``action_refers_to``) at reset."""
        if self.action_refers_to == 'edge':
            return self.init_num_edges
        elif self.action_refers_to == 'node':
            return self.init_num_nodes

    def _num_elements(self, data):
        """Count the nodes or edges of ``data`` per ``action_refers_to``."""
        if self.action_refers_to == 'edge':
            return data.edge_index.shape[1]
        elif self.action_refers_to == 'node':
            return data.x.shape[0]

    def reset(self, graph=None, idx=None, transform=True):
        """Reset the environment to an initial graph.

        Args:
            graph: If given, a clone of this graph becomes the initial
                state, with an all-zero ``batch`` vector attached.
            idx: Otherwise, if given, the dataset item at this index is
                wrapped into a batch and used.
            transform: Whether to pass the graph through ``transform``
                (device move + input scaler) before storing it.

        When neither ``graph`` nor ``idx`` is given, the first batch yielded
        by a fresh iterator over the loader is used.

        Returns:
            The new initial state.
        """
        # Remember the classifier mode so on_episode_end() can restore it,
        # then evaluate in eval mode for the duration of the episode.
        self.is_training = self.graph_clf.training
        self.graph_clf.eval()

        # Drop the previous state but keep the attribute defined: if loading
        # below fails, the environment can still be reset again afterwards.
        self.state = None

        if graph is not None:
            state = graph.clone()
            state.batch = torch.zeros(state[self._node_feature].shape[0],
                                      dtype=torch.int64,
                                      device=self.device)
        elif idx is not None:
            data = self.loader.dataset[idx]
            state = copy.deepcopy(Batch.from_data_list([data]))
        else:
            state = copy.deepcopy(next(iter(self.loader)))

        if transform:
            self.state = self.transform(state, inplace=False)
        else:
            self.state = state.clone()

        # The reward object keeps its own copy of the initial state.
        self.reward.set_state_0(self.state.clone())

        self.current_iter = 0
        self.current_reward = 0
        self.init_num_nodes = self.state[self._node_feature].shape[0]
        self.init_num_edges = self.state.edge_index.shape[1]
        self.init_reward = self.reward.compute(state=self.state,
                                               action=None).item()
        return self.state

    def render(self, mode='human', close=False):
        """Print the node count, iteration and reward of the current state."""
        print(f'\nNumber of nodes: {self.state[self._node_feature].shape[0]}')
        print(f'Iteration: {self.current_iter}')
        print(f'Reward: {self.current_reward}')

    def __str__(self):
        """Return a short multi-line summary of the environment settings."""
        return (
            "\nGraph Environment"
            f"\n\tMax episode length: {self.max_episode_length}"
            f"\n\tUse intrinsic reward: {self.use_intrinsic_reward}"
            f"\n\tNumber of graphs in the environment: {len(self.loader)}"
        )
