import numpy as np
import torch
import gym
from datasets.constants import _DATASET_CLASSES, _DATASET_ROOTS
from datasets._configs import CONFIGS
from torch_geometric.loader import DataLoader
from sb3_contrib.common.envs import InvalidActionEnvMultiDiscrete


def graph_to_dict(data):
    """Convert a graph sample into the environment's observation dictionary.

    Most entries are copied straight from attributes of ``data``.
    ``current_node`` is derived as the index of the largest entry of
    ``data.start_route`` (the position of its one-hot "current node" marker).

    Args:
        data: Graph sample object exposing ``x``, ``visited_nodes``,
            ``start_route``, ``edge_index``, ``edge_attr``, ``lengths`` and
            ``optimal_value`` attributes.

    Returns:
        dict keyed by the observation-space names, in a fixed order.
    """
    copied_fields = ("x", "visited_nodes", "start_route", "edge_index", "edge_attr")
    observation = {field: getattr(data, field) for field in copied_fields}
    observation["current_node"] = data.start_route.argmax().item()
    observation["lengths"] = data.lengths
    observation["optimal_value"] = data.optimal_value
    return observation


class TSPDistCost(InvalidActionEnvMultiDiscrete):
    """Gym environment that builds a TSP path one node at a time.

    An episode starts at the node marked in the sample's ``start_route``
    one-hot vector. Each action is the index of the next node to visit, and
    the path cost is the sum of ``edge_attr`` weights along the chosen edges.

    * Revisiting a node ends the episode with reward ``invalid_action_cost``.
    * Visiting every node ends the episode with the sparse reward
      ``1 - path_cost / optimal_value``.
    * All other steps return reward 0.

    NOTE(review): the path cost does not include the edge from the last node
    back to the start node. Confirm this matches how ``optimal_value`` is
    computed by the dataset.

    Args:
        split: Key into ``CONFIGS['tsp']`` that selects ``num_nodes`` and
            ``num_samples`` for the dataset.
        mask_invalid_actions: Stored as ``self.use_mask``; not read inside
            this class.
        optimisation_mode: ``'one_sample'`` replays the same instance on
            every ``reset``; ``'multiple_samples'`` draws a new shuffled
            instance on each ``reset``.
    """

    def __init__(self,
                 split,
                 mask_invalid_actions=False,
                 optimisation_mode="one_sample"):

        assert optimisation_mode in ['one_sample', 'multiple_samples']

        split_config = CONFIGS['tsp'][split]
        self.dataset = _DATASET_CLASSES['tsp'](
            _DATASET_ROOTS['tsp'],
            num_nodes=split_config['num_nodes'],
            num_samples=split_config['num_samples'],
        )

        self.N = split_config['num_nodes']
        # Reward given when the agent revisits an already-visited node.
        self.invalid_action_cost = -100
        self.use_mask = mask_invalid_actions
        self.nodes = np.arange(self.N)
        self.obs_dim = self.N
        self.action_dim = self.N

        self.loader = self._new_loader()

        # edge_index / edge_attr are sized N**2, and step() reshapes
        # edge_attr to (N, N). This presumably means a dense edge list that
        # includes self-loops; confirm against the dataset class.
        self.observation_space = gym.spaces.Dict({
            "x": gym.spaces.Box(low=0, high=1, shape=(self.N, 2)),
            "visited_nodes": gym.spaces.Box(low=0, high=1, shape=(self.N,)),
            "start_route": gym.spaces.Box(low=0, high=1, shape=(self.N,)),
            "edge_index": gym.spaces.Box(low=0, high=self.N-1, shape=(2, self.N**2)),
            "edge_attr": gym.spaces.Box(low=0, high=np.inf, shape=(self.N**2,)),
            "current_node": gym.spaces.Discrete(self.N),
            "lengths": gym.spaces.Box(low=0, high=np.inf, shape=(1,)),
            "optimal_value": gym.spaces.Box(low=0, high=np.inf, shape=(1,))
        })
        self.action_space = gym.spaces.Discrete(self.N)

        self.is_one_sample_mode = optimisation_mode == 'one_sample'
        if self.is_one_sample_mode:
            # Fix a single instance; every reset() replays a clone of it.
            self.item = next(self.loader)

    def _new_loader(self):
        """Return a fresh iterator over shuffled, single-sample batches."""
        return iter(DataLoader(self.dataset,
                               batch_size=1,
                               shuffle=True))

    def action_masks(self):
        """Return the mask of currently valid actions.

        Entry ``i`` is 1.0 when node ``i`` has not been visited yet (a legal
        move) and 0.0 otherwise. This follows the sb3-contrib convention that
        a truthy mask entry marks a *valid* action. The previous version
        returned the visited log itself, which inverted the mask.
        """
        return torch.FloatTensor((self.visit_log == 0).astype(np.float32))

    def step(self, action):
        """Move the tour to node ``action``.

        Args:
            action: Index of the next node to visit.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``observation`` is the
            dict built by ``graph_to_dict``; ``info`` is always empty.
        """
        self.step_count += 1
        done = False
        is_invalid = self.visit_log[action] > 0
        if is_invalid:
            # Node already visited: an illegal move that ends the episode.
            self.cumulative_reward = self.invalid_action_cost
            done = True
        else:
            edge_costs = self.state.edge_attr.reshape(self.N, self.N)
            self.cumulative_reward += edge_costs[self.current_node, action].item()
            self.current_node = action
            self.visit_log[self.current_node] = 1

        # Work on a copy so the previous state object is left untouched.
        self.state = self.state.clone()
        self.state.start_route = torch.nn.functional.one_hot(
            torch.LongTensor(np.array([self.current_node])), self.N).float().squeeze()
        self.state.visited_nodes = torch.FloatTensor(self.visit_log)

        # The tour is complete once every node has been visited.
        if self.visit_log.sum() == self.N:
            done = True

        # Sparse reward: nonzero only on the terminal step.
        reward = 0.0
        if done:
            if is_invalid:
                # Use the penalty directly. Feeding it through the ratio
                # formula below would flip its sign into a large positive
                # reward.
                reward = self.invalid_action_cost
            else:
                # 0 when the path cost equals optimal_value; more negative
                # the longer the path.
                reward = 1 - self.cumulative_reward / self.state.optimal_value.item()

        return graph_to_dict(self.state), reward, done, {}

    def reset(self):
        """Start a new episode and return its initial observation dict.

        In one-sample mode the fixed instance is cloned. Otherwise the next
        shuffled instance is drawn, and the loader is rebuilt (reshuffled)
        once an epoch is exhausted.
        """
        if self.is_one_sample_mode:
            state = self.item.clone()
        else:
            try:
                batch = next(self.loader)
            except StopIteration:
                # Epoch exhausted: start a new shuffled pass over the dataset.
                self.loader = self._new_loader()
                batch = next(self.loader)
            state = batch.clone()

        self.step_count = 0
        self.current_node = state.start_route.argmax().item()
        self.visit_log = np.zeros(self.N)
        self.visit_log[self.current_node] += 1
        self.state = state
        self.state.visited_nodes = torch.FloatTensor(self.visit_log)
        self.cumulative_reward = 0

        return graph_to_dict(self.state)
