import argparse
from functools import partial as bind

import cv2
import gymnasium as gym
import numpy as np
import ruamel.yaml as yaml
from gymnasium import spaces

import dreamerv3
import embodied
from embodied.envs.from_gymnasium import FromGymnasium
import mani_skill.envs


class RgbdWrapper(gym.Wrapper):
    """Turn a batched RGB-D environment into an un-batched, image-only one.

    The wrapped environment must return observations that can be indexed as
    ``obs['sensor_data']['base_camera']['rgb']`` and ``[...]['segmentation']``,
    each with a leading batch axis and a ``.numpy()`` method (presumably torch
    tensors as produced by ManiSkill -- confirm against the installed version).

    For every ``reset``/``step`` the wrapper:

    * keeps only batch element 0 of the observation, reward, flags and info,
    * returns the RGB frame resized to ``(image_size, image_size)`` as the
      observation,
    * stores the segmentation map (cast to ``uint8``) in ``info['segmentation']``,
    * mirrors ``info['success']`` into ``info['is_success']`` when present,
    * optionally stores the un-resized frame in ``info['source_observation']``.

    Args:
        env: Environment to wrap.
        image_size: Side length, in pixels, of the square observation returned.
        store_source_observation: When true, the un-resized RGB frame is added
            to ``info`` under ``'source_observation'``.
    """

    def __init__(self, env, image_size, store_source_observation=False):
        super().__init__(env)
        self._image_size = image_size
        self._store_source_observation = store_source_observation
        # Advertise the resized RGB frame rather than the wrapped env's dict observation.
        self._observation_space = spaces.Box(0, 255, (self._image_size, self._image_size, 3), dtype=np.uint8)

    @staticmethod
    def _first(value):
        """Return element 0 of a batched value; pass anything else through unchanged.

        ``hasattr(value, '__len__')`` alone is too broad a test for "batched":
        it is also true for strings (which would be cut to their first
        character), dicts (``value[0]`` raises ``KeyError``) and 0-d arrays or
        tensors (``value[0]`` raises ``IndexError``).
        """
        if isinstance(value, (str, bytes, dict)):
            return value

        ndim = getattr(value, 'ndim', None)
        if ndim is not None:
            # Array-like: only index when there is an axis to strip.
            return value[0] if ndim > 0 else value

        return value[0] if hasattr(value, '__len__') else value

    @staticmethod
    def _unravel(step_result):
        """Strip the batch axis from a ``reset`` or ``step`` result.

        Args:
            step_result: ``(obs, info)`` or ``(obs, reward, terminated, truncated, info)``
                as returned by the wrapped environment.

        Returns:
            A list with the same layout in which the observation is the RGB
            frame of batch element 0 and ``info`` is a new dict (see the class
            docstring for the keys that are added).
        """
        observation, *middle, raw_info = step_result
        camera = observation['sensor_data']['base_camera']

        unravel_result = [camera['rgb'][0]]
        unravel_result += [RgbdWrapper._first(x) for x in middle]

        info = {key: RgbdWrapper._first(value) for key, value in raw_info.items()}
        # NOTE(review): the uint8 cast assumes segmentation ids fit in 0..255 -- confirm.
        info['segmentation'] = camera['segmentation'][0].numpy().astype(np.uint8)
        if 'success' in info:
            info['is_success'] = info['success']

        unravel_result.append(info)

        return unravel_result

    def _resize(self, observation, info):
        """Convert the RGB frame to numpy and resize it to the configured square size.

        When ``store_source_observation`` was set, the un-resized numpy frame
        is written to ``info['source_observation']`` (``info`` is mutated).
        """
        # NOTE(review): `.numpy()` presumably requires the frame to live on the CPU -- confirm the sim backend.
        observation = observation.numpy()
        if self._store_source_observation:
            info['source_observation'] = observation

        return cv2.resize(observation, dsize=(self._image_size, self._image_size), interpolation=cv2.INTER_CUBIC)

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and return ``(resized_rgb, info)``."""
        obs, info = self._unravel(self.env.reset(*args, **kwargs))
        return self._resize(obs, info), info

    def step(self, *args, **kwargs):
        """Step the wrapped environment and return ``(resized_rgb, reward, terminated, truncated, info)``."""
        obs, reward, terminated, truncated, info = self._unravel(self.env.step(*args, **kwargs))
        return self._resize(obs, info), reward, terminated, truncated, info


def make(config, task, image_size, store_source_observation, seed):
    """Build one fully wrapped environment for ``task``.

    The environment is created through ``gym.make`` in ``rgbd`` observation
    mode with a 224x224 sensor configuration, reduced to an image observation
    by ``RgbdWrapper``, reset once with ``seed``, and finally adapted with
    ``FromGymnasium`` and ``dreamerv3.wrap_env``.

    Args:
        config: Configuration object forwarded to ``dreamerv3.wrap_env``.
        task: Environment id passed to ``gym.make``.
        image_size: Side length of the square observation produced by ``RgbdWrapper``.
        store_source_observation: Forwarded to ``RgbdWrapper``.
        seed: Seed used for the initial ``reset`` call.

    Returns:
        The wrapped environment.
    """
    base_env = gym.make(
        id=task,
        obs_mode='rgbd',
        control_mode='pd_joint_delta_pos',
        render_mode='rgb_array',
        sensor_configs=dict(width=224, height=224),
    )
    rgb_env = RgbdWrapper(base_env, image_size, store_source_observation=store_source_observation)

    # Seed the environment before it is handed over to the embodied wrappers.
    rgb_env.reset(seed=seed)

    embodied_env = FromGymnasium(rgb_env, obs_key='image')
    return dreamerv3.wrap_env(embodied_env, config)


def make_batch_env(config, task, image_size, store_source_observation, seed, n_envs, parallel=None):
    """Create ``n_envs`` environments and group them in an ``embodied.BatchEnv``.

    Environment ``i`` is built by ``make`` with seed ``seed + i``.

    Args:
        config: Configuration object; also supplies ``config.envs.parallel``
            when ``parallel`` is ``None``.
        task: Environment id forwarded to ``make``.
        image_size: Observation side length forwarded to ``make``.
        store_source_observation: Forwarded to ``make``.
        seed: Base seed; each environment adds its index to it.
        n_envs: Number of environments to create.
        parallel: Strategy handed to ``embodied.Parallel``; the string
            ``'none'`` builds the environments directly instead.

    Returns:
        An ``embodied.BatchEnv`` over the created environments.
    """
    strategy = config.envs.parallel if parallel is None else parallel
    run_parallel = strategy != 'none'

    def _constructor(env_seed):
        # Bind the seed now so every environment keeps its own value.
        def _build():
            return make(config, task, image_size, store_source_observation, env_seed)

        return _build

    environments = []
    for index in range(n_envs):
        build = _constructor(seed + index)
        if run_parallel:
            build = bind(embodied.Parallel, build, strategy)
        environments.append(build())

    return embodied.BatchEnv(environments, parallel=run_parallel)


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Required flags: ``--config_path``, ``--task``, ``--image_size``,
    ``--checkpoint_path``, ``--dataset_size``, ``--output_path`` and
    ``--greedy_epsilon``. Optional flags: ``--n_envs`` (default 1),
    ``--n_samples`` (default ``None``) and ``--max_workers`` (default 3).

    Returns:
        The populated ``argparse.Namespace``.
    """
    # Declared in the order they should be registered with the parser.
    option_table = (
        ('--config_path', dict(type=str, required=True)),
        ('--task', dict(type=str, required=True)),
        ('--image_size', dict(type=int, required=True)),
        ('--n_envs', dict(default=1, type=int)),
        ('--checkpoint_path', dict(type=str, required=True)),
        ('--dataset_size', dict(type=int, required=True)),
        ('--n_samples', dict(type=int, required=False)),
        ('--output_path', dict(type=str, required=True)),
        ('--max_workers', dict(type=int, default=3)),
        ('--greedy_epsilon', dict(type=float, required=True)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main():
    """Script entry point: build the environments and agent, then collect observations.

    Reads the CLI options, builds the configuration from the ``defaults``
    section of the YAML file updated with ``dreamerv3.configs['medium']``,
    creates the batched environment and the agent, and hands everything to
    ``embodied.run.collect_observations``.
    """
    args = parse_args()
    print('Args:', vars(args))

    yaml_text = embodied.Path(args.config_path).read()
    loaded = yaml.YAML(typ='safe').load(yaml_text)
    config = embodied.Config(loaded['defaults']).update(dreamerv3.configs['medium'])

    step = embodied.Counter()

    # One environment is built directly; for more, None lets make_batch_env
    # fall back to config.envs.parallel.
    parallel = 'none' if args.n_envs == 1 else None
    env = make_batch_env(
        config=config,
        task=args.task,
        image_size=args.image_size,
        store_source_observation=True,
        seed=config['seed'],
        n_envs=args.n_envs,
        parallel=parallel,
    )

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)

    run_config = embodied.Config(
        **config.run,
        logdir=config.logdir,
        seed=config.seed,
        batch_steps=config.batch_size * config.batch_length,
    )
    embodied.run.collect_observations(
        agent=agent,
        env=env,
        args=run_config,
        checkpoint_path=args.checkpoint_path,
        dataset_size=args.dataset_size,
        output_path=args.output_path,
        max_workers=args.max_workers,
        greedy_epsilon=args.greedy_epsilon,
        n_samples=args.n_samples,
    )


# Run the collection only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
