from typing import Dict, Optional

import gym
from omegaconf import dictconfig
import torch

from offline_rl.envs.load_custom_envs import load_custom_envs
from offline_rl.envs.bouncing_balls_env import BouncingBallsEnvRewardModel
from offline_rl.envs.custom_reacher_env import CustomReacherEnvRewardModel
from offline_rl.rewards.learning.direct_regression_reward_model import DirectRegressionRewardModel
from offline_rl.rewards.learning.discriminative_reward_model import DiscriminativeRewardModel
from offline_rl.rewards.learning.env_specific_networks import (
    LearnedBouncingBallsEnvRewardModel,
    LearnedReacherRewardModel,
)
from offline_rl.rewards.learning.preference_based_reward_model import PreferenceBasedRewardModel
from offline_rl.rewards.learning.reward_model_networks import FullyConnectedRewardModel
from offline_rl.envs.point_maze_env import PointMazeEnvRewardModel
from offline_rl.rewards.reward_model import RewardModel


def get_env(env_name: str):
    """Creates the gym environment registered under ``env_name``.

    ``load_custom_envs()`` is invoked on every call before the environment
    is created through ``gym.make``.

    Args:
        env_name: The name of the environment to create.

    Returns:
        The environment object returned by ``gym.make(env_name)``.
    """
    # Presumably makes the project's custom environments known to gym so
    # that their names can be resolved below -- confirm in load_custom_envs.
    load_custom_envs()
    env = gym.make(env_name)
    return env


def get_label_overwriting_reward_model(
        cls: str,
        obs_space: gym.spaces.Space,
        act_space: gym.spaces.Space,
        model_kwargs: Optional[Dict],
) -> RewardModel:
    """Gets a reward model given its class name and arguments to pass to it.

    Supported class names are ``"BouncingBallsEnvRewardModel"``,
    ``"PointMazeEnvRewardModel"`` and ``"CustomReacherEnvRewardModel"``.

    Args:
        cls: The name of the reward model class.
        obs_space: The observation space of the environment.
        act_space: The action space of the environment.
        model_kwargs: The keyword arguments to provide to the reward model.
            ``None`` is treated as "no keyword arguments". These are not
            forwarded when ``cls`` is ``"PointMazeEnvRewardModel"``.

    Returns:
        A reward model.

    Raises:
        ValueError: If ``cls`` is not one of the supported class names.
    """
    # Unpacking ``**None`` raises a TypeError, so a kwargs value that was left
    # unset (e.g. a null entry in the config) is normalized to an empty dict.
    if model_kwargs is None:
        model_kwargs = {}

    if cls == "BouncingBallsEnvRewardModel":
        return BouncingBallsEnvRewardModel(
            obs_space,
            act_space,
            **model_kwargs,
        )
    elif cls == "PointMazeEnvRewardModel":
        # NOTE(review): model_kwargs is not forwarded to this class; kept as
        # in the original -- confirm that this is intentional.
        return PointMazeEnvRewardModel(obs_space, act_space)
    elif cls == "CustomReacherEnvRewardModel":
        return CustomReacherEnvRewardModel(
            obs_space,
            act_space,
            **model_kwargs,
        )
    else:
        raise ValueError(f"Invalid reward model class: {cls}")


def get_gym_model(
        obs_space: gym.spaces.Space,
        act_space: gym.spaces.Space,
        config: dictconfig.DictConfig,
        checkpoint_filepath: Optional[str] = None,
):
    """Constructs the reward learning model described by ``config``.

    ``config.model_type`` selects the network class that the reward model
    wraps and ``config.reward_type`` selects the reward model built around
    it. Depending on the reward type, the following config entries are also
    read: ``hidden_sizes``, ``learning_rate``,
    ``label_overwriting_reward_model_class`` /
    ``label_overwriting_reward_model_kwargs`` (direct regression), and
    ``target_reward_model_class`` / ``target_reward_model_kwargs`` /
    ``reward_reg_weight`` (preference based).

    Args:
        obs_space: The observation space of the environment.
        act_space: The action space of the environment.
        config: The config object holding the settings listed above.
        checkpoint_filepath: Optional path to a checkpoint. When given, the
            ``"state_dict"`` entry of the checkpoint is loaded into the model
            (mapped to the CPU) and the model is put into eval mode. Meant
            for evaluation rather than for resuming training.

    Returns:
        The constructed reward model.

    Raises:
        ValueError: If ``config.model_type`` or ``config.reward_type`` is not
            one of the supported values.
    """
    # Step 1: pick the network class wrapped by the reward model.
    model_type = config.model_type
    if model_type == "fully_connected":
        network_cls = FullyConnectedRewardModel
    elif model_type == "learned_bouncing_balls_env":
        network_cls = LearnedBouncingBallsEnvRewardModel
    elif model_type == "learned_reacher":
        network_cls = LearnedReacherRewardModel
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    # All supported network classes are given the same constructor arguments.
    network_kwargs = dict(
        obs_space=obs_space,
        act_space=act_space,
        hidden_sizes=config.hidden_sizes,
    )

    # Step 2: wrap the network in the requested kind of reward model.
    reward_type = config.reward_type
    if reward_type == "discriminative":
        model = DiscriminativeRewardModel(
            submodel_cls=network_cls,
            submodel_kwargs=network_kwargs,
            learning_rate=config.learning_rate,
        )
    elif reward_type == "direct_regression":
        # The label-overwriting model is optional: it is only built when the
        # config names a class for it.
        overwriting_model = None
        if "label_overwriting_reward_model_class" in config:
            overwriting_cls = config.label_overwriting_reward_model_class
            if overwriting_cls is not None:
                overwriting_model = get_label_overwriting_reward_model(
                    overwriting_cls,
                    obs_space,
                    act_space,
                    config.label_overwriting_reward_model_kwargs,
                )
        model = DirectRegressionRewardModel(
            submodel_cls=network_cls,
            submodel_kwargs=network_kwargs,
            label_overwriting_reward_model=overwriting_model,
            learning_rate=config.learning_rate,
        )
    elif reward_type == "preference_based":
        # The target model is built with the same factory as the
        # label-overwriting models.
        preference_target = get_label_overwriting_reward_model(
            config.target_reward_model_class,
            obs_space,
            act_space,
            config.target_reward_model_kwargs,
        )
        model = PreferenceBasedRewardModel(
            submodel_cls=network_cls,
            submodel_kwargs=network_kwargs,
            target_model=preference_target,
            reward_reg_weight=config.reward_reg_weight,
            learning_rate=config.learning_rate,
        )
    else:
        raise ValueError(f"Invalid reward type: {reward_type}")

    # Step 3: optionally restore parameters for evaluation.
    if checkpoint_filepath is None:
        return model

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version and its weights_only default; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_filepath, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
