from functools import partial
import os

import hydra
from omegaconf import dictconfig, OmegaConf
from ray.rllib.policy.sample_batch import SampleBatch
import torch
import tqdm

from offline_rl.data.sample_batch_json_reader_dataset import SampleBatchJsonReaderDataset
from offline_rl.envs.bouncing_balls_env import BouncingBallsEnvRewardModel, BouncingBallsEnvFeasibilityRewardWrapper
from offline_rl.envs.custom_reacher_env import CustomReacherEnvRewardModel
from offline_rl.envs.line_env import LineEnvReward, LineEnvRightwardPotential
from offline_rl.envs.point_maze_env import PointMazeEnvRewardModel
from offline_rl.rewards.evaluation.distances import (
    compute_distance_between_reward_pairs,
    compute_direct_distance,
    compute_pearson_distance,
)
from offline_rl.rewards.evaluation.epic import EPIC, compute_scale_normalized_rewards
from offline_rl.rewards.evaluation.model_collection import ModelCollection
from offline_rl.rewards.evaluation.reward_collection import RewardCollection
from offline_rl.rewards.evaluation.bouncing_balls_env_transition_samplers import (
    BouncingBallsEnvActionSamplingTransitionSampler,
    BouncingBallsEnvConstantVelocityTransitionSampler,
)
from offline_rl.rewards.evaluation.line_env_transition_samplers import LineEnvUniformPolicyTransitionSampler
from offline_rl.rewards.evaluation.mujoco_transition_sampler import MujocoTransitionSampler
from offline_rl.rewards.evaluation.point_maze_env_transition_samplers import (
    PointMazeEnvActionSamplingTransitionSampler, )
from offline_rl.rewards.evaluation.transition_sampler import (
    FixedDistributionTransitionSampler,
    BoundaryActionSampler,
    LinearActionSampler,
    UniformlyRandomActionSampler,
)
# pylint: disable=unused-import
from offline_rl.rewards.noisy_reward_wrapper import NoisyRewardWrapper
# pylint: enable=unused-import
from offline_rl.scripts.rewards.learning.common import get_env, get_gym_model
import offline_rl.utils.space_utils as space_utils
# pylint: disable=unused-import
from offline_rl.utils.testing.rewards import ConstantRewardModel, RandomRewardModel
# pylint: enable=unused-import


def get_data_loader(config):
    """Builds a torch ``DataLoader`` over the JSON sample-batch dataset described by ``config.data``.

    Args:
        config: Full run config. Only its ``data`` section is read: ``dataset_filepath``, ``transform``
            (optional, ``None`` disables the state transform), ``debug_size``, ``debug_size_mode``,
            ``batch_size`` and ``shuffle``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``SampleBatchJsonReaderDataset``.
    """
    data_config = config.data
    if data_config.transform is None:
        state_transform = None
    else:
        state_transform = space_utils.get_numpy_space_transform(data_config.transform)
    return torch.utils.data.DataLoader(
        SampleBatchJsonReaderDataset(
            data_config.dataset_filepath,
            state_transform=state_transform,
            debug_size=data_config.debug_size,
            debug_size_mode=data_config.debug_size_mode,
        ),
        batch_size=data_config.batch_size,
        shuffle=data_config.shuffle,
    )


def _add_noisy_reward_models(reward_models, base_model, sigmas):
    """Adds one ``NoisyRewardWrapper`` around ``reward_models[base_model]`` per noise level in ``sigmas``.

    Each wrapper is stored under the key ``f"noisy_{sigma}_{base_model}"``.

    Args:
        reward_models: Dict mapping reward names to reward models; updated in place.
        base_model: Key of the model to wrap.
        sigmas: Noise levels, each passed to ``NoisyRewardWrapper`` as ``sigma``.
    """
    base_reward = reward_models[base_model]
    for sigma in sigmas:
        reward_models[f"noisy_{sigma}_{base_model}"] = NoisyRewardWrapper(base_reward=base_reward, sigma=sigma)


