import argparse
from functools import partial as bind

import cv2
import gymnasium as gym
import numpy as np
import ruamel.yaml as yaml

import dreamerv3
import embodied
from embodied.envs.from_gymnasium import FromGymnasium
from robosuite_envs.wrapper import RobosuiteEnv


class ResizeWrapper(gym.Wrapper):
    """Resize 3-channel ``uint8`` image observations to a square ``image_size`` x ``image_size``.

    Every observation returned by ``reset``/``step`` is passed through
    ``cv2.resize`` with ``cv2.INTER_CUBIC`` interpolation.

    Args:
        env: The gymnasium environment to wrap. Its observation space must be
            3-D, have 3 entries in the last axis, and use dtype ``uint8``.
        image_size: Side length, in pixels, of the square output image.
        store_source_observation: If True, the un-resized observation is put in
            ``info['source_observation']`` on every ``reset``/``step``.

    Raises:
        ValueError: If the wrapped observation space does not satisfy the
            requirements above.
    """

    def __init__(self, env, image_size, store_source_observation=False):
        super().__init__(env)
        self._image_size = image_size
        self._store_source_observations = store_source_observation

        source_space = env.observation_space
        shape = getattr(source_space, 'shape', None)
        # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
        # which would silently let an incompatible space through.
        if shape is None or len(shape) != 3 or shape[2] != 3:
            raise ValueError(f'ResizeWrapper expects an observation space of shape (H, W, 3), got {shape}')
        if source_space.dtype != np.uint8:
            raise ValueError(f'ResizeWrapper expects a uint8 observation space, got {source_space.dtype}')

        # Public setter (rather than the private `_observation_space` attribute) so the
        # resized space is what gymnasium reports to downstream wrappers.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self._image_size, self._image_size, 3), dtype=np.uint8
        )

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and return ``(resized_observation, info)``."""
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs, info), info

    def step(self, action):
        """Step the wrapped env; identical to it except that the observation is resized."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(observation, info), reward, terminated, truncated, info

    def observation(self, observation, info):
        """Resize ``observation``, optionally recording the original in ``info``.

        ``info`` is mutated in place when ``store_source_observation`` was set.
        """
        if self._store_source_observations:
            info['source_observation'] = observation

        # dsize is (width, height); both are image_size because the output is square.
        return cv2.resize(observation, dsize=(self._image_size, self._image_size), interpolation=cv2.INTER_CUBIC)


def make_robosuite(config, task, horizon, image_size, seed, initialization_noise_magnitude=None,
                   use_random_object_position=False, store_source_observation=False):
    """Construct one Robosuite environment wrapped for use with a DreamerV3 agent.

    The wrapper stack, innermost first, is:
    ``RobosuiteEnv`` -> ``ResizeWrapper`` -> ``FromGymnasium(obs_key='image')``
    -> ``dreamerv3.wrap_env(..., config)``.

    ``task``, ``horizon``, ``seed``, ``initialization_noise_magnitude`` and
    ``use_random_object_position`` are forwarded positionally to ``RobosuiteEnv``.
    The CLI passes 'large'/'medium'/'small' for ``use_random_object_position``;
    its exact meaning is defined by RobosuiteEnv (TODO confirm).
    """
    raw_env = RobosuiteEnv(task, horizon, seed, initialization_noise_magnitude, use_random_object_position)
    resized_env = ResizeWrapper(raw_env, image_size, store_source_observation=store_source_observation)
    embodied_env = FromGymnasium(resized_env, obs_key='image')
    return dreamerv3.wrap_env(embodied_env, config)


def make_robosuite_lambda(config, task, horizon, image_size, seed=None, initialization_noise_magnitude=None,
                          use_random_object_position=False, store_source_observation=False):
    """Return a zero-argument callable that builds the environment when invoked.

    All arguments are captured now and forwarded unchanged to ``make_robosuite``
    at call time. Nothing is constructed until the callable is invoked, which
    lets the factory be handed to ``embodied.Parallel``.
    """
    def build_env():
        return make_robosuite(config, task, horizon, image_size, seed, initialization_noise_magnitude,
                              use_random_object_position, store_source_observation)

    return build_env


