import torch
import numpy as np
import torchvision
from tqdm import tqdm
import torch.optim as optim 
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

from network import DynamicsModel
from utils import Batcher, RolloutGenerator, qlearning_dataset, full_dataset

# Dataset/environment identifier: passed to `full_dataset` and embedded in
# every log and checkpoint path written by this script.
ENV_NAME = "walker2d_medium_replay-v2"

# Per-phase switches and epoch budgets; both dicts are keyed by phase name.
PIPELINE = dict(ml_learning = True, me_learning = True)
NUM_EPOCHS = dict(ml_learning = 200, me_learning = 200)

# Mini-batch sizes handed to the train/test `Batcher` instances.
BATCH_SIZE = dict.fromkeys(("train", "test"), 512)
# Arguments for `sampler.sample(batch_size, length)` in the me_learning phase.
ROLLOUT = dict(batch_size = 512, length = 5)

# When not None, seeds both torch and numpy; also appears in log directory names.
random_seed = None

##########################################################################

def _run_pass(batcher, step):
	"""Run ``step`` on every batch of ``batcher`` and aggregate the losses.

	Args:
		batcher: Object exposing ``num_batches``, ``num_entries`` and
			``next_batch()``; ``next_batch()`` must return an
			``(inputs, targets)`` pair.
		step: Callable ``step(inputs, targets)`` returning the three
			per-batch losses ``(overall, mse, std)``.

	Returns:
		Tuple ``(overall, mse, std)``; each entry is the sum over batches of
		the batch loss weighted by ``len(inputs) / batcher.num_entries``.
	"""
	overall, mse, std = 0, 0, 0

	for _ in range(batcher.num_batches):

		network_input, target = batcher.next_batch()
		batch_overall, batch_mse, batch_std = step(network_input, target)

		# Weight by batch length so a differently sized batch does not skew the mean.
		batch_len = len(network_input)
		overall += batch_overall * batch_len / batcher.num_entries
		mse += batch_mse * batch_len / batcher.num_entries
		std += batch_std * batch_len / batcher.num_entries

	return overall, mse, std


def _log_losses(writer, losses, epoch):
	"""Write the ``(overall, mse, std)`` loss triple to ``writer`` for ``epoch``."""
	for tag, value in zip(("Loss/Overall", "Loss/MSE", "Loss/STD"), losses):
		writer.add_scalar(tag, value, epoch)


def _with_rollouts(model_step, sampler):
	"""Wrap ``model_step`` so every call also receives a fresh rollout batch.

	Args:
		model_step: Callable ``model_step(inputs, targets, me_input)``.
		sampler: Rollout generator exposing ``sample(batch_size, length)``
			and a ``scaler`` with ``transform``.

	Returns:
		Callable ``step(inputs, targets)`` suitable for ``_run_pass``.
	"""
	def step(network_input, target):
		rollout = sampler.sample(ROLLOUT["batch_size"], ROLLOUT["length"])
		# Only the first two returned arrays are used, joined column-wise.
		me_input = np.concatenate(rollout[:2], axis = 1)
		# NOTE(review): rollout inputs are normalised with `sampler.scaler`,
		# while dataset inputs use the mean/std computed in the main script;
		# confirm that both normalisations agree.
		me_input = sampler.scaler.transform(me_input)
		return model_step(network_input, target, me_input)

	return step


