import argparse
import os
from functools import partial as bind

import cv2
import gymnasium
import numpy as np
import ruamel.yaml as yaml
from gymnasium import spaces

import dreamerv3
import embodied
from embodied.envs.from_gymnasium import FromGymnasium
import mani_skill.envs


class StateWrapper(gymnasium.Wrapper):
    """Expose a batched ManiSkill ``state`` env through the plain Gymnasium API.

    Every sized value returned by ``reset``/``step``, and every sized value in
    the info dict, is indexed with ``[0]``. Only the first sub-environment is
    therefore surfaced; presumably the wrapped env is created with a single
    sub-env (TODO confirm). When the info dict contains ``success`` it is
    mirrored under ``is_success``.
    """

    def __init__(self, env):
        super().__init__(env)
        # The wrapped space has a leading env axis (shape[1] is the feature size);
        # the exposed space drops it. Assumes (num_envs, obs_dim) — TODO confirm.
        obs_dim = env.observation_space.shape[1]
        self.observation_space = spaces.Box(-np.inf, np.inf, (obs_dim,), dtype=np.float32)

    @staticmethod
    def _unravel(step_result):
        """Strip the env axis from a reset/step result and return it as a list.

        The last element of ``step_result`` is treated as the info dict; all
        preceding elements are the regular outputs (obs, reward, ...).
        """
        def first(item):
            # Sized values (tensors/arrays) carry the env axis; scalars pass through.
            return item[0] if hasattr(item, '__len__') else item

        *outputs, raw_info = step_result
        unravelled = list(map(first, outputs))

        info = {}
        for key, value in raw_info.items():
            info[key] = first(value)
        if 'success' in info:
            info['is_success'] = info['success']

        unravelled.append(info)
        return unravelled

    def reset(self, *args, **kwargs):
        result = self.env.reset(*args, **kwargs)
        return self._unravel(result)

    def step(self, *args, **kwargs):
        result = self.env.step(*args, **kwargs)
        return self._unravel(result)


class RgbdWrapper(gymnasium.Wrapper):
    """Expose a batched ManiSkill ``rgbd`` env as a Gymnasium env with square RGB images.

    Only ``sensor_data['base_camera']['rgb']`` of the first sub-environment is
    kept; other observation entries, e.g. depth, are dropped. The image is
    resized to ``image_size`` x ``image_size``. Every other sized step output,
    and every sized info value, is indexed with ``[0]``. When the info dict
    contains ``success`` it is mirrored under ``is_success``.

    Raises:
        ValueError: if ``image_size`` is ``None`` or not positive.
    """

    def __init__(self, env, image_size):
        super().__init__(env)
        # Fail early with a clear message instead of deep inside spaces.Box / cv2.resize.
        if image_size is None or image_size <= 0:
            raise ValueError(f'image_size must be a positive int, got {image_size!r}')
        self._image_size = image_size
        self.observation_space = spaces.Box(0, 255, (image_size, image_size, 3), dtype=np.uint8)

    @staticmethod
    def _unravel(step_result):
        """Return [rgb_image, *other_outputs, info] for the first sub-environment."""
        unravel_result = [step_result[0]['sensor_data']['base_camera']['rgb'][0]]
        unravel_result += [x[0] if hasattr(x, '__len__') else x for x in step_result[1:-1]]
        info = {key: value[0] if hasattr(value, '__len__') else value for key, value in step_result[-1].items()}
        if 'success' in info:
            info['is_success'] = info['success']

        unravel_result.append(info)

        return unravel_result

    @staticmethod
    def _to_host_array(observation):
        """Return ``observation`` as a NumPy array, moving tensors to CPU first when possible."""
        # A bare `.numpy()` fails for tensors on a GPU device and for inputs that
        # are already NumPy arrays; `.cpu()` + `np.asarray` handles both.
        if hasattr(observation, 'cpu'):
            observation = observation.cpu()
        return np.asarray(observation)

    def _resize(self, observation):
        return cv2.resize(self._to_host_array(observation), dsize=(self._image_size, self._image_size),
                          interpolation=cv2.INTER_CUBIC)

    def reset(self, *args, **kwargs):
        obs, info = self._unravel(self.env.reset(*args, **kwargs))
        return self._resize(obs), info

    def step(self, *args, **kwargs):
        obs, reward, terminated, truncated, info = self._unravel(self.env.step(*args, **kwargs))
        return self._resize(obs), reward, terminated, truncated, info


def make(config, task, obs_mode, image_size, seed):
    """Create one ManiSkill env, seed it, and wrap it for DreamerV3.

    Args:
        config: DreamerV3 config passed to ``dreamerv3.wrap_env``.
        task: Gymnasium env id to create.
        obs_mode: ``'state'`` (vector observations under key ``'vector'``) or
            ``'rgbd'`` (224x224 camera, resized to ``image_size``, under key ``'image'``).
        image_size: Side length of the square image observation (rgbd only).
        seed: Seed passed to the first ``reset``.

    Returns:
        The wrapped embodied environment.
    """
    assert obs_mode in ('rgbd', 'state'), f'Unexpected obs_mode={obs_mode}'
    make_kwargs = {
        'id': task,
        'obs_mode': obs_mode,
        'control_mode': 'pd_joint_delta_pos',
        'render_mode': 'rgb_array',
    }
    if obs_mode == 'rgbd':
        # Camera renders at a fixed 224x224; RgbdWrapper resizes to image_size.
        make_kwargs['sensor_configs'] = {'width': 224, 'height': 224}

    is_state = obs_mode == 'state'
    gym_env = gymnasium.make(**make_kwargs)
    gym_env = StateWrapper(gym_env) if is_state else RgbdWrapper(gym_env, image_size)
    gym_env.reset(seed=seed)

    obs_key = 'vector' if is_state else 'image'
    return dreamerv3.wrap_env(FromGymnasium(gym_env, obs_key=obs_key), config)


