import random
import numpy as np

import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import RenderFrame

from src.utils.misc import set_seed
from gymnasium import utils
from gymnasium.envs.mujoco import MuJocoPyEnv
from src.envs.mujoco_params import EnvParamsWrapper


class Walker2dParamsEnv(gym.Env):
    """Walker2d environment whose physical parameters define the task.

    Wraps ``gym.make("Walker2d-v3")`` in ``EnvParamsWrapper`` and exposes a
    task interface (``sample_task`` / ``set_task`` / ``get_task`` /
    ``reset_task``). A task is the parameter dict handed to
    ``EnvParamsWrapper.apply_params``. Stepping, resetting, rendering and
    closing are delegated to the wrapped environment.
    """

    # Deep copy of the unmodified Walker2d-v3 parameters, filled lazily by
    # ``_default_params`` so sampling tasks does not build a new MuJoCo env
    # on every call.
    _default_params_cache = None

    def __init__(self, goal_pos=None):
        """Build the wrapped Walker2d-v3 env and apply an initial task.

        Args:
            goal_pos: Task (parameter dict) to apply. When ``None``, a random
                task is drawn with ``sample_task``.
        """
        self._env = EnvParamsWrapper(gym.make("Walker2d-v3", render_mode="rgb_array"))
        if goal_pos is None:
            goal_pos = self.sample_task()
        self.set_task(goal_pos)
        # NOTE(review): tasks are parameter dicts, so this value may not match
        # the real task size -- confirm against consumers of ``task_dim``.
        self.task_dim = 1
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

    def step(self, action):
        """Forward ``action`` to the wrapped env and return its step result."""
        return self._env.step(action)

    def set_task(self, task):
        """Apply ``task`` (a parameter dict) to the wrapped env and remember it."""
        self._env.apply_params(task)
        self.goal = task

    def get_task(self):
        """Return the task most recently applied with ``set_task``."""
        return self.goal

    @classmethod
    def _default_params(cls):
        """Return the unmodified Walker2d-v3 parameters, building them only once.

        The vanilla env is created on the first call only. Its parameters are
        deep-copied before caching, so the cache owns its data and stays valid
        independently of that env. Callers must not mutate the returned dict.
        """
        import copy

        if cls._default_params_cache is None:
            vanilla_env = EnvParamsWrapper(gym.make("Walker2d-v3"))
            cls._default_params_cache = copy.deepcopy(vanilla_env.get_params())
        return cls._default_params_cache

    @staticmethod
    def sample_task():
        """Sample a task by randomly rescaling every vanilla Walker2d parameter.

        Each parameter is multiplied element-wise by ``1.5 ** u``, where
        ``u`` is drawn uniformly from ``[-3, 3)`` using NumPy's global RNG
        (``np.random``). Draws happen in the parameter dict's iteration order.
        Python-float parameters get a single factor and therefore come back
        as shape-``(1,)`` arrays.

        Returns:
            dict mapping each parameter name to its rescaled value.
        """
        log_scale_limit = 3.0
        params = Walker2dParamsEnv._default_params()
        new_params = {}
        for name, value in params.items():
            # Plain floats have no ``.shape``: draw one factor for them.
            # ``np.shape`` also handles ints / NumPy scalars without raising.
            size = 1 if type(value) is float else np.shape(value)
            multipliers = np.array(1.5) ** np.random.uniform(-log_scale_limit, log_scale_limit, size=size)
            new_params[name] = value * multipliers
        return new_params

    def reset_task(self, task=None):
        """Apply ``task`` (or a freshly sampled one if ``None``) and reset the env."""
        if task is None:
            task = self.sample_task()
        self.set_task(task)
        self.reset()

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return its ``(obs, info)`` pair.

        ``seed`` defaults to ``None`` so a plain ``reset()`` continues the
        existing RNG stream instead of re-seeding to a fixed value on every
        call. Pass an int only when you want to (re)seed explicitly.
        """
        obs, info = self._env.reset(seed=seed, options=options)
        return obs, info

    def pos_to_state(self, arg):
        """Flatten a task-parameter dict into a flat, hashable tuple of numbers.

        Values are flattened in the dict's iteration order. Scalars contribute
        one entry and arrays contribute all of their elements. An empty dict
        yields ``()``.
        """
        if not arg:
            return ()
        # Bring every value to 1-D so scalars and arrays can be concatenated.
        flat = [np.asarray(value).reshape(-1) for value in arg.values()]
        return tuple(np.concatenate(flat).tolist())

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Return the wrapped env's render output (created with ``render_mode="rgb_array"``)."""
        return self._env.render()

    def close(self):
        """Release resources held by the wrapped environment."""
        self._env.close()


def train_test_goals_walkp(num_test_goals, seed, num_train_goals=1000):
    """Sample reproducible train and test task lists for ``Walker2dParamsEnv``.

    Args:
        num_test_goals: Number of test tasks to sample.
        seed: Seed passed to ``set_seed`` before sampling the train tasks.
            ``seed + 1`` is used before sampling the test tasks.
        num_train_goals: Number of train tasks to sample. Defaults to 1000,
            the previously hard-coded value.

    Returns:
        ``(train_goals, test_goals)``: two lists of task dicts produced by
        ``Walker2dParamsEnv.sample_task``.
    """
    set_seed(seed)
    train_goals = [Walker2dParamsEnv.sample_task() for _ in range(num_train_goals)]
    # NOTE(review): assuming set_seed seeds NumPy's global RNG, the test goals
    # for ``seed`` are drawn from the same stream as the train goals for
    # ``seed + 1`` -- keep this in mind when comparing adjacent-seed runs.
    set_seed(seed + 1)
    test_goals = [Walker2dParamsEnv.sample_task() for _ in range(num_test_goals)]
    return train_goals, test_goals


if __name__ == "__main__":
    # Smoke test: register the env, build one instance, reset and render it,
    # then print a small set of sampled test goals.
    gym.register(
        'Walker2dParams-v0',
        entry_point='src.envs.walker_params:Walker2dParamsEnv',
        max_episode_steps=200,
        kwargs={},
    )

    goal = None
    env = gym.make("Walker2dParams-v0", goal_pos=goal)
    try:
        env.reset()
        env.render()  # exercise the rgb_array render path
    finally:
        env.close()
    print(train_test_goals_walkp(10, 0)[1])