from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.envs.atari.atari_env import AtariTrajInfo
from rlpyt.utils.logging.context import logger_context

import wandb
import torch
import numpy as np
import os
from src.models_simple_equivariant import SimpleEquivariantCatDqnModel
from src.models_symmetric_MDP import SymmetricMDPCatDqnModel

from src.rlpyt_utils import OneToOneSerialEvalCollector, SerialSampler, MinibatchRlEvalWandb
from src.algos import GroupCategoricalDQN
from src.agent import SPRAgent
from src.rlpyt_atari_env import AtariEnv
from src.utils import set_config


def build_and_train(game="pong", run_ID=0, cuda_idx=0, args=None):
    """Assemble sampler, algorithm, agent and runner for one game, then train.

    Args:
        game: Game name; forwarded to ``set_config`` and used to build the
            logger name (``"dqn_" + game``).
        run_ID: Run identifier handed to ``logger_context``.
        cuda_idx: Passed to the runner as ``affinity=dict(cuda_idx=cuda_idx)``.
        args: Parsed command-line namespace. This function reads ``seed``,
            ``model_type``, ``jumps``, ``n_steps``, ``num_logs`` and
            ``final_eval_only`` from it and overwrites ``args.discount`` with
            ``config["algo"]["discount"]``.

    Raises:
        ValueError: If ``args`` is ``None`` or ``args.model_type`` is not one
            of the supported model types.
        SystemExit: Always, once training has finished.
    """
    if args is None:
        raise ValueError("build_and_train requires a parsed `args` namespace")

    # Validate before building anything so an unsupported model type fails
    # immediately instead of as an unbound local after sampler/algo setup.
    supported_model_types = ('Simple-Equivariant', 'Symmetric-MDP')
    if args.model_type not in supported_model_types:
        raise ValueError(
            "Unsupported model_type {!r}; expected one of {}".format(
                args.model_type, supported_model_types))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    config = set_config(args, game)

    # NOTE: `SerialSampler` resolves to the `src.rlpyt_utils` import, which
    # shadows the earlier `rlpyt.samplers.serial.sampler` import of that name.
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=config["env"],
        eval_env_kwargs=config["eval_env"],
        batch_T=config['sampler']['batch_T'],
        batch_B=config['sampler']['batch_B'],
        max_decorrelation_steps=0,
        eval_CollectorCls=OneToOneSerialEvalCollector,
        eval_n_envs=config["sampler"]["eval_n_envs"],
        eval_max_steps=config['sampler']['eval_max_steps'],
        eval_max_trajectories=config["sampler"]["eval_max_trajectories"],
    )

    # Expose the configured discount on the args namespace as well.
    args.discount = config["algo"]["discount"]
    algo = GroupCategoricalDQN(
        optim_kwargs=config["optim"],
        jumps=args.jumps,
        **config["algo"]
    )

    if args.model_type == 'Simple-Equivariant':
        model = SimpleEquivariantCatDqnModel
    else:  # 'Symmetric-MDP' (validated above)
        model = SymmetricMDPCatDqnModel
    agent = SPRAgent(ModelCls=model, model_kwargs=config["model"], **config["agent"])

    # Presumably requires an active wandb run (wandb.init called by the
    # caller) -- TODO confirm.
    wandb.config.update(config)
    runner = MinibatchRlEvalWandb(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.n_steps,
        affinity=dict(cuda_idx=cuda_idx),
        log_interval_steps=args.n_steps // args.num_logs,
        seed=args.seed,
        final_eval_only=args.final_eval_only,
    )

    # Parameters recorded by the rlpyt logger; kept separate from the full
    # experiment `config` above.
    log_params = dict(game=game)
    name = "dqn_" + game
    log_dir = "logs"
    with logger_context(log_dir, run_ID, name, log_params, snapshot_mode="last"):
        runner.train()

    # Terminate the process after training. `quit()` is injected by the
    # `site` module and is not guaranteed to exist (e.g. `python -S`), so
    # raise SystemExit directly.
    raise SystemExit(0)


