from typing import Collection, Dict

import numpy as np
import numpy.testing
import torch

from common.ours.models import discretize_actions, interpret_discrete_actions
from common.ours.utils import DictDataset, ant_to_maze2d, maze2d_to_ant


def test_discretize_actions():
    """Check ``discretize_actions`` with 21 bins and ``one_hot=False``.

    The values -1, -0.5, 0, 0.5, 1 are expected to come back as the bin
    indices 0, 5, 10, 15, 20.  The check runs for a numpy array and for a
    torch Tensor, each both as a flat vector and with a leading axis added.
    """
    num_bins = 21
    raw_values = [-1., -0.5, 0., 0.5, 1.]
    expected_flat = np.array([0, 5, 10, 15, 20])
    expected_batched = np.array([[0, 5, 10, 15, 20]])

    # numpy input first, then torch Tensor input.
    for make_container in (np.array, torch.Tensor):
        continuous = make_container(raw_values)
        for add_leading_axis in (False, True):
            if add_leading_axis:
                inputs, expected = continuous[None], expected_batched
            else:
                inputs, expected = continuous, expected_flat

            binned = discretize_actions(inputs, num_bins, one_hot=False)
            assert binned.shape == expected.shape
            np.testing.assert_array_almost_equal(binned, expected)


def test_interpret_discrete_actions():
    """Check ``interpret_discrete_actions`` with 21 bins.

    The bin indices 0, 5, 10, 15, 20 are expected to come back as the values
    -1, -0.5, 0, 0.5, 1.  The check runs for a numpy array and for a torch
    Tensor, each both as a flat vector and with a leading axis added.
    """
    num_bins = 21
    bin_indices = [0, 5, 10, 15, 20]
    expected_flat = np.array([-1., -0.5, 0., 0.5, 1.])
    expected_batched = np.array([[-1., -0.5, 0., 0.5, 1.]])

    # numpy input first, then torch Tensor input.
    for make_container in (np.array, torch.Tensor):
        discrete = make_container(bin_indices)
        for add_leading_axis in (False, True):
            if add_leading_axis:
                inputs, expected = discrete[None], expected_batched
            else:
                inputs, expected = discrete, expected_flat

            decoded = interpret_discrete_actions(inputs, num_bins)
            assert decoded.shape == expected.shape
            np.testing.assert_array_almost_equal(decoded, expected)


def _dict_equal(dict1: Dict[str, Collection], dict2: Dict[str, Collection]):
    for key1, val1 in dict1.items():
        assert key1 in dict2
        np.testing.assert_array_almost_equal(val1, dict2[key1])


def test_dict_dataset():
    """Check that indexing a ``DictDataset`` returns the matching rows.

    The dataset is built from two 4x5 integer arrays stored under the keys
    'a' and 'b'; 'b' holds the same numbers as 'a' shifted by 100.  Items 0
    and 3 are compared against the first and last row of each array.
    """
    rows_a = np.arange(20).reshape(-1, 5)
    rows_b = np.arange(20).reshape(-1, 5) + 100
    dataset = DictDataset(data={'a': rows_a, 'b': rows_b})

    expected_items = (
        (0, {'a': [0, 1, 2, 3, 4],
             'b': [100, 101, 102, 103, 104]}),
        (3, {'a': [15, 16, 17, 18, 19],
             'b': [115, 116, 117, 118, 119]}),
    )
    for index, expected in expected_items:
        _dict_equal(dataset[index], expected)


def test_ant_to_maze2d():
    """Check ``ant_to_maze2d`` against three hand-picked coordinate pairs."""
    expected_maze_points = np.array([[1, 1], [3, 1], [3, 2]])
    ant_points = np.array([[0, 0], [0, 8], [4, 8]])

    np.testing.assert_array_almost_equal(
        ant_to_maze2d(ant_points), expected_maze_points)


def test_maze2d_to_ant():
    """Check ``maze2d_to_ant`` against three hand-picked coordinate pairs."""
    expected_ant_points = np.array([[0, 0], [0, 8], [4, 8]])
    maze_points = np.array([[1, 1], [3, 1], [3, 2]])

    np.testing.assert_array_almost_equal(
        maze2d_to_ant(maze_points), expected_ant_points)


def test_cycle_maze_ant():
    """Check that ``ant_to_maze2d`` and ``maze2d_to_ant`` undo each other.

    1000 random 2-D points drawn uniformly from [-10, 10) are pushed through
    both compositions of the two conversions, and each result is compared
    with the original points.
    """
    length = 1000
    # Use a seeded, local generator so that a failure is reproducible and the
    # global numpy random state is left untouched for other tests.
    rng = np.random.RandomState(0)
    array = rng.uniform(low=-10, high=10, size=(length, 2))

    res1 = ant_to_maze2d(maze2d_to_ant(array))
    np.testing.assert_array_almost_equal(array, res1)

    res2 = maze2d_to_ant(ant_to_maze2d(array))
    np.testing.assert_array_almost_equal(array, res2)


def test_pointmaze_start_goal_id_conversion():
    """Check the goal/start/task id conversions on two maze2d environments.

    For each environment, every (goal_id, start_id, task_id) triple must be
    consistent with ``goal_id_and_start_id_to_task_id``,
    ``task_id_to_goal_id`` and ``task_id_to_start_id``.  In addition,
    ``goal_id_to_task_id_list(goal_id=2)`` must equal the listed task ids.
    """
    import gym

    import d4rl
    from d4rl.pointmaze.maze_model import MazeEnv

    # (environment name, [(goal_id, start_id, task_id), ...],
    #  expected task ids for goal_id=2)
    scenarios = (
        ('maze2d-medium-v1',
         [(0, 0, 0), (3, 1, 4 * 3 + 1), (19, 3, 4 * 19 + 3)],
         (8, 9, 10, 11)),
        ('maze2d-umaze-v1',
         [(0, 0, 0), (3, 1, 3 * 3 + 1), (5, 2, 3 * 5 + 2)],
         (6, 7, 8)),
    )

    for env_name, id_triples, goal2_task_ids in scenarios:
        maze: MazeEnv = gym.make(env_name)

        for goal, start, task in id_triples:
            combined = maze.goal_id_and_start_id_to_task_id(
                goal_id=goal, start_id=start)
            assert task == combined
            assert goal == maze.task_id_to_goal_id(task_id=task)
            assert start == maze.task_id_to_start_id(task_id=task)

        np.testing.assert_array_equal(
            maze.goal_id_to_task_id_list(goal_id=2), goal2_task_ids)
