import argparse
import torch
import rlkit.torch.pytorch_util as ptu
import yaml
from rlkit.data_management.torch_replay_buffer import TorchReplayBuffer
from rlkit.envs import make_env
from rlkit.envs.vecenv import SubprocVectorEnv, VectorEnv
from rlkit.launchers.launcher_util import set_seed, setup_logger
from rlkit.samplers.data_collector import (VecMdpPathCollector, VecMdpStepCollector)
from rlkit.torch.dsac.dsac_reg import DSACTrainer_reg
from rlkit.torch.dsac.networks import CategoricalMlp, softmax
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
from rlkit.torch.torch_rl_algorithm import TorchVecOnlineRLAlgorithm

# Cap PyTorch CPU parallelism for this process: the same limit is applied to
# intra-op threads (set_num_threads) and inter-op threads
# (set_num_interop_threads). These run at import time.
_TORCH_CPU_THREADS = 4
torch.set_num_threads(_TORCH_CPU_THREADS)
torch.set_num_interop_threads(_TORCH_CPU_THREADS)


def experiment(variant, cli_args=None):
    """Build the DSAC (regularization variant) agent described by ``variant`` and train it.

    Creates exploration/evaluation vector environments, four ``CategoricalMlp``
    critics (two online, two target), an online and a target
    ``TanhGaussianPolicy``, an optional pair of fraction-proposal networks,
    a replay buffer, a ``DSACTrainer_reg`` and a ``TorchVecOnlineRLAlgorithm``.
    It then moves the algorithm to ``ptu.device`` and calls ``train()``.

    Args:
        variant: experiment configuration (loaded from YAML by the script).
            Keys read here: ``env``, ``seed``, ``expl_env_num``,
            ``eval_env_num``, ``layer_size``, ``num_atoms``,
            ``replay_buffer_size``, ``trainer_kwargs`` (dict forwarded to the
            trainer; its ``tau_type`` entry is also inspected) and
            ``algorithm_kwargs`` (dict forwarded to the algorithm).
        cli_args: parsed command-line namespace providing ``entropy``,
            ``gradient``, ``vmax``, ``varepsilon``, ``alg`` and ``reg``.
            When omitted, the module-level ``args`` set by the ``__main__``
            block is used, preserving the original ``experiment(variant)``
            calling convention.

    Raises:
        NameError: if ``cli_args`` is omitted and no module-level ``args`` exists.
    """
    if cli_args is None:
        # Backward compatibility: this function historically read the global
        # ``args`` parsed in ``__main__``; passing ``cli_args`` makes the
        # dependency explicit and lets the function be called from elsewhere.
        cli_args = args

    env_name = variant['env']
    seed = variant["seed"]

    # (1) basic setting: a single (non-vectorised) env supplies the space
    # sizes and is handed to the replay buffer and the trainer.
    dummy_env = make_env(env_name)
    obs_dim = dummy_env.observation_space.low.size
    action_dim = dummy_env.action_space.low.size
    expl_env = VectorEnv([lambda: make_env(env_name) for _ in range(variant['expl_env_num'])])
    expl_env.seed(seed)
    expl_env.action_space.seed(seed)
    eval_env = SubprocVectorEnv([lambda: make_env(env_name) for _ in range(variant['eval_env_num'])])
    eval_env.seed(seed)

    M = variant['layer_size']  # hidden layer size
    num_atoms = variant['num_atoms']  # atoms for C51
    sa_dim = obs_dim + action_dim  # critics and fp take the concatenated (s, a) input

    def _make_critic():
        """Return a new two-hidden-layer CategoricalMlp critic over (s, a)."""
        return CategoricalMlp(input_size=sa_dim, num_atoms=num_atoms, hidden_sizes=[M, M])

    def _make_policy():
        """Return a new two-hidden-layer TanhGaussianPolicy."""
        return TanhGaussianPolicy(obs_dim=obs_dim, action_dim=action_dim, hidden_sizes=[M, M])

    def _make_fp():
        """Return a new fraction-proposal MLP with a softmax over ``num_atoms`` outputs."""
        return FlattenMlp(
            input_size=sa_dim,
            output_size=num_atoms,
            hidden_sizes=[M // 2, M // 2],
            output_activation=softmax,
        )

    # (2) define networks: 2 distributional critics plus their targets.
    # Keep this construction order: reordering could change which random
    # draws each network's initialisation receives.
    zf1 = _make_critic()
    zf2 = _make_critic()
    target_zf1 = _make_critic()
    target_zf2 = _make_critic()

    # continuous action
    policy = _make_policy()
    eval_policy = MakeDeterministic(policy)
    target_policy = _make_policy()

    # (3) fraction proposal networks: only built when trainer_kwargs['tau_type']
    # is 'C51' or 'fqf'; otherwise the trainer receives None for both.
    fp = target_fp = None
    if variant['trainer_kwargs'].get('tau_type') in ('C51', 'fqf'):
        fp = _make_fp()
        target_fp = _make_fp()

    eval_path_collector = VecMdpPathCollector(eval_env, eval_policy)
    expl_path_collector = VecMdpStepCollector(expl_env, policy)
    replay_buffer = TorchReplayBuffer(variant['replay_buffer_size'], dummy_env)

    # (4) define trainer with corresponding networks
    trainer = DSACTrainer_reg(
        env=dummy_env,
        policy=policy,
        target_policy=target_policy,
        zf1=zf1,
        zf2=zf2,
        target_zf1=target_zf1,
        target_zf2=target_zf2,
        fp=fp,
        target_fp=target_fp,
        num_atoms=num_atoms,  # C51 atoms
        Flag_entropy=cli_args.entropy == 1,
        gradient=cli_args.gradient,
        v_max=cli_args.vmax,  # ablation study
        varepsilon=cli_args.varepsilon,
        alg=cli_args.alg,
        reg=cli_args.reg,
        **variant['trainer_kwargs'],
    )
    algorithm = TorchVecOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'],
    )
    # (5) train
    algorithm.to(ptu.device)
    algorithm.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Interpret Distributional RL via Regularization')
    parser.add_argument('--config', type=str, default="ant")
    parser.add_argument('--alg', type=str, default="Entropy", help='C51, Entropy')
    parser.add_argument('--vmax', type=int, default=5000)
    parser.add_argument('--gpu', type=int, default=0, help="using cpu with -1")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--entropy', type=int, default=1)
    parser.add_argument('--gradient', type=int, default=0)
    parser.add_argument('--varepsilon', type=float, default=1.0)  # 1.0 for CE loss
    parser.add_argument('--reg', type=float, default=1.0)  # 1.0 for CE loss

    # NOTE: ``args`` must stay a module-level name: ``experiment(variant)``
    # reads it as a global.
    args = parser.parse_args()

    # Both 'C51' and 'Entropy' load configs/dsac/<config>_C51.yaml; any other
    # --alg value loads configs/dsac/<config>.yaml.
    config_suffix = '_C51' if args.alg in ('C51', 'Entropy') else ''
    config_path = f'configs/dsac/{args.config}{config_suffix}.yaml'
    with open(config_path, 'r', encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct some Python objects; if the
        # configs are plain data, yaml.safe_load is the safer choice.
        variant = yaml.load(f, Loader=yaml.FullLoader)
    variant["seed"] = args.seed

    # Log prefix layout:
    #   <dsac|dsacNoent>[_grad]_<env>_<version>_-vmax<V>_-varepsilon<E>[_-ent<R>]
    # "dsac" when --entropy == 1, "dsacNoent" otherwise; the -ent part is
    # added for every --alg other than 'C51'.
    grad = '' if args.gradient == 0 else '_grad'
    run_tag = ("dsac" if args.entropy == 1 else "dsacNoent") + grad
    name_parts = [
        run_tag,
        # drops the last 3 characters of the env id (presumably a "-vN" suffix)
        variant["env"][:-3].lower(),
        str(variant["version"]),
        '-vmax' + str(args.vmax),
        '-varepsilon' + str(args.varepsilon),
    ]
    if args.alg != 'C51':  # entropy variant also records its regularization weight
        name_parts.append('-ent' + str(args.reg))
    log_prefix = "_".join(name_parts)

    setup_logger(log_prefix, variant=variant, seed=args.seed)
    if args.gpu >= 0:
        ptu.set_gpu_mode(True, args.gpu)
    set_seed(args.seed)
    experiment(variant)