if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters, start a wandb run and
    # launch training for the selected game.
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--game', help='Atari game', default='ms_pacman')
    parser.add_argument('--model-type', help='Type of model 1)Simple-Equivariant 2)Symmetric-MDP',
                        choices=["Simple-Equivariant", "Symmetric-MDP"], default='Simple-Equivariant')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--grayscale', type=int, default=1)
    parser.add_argument('--framestack', type=int, default=4)
    parser.add_argument('--imagesize', type=int, default=84)
    parser.add_argument('--n-steps', type=int, default=100000)
    parser.add_argument('--dqn-hidden-size', type=int, default=256)
    parser.add_argument('--target-update-interval', type=int, default=1)
    parser.add_argument('--target-update-tau', type=float, default=1.)
    parser.add_argument('--momentum-tau', type=float, default=1.)
    parser.add_argument('--batch-b', type=int, default=1)
    parser.add_argument('--batch-t', type=int, default=1)
    parser.add_argument('--beluga', action="store_true")
    parser.add_argument('--jumps', type=int, default=5)
    parser.add_argument('--num-logs', type=int, default=10)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--replay-ratio', type=int, default=64)
    parser.add_argument('--dynamics-blocks', type=int, default=0)
    parser.add_argument('--n-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--tag', type=str, default='', help='Tag for wandb run.')
    parser.add_argument('--wandb-dir', type=str, default='', help='Directory for wandb files.')
    parser.add_argument('--norm-type', type=str, default='bn', choices=["bn", "ln", "in", "none"], help='Normalization')
    parser.add_argument('--aug-prob', type=float, default=1., help='Probability to apply augmentation')
    parser.add_argument('--dropout', type=float, default=0., help='Dropout probability in convnet.')
    parser.add_argument('--distributional', type=int, default=1)
    parser.add_argument('--delta-clip', type=float, default=1., help="Huber Delta")
    parser.add_argument('--prioritized-replay', type=int, default=1)
    parser.add_argument('--momentum-encoder', type=int, default=1)
    parser.add_argument('--shared-encoder', type=int, default=0)
    parser.add_argument('--noisy-nets', type=int, default=1)
    parser.add_argument('--noisy-nets-std', type=float, default=0.5)
    # 'q_l1' is the default, so it must also be an accepted choice: argparse
    # does not check defaults against `choices`, which previously made the
    # default value impossible to pass explicitly on the command line.
    # NOTE(review): confirm the model code handles 'q_l1' as intended.
    parser.add_argument('--projection-type', type=str, default='q_l1',
                        choices=["mlp", "q_l1", "q_shared", "linear", "none"],
                        help='Style of NCE classifier')
    parser.add_argument('--second-projection-type', type=str, default='mlp', choices=["mlp", "q_shared", "linear", "none"], help='Style of NCE classifier')
    parser.add_argument('--augmentation', type=str, default=["shift", "intensity"], nargs="+",
                        choices=["none", "rrc", "affine", "crop", "blur", "shift", "intensity"],
                        help='Style of augmentation')
    parser.add_argument('--q-shared-type', type=str, default=["value", "advantage"], nargs="+",
                        choices=["noisy", "value", "advantage", "relu"],
                        help='Style of q_l1 projection')
    parser.add_argument('--target-augmentation', type=int, default=1, help='Use augmentation on inputs to target networks')
    parser.add_argument('--eval-augmentation', type=int, default=0, help='Use augmentation on inputs at evaluation time')
    parser.add_argument('--reward-loss-weight', type=float, default=0.)
    parser.add_argument('--model-rl-weight', type=float, default=0.)
    parser.add_argument('--model-group-weight', type=float, default=5.)
    parser.add_argument('--symmetricMDP-loss-weight', type=float, default=0.)
    parser.add_argument('--eps-steps', type=int, default=2001)
    parser.add_argument('--min-steps-learn', type=int, default=2000)
    parser.add_argument('--eps-init', type=float, default=1.)
    parser.add_argument('--eps-final', type=float, default=0.)
    parser.add_argument('--final-eval-only', type=int, default=0)
    parser.add_argument('--time-offset', type=int, default=0)
    parser.add_argument('--project', type=str, default="EquivariantRepRL")
    parser.add_argument('--entity', type=str, default="random")
    parser.add_argument('--cuda_idx', help='gpu to use ', type=int, default=0)
    parser.add_argument('--max-grad-norm', type=float, default=10., help='Max Grad Norm')
    parser.add_argument('--public', action='store_true', help='If set, uses anonymous wandb logging')

    # Group arguments
    parser.add_argument('--group-type', type=str, default='En', choices=["SOn", "En", "GLn", "Real", "none"],
                        help='Group type options 1)SOn 2)En 3)GLn 4)Real 5)none')
    parser.add_argument('--group-dim', type=int, default=2)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--parameterization-type', type=str, default='Euler',
                        help='Parametrization type options 1)Lie 2)Euler')
    parser.add_argument('--only-action-transition', type=int, default=0)
    args = parser.parse_args()

    # Force wandb into offline mode for this process, overriding any value
    # already present in the environment.
    os.environ["WANDB_MODE"] = "offline"

    # Keyword arguments shared by both the anonymous and the named-project run.
    wandb_init_kwargs = dict(
        config=args,
        tags=[args.tag] if args.tag else None,
        dir=args.wandb_dir,
    )
    if args.public:
        wandb.init(anonymous="allow", **wandb_init_kwargs)
    else:
        wandb.init(project=args.project, entity=args.entity, **wandb_init_kwargs)

    wandb.config.update(vars(args))
    build_and_train(game=args.game,
                    cuda_idx=args.cuda_idx,
                    args=args)
