# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Policy gradient agents trained and evaluated on Kuhn Poker."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf

from open_spiel.python import policy
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
import policy_gradient_ as policy_gradient
import numpy as np
import time
import argparse
import os.path
import sys

FLAGS = flags.FLAGS

# Training / evaluation / checkpoint schedule, all counted in episodes.
flags.DEFINE_integer("num_episodes", int(1e7), "Number of train episodes.")
flags.DEFINE_integer("eval_every", int(1e4), "Eval agents every x episodes.")
flags.DEFINE_integer("save_every", int(1e6),
                     "Save agent checkpoints every x episodes.")

# Game and loss selection.
flags.DEFINE_enum("loss_str", "ach",
                  ["a2c", "rpg", "qpg", "rm", "neurd", "ach"],
                  "PG loss to use.")
flags.DEFINE_string("game_name", "kuhn_poker", "Name of the game")

# Hyperparameters forwarded to each PolicyGradient agent in main().
flags.DEFINE_integer("batch_size", 64, "batchsize of network")
flags.DEFINE_float("critic_lr", 0.001,
                   "learning rate of critic network. Also learning rate of ACH.")
flags.DEFINE_float("pi_lr", 0.001, "learning rate of policy network")
flags.DEFINE_float('etpcost', 0.01, "entropy cost")
flags.DEFINE_integer('nctopi', 1, "num critic before pi")
flags.DEFINE_string("optm", "sgd", "optimizer for training")
flags.DEFINE_integer("seed_", 7, "set seed for random")
flags.DEFINE_float("threshold", 2, "threshold for clip method")
flags.DEFINE_float("hedge", float(1), "Hedge decay")
flags.DEFINE_float("alpha", 2, "Alpha of Loss")

# Checkpointing. `checkpoint_dir` is used as a path prefix: main() appends the
# run name to it directly, so the default keeps its trailing slash.
flags.DEFINE_bool("use_checkpoints", True, "Save/load neural network weights.")
flags.DEFINE_string("checkpoint_dir", "tmp/",
                    "Directory to save/load the agent.")

def writefile(msg, filename):
    """Append `msg` to the file at `filename`, creating the file if needed.

    The parent directory of `filename` is created when it does not exist yet,
    so the first write to a fresh log location does not raise
    FileNotFoundError.

    Args:
      msg: Text to append. It is written verbatim; no newline is added.
      filename: Path of the file to append to.
    """
    parent = os.path.dirname(filename)
    if parent:  # A bare file name has no parent directory to create.
        os.makedirs(parent, exist_ok=True)
    # The `with` block closes the file on exit, so no explicit close() is
    # needed.
    with open(filename, 'a') as f:
        f.write(msg)


class PolicyGradientPolicies(policy.Policy):
  """Wraps one agent per player as a single joint policy for evaluation.

  The wrapped agents are queried through their `step` method in evaluation
  mode; the resulting per-action probabilities are exposed through the
  `policy.Policy` interface.
  """

  def __init__(self, env, nfsp_policies):
    """Initializes the joint policy.

    Args:
      env: Environment whose `game` attribute is handed to the base class.
      nfsp_policies: Sequence of agents indexed by player id (0 and 1).
    """
    super().__init__(env.game, [0, 1])
    self._policies = nfsp_policies
    # Observation buffer reused across calls; one slot per player.
    self._obs = {"info_state": [None, None], "legal_actions": [None, None]}

  def action_probabilities(self, state, player_id=None):
    """Returns `{action: probability}` for the player to act at `state`.

    Args:
      state: Game state to evaluate; its current player is the one queried.
      player_id: Accepted for interface compatibility; not used here.

    Returns:
      A dict mapping each legal action of the current player to the
      probability reported by that player's agent.
    """
    acting_player = state.current_player()
    legal = state.legal_actions(acting_player)

    # Refresh the acting player's slots in the shared observation buffer.
    observations = self._obs
    observations["current_player"] = acting_player
    observations["info_state"][acting_player] = (
        state.information_state_tensor(acting_player))
    observations["legal_actions"][acting_player] = legal

    time_step = rl_environment.TimeStep(
        observations=observations,
        rewards=None,
        discounts=None,
        step_type=None)

    agent_output = self._policies[acting_player].step(
        time_step, is_evaluation=True)
    probs = agent_output.probs
    return {action: probs[action] for action in legal}