def get_reward_models_for_env(obs_space, act_space, config):
    """Gets (typically manually defined) reward models associated with a specific environment.

    Args:
        obs_space: Observation space, forwarded to the reward-model constructors that take it.
        act_space: Action space, forwarded to the reward-model constructors that take it.
        config: Full run config; ``config.rewards.env_name`` selects which reward models are built.

    Returns:
        Dict mapping reward names to reward models.

    Raises:
        ValueError: If no reward models are defined for ``config.rewards.env_name``.
    """
    env_name = config.rewards.env_name
    if env_name == "LineEnv-v0":
        reward_models = {
            "ground_truth": LineEnvReward.make_ground_truth_reward(),
            "rightward_shaped_ground_truth": LineEnvRightwardPotential(LineEnvReward.make_ground_truth_reward()),
            "reverse": LineEnvReward.make_reverse_reward(),
            "zero": LineEnvReward.make_zero_reward(),
            "center": LineEnvReward.make_center_reward(),
        }
    elif env_name == "BouncingBallsEnv-v0":
        # An action magnitude scale of 1 dominates the other rewards to the extent that they don't matter.
        # Only referenced by the commented-out variants below.
        # pylint: disable=unused-variable
        actmag_scale = 0.01
        reward_models = {
            "goal_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 0, 0, 0, 0, 0),
            # "goal_+1_coll_-0.1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, -0.1, 0, 0.0, 0, 0),
            # f"actmag_-{actmag_scale}": BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, -actmag_scale, 0, 0, 0),
            # "goal_+1_coll_-1_actmag_-1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, -1, -1, 0, 0, 0),
            # "goal_+1_coll_-1_actmag_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, -1, 1, 0, 0, 0),
            # "goal_+1_coll_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 1, 0, 0, 0, 0),
            # "goal_+1_coll_+1_actmag_-1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 1, -1, 0, 0, 0),
            # "goal_+1_coll_+1_actmag_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 1, 1, 0, 0, 0),
            "goal_+1_shaping_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 0, 0, 0, 1, 0.95),
            # "goal_+1_shaping_-1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, 0, 0, 0, -1, 1.0),
            # "goal_dist_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, 0, 1, 0, 0),
            # "goal_dist_+1_shaping_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, 0, 1, 1.0, 1.0),
            # "goal_+1_coll_-0.1_shaping_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 1, -0.1, 0, 0.0, 1.0, 1.0),
            # f"actmag_+{actmag_scale}": BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, actmag_scale, 0, 0, 0),
            # f"actmag_-{actmag_scale}_shaping_+1": BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, -actmag_scale, 0, 1.0, 1.0),
        }
        base_model = "goal_+1_shaping_+1"
        # NOTE(review): hard-coded; presumably meant to match the environment's dt (the bouncing-balls transition
        # samplers use env.dt), but no env instance is available in this function — confirm the value.
        dt = 0.2
        # reward_models[f"feasibility_constant_{base_model}"] = BouncingBallsEnvFeasibilityRewardWrapper(
        #     base_reward=reward_models[base_model],
        #     alternative_reward=ConstantRewardModel(1.0),
        #     dt=dt,
        #     max_position_error=10.0,
        # )
        reward_models[f"feas_random_{base_model}"] = BouncingBallsEnvFeasibilityRewardWrapper(
            base_reward=reward_models[base_model],
            alternative_reward=RandomRewardModel(),
            dt=dt,
            max_position_error=10.0,
        )
        # feasibility_alternative_model = BouncingBallsEnvRewardModel(obs_space, act_space, 0, 0, 0, -1, 0, 0)
        # reward_models[f"feasibility_state_dep_{base_model}"] = BouncingBallsEnvFeasibilityRewardWrapper(
        #     base_reward=reward_models[base_model],
        #     alternative_reward=feasibility_alternative_model,
        #     dt=dt,
        #     max_position_error=10.0,
        # )
        _add_noisy_reward_models(reward_models, base_model, [1e-3, 1e-2, 1e-1, 1])
    elif "PointMaze" in env_name:
        base_model = "ground_truth"
        reward_models = {base_model: PointMazeEnvRewardModel(obs_space, act_space)}
        _add_noisy_reward_models(reward_models, base_model, [1e-4, 1e-3, 1e-2, 1e-1, 1])
    elif env_name == "CustomReacherEnv-v0":
        reward_models = {
            # "ground_truth": CustomReacherEnvRewardModel(obs_space, act_space, 0, 0, 1, 0, 0),
            # "ground_truth_shaping": CustomReacherEnvRewardModel(obs_space, act_space, 0, 0, 1, 10, 1),
            "ground_truth": CustomReacherEnvRewardModel(obs_space, act_space, 1, 1, 1, 0, 0),
            "ground_truth_shaping": CustomReacherEnvRewardModel(obs_space, act_space, 1, 1, 1, 10, 0.95),
        }
        _add_noisy_reward_models(reward_models, "ground_truth_shaping", [1e-3, 1e-2, 1e-1, 1])
    else:
        raise ValueError(f"No reward models for environment name: {env_name}")
    return reward_models


