import torch
import argparse
import datetime
import numpy as np
from tqdm import tqdm

from SAC.sac import SAC
from network import DynamicsModel
from SAC.buffer import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from utils import StandardScaler, qlearning_dataset, full_dataset
from termination_functions import termination_function

import matplotlib.pyplot as plt
import gym
import d4rl

# Default value for the --seed argument. While the parsed seed is None the
# env / torch / numpy seeding calls further down are skipped entirely.
random_seed = None

##########################################################################

def _str2bool(value):
	"""Convert a command-line token into a real boolean.

	``type=bool`` calls ``bool()`` on the raw string, which is True for every
	non-empty token -- including "False" and "0" -- so the flag could never be
	switched off from the command line.

	Args:
		value: Raw argument string (an actual bool is passed through unchanged).

	Returns:
		True or False. The empty string maps to False, as ``bool("")`` did.

	Raises:
		argparse.ArgumentTypeError: If the token is not a recognised boolean.
	"""
	if isinstance(value, bool):
		return value
	token = value.strip().lower()
	if token in ("true", "t", "yes", "y", "1"):
		return True
	if token in ("false", "f", "no", "n", "0", ""):
		return False
	raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Command-line interface. Help texts use %(default)s so the documented default
# can never drift away from the real one.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="walker2d_medium_replay-v2",
                    help='Mujoco Gym environment (default: %(default)s)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: %(default)s)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates the policy every 10 episodes (default: %(default)s)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: %(default)s)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: %(default)s)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=True, metavar='G',
                    help='Automatically adjust α (default: %(default)s)')
parser.add_argument('--seed', type=int, default=random_seed, metavar='N',
                    help='random seed (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: %(default)s)')
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps (default: %(default)s)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: %(default)s)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: %(default)s)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: %(default)s)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: %(default)s)')
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: %(default)s)')
parser.add_argument('--cuda', action="store_true",
                    help='run on CUDA (default: %(default)s)')
args = parser.parse_args()

##########################################################################

# --- Outer training schedule -------------------------------------------------
NUM_EPOCHS = 2_000          # iterations of the main training loop
UPDATES_PER_STEP = 1_000    # agent parameter updates performed in every epoch

# --- Model rollouts ----------------------------------------------------------
ROLLOUT_FREQ = 1                # a rollout is generated when epoch % ROLLOUT_FREQ == 0
ROLLOUT_BATCH_SIZE = 100_000    # number of dataset states sampled as rollout starts
ROLLOUT_HORIZON = 4             # maximum model steps taken from each start state
REWARD_PENALTY_COEFF = 1        # multiplier of the uncertainty penalty subtracted from the reward

# --- Evaluation --------------------------------------------------------------
TEST_FREQ = 10       # evaluation runs when epoch % TEST_FREQ == 1
TEST_EPISODES = 10   # real-environment episodes averaged per evaluation

##########################################################################

# Evaluation environment. The id used to be hard-coded to
# "walker2d-medium-replay-v2", so passing a different --env-name loaded another
# task's dataset and checkpoint but still evaluated on walker2d. Deriving the id
# from --env-name keeps them in step; the default yields the same id as before.
# NOTE(review): assumes the registered id equals --env-name with "_" replaced
# by "-" (true for the default) -- confirm for other datasets.
env = gym.make(args.env_name.replace("_", "-"))

##########################################################################

# Offline transitions, loaded through the project helper under --env-name.
dataset = full_dataset(args.env_name)

##########################################################################

# Seed the environment, torch and numpy only when a seed was supplied.
if args.seed is not None:
	env.seed(args.seed)
	torch.manual_seed(args.seed)
	np.random.seed(args.seed)

##########################################################################

# Pull the transition arrays out of the loaded dataset in one pass.
observations, actions, next_observations, terminals = (
	dataset[field]
	for field in ("observations", "actions", "next_observations", "terminals")
)
# Rewards are kept as a single-column 2-D array.
rewards = dataset["rewards"].reshape(-1, 1)

# Fit the input scaler on the same (observation, action) rows that are later
# fed to the dynamics model.
scaler = StandardScaler()
scaler.fit(np.concatenate((observations, actions), axis=1))