def _ensure_parent_dir(path):
  """Creates the directory that would contain `path`, if `path` has one."""
  parent = os.path.dirname(path)
  if parent:
    os.makedirs(parent, exist_ok=True)


def main(_):
  """Trains two policy-gradient agents in self-play on `FLAGS.game_name`.

  Every `FLAGS.eval_every` episodes the exploitability of the joint policy and
  the agents' current losses are printed and appended to a log file under
  `log/`. Every `FLAGS.save_every` episodes (when `FLAGS.use_checkpoints` is
  set) each agent is asked to save a checkpoint. The two cadences are
  independent of each other.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  np.random.seed(FLAGS.seed_)
  tf.set_random_seed(FLAGS.seed_)
  game = FLAGS.game_name
  num_players = 2

  env_configs = {"players": num_players}
  chance_event_sampler = rl_environment.ChanceEventSampler(seed=FLAGS.seed_)
  env = rl_environment.Environment(
      game, **env_configs, chance_event_sampler=chance_event_sampler)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  # Run name encodes the hyperparameters so that logs and checkpoints of
  # different configurations do not overwrite each other.
  filename = "{}_{}_cl{}_pl{}_ec{}_sd{}_hg{}_th{}".format(
      FLAGS.game_name, FLAGS.loss_str, FLAGS.critic_lr, FLAGS.pi_lr,
      FLAGS.etpcost, FLAGS.seed_, FLAGS.hedge, FLAGS.threshold)
  # The flag is updated in place (as before) and used as a path prefix.
  FLAGS.checkpoint_dir = FLAGS.checkpoint_dir + filename
  checkpoint_prefix = FLAGS.checkpoint_dir
  filename = "log/" + filename

  # Create output locations up front so a missing directory fails (or is
  # fixed) before training starts rather than at the first evaluation.
  _ensure_parent_dir(filename)
  if FLAGS.use_checkpoints:
    # NOTE(review): assumes agent.save() writes under this prefix's parent
    # directory -- confirm against policy_gradient_.PolicyGradient.save.
    _ensure_parent_dir(checkpoint_prefix)

  with tf.Session() as sess:
    agents = [
        policy_gradient.PolicyGradient(
            sess,
            idx,
            info_state_size,
            num_actions,
            loss_str=FLAGS.loss_str,
            hidden_layers_sizes=(128,),
            batch_size=FLAGS.batch_size,
            critic_learning_rate=FLAGS.critic_lr,
            pi_learning_rate=FLAGS.pi_lr,
            entropy_cost=FLAGS.etpcost,
            threshold=FLAGS.threshold,
            num_critic_before_pi=FLAGS.nctopi,
            optimizer_str=FLAGS.optm,
            hedge_=FLAGS.hedge,
            alpha=FLAGS.alpha) for idx in range(num_players)
    ]

    expl_policies_avg = PolicyGradientPolicies(env, agents)

    sess.run(tf.global_variables_initializer())

    # NOTE: this script only writes checkpoints; it never restores from one,
    # so every run starts from freshly initialized variables.

    for ep in range(FLAGS.num_episodes):
      if (ep + 1) % FLAGS.eval_every == 0:
        losses = [agent.loss for agent in agents]
        expl = exploitability.exploitability(env.game, expl_policies_avg)
        msg = "{} {} {} {}\n".format(
            time.asctime(time.localtime(time.time())), ep + 1, expl, losses)
        print(msg)
        writefile(msg, filename)

      # Checked independently of the evaluation cadence so that save_every
      # is honoured even when it is not a multiple of eval_every.
      if FLAGS.use_checkpoints and (ep + 1) % FLAGS.save_every == 0:
        for agent in agents:
          agent.save(checkpoint_prefix + "_iter" + str(ep + 1))

      # Play one self-play episode; only the current player acts each step.
      time_step = env.reset()
      while not time_step.last():
        player_id = time_step.observations["current_player"]
        agent_output = agents[player_id].step(time_step)
        action_list = [agent_output.action]
        time_step = env.step(action_list)

      # Episode is over, step all agents with final info state.
      for agent in agents:
        agent.step(time_step)


# Script entry point: absl's app.run parses the command-line flags and then
# calls main.
if __name__ == "__main__":
  app.run(main)
