from typing import Callable, Dict, List

import gym
import h5py
import numpy as np
from gym.spaces import Box
from omegaconf import DictConfig

# yapf: disable
# Lookup table from a discretized goal position to a 1-based integer task id.
# First key: coordinate family ("maze2d" or "point"); `get_task_id` also uses
# the "point" tables for env ids containing "ant".
# Second key: maze type, i.e. the second hyphen-separated field of the env id.
# Third key: a 2-tuple of ints built by `get_task_id` from the goal position:
#   - "maze2d": (target + 0.5) cast to int,
#   - "point":  ((target + 2) / 4) cast to int, then multiplied by 4.
goal_to_id = {
    "maze2d": {
        "umaze": {(1, 1): 1, (1, 2): 2, (1, 3): 3, (2, 3): 4, (3, 1): 5, (3, 2): 6, (3, 3): 7,},
        "medium": {(1, 1):1, (1, 2):2, (1, 5):3, (1, 6):4, (2, 1):5, (2, 2):6, (2, 4):7, (2, 5):8, (2, 6):9, (3, 2):10, (3, 3):11, (3, 4):12, (4, 1):13, (4, 2):14, (4, 4):15, (4, 5):16, (4, 6):17, (5, 1):18, (5, 3):19, (5, 4):20, (5, 6):21, (6, 1):22, (6, 2):23, (6, 3):24, (6, 5):25, (6, 6):26,},
        "large": {(1, 1):1, (1, 2):2, (1, 3):3, (1, 4):4, (1, 6):5, (1, 7):6, (1, 8):7, (1, 9):8, (1, 10):9, (2, 1):10, (2, 4):11, (2, 6):12, (2, 8):13, (2, 10):14, (3, 1):15, (3, 2):16, (3, 3):17, (3, 4):18, (3, 5):19, (3, 6):20, (3, 8):21, (3, 9):22, (3, 10):23, (4, 1):24, (4, 6):25, (4, 10):26, (5, 1):27, (5, 2):28, (5, 4):29, (5, 6):30, (5, 7):31, (5, 8):32, (5, 9):33, (5, 10):34, (6, 2):35, (6, 4):36, (6, 6):37, (6, 8):38, (7, 1):39, (7, 2):40, (7, 4):41, (7, 5):42, (7, 6):43, (7, 8):44, (7, 9):45, (7, 10):46,},
    },
    "point": {
        "umaze": {(0, 0):1, (0, 8):2, (4, 0):3, (4, 8):4, (8, 0):5, (8, 4):6, (8, 8):7,},
        "medium": {(0, 0):1, (0, 4):2, (0, 12):3, (0, 16):4, (0, 20):5, (4, 0):6, (4, 4):7, (4, 8):8, (4, 12):9, (4, 20):10, (8, 8):11, (8, 16):12, (8, 20):13, (12, 4):14, (12, 8):15, (12, 12):16, (12, 16):17, (16, 0):18, (16, 4):19, (16, 12):20, (16, 20):21, (20, 0):22, (20, 4):23, (20, 12):24, (20, 16):25, (20, 20):26,},
        "large": {(0, 0):1, (0, 4):2, (0, 8):3, (0, 12):4, (0, 16):5, (0, 24):6, (4, 0):7, (4, 8):8, (4, 16):9, (4, 20):10, (4, 24):11, (8, 0):12, (8, 8):13, (12, 0):14, (12, 4):15, (12, 8):16, (12, 16):17, (12, 20):18, (12, 24):19, (16, 8):20, (16, 24):21, (20, 0):22, (20, 4):23, (20, 8):24, (20, 12):25, (20, 16):26, (20, 20):27, (20, 24):28, (24, 0):29, (24, 16):30, (28, 0):31, (28, 4):32, (28, 8):33, (28, 16):34, (28, 20):35, (28, 24):36, (32, 0):37, (32, 8):38, (32, 16):39, (32, 24):40, (36, 0):41, (36, 4):42, (36, 8):43, (36, 12):44, (36, 16):45, (36, 24):46,},
    }
}
# yapf: enable


def get_task_id(target, env_id):
    """Map a goal position to its integer task id for the given environment.

    Args:
        target: array-like with the two goal coordinates.
        env_id (str): environment id. It must contain "maze2d", "point" or
            "ant", and its second hyphen-separated field must be the maze
            type (e.g. "maze2d-umaze-v1").

    Returns:
        The task id stored in ``goal_to_id`` for the discretized goal.

    Raises:
        ValueError: if ``env_id`` names no known environment family or has
            no maze-type field.
        KeyError: if the maze type or the discretized goal is not present
            in ``goal_to_id``.
    """
    target = np.array(target)

    if "maze2d" in env_id:
        family = "maze2d"
        # Shift by half a cell before truncating to an integer cell index.
        cell = (target + 0.5).astype("int")
    elif "point" in env_id or "ant" in env_id:
        family = "point"
        # Snap to a multiple of 4, matching the keys of the "point" tables.
        cell = ((target + 2) / 4).astype("int") * 4
    else:
        raise ValueError("Unrecognized env_id " + env_id)

    # Validate the id layout only after the family check so that every
    # malformed id is reported with the same ValueError instead of leaking
    # an IndexError from the split.
    fields = env_id.split("-")
    if len(fields) < 2:
        raise ValueError("Unrecognized env_id " + env_id)

    return goal_to_id[family][fields[1]][tuple(cell)]


