import numpy as np
import gc
import scipy.stats
import copy
from absl import app, flags
import time

from games.oh_hell import OhHellGame
from algorithms.enumerate import StateEnumerator
from algorithms.gibbs import GibbsSampler
from algorithms.importance import ImportanceSampler
from algorithms.util import RandomJointPolicy, expected_value, RLPolicyWrapper
from algorithms.rl import LinearQAgent

FLAGS = flags.FLAGS

# Sampling budget and repetition count.
flags.DEFINE_integer(
	name='num_samples',
	default=1000,
	help='number of samples to draw from gibbs',
)
flags.DEFINE_integer(
	name='num_tricks',
	default=2,
	help='number of tricks in the game (cards per player)',
)

# How far into the game the experiment's starting state should be.
flags.DEFINE_integer(
	name='num_tricks_played',
	default=1,
	help='number of tricks played before the experiment',
)
flags.DEFINE_integer(
	name='cards_in_current_trick',
	default=0,
	help='number of cards played in current trick before the experiment',
)

# Deck shape.
flags.DEFINE_integer(
	name='num_suits',
	default=2,
	help='number of suits in the game',
)
flags.DEFINE_integer(
	name='num_ranks',
	default=4,
	help='number of cards per suit',
)
flags.DEFINE_integer(
	name='num_repeats',
	default=1,
	help='number of times to repeat the experiment',
)

# RL training hyper-parameters.
flags.DEFINE_float(
	name='epsilon',
	default=0.05,
	help='epsilon greedy for RL policies',
)
flags.DEFINE_float(
	name='learning_rate',
	default=0.01,
	help='Learning rate for RL policies',
)
flags.DEFINE_integer(
	name='num_episodes',
	default=10000,
	help='Number of episodes to train RL',
)

# Diagnostics.
flags.DEFINE_boolean(
	name='time',
	default=False,
	help='Output time at each stage.',
)


def print_metadata():
	"""Print the experiment's flag settings as one comma-separated line.

	Each entry has the form `name=value`; the values are read from FLAGS at
	call time. Nothing is returned.
	"""
	# Each segment except the last ends with its own trailing comma, so the
	# pieces are joined with an empty separator.
	segments = [
		f'num_suits={FLAGS.num_suits},num_ranks={FLAGS.num_ranks},num_repeats={FLAGS.num_repeats},epsilon={FLAGS.epsilon},',
		f'learning_rate={FLAGS.learning_rate},num_episodes={FLAGS.num_episodes},num_tricks={FLAGS.num_tricks},',
		f'num_tricks_played={FLAGS.num_tricks_played},cards_in_current_trick={FLAGS.cards_in_current_trick}',
	]
	print(''.join(segments))

def output_samples(sampler, samples):
	"""Print one result line of the form `<sampler.name>:<samples>`.

	Args:
		sampler: any object with a `name` attribute.
		samples: the value to print after the colon, formatted as-is.
	"""
	line = f'{sampler.name}:{samples}'
	print(line)

def find_goal_state(state, policy):
	"""Advance `state` under `policy` until it reaches the configured depth.

	The first mover plays `FLAGS.num_tricks` as its action. After that,
	players with a non-negative id sample an action from
	`policy.get_action_probabilities`, and any other player (presumably the
	chance player) picks uniformly among its legal actions.

	Args:
		state: game state to advance; it is mutated in place.
		policy: object exposing
			`get_action_probabilities(infostate, player, actions)`.

	Returns:
		The same `state` object, stopped at the first decision point of a
		non-negative player where at least `FLAGS.num_tricks_played` tricks
		have been played and at least `FLAGS.cards_in_current_trick` cards
		are in the current trick.

	Raises:
		ValueError: if the game becomes terminal before such a state is found.
	"""
	# set the number of tricks
	chance_player = state.get_player_to_move()
	state.play(chance_player, FLAGS.num_tricks)
	# find a state that matches the input criteria, following the policy
	while not state.terminal():
		player = state.get_player_to_move()
		actions = state.get_legal_actions(player)
		if player >= 0:
			probs = policy.get_action_probabilities(state.get_infostate(player), player, actions)
			# The action is drawn before the goal check on purpose: this keeps
			# the random stream identical to earlier runs of the experiment.
			action = np.random.choice(actions, p=probs)
			if state.num_tricks_played() >= FLAGS.num_tricks_played and state.num_cards_played_current_trick() >= FLAGS.cards_in_current_trick:
				return state
		else:
			action = np.random.choice(actions)
		state.play(player, action)
	# The exception used to be constructed but never raised, so callers
	# silently received None when no matching state existed.
	raise ValueError("STATE NOT FOUND")