def _run_phase(phase, dynamics_model, data_split, train_step, test_step, after_epoch = None):
	"""Run one training phase: epochs of train + evaluate, logging and checkpointing.

	Args:
		phase: Key into ``NUM_EPOCHS``; also used as a log sub-directory name.
		dynamics_model: Model exposing ``save_checkpoint`` and ``save_best``.
		data_split: ``(train_inputs, test_inputs, train_targets, test_targets)``.
			The same split is used for every epoch, so the test rows are
			never trained on.
		train_step: Callable ``(inputs, targets) -> (overall, mse, std)``
			applied to each training batch.
		test_step: Same signature as ``train_step``, applied to each test batch.
		after_epoch: Optional zero-argument callable invoked after the test
			pass of every epoch, before logging and checkpointing.
	"""
	train_inputs, test_inputs, train_targets, test_targets = data_split

	num_epochs = NUM_EPOCHS[phase]
	log_dir = "model_runs/{}/{}/random_seed_{}_num_epochs_{}".format(
		ENV_NAME, phase, random_seed, num_epochs)
	model_dir = "./data/%s/models" % (ENV_NAME)

	train_writer = SummaryWriter(log_dir + "/train")
	test_writer = SummaryWriter(log_dir + "/test")

	best_loss = np.inf

	# try/finally so the event files are flushed and closed even if an epoch fails.
	try:
		for epoch in tqdm(range(num_epochs)):

			# Batchers are rebuilt every epoch (as before) from the fixed split.
			train_batcher = Batcher(BATCH_SIZE["train"], [train_inputs, train_targets])
			test_batcher = Batcher(BATCH_SIZE["test"], [test_inputs, test_targets])

			for batcher in (train_batcher, test_batcher):
				batcher.reset()
				batcher.shuffle()

			train_loss, train_mse, train_std = _run_pass(train_batcher, train_step)
			test_loss, test_mse, test_std = _run_pass(test_batcher, test_step)

			if after_epoch is not None:
				after_epoch()

			_log_losses(train_writer, (train_loss, train_mse, train_std), epoch)
			_log_losses(test_writer, (test_loss, test_mse, test_std), epoch)

			print("Epoch %d - Train Loss: %f - Test Loss: %f" % (epoch, train_loss, test_loss))
			print("Train MSE Loss: %f - Test MSE Loss: %f" % (train_mse, test_mse))
			dynamics_model.save_checkpoint(model_dir)

			# Keep the model with the lowest loss on the held-out split.
			if test_loss < best_loss:
				best_loss = test_loss
				dynamics_model.save_best(model_dir)
	finally:
		train_writer.close()
		test_writer.close()


if __name__ == "__main__":

	dataset = full_dataset(ENV_NAME)

	observations = dataset["observations"]
	actions = dataset["actions"]
	next_observations = dataset["next_observations"]

	if random_seed is not None:
		torch.manual_seed(random_seed)
		np.random.seed(random_seed)

	# The model maps (observation, action) to the observation delta;
	# rewards are not part of the targets.
	inputs = np.concatenate([observations, actions], axis = 1)
	targets = next_observations - observations

	# Standardise each input column; near-constant columns keep std = 1 to
	# avoid dividing by ~0.
	# NOTE(review): the statistics are computed over all rows, including the
	# ones that end up in the test split.
	mean = np.mean(inputs, axis = 0, keepdims = True)
	std = np.std(inputs, axis = 0, keepdims = True)
	std[std < 1e-12] = 1.0
	inputs = (inputs - mean) / std

	# Split ONCE and share the split between all epochs and both phases.
	# Re-splitting every epoch would let test rows leak into training, making
	# the test loss (and the best-model selection based on it) meaningless.
	data_split = train_test_split(inputs, targets, test_size = 0.1, random_state = random_seed)

	dynamics_model = DynamicsModel(input_dim = inputs.shape[1], hidden_dim = 200,
		output_dim = targets.shape[1], learning_rate = 3e-4, name = "dynamics_model_v0")

	if PIPELINE["ml_learning"]:

		_run_phase("ml_learning", dynamics_model, data_split,
			train_step = dynamics_model.fit,
			test_step = dynamics_model.evaluate)

	if PIPELINE["me_learning"]:

		# Start from the best checkpoint saved under the current model name,
		# then rename so this phase writes its own checkpoint files.
		dynamics_model.load_checkpoint(
			"./data/%s/models/%s_best.bin" % (ENV_NAME, dynamics_model.name))
		dynamics_model.name = "dynamics_model_v1"

		# NOTE: the learning rate is left unchanged for this phase; whether it
		# should be changed requires further investigation.

		# NOTE: the sampler could also be initialised with the original ML
		# model. It is handed the live model object here; whether it tracks
		# later weight updates on its own is not visible from this file
		# (TODO confirm), so `update_model` is called after every epoch.
		# Updating the sampler model should be investigated further.
		sampler = RolloutGenerator(dataset, ENV_NAME, dynamics_model)

		_run_phase("me_learning", dynamics_model, data_split,
			train_step = _with_rollouts(dynamics_model.fit, sampler),
			test_step = _with_rollouts(dynamics_model.evaluate, sampler),
			after_epoch = lambda: sampler.update_model(dynamics_model))