class CCATransferWrapper(gym.Wrapper):
    """Environment wrapper that adds a latent-matching auxiliary reward.

    At construction time the source observations are mapped to latent space
    with ``source_transform``, split into episodes per inference task id and
    padded to a common length. On every ``step`` the wrapped environment's
    observation is mapped with ``target_transform`` and compared against the
    source latents at the current time index; the negative squared error,
    scaled by ``alpha``, is added to (or replaces) the environment reward.
    """

    def __init__(
        self,
        env: gym.Env,
        args: DictConfig,
        source_dataset: Dict[str, np.ndarray],
        source_transform: Callable,
        target_transform: Callable,
        inference_task_ids: List[int],
        num_task_ids: int,
        use_task_id_for_obs: bool,
        alpha: float,
        aux_reward_only: bool,
    ):
        """Build the per-task source latent trajectories and set up the wrapper.

        Args:
            env (gym.Env): original environment

            args (DictConfig): args; ``args.target_env_id`` is read

            source_dataset (Dict[str, np.ndarray]):
                dataset from source domain, specific keys ("terminals",
                "observations", "infos/goal_id") are required. The dataset
                is not modified.

            source_transform (Callable):
                transform function that maps source observation to latent
                representation

            target_transform (Callable):
                transform function that maps target observation to latent
                representation

            inference_task_ids (List[int]):
                task ids for inference tasks

            num_task_ids (int):
                number of all task ids

            use_task_id_for_obs (bool):
                if True, concatenate task id and observation

            alpha (float):
                weight for auxiliary reward

            aux_reward_only (bool):
                if True, use auxiliary reward only

        Raises:
            ValueError: if the dataset contains no episode for one of
                ``inference_task_ids``.
        """

        super().__init__(env)

        # Work on a boolean copy: the last transition is forced to be an
        # episode end, and that must not leak into the caller's dataset.
        terminals = np.asarray(source_dataset["terminals"]).astype(bool)
        terminals[-1] = True

        # An episode starts at index 0 and right after every terminal step.
        starts = np.zeros_like(terminals)
        starts[0] = True
        starts[1:] = terminals[:-1]

        task_ids = source_dataset["infos/goal_id"]
        transformed = source_transform(source_dataset["observations"])

        source_latents = {}
        for task_id in inference_task_ids:
            mask = (task_ids == task_id)
            start_idxs = np.where(starts & mask)[0]
            terminal_idxs = np.where(terminals & mask)[0]

            # NOTE(review): the slice end is exclusive, so the latent at the
            # terminal index itself is not included in the episode.
            episodes = [
                transformed[start:terminal_idxs[i]]
                for i, start in enumerate(start_idxs)
            ]
            if not episodes:
                raise ValueError(
                    "source_dataset contains no episode for task id {}".format(
                        task_id))

            # Pad every episode to the longest one by repeating its last
            # latent, so all episodes of a task stack into a single array.
            max_episode_length = max(len(ep) for ep in episodes)
            source_latents[task_id] = np.array([
                np.concatenate(
                    (ep,
                     np.tile(ep[-1][None],
                             (max_episode_length - ep.shape[0], 1))))
                for ep in episodes
            ])

        self.source_latents = source_latents
        self.target_transform = target_transform
        self.args = args
        self.env_id = self.args.target_env_id
        self.alpha = alpha
        self.aux_reward_only = aux_reward_only

        self.inference_task_ids = inference_task_ids
        self.num_task_ids = num_task_ids
        self.task_id = inference_task_ids[0]
        self.use_task_id_for_obs = use_task_id_for_obs

        if self.use_task_id_for_obs:
            # One-hot vectors have num_task_ids + 1 entries and are indexed
            # directly by the task id.
            self.onehot = np.eye(self.num_task_ids + 1)[self.task_id]
            state_dim = self.env.observation_space.shape[0] + num_task_ids + 1
            self._observation_space = Box(-np.ones((state_dim, )),
                                          np.ones((state_dim, )),
                                          shape=(state_dim, ))

        # Time index into the padded source latent trajectories.
        self.t = 0

    def reset(self, **kwargs):
        """Reset the wrapped env and switch to a randomly drawn inference task.

        ``kwargs`` are accepted but not forwarded to the wrapped environment.

        Returns:
            The observation from ``get_obs()``, with the task one-hot
            appended when ``use_task_id_for_obs`` is True.
        """
        self.env.reset()
        task_id = np.random.choice(self.inference_task_ids)
        target = self.env.id_to_xy[task_id]

        self.task_id = task_id

        self.env.set_target(target)
        self.env.set_init_xy()
        self.t = 0
        obs = self.get_obs()

        if self.use_task_id_for_obs:
            self.onehot = np.eye(self.num_task_ids + 1)[self.task_id]
            obs = np.concatenate((obs, self.onehot), -1)

        return obs

    def step(self, action):
        """Step the wrapped env and apply the auxiliary latent-matching reward.

        Returns:
            ``(obs, rew, done, info)`` where ``rew`` is the auxiliary reward
            alone if ``aux_reward_only`` is True, otherwise the environment
            reward plus the auxiliary reward. ``obs`` has the task one-hot
            appended when ``use_task_id_for_obs`` is True.
        """
        obs, rew, done, info = self.env.step(action)
        target_latents = self.target_transform(obs)

        # The task id is recovered from the env's current target.
        task_id = get_task_id(self.env.get_target(), self.env_id)
        latents = self.source_latents[task_id]
        num_episodes, max_length = latents.shape[0], latents.shape[1]

        # Squared error against every source episode at the current time
        # index, summed over all entries and averaged over episodes.
        squared_error = (latents[:, self.t] - target_latents)**2
        aux_rew = -(np.sum(squared_error) / num_episodes) * self.alpha

        if self.aux_reward_only:
            rew = aux_rew
        else:
            rew += aux_rew

        # Clamp the time index to the last step of the source trajectories.
        if self.t < max_length - 1:
            self.t += 1

        if self.use_task_id_for_obs:
            obs = np.concatenate((obs, self.onehot), -1)

        return obs, rew, done, info
