import argparse
import datetime
# import gym
from d3rlpy.datasets import get_pybullet
import numpy as np
import itertools
import torch
from tqdm import tqdm
from sac import SAC
from torch.utils.tensorboard import SummaryWriter
from replay_memory import ReplayMemory
from buffer import ReplayBuffer
from ..utils import transform_dataset
from .network import DynamicsModel

def _str2bool(value):
	"""Convert a command-line token into a real boolean.

	``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
	because every non-empty string is truthy, so a boolean option could never
	be switched off with a natural spelling.

	Args:
		value: Raw token from the command line (an actual ``bool`` is passed
			through unchanged).

	Returns:
		``True`` for true/t/yes/y/on/1 and ``False`` for false/f/no/n/off/0 or
		the empty string (case-insensitive, surrounding whitespace ignored).

	Raises:
		argparse.ArgumentTypeError: If the token is not a recognised spelling.
	"""
	if isinstance(value, bool):
		return value
	token = value.strip().lower()
	if token in ('true', 't', 'yes', 'y', 'on', '1'):
		return True
	if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
		return False
	raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface. Help texts quote the defaults actually configured below.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="hopper-bullet-mixed-v0",
                    help='PyBullet dataset/environment name (default: hopper-bullet-mixed-v0)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates the policy periodically during training (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy '
                         'term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=123456, metavar='N',
                    help='random seed (default: 123456)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps (default: 1000001)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')
parser.add_argument('--cuda', action="store_true",
                    help='run on CUDA (default: False)')
# Parse the command line once; later stages read their settings from `args`.
args = parser.parse_args()

# Schedule for the model-based rollouts.
NUM_EPOCHS = 10_000          # iterations of the outer training loop
REWARD_PENALTY_COEFF = None  # NOTE(review): not referenced in the visible part of this script
ROLLOUT_HORIZON = 2          # dynamics-model steps taken per rollout
ROLLOUT_BATCH_SIZE = 4_000   # start states sampled from the dataset per epoch

# Environment
# Dataset and environment object for the requested task.
dataset, env = get_pybullet(args.env_name)

# Per-column mean/std of the concatenated (observation, action) rows; the
# training loop standardises dynamics-model inputs with these.
observations, actions, rewards, next_observations = transform_dataset(dataset)
inputs = np.concatenate((observations, actions), axis=1)
mu = inputs.mean(axis=0, keepdims=True)
sigma = inputs.std(axis=0, keepdims=True)
# Columns with (near-)zero spread keep a unit scale to avoid dividing by ~0.
np.putmask(sigma, sigma < 1e-12, 1.0)

# Seed the environment, its action space, torch and numpy with the same value.
for seed_fn in (env.seed, env.action_space.seed, torch.manual_seed, np.random.seed):
	seed_fn(args.seed)

# Model
# Dynamics network: its input is an (observation, action) row and its output is
# sized as observation columns plus reward columns (split that way in the
# training loop).
dynamics_model = DynamicsModel(
	input_dim=observations.shape[1] + actions.shape[1],
	hidden_dim=200,
	output_dim=observations.shape[1] + rewards.shape[1],
	learning_rate=1e-3,
	name="dynamics_model_v0",
)
# The checkpoint path is keyed by the environment name. `args.env_name` is the
# only environment identifier this script defines; the bare name ENV_NAME does
# not exist here and raised NameError.
dynamics_model.load_checkpoint("./data/%s/models/%s_best.bin" % (args.env_name, dynamics_model.name))

# Agent
obs_dim = env.observation_space.shape[0]
agent = SAC(obs_dim, env.action_space, args)

# TensorBoard: one run directory per launch, tagged with timestamp, environment,
# policy type and whether entropy auto-tuning is enabled.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
tuning_tag = "autotune" if args.automatic_entropy_tuning else ""
writer = SummaryWriter('runs/{}_SAC_{}_{}_{}'.format(run_stamp, args.env_name, args.policy, tuning_tag))

# Memory
memory = ReplayBuffer(args.replay_size, obs_dim, env.action_space.shape[0])

# Training Loop
# Each epoch: (1) roll the learned dynamics model forward from dataset states
# and store the synthetic transitions, (2) run a fixed number of agent updates
# on the buffer, (3) periodically evaluate the policy in the real environment.
UPDATES_PER_EPOCH = 20  # agent.update_parameters calls per epoch
EVAL_INTERVAL = 100     # evaluate every this many epochs (epoch 0 included)
EVAL_EPISODES = 10      # real-environment episodes averaged per evaluation

total_numsteps = 0  # never advanced in this loop; kept so module globals are unchanged
updates = 0

score_list = []
for epoch in tqdm(range(NUM_EPOCHS)):
	# Rollout start states: rows of the offline dataset, sampled with replacement.
	state = observations[np.random.choice(len(observations), ROLLOUT_BATCH_SIZE)]

	for _ in range(ROLLOUT_HORIZON):
		action = agent.select_action(state)

		# Standardise with the dataset statistics before querying the model.
		# (Named model_input so the builtin `input` is not shadowed.)
		model_input = (np.concatenate([state, action], axis=1) - mu) / sigma
		prediction = dynamics_model.predict(model_input)
		# The first obs-dim columns are the next state, the rest is the reward.
		state_, reward = np.split(prediction, [*env.observation_space.shape], axis=1)

		memory.store_batch(state, action, reward.flatten(), state_)

		# Continue the rollout from the predicted states.
		state = np.copy(state_)

	for _ in range(UPDATES_PER_EPOCH):
		agent.update_parameters(memory, args.batch_size, updates)
		updates += 1

	# Honour the --eval flag; it was previously parsed but never consulted.
	if args.eval and epoch % EVAL_INTERVAL == 0:
		mean_score = 0.0
		for _ in range(EVAL_EPISODES):
			score = 0
			done = False
			state = env.reset()
			while not done:
				action = agent.select_action(state, evaluate=True)

				state_, reward, done, info = env.step(action)
				score += reward
				state = state_

			mean_score += score / EVAL_EPISODES

		score_list.append(mean_score)
		# Record the score in the TensorBoard run; the writer was otherwise unused.
		writer.add_scalar('avg_reward/test', mean_score, epoch)
		print(mean_score)

# Flush pending TensorBoard events to disk.
writer.close()