def get_learned_reward_models(obs_space, act_space, config):
    """Loads the learned reward models listed under ``config.rewards.learned``.

    Each entry of ``config.rewards.learned`` maps a model name to a mapping with two keys:
    ``config_filepath`` (a config file loaded with OmegaConf, whose ``model`` section is passed to
    ``get_gym_model``) and ``checkpoint_filepath``.

    Args:
        obs_space: Observation space forwarded to ``get_gym_model``.
        act_space: Action space forwarded to ``get_gym_model``.
        config: Full run config; ``config.common.device`` is the device the models are moved to.

    Returns:
        Dict mapping model names to loaded models; empty when ``learned`` is absent or ``None``.
    """
    learned = config.rewards.learned if "learned" in config.rewards else None
    if learned is None:
        return {}

    loaded = {}
    for name, paths in learned.items():
        config_filepath, checkpoint_filepath = paths["config_filepath"], paths["checkpoint_filepath"]
        architecture_config = OmegaConf.load(config_filepath).model
        model = get_gym_model(obs_space, act_space, architecture_config, checkpoint_filepath)
        model.to(config.common.device)
        loaded[name] = model
    return loaded


def get_reward_models(obs_space, act_space, config):
    """Combines hand-defined and learned reward models into a single ``ModelCollection``.

    Learned models are merged last, so a learned model whose name matches a hand-defined one replaces it.
    """
    env_models = get_reward_models_for_env(obs_space, act_space, config)
    learned_models = get_learned_reward_models(obs_space, act_space, config)
    return ModelCollection({**env_models, **learned_models})


def get_fixed_distribution_transition_sampler(config):
    """Builds a ``FixedDistributionTransitionSampler`` from the first transitions of a fresh data loader.

    Reads batches until ``config.transition_sampler.fixed_distribution.num_transitions_per_state`` transitions
    have been collected. It then keeps exactly that many (action, next_state) pairs and moves them to
    ``config.common.device``.

    Args:
        config: Full run config.

    Returns:
        A ``FixedDistributionTransitionSampler`` over the collected actions and next states.

    Raises:
        ValueError: If the data loader yields no transitions at all.
    """
    num_transitions_per_state = config.transition_sampler.fixed_distribution.num_transitions_per_state
    # Get an independent data loader from that used for sampling the first set of transitions.
    if config.data.debug_size is not None and config.data.debug_size_mode != "shuffled":
        print("WARNING: Fixed distribution sampler works best with shuffled data!")
    data_loader = get_data_loader(config)

    actions, next_states = [], []
    num_transitions = 0
    for batch in data_loader:
        num_transitions += len(batch[SampleBatch.OBS])
        actions.append(batch[SampleBatch.ACTIONS])
        next_states.append(batch[SampleBatch.NEXT_OBS])
        # Stop as soon as enough transitions are collected; `>=` avoids loading one needless extra batch.
        if num_transitions >= num_transitions_per_state:
            break

    if num_transitions == 0:
        # torch.cat on an empty list would otherwise fail with an opaque error.
        raise ValueError("Fixed distribution transition sampler found no transitions in the data.")
    if num_transitions < num_transitions_per_state:
        print("WARNING: Fixed distribution transition sampler did not collect requested number of samples.")

    actions = torch.cat(actions)[:num_transitions_per_state].to(config.common.device)
    next_states = torch.cat(next_states)[:num_transitions_per_state].to(config.common.device)
    return FixedDistributionTransitionSampler(actions, next_states)