def make_batch_env(config, task, obs_mode, image_size, seed, n_envs, parallel=None):
    """Build an ``embodied.BatchEnv`` of ``n_envs`` envs seeded ``seed``, ``seed + 1``, ...

    Args:
        parallel: Parallelism strategy passed to ``embodied.Parallel``; ``'none'``
            builds the envs in-process. Defaults to ``config.envs.parallel``.
    """
    if parallel is None:
        parallel = config.envs.parallel
    use_parallel = parallel != 'none'

    def _factory(env_seed):
        # Return a zero-arg callable that builds (or spawns a worker for) one env.
        def _ctor():
            return make(config, task, obs_mode, image_size, env_seed)

        return bind(embodied.Parallel, _ctor, parallel) if use_parallel else _ctor

    factories = [_factory(seed + offset) for offset in range(n_envs)]
    envs = [factory() for factory in factories]
    return embodied.BatchEnv(envs, parallel=use_parallel)


def parse_args(argv=None):
    """Parse the training script's command-line arguments.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``
            (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with config_path, task, obs_mode, image_size,
        n_envs, n_eval_envs, checkpoint_path and param_prefixes.

    Exits via ``parser.error`` (SystemExit, code 2) when ``--obs_mode rgbd`` is
    given without ``--image_size``, or when an env count is below 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, required=True)
    parser.add_argument('--task', type=str, required=True)
    parser.add_argument('--obs_mode', type=str, choices=['rgbd', 'state'], required=True)
    # Only meaningful (and then mandatory) for obs_mode=rgbd; validated below.
    parser.add_argument('--image_size', type=int, required=False)
    parser.add_argument('--n_envs', default=1, type=int)
    parser.add_argument('--n_eval_envs', default=1, type=int)
    parser.add_argument('--checkpoint_path', type=str, required=False)
    parser.add_argument('--param_prefixes', nargs='*', type=str, required=False)
    args = parser.parse_args(argv)

    if args.obs_mode == 'rgbd' and args.image_size is None:
        parser.error('--image_size is required when --obs_mode is rgbd')
    if args.n_envs < 1 or args.n_eval_envs < 1:
        parser.error('--n_envs and --n_eval_envs must be >= 1')
    return args


def main():
    """Train a DreamerV3 agent on a ManiSkill task with periodic evaluation.

    Loads the ``defaults`` section of the YAML file at ``--config_path``,
    overlays DreamerV3's ``medium`` preset, builds the train and eval batch
    envs, then runs ``embodied.run.train_eval_no_replay``. Checkpoints go to
    ``--checkpoint_path``. Otherwise they go to ``<wandb.project>/<wandb.name>``
    when W&B is configured, and to ``<logdir>/checkpoints`` when it is not.
    """
    cli_args = parse_args()
    print('Args:', vars(cli_args))
    checkpoint_path = cli_args.checkpoint_path
    param_prefixes = cli_args.param_prefixes

    configs = yaml.YAML(typ='safe').load(embodied.Path(cli_args.config_path).read())
    config = embodied.Config(configs['defaults'])
    # Overlay DreamerV3's "medium" preset onto the file's defaults.
    config = config.update(dreamerv3.configs['medium'])

    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
    ]
    if config.wandb.project is not None:
        os.makedirs(config.wandb.dir, exist_ok=True)
        outputs.append(embodied.logger.WandBOutput(config))
    logger = embodied.Logger(step, outputs)

    def _parallel_mode(n):
        # A single env runs in-process; several envs each get a worker process.
        return 'none' if n == 1 else 'process'

    seed = config['seed']
    # The train env's mode must follow n_envs (it previously followed n_eval_envs by mistake).
    env = make_batch_env(config, cli_args.task, cli_args.obs_mode, cli_args.image_size, seed,
                         n_envs=cli_args.n_envs, parallel=_parallel_mode(cli_args.n_envs))
    # Eval seeds start at seed + n_envs so they never collide with the train envs' seeds.
    eval_env = make_batch_env(config, cli_args.task, cli_args.obs_mode, cli_args.image_size,
                              seed + cli_args.n_envs, n_envs=cli_args.n_eval_envs,
                              parallel=_parallel_mode(cli_args.n_eval_envs))

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, logdir / 'replay')
    run_args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)

    if checkpoint_path is None:
        if config.wandb.project is not None and config.wandb.name is not None:
            checkpoint_path = os.path.join(config.wandb.project, config.wandb.name)
        else:
            # Without a W&B run there is nothing to derive a path from
            # (os.path.join(None, ...) would raise); keep checkpoints under logdir.
            checkpoint_path = os.path.join(str(logdir), 'checkpoints')
    os.makedirs(checkpoint_path, exist_ok=True)
    embodied.run.train_eval_no_replay(agent, train_env=env, eval_env=eval_env, train_replay=replay, logger=logger,
                                      args=run_args, checkpoint_path=checkpoint_path, param_prefixes=param_prefixes)


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()
