import argparse
import datetime
import os
from pprint import pprint

import random
import numpy as np
import torch
from vmoc.mujoco_env import make_mujoco_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import VectorReplayBuffer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
# from vmoc.my_collector import MyCollector as Collector
from tianshou.data import Collector
from vmoc.vmoc_policy import VMOCPolicy
from utils.utils_config import ConfigDict, load_config, ModelDict
from utils.utils_ray import get_git_commit_hash

from torch import nn
from copy import deepcopy
# Used for policy exploration
from tianshou.exploration import GaussianNoise
# Option Nets
from tianshou.utils.net.discrete import Critic  # , Actor, IntrinsicCuriosityModule
# Actor Nets
from tianshou.utils.net.common import Net
from vmoc.vmoc_net import OptionNet, OptionActor, BiActorProb, TriNet, TriCritic  # TODO: , DoeContiCritic


def _str2bool(value):
  """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool.

  ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
  non-empty value (including ``"False"``) would be parsed as ``True``.

  Raises:
    argparse.ArgumentTypeError: if the token is not a recognised boolean.
  """
  if isinstance(value, bool):
    return value
  token = value.strip().lower()
  if token in ("1", "true", "t", "yes", "y", "on"):
    return True
  if token in ("0", "false", "f", "no", "n", "off", ""):
    return False
  raise argparse.ArgumentTypeError(
      f"expected a boolean value, got {value!r}")