def get_transition_sampler(env, config):
    """Builds the transition sampler selected by ``config.transition_sampler.type``.

    Args:
        env: Environment returned by ``get_env``. The env-specific samplers read its attributes (``dt``,
            ``max_accel_magnitude``, ``side_length``, ``action_space``). It is not read for the
            ``fixed_distribution`` and ``none`` types.
        config: Full run config.

    Returns:
        A transition sampler, or ``None`` when the type is ``"none"``.

    Raises:
        ValueError: If the sampler type is unknown. Also raised when the type ``"none"`` is used without
            ``config.epic.skip_canonicalization`` enabled.
    """
    sampler_config = config.transition_sampler
    sampler_type = sampler_config.type
    if sampler_type == "fixed_distribution":
        return get_fixed_distribution_transition_sampler(config)
    if sampler_type == "bouncing_balls_env_constant_velocity":
        return BouncingBallsEnvConstantVelocityTransitionSampler(dt=env.dt)
    if sampler_type == "bouncing_balls_env_stochastic_ego":
        return BouncingBallsEnvActionSamplingTransitionSampler(
            dt=env.dt,
            action_sampler=UniformlyRandomActionSampler(
                num_actions=sampler_config.stochastic_ego.num_actions,
                max_magnitude=env.max_accel_magnitude,
            ),
        )
    if sampler_type == "bouncing_balls_env_boundary":
        return BouncingBallsEnvActionSamplingTransitionSampler(
            dt=env.dt,
            action_sampler=BoundaryActionSampler(max_magnitude=env.max_accel_magnitude),
        )
    if sampler_type == "bouncing_balls_env_linear":
        return BouncingBallsEnvActionSamplingTransitionSampler(
            dt=env.dt,
            action_sampler=LinearActionSampler(
                num_actions_each_dim=sampler_config.bouncing_balls_env_linear.num_actions_each_dim,
                max_magnitude=env.max_accel_magnitude,
            ),
        )
    if sampler_type == "line_env_uniform_policy":
        return LineEnvUniformPolicyTransitionSampler(side_length=env.side_length)
    if sampler_type == "point_maze_env_linear":
        return PointMazeEnvActionSamplingTransitionSampler(
            dt=env.dt,
            action_sampler=LinearActionSampler(
                num_actions_each_dim=sampler_config.point_maze_env_linear.num_actions_each_dim,
                max_magnitude=1.0,
            ),
        )
    if sampler_type == "mujoco_sim":
        # TODO(redacted): Check that max magnitude for all mujoco envs is 1.
        assert len(env.action_space.shape) == 1
        ndim = env.action_space.shape[0]
        action_sampler = LinearActionSampler(
            num_actions_each_dim=sampler_config.mujoco_sim.num_actions_each_dim,
            max_magnitude=1.0,
            ndim=ndim,
        )
        make_env_fn = partial(get_env, env_name=config.rewards.env_name)
        return MujocoTransitionSampler(
            make_env_fn=make_env_fn,
            action_sampler=action_sampler,
            num_workers=sampler_config.mujoco_sim.num_workers,
        )
    if sampler_type == "none":
        # With no sampler there is nothing to canonicalize with, so canonicalization must be disabled.
        # Read via .get() with a False default, matching how get_canonical_rewards reads the same flag.
        if not config.epic.get("skip_canonicalization", False):
            raise ValueError('Transition sampler type "none" requires config.epic.skip_canonicalization to be set.')
        return None
    raise ValueError(f"Invalid transition sampler type: {sampler_type}")