##########################################################################

# The dynamics model takes a concatenated (observation, action) row and emits
# one value per observation column.
observation_dim = observations.shape[1]
model_input_dim = observation_dim + actions.shape[1]

dynamics_model = DynamicsModel(
	input_dim=model_input_dim,
	hidden_dim=200,
	output_dim=observation_dim,
	learning_rate=1e-3,
	name="dynamics_model_v1",
)

# Load the "<name>_best.bin" checkpoint stored under this environment's data directory.
checkpoint_path = "./data/%s/models/%s_best.bin" % (args.env_name, dynamics_model.name)
dynamics_model.load_checkpoint(checkpoint_path)

##########################################################################

# writer = SummaryWriter(
# 	"runs/{}/random_seed_{}_update_per_step{}_replay_size_{}_rollout_freq_{}_rollout_batch_size{}_rollout_horizon_{}_coeff_{}_{}".format(
# 															args.env_name, args.seed, UPDATES_PER_STEP, args.replay_size, ROLLOUT_FREQ,
# 															ROLLOUT_BATCH_SIZE, ROLLOUT_HORIZON, REWARD_PENALTY_COEFF,dynamics_model.name))

##########################################################################

# The policy is always fed observations with their first column removed, so the
# agent and the replay buffer are sized for one feature fewer than the dataset rows.
policy_state_dim = observations.shape[1] - 1
action_dim = actions.shape[1]

agent = SAC(policy_state_dim, (action_dim,), args)

##########################################################################

# Disabled alternative that pre-fills the buffer with the offline dataset:
# memory = ReplayBuffer(args.replay_size, env.observation_space.shape[0], env.action_space.shape[0],
# 				data = (observations, actions, rewards.flatten(), next_observations, terminals))
memory = ReplayBuffer(args.replay_size, policy_state_dim, action_dim)

##########################################################################

# Optional warm-up: while RANDOM_EXPLORATION is True, rollout actions are drawn
# uniformly from [-1, 1] instead of from the policy, for at most
# RANDOM_EXPLORATION_COUNTER_MAX rollout steps in total.
RANDOM_EXPLORATION = False
RANDOM_EXPLORATION_COUNTER_MAX = 20

test_scores = []                 # rounded average return of every evaluation round
updates = 0                      # running count of agent updates, passed to the agent
random_exploration_counter = 0

# Loop invariants, computed once instead of on every step / update.
env_family = args.env_name.split("_")[0]          # key handed to termination_function
policy_observations = observations[:, 1:]         # dataset states as the policy sees them
policy_next_observations = next_observations[:, 1:]

