from functools import partial
from multiprocessing import cpu_count
import time

import pytest
import torch

from offline_rl.envs.custom_reacher_env import CustomReacherEnv
from offline_rl.rewards.evaluation.mujoco_transition_sampler import MujocoTransitionSampler
from offline_rl.rewards.evaluation.transition_sampler import LinearActionSampler


def sample_reacher_v3_states(num_states, dtype=torch.float32):
    """Collect start states from ``CustomReacherEnv`` for use as sampler inputs.

    A single temporary environment is reset ``num_states`` times. Each
    observation returned by ``reset()`` becomes one row of the result.

    Args:
        num_states: Number of ``reset()`` observations to collect.
        dtype: Torch dtype of the returned state tensor.

    Returns:
        A ``(make_env_fn, states)`` pair:

        - ``make_env_fn`` is a zero-argument factory creating fresh ``CustomReacherEnv`` instances.
        - ``states`` stacks the observations along a new leading dimension. Its shape is
          ``(num_states, obs_dim)`` assuming ``reset()`` returns a flat observation vector
          (TODO confirm).
        - When ``num_states <= 0`` it is an empty tensor of shape ``(0,)``, as before.
    """
    env = CustomReacherEnv()
    try:
        # Convert each observation individually and stack them. Calling
        # torch.tensor on a list of ndarrays is very slow, and PyTorch warns about it.
        observations = [torch.as_tensor(env.reset(), dtype=dtype) for _ in range(num_states)]
    finally:
        # The env is only needed to draw start states; release it right away.
        env.close()
    states = torch.stack(observations) if observations else torch.empty(0, dtype=dtype)

    make_env_fn = partial(CustomReacherEnv)
    return make_env_fn, states


class TestMujocoTransitionSampler:
    """Tests for ``MujocoTransitionSampler`` driven by ``CustomReacherEnv`` start states."""

    # Action dimensionality used for the reacher env in every test here. It matches
    # the ``action_dim`` that the shape assertions in ``test_sample`` expect.
    REACHER_ACTION_DIM = 2

    @pytest.mark.parametrize("num_states", [10, 1])
    def test_sample(self, num_states):
        """Check the shapes and dtype of the ``(actions, next_states, weights)`` triple."""
        make_env_fn, states = sample_reacher_v3_states(num_states)
        state_dim = states.shape[-1]

        action_dim = self.REACHER_ACTION_DIM
        action_sampler = LinearActionSampler(
            num_actions_each_dim=2, max_magnitude=1.0, ndim=action_dim
        )

        sampler = MujocoTransitionSampler(make_env_fn, action_sampler, num_workers=num_states)
        actions, next_states, weights = sampler.sample(states)

        assert actions.shape == (num_states, action_sampler.num_actions, action_dim)
        assert next_states.dtype == states.dtype
        assert next_states.shape == (num_states, action_sampler.num_actions, state_dim)
        assert weights.shape == (num_states, action_sampler.num_actions)

    # TODO(redacted): Make it so that you can run this profiling test from the command line.
    @pytest.mark.skip(reason="Profiling test; run manually to measure sampling time.")
    def test_sample_timing(self):
        """Print the average wall-clock time ``sampler.sample`` takes for a large batch."""
        num_states = 1000
        make_env_fn, states = sample_reacher_v3_states(num_states)

        # ndim must match the environment's action dimensionality, the same value
        # test_sample uses. A mismatched ndim would produce wrongly-shaped actions.
        action_sampler = LinearActionSampler(
            num_actions_each_dim=2, max_magnitude=1.0, ndim=self.REACHER_ACTION_DIM
        )
        # NOTE(review): num_workers=0 presumably means in-process sampling; confirm.
        sampler = MujocoTransitionSampler(make_env_fn, action_sampler, num_workers=0)

        num_runs = 1
        # perf_counter is monotonic and high-resolution, so it suits interval timing
        # better than time.time(), which can jump when the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(num_runs):
            sampler.sample(states)
        elapsed = time.perf_counter() - start
        print(f"\nTook {elapsed / num_runs:0.4f} seconds on average to step {num_states} states.")
