from typing import Callable, List, Optional, Set, Tuple
from itertools import product

import torch
import numpy as np

from dataset import ShuffleDataset
from config import Config
from figure_utils import make_board_from_info
from driving_gridworld.road import Road
from driving_gridworld.car import Car
from driving_gridworld.obstacles import Bump
from driving_gridworld.rewards import reward
from driving_gridworld.actions import ACTIONS


class EnvironmentDataset:
    """
    Convenience class for obtaining everything useful for our experiments in
    a much more convenient interface than what driving_gridworld provides,
    with the sacrifice of less flexibility (made exclusively for
    these experiments).
    """
    def __init__(
            self,
            allowed_columns: Optional[List[Set[int]]] = None,
            state_to_tensor: Optional[Callable] = None,
            prob_of_appearing: float = 0.16,
    ):
        """
        Tabulate a driving_gridworld road and build state-index lookup tables.

        The headlight range is not a parameter; it is read from
        ``Config().headlight_range``.

        :param allowed_columns: Where the obstacles can spawn. The obstacle
        list passed to ``Road`` has one entry per element of this list.
        Defaults to ``[{0, 3}]``.
        :param state_to_tensor: The transform to use on the state to obtain
        a tensor. Used for state transformation, with the inspired use case
        being the head architecture [model.head]. Defaults to a transform
        returning ``(layers, argmax(speed_onehot).unsqueeze(0))``.
        :param prob_of_appearing: The ``prob_of_appearing`` given to the
        Bump obstacle.
        """
        if allowed_columns is None:
            # Build a fresh default per call: a literal list default would be
            # a single mutable object shared by every instance.
            allowed_columns = [{0, 3}]

        self.config = Config()
        self.road = Road(
                self.config.headlight_range,
                Car(2, 0),
                # NOTE(review): this repeats the *same* Bump instance
                # len(allowed_columns) times; that is only safe if Road never
                # mutates obstacles in place — confirm.
                obstacles=[Bump(-1, -1, prob_of_appearing=prob_of_appearing)] *
                          len(allowed_columns),
                allowed_obstacle_appearance_columns=allowed_columns,
                allow_crashing=False)

        (self.prob_trans_mat,
         self.true_reward,
         state_indices,
         key_to_road,
         self.state_action_next_state_reward) = (
            self.road.tabulate(reward))
        self.true_reward = np.array(self.true_reward)
        self.index_to_road_layers = {
            index: key_to_road[key].ordered_layers()
            for (key, index) in state_indices.items()
        }
        self.index_to_road = {
            index: key_to_road[key].copy()
            for (key, index) in state_indices.items()
        }

        def true_reward_callable(state, action):
            return self.true_reward[state][action]
        self.true_reward_callable = true_reward_callable

        self.index_to_board = {index: make_board_from_info(
            car_col, car_spd, obstacles, self.config.headlight_range)
            for ((car_col, car_spd, obstacles), index) in state_indices.items()}
        self.index_to_info = {index: info for info, index in state_indices.items()}

        # Reverse lookup: flattened board contents -> state index.
        self.board_key_to_index = {
            tuple(board.flatten()): index
            for index, board in self.index_to_board.items()
        }

        if state_to_tensor is None:
            def state_to_tensor(state):
                # Keep the layers and collapse the one-hot speed to its index;
                # the obstacle row/column distances are dropped.
                layers, spd, _row_dist, _col_dist = state
                return (
                    layers, torch.argmax(spd).unsqueeze(0),)
        self.state_to_tensor = state_to_tensor

    def obtain_dataset(
            self,
            *args,
            **kwargs,
    ) -> "ShuffleDataset":
        """
        Build a ShuffleDataset of ``((state_tensor, next_state_tensor), reward)``
        samples, one per (state, action, next_state) triple whose transition
        probability is positive. Each state tensor is
        ``self.state_to_tensor(self.board_to_state(board))`` for that state's
        board; the reward is a float tensor on ``config.device``.

        :param args: Extra positional arguments forwarded to the ShuffleDataset
        constructor after the list of samples.
        :param kwargs: Keyword arguments forwarded to the ShuffleDataset
        constructor.
        :returns: The ShuffleDataset used to train reward function approximators.
        """
        nonzero_transitions = []
        for state, action_rows in enumerate(self.prob_trans_mat):
            for action, next_state_probs in enumerate(action_rows):
                for next_state, prob in enumerate(next_state_probs):
                    if prob > 0:
                        nonzero_transitions.append((
                            state,
                            next_state,
                            self.state_action_next_state_reward[state][action][next_state],
                        ))

        def to_tensor(state_index):
            # Route through the board representation so the state tensor is
            # derived exactly as for any externally supplied board.
            return self.state_to_tensor(
                self.board_to_state(self.obtain_board_representation(state_index)))

        samples = [
            (
                (to_tensor(state), to_tensor(next_state)),
                torch.tensor(rew).type(torch.FloatTensor).to(self.config.device),
            )
            for (state, next_state, rew) in nonzero_transitions
        ]
        return ShuffleDataset(samples, *args, **kwargs)

    def board_to_state(
            self,
            board: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert a board into the 4-tuple consumed by ``state_to_tensor``.

        :param board: A board whose flattened contents are a key of
        ``board_key_to_index`` (i.e. one produced for a tabulated state).
        :returns: ``(layer_repr, speed_onehot, obs_car_row_dist,
        obs_car_col_dist)``:
        ``layer_repr`` stacks columns ``1:-2`` of every road layer except
        ``'^'`` and ``'|'`` as float32;
        ``speed_onehot`` has ``board.shape[0] + 1`` entries with a 1 at the
        number of ``'^'`` cells in the board's last column;
        the two distances are 1-element tensors computed from the first
        obstacle, or 0 when there is no obstacle.
        All tensors are placed on ``config.device``.
        :raises KeyError: if the board does not match any tabulated state.
        """
        device = self.config.device
        state_index = self.board_to_state_index(board)
        car_col, _car_spd, obstacles = self.index_to_info[state_index]
        if len(obstacles) > 0:
            # Only the first obstacle (in iteration order) is considered.
            _, obs_row, obs_col, _, _ = next(iter(obstacles))
            obs_car_col_dist = car_col - obs_col
            # Distance from the obstacle's row to the board's last row.
            obs_car_row_dist = board.shape[0] - 1 - obs_row
        else:
            obs_car_col_dist = obs_car_row_dist = 0
        obs_car_row_dist = torch.Tensor([obs_car_row_dist]).to(device)
        obs_car_col_dist = torch.Tensor([obs_car_col_dist]).to(device)

        # The last board column holds the speed indicator: count its '^'
        # cells (character code 94).
        speed_val = int(np.sum(board[:, -1] == ord('^')))
        speed_onehot = torch.zeros(board.shape[0] + 1).to(device)
        speed_onehot[speed_val] = 1

        layer_repr = torch.stack([
            torch.from_numpy(layer[:, 1:-2]).float()
            for char, layer in self.index_to_road_layers[state_index]
            if char not in ('^', '|')
        ]).to(device)
        return layer_repr, speed_onehot, obs_car_row_dist, obs_car_col_dist

    def obtain_board_representation(self, state: int):
        """
        :param state: the state index.
        :returns: the board corresponding to the state index.
        :raises KeyError: if the index is unknown.
        """
        return self.index_to_board[state]

    def board_to_state_index(self, board: np.ndarray) -> int:
        """
        :param board: A board representation of some state.
        :returns: The state index corresponding to that board.
        :raises KeyError: if no tabulated state has this exact board.
        """
        key = tuple(board.flatten())
        return self.board_key_to_index[key]

    @classmethod
    def obtain_train_env(
            cls,
            *args,
            **kwargs,
    ) -> "EnvironmentDataset":
        """
        :returns: The default training environment (obstacles in columns
        ``{0, 3}``, ``prob_of_appearing=0.9``). Extra arguments are forwarded
        to the constructor.
        """
        return cls(
            *args,
            allowed_columns=[{0, 3}],
            prob_of_appearing=0.9,
            **kwargs,
        )

    @classmethod
    def obtain_test_env(
            cls,
            *args,
            **kwargs,
    ) -> "EnvironmentDataset":
        """
        :returns: The default testing environment (obstacles in columns
        ``{0, 1, 2, 3}``, ``prob_of_appearing=0.9``). Extra arguments are
        forwarded to the constructor.
        """
        return cls(
            *args,
            allowed_columns=[
                {0, 1, 2, 3},
                ],
            prob_of_appearing=0.9,
            **kwargs,
        )