for epoch in tqdm(range(NUM_EPOCHS)):
	# ------------------------------------------------------------------
	# 1) Generate synthetic transitions with the learned dynamics model.
	# ------------------------------------------------------------------
	if (epoch % ROLLOUT_FREQ) == 0:
		# NOTE: a disabled variant repeated the rollout below inside
		#   for _ in range(int(memory.mem_size / (ROLLOUT_BATCH_SIZE * ROLLOUT_HORIZON))):
		# so that one rollout phase would refill the whole buffer.
		state = observations[np.random.choice(len(observations), ROLLOUT_BATCH_SIZE)]

		for _ in range(ROLLOUT_HORIZON):
			if RANDOM_EXPLORATION:
				action = np.random.uniform(-1, 1, (len(state), actions.shape[1]))
				random_exploration_counter += 1

				if random_exploration_counter >= RANDOM_EXPLORATION_COUNTER_MAX:
					RANDOM_EXPLORATION = False
			else:
				# The policy never sees the first state column.
				action = agent.select_action(state[:, 1:])

			model_input = scaler.transform(np.concatenate([state, action], axis=1))
			prediction = dynamics_model.predict(model_input)

			# Reward rebuilt from the predicted change of the first state column
			# plus an action cost and a constant bonus. The original note labels
			# these constants as the Hopper variant.
			# NOTE(review): confirm they also match the environment selected by --env-name.
			reward = prediction[:, 0] / 0.008 - 0.001 * np.sum(action ** 2, axis=1) + 1
			# reward = prediction[:, 0] / 0.05 - 0.1 * np.sum(action ** 2, axis = 1) # This is for HalfCheetah Environment

			# Uncertainty penalty: L2 norm of the model's predicted standard
			# deviation. Only the values are needed, so the forward pass runs
			# without building an autograd graph for the whole rollout batch.
			with torch.no_grad():
				model_input_tensor = torch.FloatTensor(model_input).to(dynamics_model.device)
				_, log_std = dynamics_model.forward(model_input_tensor)
				std = torch.exp(log_std).cpu().detach().numpy()
			penalty = np.sqrt(np.sum(std ** 2, axis=1))

			# The model output is added to the current state to obtain the next one.
			next_state = state + prediction
			done = termination_function(state[:, 1:], action, next_state[:, 1:], env_family)

			memory.store_batch(state[:, 1:], action, reward.flatten() - REWARD_PENALTY_COEFF * penalty,
								next_state[:, 1:], done.flatten())

			# Continue only from the trajectories that have not terminated.
			state = next_state[~done.flatten()]
			if len(state) == 0:
				break

	# ------------------------------------------------------------------
	# 2) Update the agent; it receives both the model buffer and the
	#    offline dataset arrays.
	# ------------------------------------------------------------------
	# for _ in range(UPDATES_PER_STEP * (int(epoch / 500) + 1)):
	for _ in range(UPDATES_PER_STEP):
		# critic_1_loss, critic_2_loss, policy_loss, ent_loss, _ =\
		# agent.update_parameters(memory, args.batch_size, updates)
		critic_1_loss, critic_2_loss, policy_loss, ent_loss, _ = agent.update_parameters(
			memory, args.batch_size, updates,
			policy_observations, actions, rewards, policy_next_observations, terminals)

		# writer.add_scalar('Loss/Critic_1', critic_1_loss, updates)
		# writer.add_scalar('Loss/Critic_2', critic_2_loss, updates)
		# writer.add_scalar('Loss/Policy', policy_loss, updates)
		# writer.add_scalar('Loss/Entropy_Loss', ent_loss, updates)
		updates += 1

	# ------------------------------------------------------------------
	# 3) Periodic evaluation in the real environment (epochs 1, 11, 21, ...
	#    with the default TEST_FREQ).
	# ------------------------------------------------------------------
	if (epoch % TEST_FREQ) == 1:
		avg_reward = 0

		for _ in range(TEST_EPISODES):
			done = False
			state = env.reset()

			# Debug output: the smaller of the two critic estimates for the
			# evaluation action at the episode's initial state.
			s_temp = torch.FloatTensor(state).reshape(1, -1).to(agent.device)
			a_temp = torch.FloatTensor(agent.select_action(state, evaluate=True)).reshape(1, -1).to(agent.device)
			q_val = torch.min(*agent.critic(s_temp, a_temp))
			print(q_val)

			while not done:
				action = agent.select_action(state, evaluate=True)

				state_, reward, done, info = env.step(action)
				# Accumulate the mean episode return directly.
				avg_reward += reward / TEST_EPISODES
				state = state_

		# writer.add_scalar('avg_reward/test', avg_reward, epoch)

		print("----------------------------------------")
		print("Test Episodes: {}, Avg. Reward: {}".format(TEST_EPISODES, round(avg_reward, 2)))
		print("----------------------------------------")

		test_scores.append(round(avg_reward, 2))

		agent.save_checkpoint(args.env_name)

# Plot the evaluation curve once training has finished.
plt.plot(test_scores)
plt.show()

# avg_reward = 0

# for _ in range(TEST_EPISODES):
# 	done = False
# 	state = env.reset()
# 	while not done:
# 		action = agent.select_action(state, evaluate = True)

# 		state_, reward, done, info = env.step(action)
# 		avg_reward += reward / TEST_EPISODES
# 		state = state_

# writer.add_scalar('avg_reward/test', avg_reward, epoch)

# print("----------------------------------------")
# print("Test Episodes: {}, Avg. Reward: {}".format(TEST_EPISODES, round(avg_reward, 2)))
# print("----------------------------------------")