import argparse
import os
from functools import partial as bind

import gymnasium
import ruamel.yaml as yaml
from gymnasium import Wrapper
from gymnasium.wrappers import ResizeObservation

import dreamerv3
import embodied
from embodied.envs.from_gymnasium import FromGymnasium
from vizdoom import gymnasium_wrapper


class ImageExtractorWrapper(Wrapper):
    """Expose only the ``'screen'`` entry of a dict observation.

    The wrapped environment is expected to produce mapping observations that
    contain a ``'screen'`` key (presumably the rendered frame -- confirm
    against the wrapped environment). Every other observation entry is moved
    into the ``info`` dict returned by ``reset`` and ``step``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Advertise only the 'screen' sub-space so the declared space matches
        # the observations this wrapper actually returns.
        self.observation_space = env.observation_space['screen']

    @staticmethod
    def _separate_image_and_features(observation, info):
        """Split ``observation`` into its screen and the remaining entries.

        All non-``'screen'`` entries are written into ``info`` (in place,
        overwriting keys that already exist there).

        Returns:
            A ``(screen, info)`` tuple, where ``info`` is the same dict object
            that was passed in.
        """
        for name, value in observation.items():
            if name == 'screen':
                continue
            info[name] = value
        return observation['screen'], info

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and return ``(screen, info)``."""
        raw_observation, info = self.env.reset(seed=seed, options=options)
        return self._separate_image_and_features(raw_observation, info)

    def step(self, action):
        """Step the wrapped env and return the usual 5-tuple with the screen as observation."""
        raw_observation, reward, terminated, truncated, info = self.env.step(action)
        screen, info = self._separate_image_and_features(raw_observation, info)
        return screen, reward, terminated, truncated, info


def make(config, task, image_size, seed):
    """Build one fully wrapped environment for ``task``.

    The Gymnasium environment is created with ``render_mode='rgb_array'``,
    reduced to its ``'screen'`` observation, resized, seeded through an
    initial ``reset`` and finally adapted for DreamerV3.

    Args:
        config: Configuration object forwarded to ``dreamerv3.wrap_env``.
        task: Gymnasium environment id passed to ``gymnasium.make``.
        image_size: Target observation size. An int is treated as a square
            ``(image_size, image_size)``; a two-element sequence is used as is.
        seed: Seed passed to the initial ``env.reset``.

    Returns:
        The environment returned by ``dreamerv3.wrap_env``.
    """
    # Older Gymnasium releases accept a bare int for ResizeObservation's shape,
    # but Gymnasium 1.x requires a tuple. A tuple works with both.
    if isinstance(image_size, int):
        shape = (image_size, image_size)
    else:
        shape = tuple(image_size)

    env = gymnasium.make(task, render_mode='rgb_array')
    env = ImageExtractorWrapper(env)
    env = ResizeObservation(env, shape)
    # Seed while this is still a Gymnasium env, before it is adapted below.
    env.reset(seed=seed)
    env = FromGymnasium(env, obs_key='image')
    env = dreamerv3.wrap_env(env, config)

    return env


def make_batch_env(config, task, image_size, seed, n_envs, parallel=None):
    """Create ``n_envs`` environments and group them in an ``embodied.BatchEnv``.

    Args:
        config: Configuration forwarded to ``make``; also supplies the default
            parallel mode via ``config.envs.parallel``.
        task: Environment id forwarded to ``make``.
        image_size: Observation size forwarded to ``make``.
        seed: Base seed; the i-th environment is built with ``seed + i``.
        n_envs: Number of environments to create.
        parallel: Parallel mode. ``None`` falls back to
            ``config.envs.parallel``; ``'none'`` builds the environments
            directly, any other value wraps each in ``embodied.Parallel``.

    Returns:
        An ``embodied.BatchEnv`` over the created environments.
    """
    mode = config.envs.parallel if parallel is None else parallel
    use_workers = mode != 'none'

    def _build(env_seed):
        # The seed is a parameter of this helper, so each constructor keeps
        # its own value rather than sharing a loop variable.
        def _construct():
            return make(config, task, image_size, env_seed)

        if use_workers:
            return embodied.Parallel(_construct, mode)
        return _construct()

    batch = [_build(seed + offset) for offset in range(n_envs)]
    return embodied.BatchEnv(batch, parallel=use_workers)


def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Required: ``--config_path`` (str), ``--task`` (str), ``--image_size`` (int).
    Optional: ``--n_envs`` and ``--n_eval_envs`` (int, default 1),
    ``--checkpoint_path`` (str) and ``--param_prefixes`` (zero or more str).

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    parser = argparse.ArgumentParser()

    mandatory = (
        ('--config_path', str),
        ('--task', str),
        ('--image_size', int),
    )
    for flag, value_type in mandatory:
        parser.add_argument(flag, type=value_type, required=True)

    for flag in ('--n_envs', '--n_eval_envs'):
        parser.add_argument(flag, default=1, type=int)

    parser.add_argument('--checkpoint_path', type=str, required=False)
    parser.add_argument('--param_prefixes', nargs='*', type=str, required=False)

    return parser.parse_args()


def main():
    """Script entry point: build the config, environments and agent, then train.

    Reads the YAML file given by ``--config_path``, takes its ``defaults``
    section and overlays ``dreamerv3.configs['medium']``. It then sets up
    logging, the training and evaluation batch environments, the agent and a
    uniform replay, and hands everything to
    ``embodied.run.train_eval_no_replay``.

    Raises:
        ValueError: If ``--checkpoint_path`` is omitted and the config has no
            W&B project/name to derive a checkpoint directory from.
    """
    args = parse_args()
    print('Args:', vars(args))
    param_prefixes = args.param_prefixes
    configs = yaml.YAML(typ='safe').load(embodied.Path(args.config_path).read())
    config = embodied.Config(configs['defaults'])
    config = config.update(dreamerv3.configs['medium'])

    # Resolve the checkpoint directory before any environment is created, so
    # a missing value fails fast with a clear message. Without this check,
    # os.path.join(None, ...) raises an opaque TypeError after the (possibly
    # parallel) environments have already been started.
    checkpoint_path = args.checkpoint_path
    if checkpoint_path is None:
        if config.wandb.project is None or config.wandb.name is None:
            raise ValueError(
                '--checkpoint_path is required when config.wandb.project or '
                'config.wandb.name is not set')
        checkpoint_path = os.path.join(config.wandb.project, config.wandb.name)

    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
    ]
    # W&B logging is optional and only enabled when a project is configured.
    if config.wandb.project is not None:
        os.makedirs(config.wandb.dir, exist_ok=True)
        outputs.append(embodied.logger.WandBOutput(config))

    logger = embodied.Logger(step, outputs)
    env = make_batch_env(config, args.task, args.image_size, config['seed'], args.n_envs)
    # Evaluation seeds start after the training seeds so the two sets do not overlap.
    eval_env = make_batch_env(
        config, args.task, args.image_size, config['seed'] + args.n_envs,
        n_envs=args.n_eval_envs,
        parallel='none' if args.n_eval_envs == 1 else 'process')

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, logdir / 'replay')
    # Kept under its own name so the parsed CLI namespace is not shadowed.
    run_args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)

    os.makedirs(checkpoint_path, exist_ok=True)
    embodied.run.train_eval_no_replay(
        agent, train_env=env, eval_env=eval_env, train_replay=replay,
        logger=logger, args=run_args, checkpoint_path=checkpoint_path,
        param_prefixes=param_prefixes)

# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
