import numpy as np
import torch
import gym
import argparse
import os
import d4rl
import datetime
import buffer
from buffer import ReplayBuffer
from a2po import OffRL
from tensorboardX import SummaryWriter
from datetime import datetime
import h5py
from tqdm import tqdm
# Time of day (HH:MM:SS) captured once when the module is loaded. It is used
# further down to tag the TensorBoard run directory. Note that the stamp
# carries no date component.
now = datetime.now()
current_time = datetime.strftime(now, "%H:%M:%S")
def eval_policy(policy, env, seed, mean, std, seed_offset=100, eval_episodes=10):
	"""Roll out ``policy`` in ``env`` and return the scaled normalized score.

	Args:
		policy: Object exposing ``select_action(state)``.
		env: Environment exposing ``seed``, ``reset``, ``step`` (returning a
			4-tuple ``(state, reward, done, info)``) and
			``get_normalized_score``.
		seed: Base seed; the environment is seeded with ``seed + seed_offset``.
		mean: Value subtracted from every observation before it is passed to
			the policy.
		std: Value the centred observation is divided by.
		seed_offset: Offset added to ``seed`` for the evaluation rollouts.
		eval_episodes: Number of episodes to average over.

	Returns:
		``env.get_normalized_score(average episode return) * 100``.
	"""
	env.seed(seed + seed_offset)

	# A single running total over all episodes; divided by the episode count
	# at the end to obtain the average undiscounted return.
	total_reward = 0.
	for _episode in range(eval_episodes):
		obs = env.reset()
		finished = False
		while not finished:
			# Flatten the observation to one row and apply the normalization
			# statistics supplied by the caller.
			flat_obs = np.array(obs).reshape(1, -1)
			normalized_obs = (flat_obs - mean) / std
			act = policy.select_action(normalized_obs)
			obs, step_reward, finished, _info = env.step(act)
			total_reward += step_reward

	mean_return = total_reward / eval_episodes
	return env.get_normalized_score(mean_return) * 100


def str2bool(value):
	"""Parse a command-line boolean flag value.

	``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
	non-empty text -- including ``"False"`` -- becomes ``True``. This converter
	maps the usual spellings explicitly instead.

	Args:
		value: A ``bool`` (returned unchanged) or a string such as
			``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``.

	Returns:
		The parsed boolean.

	Raises:
		argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
	"""
	if isinstance(value, bool):
		return value
	text = str(value).strip().lower()
	if text in ("true", "t", "yes", "y", "1", "on"):
		return True
	# The empty string stays falsy, exactly as it was under ``type=bool``.
	if text in ("false", "f", "no", "n", "0", "off", ""):
		return False
	raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	# Experiment
	parser.add_argument("--algo", default="AC")                            # Algorithm label used in logs and the run directory
	parser.add_argument("--env", default="halfcheetah-expert-v2")          # Environment name handed to buffer.get_env_dataset
	parser.add_argument("--seed", default=1, type=int)                     # Sets Gym, PyTorch and Numpy seeds
	# argparse does not run ``type`` on non-string defaults, so the defaults
	# are given as real ints rather than float literals such as 4e4.
	parser.add_argument("--eval_freq", default=int(4e4), type=int)         # How often (time steps) we evaluate
	parser.add_argument("--max_timesteps", default=int(1e6), type=int)     # Number of training iterations
	parser.add_argument("--save_model", action="store_true")               # Save model and optimizer parameters
	# TD3
	parser.add_argument("--batch_size", default=256, type=int)             # Batch size for both actor and critic
	# The numeric options below need ``type=float``; without it a value given
	# on the command line stays a string and breaks the arithmetic further down.
	parser.add_argument("--discount", default=0.99, type=float)            # Discount factor
	parser.add_argument("--tau", default=0.005, type=float)                # Target network update rate
	parser.add_argument("--policy_noise", default=0.2, type=float)         # Noise added to target policy during critic update
	parser.add_argument("--noise_clip", default=0.5, type=float)           # Range to clip target policy noise
	parser.add_argument("--policy_freq", default=2, type=int)              # Frequency of delayed policy updates
	parser.add_argument("--epsilon", default=-0.0, type=float)             # positive threshold \epsilon
	parser.add_argument("--bc_weight", default=1.0, type=float)            # BC term weight
	parser.add_argument("--use_cuda", default=True, type=str2bool)         # whether use gpu
	parser.add_argument("--vae_step", default=200000, type=int)            # VAE train step K
	parser.add_argument("--use_discrete", default=False, type=str2bool)    # whether use discrete \xi
	parser.add_argument("--alpha", default=1.0, type=float)                # max Q weight
	parser.add_argument("--normalize", default=True, type=str2bool)        # Normalize states with replay-buffer statistics
	args = parser.parse_args()

	# Use CUDA only when it is both requested and available.
	use_gpu = args.use_cuda and torch.cuda.is_available()
	device = torch.device("cuda" if use_gpu else "cpu")

	print("---------------------------------------")
	print(f"Setting: Training {args.algo}, Env: {args.env}, Seed: {args.seed}, Discrete: {args.use_discrete}")
	print("---------------------------------------")
	env, dataset = buffer.get_env_dataset(args.env)
	# Set seeds
	env.seed(args.seed)
	env.action_space.seed(args.seed)
	torch.manual_seed(args.seed)
	np.random.seed(args.seed)

	state_dim = env.observation_space.shape[0]
	action_dim = env.action_space.shape[0]
	max_action = float(env.action_space.high[0])

	kwargs = {
		"state_dim": state_dim,
		"action_dim": action_dim,
		"max_action": max_action,
		"discount": args.discount,
		"tau": args.tau,
		"device": device,
		# TD3 (noise parameters are scaled by the action bound)
		"policy_noise": args.policy_noise * max_action,
		"noise_clip": args.noise_clip * max_action,
		"policy_freq": args.policy_freq,
		"alpha": args.alpha,
		# AC
		"epsilon": args.epsilon,
		"bc_weight": args.bc_weight,
		"vae_step": args.vae_step,
		"use_discrete": args.use_discrete,
	}

	# Initialize policy
	policy = OffRL(**kwargs)

	replay_buffer = ReplayBuffer(state_dim, action_dim)
	replay_buffer.convert_D4RL(dataset)
	if args.normalize:
		mean, std = replay_buffer.normalize_states()
	else:
		# Identity statistics: eval_policy then leaves observations unchanged.
		mean, std = 0, 1

	writer = SummaryWriter(
		logdir=f'runs/{args.algo}_{args.env}_{args.seed}_{current_time}'
	)
	evaluations = []
	try:
		for t in tqdm(range(int(args.max_timesteps)), desc='PI training'):
			policy.policy_train(replay_buffer, writer, args.algo, args.env, args.batch_size)
			# Evaluate whenever t is a multiple of eval_freq. This includes
			# t == 0, i.e. right after the first training call.
			if t % args.eval_freq == 0:
				eval_res = eval_policy(policy, env, args.seed, mean, std)
				evaluations.append(eval_res)
				writer.add_scalar(f'{args.env}/eval_reward', eval_res, t)
				print(f"| {args.algo} | {args.env}_{args.seed} | iterations: {t} | eval_reward: {eval_res} |")
	finally:
		# Flush and close the event file even if training is interrupted.
		writer.close()