def get_args():
  """Parse the command-line options of the VMOC training script.

  Returns:
    argparse.Namespace with the attributes ``env``, ``dmodel``,
    ``num_options``, ``hid_dim``, ``rand`` and ``device``. Options whose
    default is ``None`` are left unset unless given on the command line.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--env", type=str, default="HalfCheetah-v4")
  parser.add_argument("--dmodel", type=int, default=None)
  parser.add_argument("--num_options", type=int, default=None)
  parser.add_argument("--hid_dim", type=int, default=None)
  # Accepts `--rand`, `--rand true` and `--rand false`; the explicit converter
  # makes `--rand False` really disable the option.
  parser.add_argument(
      "--rand", type=_str2bool, nargs="?", const=True, default=False)
  parser.add_argument("--device", type=str, default="cpu")
  return parser.parse_args()


def test_vmoc(args=None):
  """Build a VMOC agent for a MuJoCo task and train it off-policy.

  Loads ``./ts_vmoc_config.yaml``, applies the command-line overrides found
  in ``args``, constructs the option-level (``pO``, ``qO1``, ``qO2``) and
  action-level (``pA``, ``qA1``, ``qA2``) networks with their Adam
  optimisers, wraps them in a ``VMOCPolicy`` and, unless ``config.watch`` is
  set, runs an ``OffpolicyTrainer`` and pretty-prints its result.

  Args:
    args: Parsed command-line namespace as returned by ``get_args()``. When
      ``None`` (the default) the command line is parsed at call time.
  """
  # Parse lazily: a `get_args()` default value would be evaluated once, when
  # the function is defined, i.e. on import of this module.
  if args is None:
    args = get_args()

  config: ConfigDict = load_config('./ts_vmoc_config.yaml')
  config.env = args.env

  # Command-line overrides of the YAML configuration.
  if args.rand:
    config.seed = random.randint(0, 10000)
  if args.hid_dim is not None:
    config.hidden_sizes = [args.hid_dim, args.hid_dim]
  if args.num_options is not None:
    config.num_options = args.num_options
  if args.dmodel is not None:
    config.dmodel = args.dmodel
  # NOTE(review): `--device` defaults to "cpu", so this branch always
  # replaces the device from the YAML file -- confirm that is intended.
  if args.device is not None:
    config.device = args.device

  env, train_envs, test_envs = make_mujoco_env(
      config.env,
      config.seed,
      config.num_envs_per_worker,
      config.eval_num_envs_per_worker,
      obs_norm=False)
  config.state_shape = env.observation_space.shape or env.observation_space.n
  config.action_shape = env.action_space.shape or env.action_space.n
  config.max_action = env.action_space.high[0]
  print("Observations shape:", config.state_shape)
  print("Actions shape:", config.action_shape)
  print("Action range:", np.min(env.action_space.low),
        np.max(env.action_space.high))
  # seed
  random.seed(config.seed)
  np.random.seed(config.seed)
  torch.manual_seed(config.seed)
  #* Option
  # Option embedding table. It is handed to a module only when the matching
  # `detach_*_oembed` flag is false, and always to `VMOCPolicy`.
  option_embeds = nn.Embedding(config.num_options,
                               config.dmodel).to(config.device)

  # P(o_t|s_t, o_{t-1}), Categorical
  netO = OptionNet(
      config.state_shape,
      config.dmodel,
      hidden_sizes=config.hidden_sizes,
      device=config.device,
      option_embeds=None if config.detach_pO_oembed else option_embeds,
  )
  pO = OptionActor(
      netO, config.num_options, device=config.device).to(config.device)
  pO_optim = torch.optim.Adam(pO.parameters(), lr=config.o_actor_lr)

  # Q_O(s_t, o_t) -> Q_O(s_t), but with all options
  net1 = Net(
      config.state_shape,
      hidden_sizes=config.hidden_sizes,
      device=config.device,
  )
  qO1 = Critic(
      net1, last_size=config.num_options,
      device=config.device).to(config.device)
  qO1_optim = torch.optim.Adam(qO1.parameters(), lr=config.o_critic_lr)
  net2 = Net(
      config.state_shape,
      hidden_sizes=config.hidden_sizes,
      device=config.device,
  )
  qO2 = Critic(
      net2, last_size=config.num_options,
      device=config.device).to(config.device)
  qO2_optim = torch.optim.Adam(qO2.parameters(), lr=config.o_critic_lr)

  #* Actor
  # P(a|s,o)
  net_a = Net(
      config.state_shape,
      config.dmodel,
      hidden_sizes=config.hidden_sizes,
      concat=True,
      device=config.device)
  pA = BiActorProb(
      net_a,
      config.action_shape,
      option_embeds=None if config.detach_pA_oembed else option_embeds,
      device=config.device,
  ).to(config.device)
  pA_optim = torch.optim.Adam(pA.parameters(), lr=config.a_actor_lr)

  # Q_A(s,a,o)
  net_c1 = TriNet(
      config.state_shape,
      config.action_shape,
      option_shape=config.dmodel,
      hidden_sizes=config.hidden_sizes,
      concat=True,
      device=config.device)
  net_c2 = TriNet(
      config.state_shape,
      config.action_shape,
      option_shape=config.dmodel,
      hidden_sizes=config.hidden_sizes,
      concat=True,
      device=config.device)
  qA1 = TriCritic(
      net_c1,
      option_embeds=None if config.detach_qA_oembed else option_embeds,
      device=config.device,
  ).to(config.device)
  qA1_optim = torch.optim.Adam(qA1.parameters(), lr=config.a_critic_lr)
  qA2 = TriCritic(
      net_c2,
      option_embeds=None if config.detach_qA_oembed else option_embeds,
      device=config.device,
  ).to(config.device)
  qA2_optim = torch.optim.Adam(qA2.parameters(), lr=config.a_critic_lr)

  if config.auto_alpha:
    # The log-alpha tensors are created directly on the target device and are
    # not passed through `.to()` afterwards: the optimiser needs leaf tensors,
    # and a `.to()` that copies would return a non-leaf tensor.
    config.a_target_entropy = -np.prod(env.action_space.shape)
    config.a_log_alpha = torch.tensor([config.a_P_alpha],
                                      requires_grad=True,
                                      device=config.device)
    config.a_alpha_optim = torch.optim.Adam([config.a_log_alpha],
                                            lr=config.a_alpha_lr)

    config.o_target_entropy = 0.98 * np.log(config.num_options)
    config.o_log_alpha = torch.tensor([config.o_P_alpha],
                                      requires_grad=True,
                                      device=config.device)
    config.o_alpha_optim = torch.optim.Adam([config.o_log_alpha],
                                            lr=config.o_alpha_lr)

  # 'oL0' holds the action-level models, 'oL1' the option-level models.
  actor_dict = ModelDict({
      'oL0': {
          'probO': pA,
          'probO_optim': pA_optim
      },
      'oL1': {
          'probO': pO,
          'probO_optim': pO_optim
      },
  })
  # Each `*_old` entry is a deep copy of the corresponding critic.
  critic_dict = ModelDict({
      'oL0': {
          'qO1': qA1,
          'qO1_old': deepcopy(qA1),
          'qO1_optim': qA1_optim,
          'qO2': qA2,
          'qO2_old': deepcopy(qA2),
          'qO2_optim': qA2_optim
      },
      'oL1': {
          'qO1': qO1,
          'qO1_old': deepcopy(qO1),
          'qO1_optim': qO1_optim,
          'qO2': qO2,
          'qO2_old': deepcopy(qO2),
          'qO2_optim': qO2_optim
      }
  })
  #* Policy
  policy = VMOCPolicy(
      actor_dict,
      critic_dict,
      tau=config.tau,
      gamma=config.gamma,
      estimation_step=config.estimation_step,
      action_space=env.action_space,
      option_embeds=option_embeds,
      exploration_noise=GaussianNoise(
          sigma=0.2) if config.exploration_noise else None,
      config=config)

  # load a previous policy
  if config.resume_path:
    policy.load_state_dict(
        torch.load(config.resume_path, map_location=config.device))
    print("Loaded agent from: ", config.resume_path)

  # collector
  buffer = VectorReplayBuffer(config.buffer_size, len(train_envs))

  # SG: exploration_noise=True no effect: policy._noise is None default.
  train_collector = Collector(
      policy, train_envs, buffer, exploration_noise=True)
  # SG: collect config.start_timesteps (10000) random steps as warmup
  # train_collector.collect(n_step=config.start_timesteps, random=True)

  # log
  now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
  git_hash = get_git_commit_hash()
  exp_name = f"{config.log_folder_suffix}_{git_hash}_{now}"
  params_name = f"dmodel{config.dmodel}_numo{config.num_options}_nhid{config.hidden_sizes[0]}"
  config.algo_name = "vmoc"
  log_name = os.path.join(config.env, config.algo_name, params_name,
                          str(config.seed), exp_name)
  log_path = os.path.join(config.logdir, log_name)

  print('log path: ', log_path, flush=True)
  writer = SummaryWriter(log_path)
  # Close the writer on every exit path so buffered events reach the disk.
  try:
    writer.add_text("config", str(config))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
      torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    # No test collector is passed to the trainer.
    test_collector = None
    if not config.watch:
      # trainer
      trainer = OffpolicyTrainer(
          policy,
          train_collector,
          test_collector,
          config.max_epoch,
          config.step_per_epoch,
          config.step_per_collect,
          config.eval_num_envs_per_worker,
          config.batch_size,
          save_best_fn=save_best_fn,
          logger=logger,
          update_per_step=config.gradient_step_per_collect,
          test_in_train=False,
          show_progress=config.show_progress,
      )
      result = trainer.run()
      pprint(result)
  finally:
    writer.close()


if __name__ == "__main__":
  # Script entry point: parse the command line, then start the training run.
  cli_args = get_args()
  test_vmoc(cli_args)
