import collections
import itertools
import os

import fire
from omegaconf import OmegaConf
import torch

from offline_rl.scripts.rewards.learning.common import get_env, get_gym_model
from offline_rl.utils.space_utils import get_space_size


class LineEnvVisualizer:
    """Render a learned reward function over every transition of a discrete env.

    The reward model is queried for the full cross product of
    (state, action, next_state) indices and the result is handed to the
    environment's own ``render_reward_function``.
    """

    def __init__(self, env, model, output_dir):
        """Store the collaborators used by :meth:`visualize`.

        Args:
            env: Environment exposing ``observation_space``, ``action_space``
                and ``render_reward_function(reward_dict)``.
            model: Reward model exposing
                ``reward(states, actions, next_states, <fourth argument>)``.
            output_dir: Stored as-is. ``visualize`` does not read it (no path
                is passed to the renderer); it is accepted so all visualizers
                share one constructor signature.
        """
        self.env = env
        self.model = model
        self.output_dir = output_dir

    def visualize(self):
        """Evaluate the reward on all index triples and render the result.

        Raises:
            ValueError: If the state or action space is empty, so there is
                no transition to evaluate.
        """
        num_states = get_space_size(self.env.observation_space)
        num_actions = get_space_size(self.env.action_space)

        # Every (state, action, next_state) index triple, in itertools.product
        # order. These tuples double as the keys of the rendered dictionary.
        transitions = list(
            itertools.product(range(num_states), range(num_actions), range(num_states))
        )
        if not transitions:
            raise ValueError(
                "Cannot visualize rewards: the state or action space is empty."
            )

        # zip(*...) transposes the list of triples into three parallel columns.
        states, actions, next_states = zip(*transitions)
        states = torch.LongTensor(states)
        actions = torch.LongTensor(actions)
        next_states = torch.LongTensor(next_states)

        # Pure inference: disable autograd so no graph is kept alive for the
        # num_states * num_actions * num_states evaluated transitions.
        with torch.no_grad():
            # The fourth positional argument is passed as None, exactly as
            # before; its meaning is defined by the model.
            rewards = self.model.reward(states, actions, next_states, None)

        # Dictionary format expected by the line env plotting function:
        # (state, action, next_state) -> reward as a plain float.
        reward_dict = {
            key: float(reward) for key, reward in zip(transitions, rewards)
        }

        self.env.render_reward_function(reward_dict)


class MazeEnvVisualizer:
    """Plot a learned reward function over the transitions of a maze env.

    The figure path handed to the environment's renderer is
    ``<output_dir>/rewards.png``.
    """

    def __init__(self, env, model, output_dir):
        """Store the collaborators and make sure the output directory exists.

        Args:
            env: Environment exposing ``state_action_next_state_transitions()``
                and ``render_reward_function(reward_dict, filepath)``.
            model: Reward model exposing
                ``reward(states, actions, next_states, <fourth argument>)``.
            output_dir: Directory that receives the plot. ``None`` falls back
                to the current directory instead of failing inside
                ``os.makedirs``.
        """
        self.env = env
        self.model = model
        if output_dir is None:
            output_dir = os.curdir
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def visualize(self):
        """Evaluate the reward on the env's transitions and render it.

        Raises:
            ValueError: If the environment yields no transitions, or if the
                same (state, action) pair appears more than once.
        """
        transitions = list(self.env.state_action_next_state_transitions())
        if not transitions:
            raise ValueError("Cannot visualize rewards: the env returned no transitions.")

        # Transpose the list of (s, a, s') triples into three parallel columns.
        states, actions, next_states = zip(*transitions)
        states = torch.LongTensor(states)
        actions = torch.LongTensor(actions)
        next_states = torch.LongTensor(next_states)

        # Pure inference: no autograd graph is needed for plotting.
        with torch.no_grad():
            # The fourth positional argument is passed as None, exactly as
            # before; its meaning is defined by the model.
            rewards = self.model.reward(states, actions, next_states, None)

        # Nested dictionary handed to the maze renderer:
        # (state[0], state[1]) -> action -> reward as a plain float.
        reward_dict = collections.defaultdict(dict)
        for state, action, reward in zip(states, actions, rewards):
            state_key = (int(state[0]), int(state[1]))
            action_key = int(action)
            # Explicit raise rather than assert, so the check survives `python -O`.
            if action_key in reward_dict[state_key]:
                raise ValueError("(s,a) pair already inserted.")
            reward_dict[state_key][action_key] = float(reward)

        reward_plot_filepath = os.path.join(self.output_dir, "rewards.png")
        self.env.render_reward_function(reward_dict, reward_plot_filepath)


class BouncingBallsEnvVisualizer:
    """Placeholder visualizer for the bouncing balls environment.

    Construction works like the other visualizers, but ``visualize`` is not
    implemented here and points to a Jupyter notebook instead.
    """

    def __init__(self, env, model, output_dir):
        """Store the collaborators and make sure the output directory exists.

        Args:
            env: Environment instance; only stored.
            model: Reward model instance; only stored.
            output_dir: Directory to create if missing. ``None`` falls back to
                the current directory, so construction no longer fails inside
                ``os.makedirs`` before ``visualize`` can raise its own error.
        """
        self.env = env
        self.model = model
        if output_dir is None:
            output_dir = os.curdir
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def visualize(self):
        """Always raise, since this visualization is not implemented here.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("See jupyter notebook")


def get_visualizer_class(env_name):
    """Return the visualizer class that handles the given environment id.

    Args:
        env_name: Environment id; one of ``"LineEnv-v0"``, ``"MazeEnv-v0"``
            or ``"BouncingBallsEnv-v0"``.

    Returns:
        The matching visualizer class (not an instance).

    Raises:
        ValueError: If ``env_name`` is none of the supported ids.
    """
    if env_name == "LineEnv-v0":
        return LineEnvVisualizer

    if env_name == "MazeEnv-v0":
        return MazeEnvVisualizer

    if env_name == "BouncingBallsEnv-v0":
        return BouncingBallsEnvVisualizer

    # Nothing matched: report the offending id to the caller.
    raise ValueError(f"Visualization of env not supported: {env_name}")


def make_reward_visualizer(config_filepath, checkpoint_filepath, output_dir=None):
    """Build the reward visualizer described by a config file and checkpoint.

    Args:
        config_filepath: Path to a config loadable by ``OmegaConf.load``; it
            must provide ``data.env_name`` and ``model``.
        checkpoint_filepath: Checkpoint path forwarded to ``get_gym_model``.
        output_dir: Forwarded unchanged to the visualizer constructor.

    Returns:
        An instance of the visualizer class selected for ``data.env_name``.

    Raises:
        ValueError: If no visualizer supports the configured environment.
    """
    cfg = OmegaConf.load(config_filepath)
    env_name = cfg.data.env_name

    # Resolve the visualizer first so an unsupported environment fails before
    # the environment and model are constructed.
    build_visualizer = get_visualizer_class(env_name)

    environment = get_env(env_name)
    reward_model = get_gym_model(environment, cfg.model, checkpoint_filepath)
    return build_visualizer(environment, reward_model, output_dir)


if __name__ == "__main__":
    # Expose make_reward_visualizer as a command-line interface via Python Fire;
    # its parameters become the CLI arguments. The function only returns a
    # visualizer object, so the plot is presumably produced by chaining the
    # `visualize` method on the command line — confirm against project usage.
    fire.Fire(make_reward_visualizer)
