import torch
import numpy as np
import torchvision
from tqdm import tqdm
import torch.optim as optim 
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

from network import DynamicsModel
from utils import Batcher, RolloutGenerator, qlearning_dataset, full_dataset

# Dataset identifier: passed to `full_dataset` and embedded in the
# TensorBoard log and checkpoint directory names.
ENV_NAME = "walker2d_medium_replay-v2"

# Per-stage switches and epoch budgets. Only NUM_EPOCHS["ml_learning"] is
# read by the training loop in this script.
PIPELINE = dict(ml_learning=True, me_learning=True)
NUM_EPOCHS = dict(ml_learning=200, me_learning=200)

# Mini-batch sizes for the train / held-out loaders.
BATCH_SIZE = dict(train=512, test=512)
# Rollout settings (not read by the training loop in this script).
ROLLOUT = dict(batch_size=512, length=5)

# Set to an int to seed torch and numpy; None leaves runs unseeded.
random_seed = None

##########################################################################

def _run_epoch(batcher, step):
	"""Run ``step`` over one full pass of ``batcher`` and average its losses.

	Args:
		batcher: a ``Batcher`` from ``utils``; only its ``reset``, ``shuffle``,
			``next_batch``, ``num_batches`` and ``num_entries`` members are used.
		step: callable invoked as ``step(network_input, target)`` that returns
			an ``(nll, mse, std)`` triple -- ``dynamics_model.fit`` for training
			or ``dynamics_model.evaluate`` for evaluation in this script.

	Returns:
		``(nll, mse, std)``: each batch value weighted by
		``len(batch) / batcher.num_entries`` and summed, i.e. the per-sample
		mean over the pass, assuming the batches cover the ``num_entries`` rows
		exactly once (TODO confirm against ``Batcher``).
	"""
	batcher.reset()
	batcher.shuffle()

	total_nll = total_mse = total_std = 0
	for _ in range(batcher.num_batches):
		network_input, target = batcher.next_batch()
		nll, mse, std_loss = step(network_input, target)
		# Weight by batch size so a short final batch is not over-counted.
		total_nll += nll * len(network_input) / batcher.num_entries
		total_mse += mse * len(network_input) / batcher.num_entries
		total_std += std_loss * len(network_input) / batcher.num_entries
	return total_nll, total_mse, total_std


if __name__ == "__main__":

	# Seed first so that everything below which draws random numbers
	# (e.g. the train/test split, which uses numpy's global RNG) is reproducible.
	if random_seed is not None:
		torch.manual_seed(random_seed)
		np.random.seed(random_seed)

	dataset = full_dataset(ENV_NAME)

	observations = dataset["observations"]
	actions = dataset["actions"]
	next_observations = dataset["next_observations"]

	# The model maps (observation, action) -> observation delta; rewards and
	# terminals from the dataset are not modeled here.
	inputs = np.concatenate([observations, actions], axis=1)
	targets = next_observations - observations

	# Standardise each input feature; near-constant features keep std 1 so the
	# division cannot blow up.
	# NOTE(review): mean/std are computed over all rows (including the held-out
	# split below) and are not persisted by this script; whatever feeds this
	# model later must reproduce the same normalisation -- confirm.
	mean = np.mean(inputs, axis=0, keepdims=True)
	std = np.std(inputs, axis=0, keepdims=True)
	std[std < 1e-12] = 1.0
	inputs = (inputs - mean) / std

	dynamics_model = DynamicsModel(input_dim=inputs.shape[1], hidden_dim=200,
		output_dim=targets.shape[1], learning_rate=3e-4, name="dynamics_model_base")

	# Split ONCE, before training. Re-drawing the split every epoch would let
	# every sample eventually be trained on, making the "test" loss -- and the
	# best-checkpoint selection that depends on it -- optimistic.
	train_inputs, test_inputs, train_targets, test_targets = train_test_split(
		inputs, targets, test_size=0.1)
	train_batcher = Batcher(BATCH_SIZE["train"], [train_inputs, train_targets])
	test_batcher = Batcher(BATCH_SIZE["test"], [test_inputs, test_targets])

	log_dir = "model_runs/{}/ml_learning/random_seed_{}_num_epochs_{}".format(
		ENV_NAME, random_seed, NUM_EPOCHS["ml_learning"])
	model_dir = "./data/%s/models" % (ENV_NAME)

	train_writer = SummaryWriter(log_dir + "/train")
	test_writer = SummaryWriter(log_dir + "/test")

	best_loss = np.inf
	try:
		for epoch in tqdm(range(NUM_EPOCHS["ml_learning"])):

			train_loss, train_mse, train_std = _run_epoch(train_batcher, dynamics_model.fit)
			test_loss, test_mse, test_std = _run_epoch(test_batcher, dynamics_model.evaluate)

			train_writer.add_scalar("Loss/Overall", train_loss, epoch)
			train_writer.add_scalar("Loss/MSE", train_mse, epoch)
			train_writer.add_scalar("Loss/STD", train_std, epoch)

			test_writer.add_scalar("Loss/Overall", test_loss, epoch)
			test_writer.add_scalar("Loss/MSE", test_mse, epoch)
			test_writer.add_scalar("Loss/STD", test_std, epoch)

			print("Epoch %d - Train Loss: %f - Test Loss: %f" % (epoch, train_loss, test_loss))
			print("Train MSE Loss: %f - Test MSE Loss: %f" % (train_mse, test_mse))

			# Latest weights every epoch; separately keep the lowest held-out loss.
			dynamics_model.save_checkpoint(model_dir)
			if test_loss < best_loss:
				best_loss = test_loss
				dynamics_model.save_best(model_dir)
	finally:
		# Flush and close the event files even if training is interrupted.
		train_writer.close()
		test_writer.close()