def make_batch_env(task, horizon, seed, image_size, config, n_envs, initialization_noise_magnitude=None,
                   use_random_object_position=False, store_source_observation=False, parallel=None):
    """Build ``n_envs`` Robosuite environments and combine them into an ``embodied.BatchEnv``.

    Args:
        task, horizon, image_size, config, initialization_noise_magnitude,
        use_random_object_position, store_source_observation: Forwarded to
            ``make_robosuite_lambda`` for every environment.
        seed: Base seed. Environment ``i`` gets ``seed + i``. If ``seed`` is None,
            every environment gets ``None``.
        n_envs: Number of environments; must be at least 1.
        parallel: ``embodied.Parallel`` strategy. ``'none'`` builds the envs in
            this process. ``None`` falls back to ``config.envs.parallel``.

    Returns:
        An ``embodied.BatchEnv`` over the constructed environments.

    Raises:
        ValueError: If ``n_envs`` is less than 1.
    """
    if n_envs < 1:
        raise ValueError(f'n_envs must be at least 1, got {n_envs}')
    if parallel is None:
        parallel = config.envs.parallel
    use_parallel = parallel != 'none'

    env_fns = []
    for i in range(n_envs):
        # Offset the seed per env so they diverge; propagate None instead of crashing on `None + i`.
        env_seed = None if seed is None else seed + i
        env_fn = make_robosuite_lambda(config, task, horizon, image_size, seed=env_seed,
                                       initialization_noise_magnitude=initialization_noise_magnitude,
                                       use_random_object_position=use_random_object_position,
                                       store_source_observation=store_source_observation)
        if use_parallel:
            env_fn = bind(embodied.Parallel, env_fn, parallel)
        env_fns.append(env_fn)

    # All factories are created first, then instantiated, as in the original ordering.
    envs = [env_fn() for env_fn in env_fns]
    return embodied.BatchEnv(envs, parallel=use_parallel)


def parse_args():
    """Parse the command-line options of the observation-collection script.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    # (flag, add_argument keyword options), registered in this order.
    argument_spec = [
        ('--config_path', dict(type=str, required=True)),
        ('--task', dict(type=str, required=True)),
        ('--horizon', dict(type=int, required=True)),
        ('--image_size', dict(type=int, required=True)),
        ('--n_envs', dict(default=1, type=int)),
        ('--initialization_noise_magnitude', dict(type=float, required=False)),
        ('--use_random_object_position', dict(choices=['large', 'medium', 'small'], required=True)),
        ('--checkpoint_path', dict(type=str, required=True)),
        ('--dataset_size', dict(type=int, required=True)),
        ('--n_samples', dict(type=int, required=True)),
        ('--output_path', dict(type=str, required=True)),
        ('--max_workers', dict(default=3, type=int)),
    ]
    cli = argparse.ArgumentParser()
    for flag, options in argument_spec:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def main():
    """Script entry point: build envs and an agent, then run ``embodied.run.collect_observations``.

    The config is the YAML file's ``defaults`` section, updated with
    ``dreamerv3.configs['medium']``. The batch environment is closed on exit,
    including when collection raises.
    """
    args = parse_args()
    print('Args:', vars(args))
    configs = yaml.YAML(typ='safe').load(embodied.Path(args.config_path).read())
    config = embodied.Config(configs['defaults'])
    config = config.update(dreamerv3.configs['medium'])

    step = embodied.Counter()
    env = make_batch_env(args.task, args.horizon, image_size=args.image_size, config=config, seed=config.seed,
                         n_envs=args.n_envs, initialization_noise_magnitude=args.initialization_noise_magnitude,
                         use_random_object_position=args.use_random_object_position, store_source_observation=True)
    try:
        agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
        config_args = embodied.Config(
            **config.run, logdir=config.logdir, seed=config.seed,
            batch_steps=config.batch_size * config.batch_length)
        embodied.run.collect_observations(agent=agent, env=env, args=config_args, checkpoint_path=args.checkpoint_path,
                                          dataset_size=args.dataset_size, n_samples=args.n_samples,
                                          output_path=args.output_path, max_workers=args.max_workers)
    finally:
        # The envs may run in worker processes (embodied.Parallel); release them even on error.
        env.close()


if __name__ == "__main__":
    main()