def ev_experiment(state, policy):
	"""Run one comparison of the samplers from a goal state and print results.

	Steps, in order:
	  1. advance `state` with `find_goal_state`;
	  2. build a `StateEnumerator` on the result and print the first entry of
	     its `expected_value()` after the metadata line;
	  3. ask the enumerator, an `ImportanceSampler` and a `GibbsSampler` for
	     `FLAGS.num_samples` samples each and print them via `output_samples`.

	When `FLAGS.time` is set, elapsed seconds since the start of this call are
	printed after the enumerator is built and after each sampler finishes.

	Args:
		state: initial game state handed to `find_goal_state`.
		policy: joint policy shared by the goal search and all samplers.
	"""
	start_time = time.perf_counter()

	goal = find_goal_state(state, policy)
	enumerator = StateEnumerator(goal, policy)
	value = enumerator.expected_value()[0]
	if FLAGS.time:
		print(f'enumerator set @ {time.perf_counter() - start_time:0.4f}')

	print_metadata()
	print(f'value:{value}')

	# All samplers are constructed up front, before any of them is run.
	importance = ImportanceSampler(policy, enumerator)
	gibbs = GibbsSampler(enumerator.sample_history_uniformly(), policy)
	for sampler in (enumerator, importance, gibbs):
		drawn = sampler.generate_samples(FLAGS.num_samples)
		if FLAGS.time:
			print(f'sampler finished @ {time.perf_counter() - start_time:0.4f}')
		output_samples(sampler, drawn)
def train_rl_policy(game, agent):
	"""Train `agent` by self-play for `FLAGS.num_episodes` episodes.

	In every episode a fresh state is created, the first mover plays
	`FLAGS.num_tricks`, and the game is played to a terminal state. Players
	with a non-negative id act through `agent.act`; any other player acts
	uniformly at random. Once the episode ends, `agent.feedback` receives the
	per-player list of `(state snapshot, action)` pairs and `state.score()`.

	Args:
		game: object providing `num_players` and `new_initial_state()`.
		agent: object providing `act(infostate, player, legal_actions)` and
			`feedback(episode, scores)`.

	Returns:
		The same `agent` object that was passed in.
	"""
	for _ in range(FLAGS.num_episodes):
		# One list of (state snapshot, action) pairs per player.
		trajectories = [[] for _ in range(game.num_players)]
		state = game.new_initial_state()
		# set the number of tricks
		opener = state.get_player_to_move()
		state.play(opener, FLAGS.num_tricks)
		while not state.terminal():
			mover = state.get_player_to_move()
			options = state.get_legal_actions(mover)
			if mover < 0:
				choice = np.random.choice(options)
			else:
				choice = agent.act(state.get_infostate(mover), mover, options)
				# Snapshot the state: it is mutated by play() right below.
				trajectories[mover].append((copy.deepcopy(state), choice))
			state.play(mover, choice)
		agent.feedback(trajectories, state.score())
	return agent

def main(_):
	"""Train a linear Q-learning policy, then run the sampling experiment.

	The game is built from `FLAGS.num_suits` and `FLAGS.num_ranks`, the agent
	is trained with `train_rl_policy`, and `ev_experiment` is run
	`FLAGS.num_repeats` times, each time from a new initial state.

	Args:
		_: positional argument supplied by `app.run`; unused.
	"""
	game = OhHellGame(num_suits=FLAGS.num_suits, num_ranks=FLAGS.num_ranks)
	trained_agent = train_rl_policy(
		game,
		LinearQAgent(game, FLAGS.epsilon, FLAGS.learning_rate),
	)
	policy = RLPolicyWrapper(trained_agent)
	for _run in range(FLAGS.num_repeats):
		ev_experiment(game.new_initial_state(), policy)

# Script entry point: hand control to absl, which calls main().
if __name__ == "__main__":
	app.run(main)