def get_canonical_rewards(data_loader, reward_models, transition_sampler, config):
    """Computes EPIC canonicalized rewards for every batch of ``data_loader``.

    Runs under ``torch.no_grad()``. When the loader yields more than one batch, only the per-batch-safe
    ``total_mean_mode`` values are accepted and scale normalization must be disabled.

    Args:
        data_loader: Iterable of batches keyed by ``SampleBatch.OBS``, ``ACTIONS``, ``NEXT_OBS`` and ``DONES``.
        reward_models: Collection of reward models to canonicalize.
        transition_sampler: Sampler passed through to ``EPIC.compute_canonical_rewards`` (may be ``None``).
        config: Full run config; the ``epic`` and ``common.device`` entries are read.

    Returns:
        A ``RewardCollection`` with one appended entry per batch.
    """
    device = config.common.device
    epic_config = config.epic
    with torch.no_grad():
        epic = EPIC(
            batch_size=epic_config.batch_size,
            skip_canonicalization=epic_config.get("skip_canonicalization", False),
        )
        canonical_rewards = RewardCollection()
        if len(data_loader) > 1:
            multi_batch_mean_modes = (
                "none",
                "conditional_per_state_in_distribution",
                "conditional_per_state_out_of_distribution",
                "conditional_per_state_out_of_distribution_quick",
            )
            assert epic_config.total_mean_mode in multi_batch_mean_modes, \
                "Total mean subtraction invalid for multiple batches"
            assert epic_config.should_normalize_scale is False, "Scale normalization invalid for multiple batches"
        batch_keys = (SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS, SampleBatch.DONES)
        for batch in tqdm.tqdm(iterable=data_loader):
            states, actions, next_states, terminals = (batch[key].to(device) for key in batch_keys)
            canonical_rewards.append(
                epic.compute_canonical_rewards(
                    reward_models,
                    states,
                    actions,
                    next_states,
                    terminals,
                    transition_sampler,
                    discount=epic_config.discount,
                    total_mean_mode=epic_config.total_mean_mode,
                    should_normalize_scale=epic_config.should_normalize_scale,
                ))
    return canonical_rewards


def compute_distances(rewards, config):
    """Computes the pairwise distance matrix between rewards using the metric named by ``config.epic.distance``.

    Supported metrics are ``"pearson_distance"`` and ``"direct_distance"``.

    Raises:
        ValueError: If ``config.epic.distance`` names an unsupported metric.
    """
    metric = config.epic.distance
    if metric == "pearson_distance":
        return compute_distance_between_reward_pairs(rewards, compute_pearson_distance)
    if metric == "direct_distance":
        # (I believe but am not sure that) direct distance only makes sense with scale normalized rewards.
        # That's probably also the case with translation, but that seems harder to enforce. You might try to mean
        # subtract the rewards, but that potentially changes the optimal policy, so it seems like a bad idea.
        normalized_rewards = compute_scale_normalized_rewards(rewards)
        return compute_distance_between_reward_pairs(normalized_rewards, compute_direct_distance)
    raise ValueError(f"Invalid distance: {metric}")


@hydra.main(config_path="configs")
def main(config: dictconfig.DictConfig) -> None:
    """Runs reward evaluation on a gym environment.

    Call this by passing in a specific config file like:
    `python run_gym_reward_evaluation.py +gym=line_env`
    Replace `line_env` with the name of the config file you want to use, e.g., `bouncing_balls_env`.

    Writes `distances.png` (a visualization) and `distances.pkl` into `config.visualization.output_dir`.
    """
    output_dir = config.visualization.output_dir
    os.makedirs(output_dir, exist_ok=True)

    env_name = config.rewards.env_name
    data_loader = get_data_loader(config)
    env = get_env(env_name)
    reward_models = get_reward_models(env.observation_space, env.action_space, config)
    transition_sampler = get_transition_sampler(env, config)
    canonical_rewards = get_canonical_rewards(data_loader, reward_models, transition_sampler, config)

    distances = compute_distances(canonical_rewards, config)
    distances.visualize(os.path.join(output_dir, "distances.png"), title=f"Reward Distances for {env_name}")
    distances.save(os.path.join(output_dir, "distances.pkl"))


if __name__ == "__main__":
    # `main` is wrapped by `hydra.main`, which builds `config` from the command line, so it is called with no args.
    # pylint: disable=no-value-for-parameter
    main()
