import argparse
import os
import random

import gymnasium as gym
import numpy as np
import ruamel.yaml as yaml
from omegaconf import OmegaConf

import dreamerv3
import embodied
from cw_envs import CwTargetEnv
from embodied.envs.from_gymnasium import FromGymnasium


def make_causal_world(config, env_config_path, seed, *, obs_key='image'):
    """Build a seeded CausalWorld target env wrapped for DreamerV3.

    Args:
        config: DreamerV3/embodied config, forwarded to ``dreamerv3.wrap_env``.
        env_config_path: Path of the YAML file loaded with ``OmegaConf.load``
            and passed to ``CwTargetEnv``.
        seed: Seed applied to Python's ``random`` module, NumPy's global RNG,
            the ``CwTargetEnv`` constructor, and the env's action space.
        obs_key: Key under which ``FromGymnasium`` exposes the observation.
            Defaults to ``'image'``; ``'vector'`` is the alternative the
            original script mentioned. Presumably this must agree with the
            ``encoder``/``decoder`` key settings in ``config`` — confirm.

    Returns:
        An ``embodied.BatchEnv`` holding the single wrapped env, stepped
        in-process (``parallel=False``).

    Note:
        Seeding ``random`` and ``np.random`` is process-global. Calling this
        function again (e.g. to build an eval env) reseeds them for the whole
        process.
    """
    env_config = OmegaConf.load(env_config_path)
    random.seed(seed)
    np.random.seed(seed)
    env = CwTargetEnv(env_config, seed)
    env.action_space.seed(seed)
    env = FromGymnasium(env, obs_key=obs_key)
    env = dreamerv3.wrap_env(env, config)
    env = embodied.BatchEnv([env], parallel=False)

    return env


def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list is handy for tests and
            programmatic use.

    Returns:
        An ``argparse.Namespace`` with:

        * ``config_path`` (str, required): YAML config file for the run.
        * ``checkpoint_path`` (str or None): ``None`` when the flag is omitted.
        * ``param_prefixes`` (list of str or None): zero or more strings.
          ``None`` when the flag is omitted, ``[]`` when it is given without
          values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, required=True,
                        help='YAML config file whose "defaults" section seeds the agent config.')
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='Checkpoint directory (created if missing).')
    parser.add_argument('--param_prefixes', nargs='*', type=str, default=None,
                        help='Parameter-name prefixes forwarded to the training loop.')
    return parser.parse_args(argv)


def _build_config(config_path):
    """Load the agent config: YAML ``defaults`` + DreamerV3 ``medium`` preset + script overrides."""
    configs = yaml.YAML(typ='safe').load(embodied.Path(config_path).read())
    config = embodied.Config(configs['defaults'])
    config = config.update(dreamerv3.configs['medium'])
    return config.update({
        'run.train_ratio': 64,
        'run.log_every': 30,  # Seconds
        'batch_size': 16,
        'jax.prealloc': False,
        # '$^' can match no non-empty string. Assuming these keys are regex
        # filters (as in DreamerV3), no observation uses the MLP
        # encoder/decoder, and only 'image' goes through the CNN.
        'encoder.mlp_keys': '$^',
        'decoder.mlp_keys': '$^',
        'encoder.cnn_keys': 'image',
        'decoder.cnn_keys': 'image',
        # Uncomment to run JAX on CPU.
        # 'jax.platform': 'cpu',
    })


def main(env_config_path='cw_envs/config/reaching-hard_orig.yaml'):
    """Train a DreamerV3 agent on the CausalWorld target env with a separate eval env.

    Reads CLI options via ``parse_args``. Builds the config from the
    ``defaults`` section of ``--config_path``, layered with DreamerV3's
    ``medium`` preset and script overrides. Sets up terminal/JSONL/W&B
    logging, creates train and eval envs, and hands everything to
    ``embodied.run.train_eval_no_replay``.

    Args:
        env_config_path: YAML env description used for both the training and
            the evaluation environment.
    """
    cli_args = parse_args()
    print('Args:', vars(cli_args))
    checkpoint_path = cli_args.checkpoint_path
    param_prefixes = cli_args.param_prefixes
    config = _build_config(cli_args.config_path)

    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    logger = embodied.Logger(step, [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
        embodied.logger.WandBOutput(config),
        # embodied.logger.MLFlowOutput(logdir.name),
    ])
    # The eval env is seeded with config.seed + 1, distinct from the training
    # env. NOTE: make_causal_world also seeds the global random/np.random
    # RNGs, so after both calls they are left seeded with config.seed + 1.
    env = make_causal_world(config=config, env_config_path=env_config_path, seed=config.seed)
    eval_env = make_causal_world(config=config, env_config_path=env_config_path, seed=config.seed + 1)

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, logdir / 'replay')
    # Run-loop settings. Kept separate from the parsed CLI namespace
    # (cli_args) to avoid shadowing it.
    run_args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)
    if checkpoint_path is None:
        # Fall back to <wandb project>/<wandb run name>.
        # NOTE(review): os.path.join raises TypeError if either field is
        # None — confirm the config always sets both.
        checkpoint_path = os.path.join(config.wandb.project, config.wandb.name)
    os.makedirs(checkpoint_path, exist_ok=True)
    embodied.run.train_eval_no_replay(
        agent, train_env=env, eval_env=eval_env, train_replay=replay, logger=logger,
        args=run_args, checkpoint_path=checkpoint_path, param_prefixes=param_prefixes)


if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()
