import os
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.distributions.normal import Normal

from utils import Swish, init_weights, mlp

# (min, max) clamp applied to the predicted log standard deviation in
# DynamicsModel.forward. Upper bounds of 0.25 and 1.5 were also tried earlier.
LOG_STD_BOUNDS = (-5, 1)

##########################################################################

# def swish(input):
# 	return input * torch.sigmoid(input)

##########################################################################

# class DynamicsModel(nn.Module):
# 	def __init__(self, input_dimension, hidden_dimension, output_dimension,
# 		learning_rate, name):
# 		super(DynamicsModel, self).__init__()

# 		self.name = name
# 		self.learning_rate = learning_rate
# 		self.input_dimension = input_dimension
# 		self.hidden_dimension = hidden_dimension
# 		self.output_dimension = output_dimension

# 		self.fully_connected1 = nn.Linear(self.input_dimension, self.hidden_dimension)
# 		self.fully_connected2 = nn.Linear(self.hidden_dimension, self.hidden_dimension)
# 		self.fully_connected3 = nn.Linear(self.hidden_dimension, self.hidden_dimension)
# 		self.fully_connected4 = nn.Linear(self.hidden_dimension, self.hidden_dimension)
# 		self.fully_connected5 = nn.Linear(self.hidden_dimension, self.output_dimension)
# 		self.fully_connected6 = nn.Linear(self.hidden_dimension, self.output_dimension)

# 		self.init_weights()

# 		self.optimizer = optim.Adam(self.parameters(), lr = self.learning_rate)
# 		self.loss_function = nn.GaussianNLLLoss()

# 		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 		self.to(self.device)

# 	def forward(self, input):

# 		input = swish(self.fully_connected1(input))
# 		input = swish(self.fully_connected2(input))
# 		input = swish(self.fully_connected3(input))
# 		input = swish(self.fully_connected4(input))

# 		mean_prediction = self.fully_connected5(input)
# 		variance_prediction = torch.exp(self.fully_connected6(input))

# 		return mean_prediction, variance_prediction

# 	# def forward(self, input):

# 	# 	input = swish(self.fully_connected1(input))
# 	# 	input = swish(self.fully_connected2(input))
# 	# 	input = swish(self.fully_connected3(input))
# 	# 	input = swish(self.fully_connected4(input))

# 	# 	mean_pred = self.fully_connected5(input)
# 	# 	log_std_pred = self.fully_connected6(input)

# 	def fit(self, input, target):

# 		input = torch.FloatTensor(input).to(self.device)
# 		target = torch.FloatTensor(target).to(self.device)

# 		mean_prediction, variance_prediction = self.forward(input)
# 		loss = self.loss_function(mean_prediction, target, variance_prediction)

# 		self.optimizer.zero_grad()
# 		loss.backward()
# 		# for parameter in self.parameters():
# 		# 	parameter.grad.data.clamp(-0.1, 0.1)
# 		self.optimizer.step()

# 		return loss.cpu().data.numpy().item()

# 	def evaluate(self, input, target):

# 		input = torch.FloatTensor(input).to(self.device)
# 		target = torch.FloatTensor(target).to(self.device)

# 		mean_prediction, variance_prediction = self.forward(input)
# 		loss = self.loss_function(mean_prediction, target, variance_prediction)

# 		return loss.cpu().data.numpy().item()

# 	def init_weights(self):
# 		for layer in self.parameters():
# 			if isinstance(layer, nn.Linear):
# 				init = 1. / np.sqrt(layer.weight.data.size()[0])
# 				nn.init.uniform_(layer.weight.data, -init, init)
# 				nn.init.constant_(layer.bias.data, 0)

# 	def save_checkpoint(self, checkpoint_directory):
# 		print("... saving checkpoint ...")
# 		if not os.path.exists(checkpoint_directory):
# 			os.makedirs(checkpoint_directory)
# 		torch.save(self.state_dict(), checkpoint_directory + "/" + self.name + ".bin")

# 	def load_checkpoint(self, checkpoint_directory):
# 		print("... loading checkpoint ...")
# 		self.load_state_dict(torch.load(checkpoint_directory))

# 	def save_best(self, checkpoint_directory):
# 		print("... saving best checkpoint ...")
# 		if not os.path.exists(checkpoint_directory):
# 			os.makedirs(checkpoint_directory)
# 		torch.save(self.state_dict(), checkpoint_directory + "/" + self.name + "_best.bin")

##########################################################################

# class DynamicsModel(nn.Module):
# 	def __init__(self, input_dim, hidden_dim, output_dim,
# 		learning_rate = 0.001, alpha = 1, name = "dynamics_model"):
# 		super(DynamicsModel, self).__init__()

# 		self.name = name
# 		self.alpha = alpha
# 		self.input_dim = input_dim
# 		self.hidden_dim = hidden_dim
# 		self.output_dim = output_dim
# 		self.hidden_depth = 4
# 		self.learning_rate = learning_rate

# 		self.network = mlp(self.input_dim, self.hidden_dim,
# 			self.output_dim * 2, self.hidden_depth, activation = Swish())

# 		self.optimizer = optim.Adam([
# 			{"params": self.network[0].parameters(), "weight_decay": 0.000025},
# 			{"params": self.network[2].parameters(), "weight_decay": 0.00005},
# 			{"params": self.network[4].parameters(), "weight_decay": 0.000075},
# 			{"params": self.network[6].parameters(), "weight_decay": 0.000075},
# 			{"params": self.network[8].parameters(), "weight_decay": 0.0001}
# 			], lr = self.learning_rate)

# 		self.apply(init_weights)

# 		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 		self.to(self.device)

# 	def forward(self, input):

# 		mean, log_std = self.network(input).chunk(2, dim = -1)
# 		log_std = torch.clamp(log_std, *LOG_STD_BOUNDS)

# 		return mean, log_std

# 	def likelihood_loss(self, input, target):

# 		input = torch.FloatTensor(input).to(self.device)
# 		target = torch.FloatTensor(input).to(self.device)

# 		mean, log_std = self.forward(input)
# 		loss = nn.GaussianNLLLoss(mean, target, torch.exp(log_std).pow(2))

# 		return loss

# 	def entropy_loss(self, input):

# 		input = torch.FloatTensor(input).to(self.device)

# 		mean, log_std = self.forward(input)
# 		distribution = Normal(mean, torch.exp(log_std))

# 		prediction = distribution.rsample()
# 		log_probs = distribution.log_prob(prediction)
# 		log_probs = log_probs.sum(dim = 1)
# 		loss = torch.mean(log_probs)

# 		return loss

# 	def fit(self, likelihood_loss, entropy_loss = None):

# 		""" 
		
# 		I am not sure about the naming, it might cause some future problems
		
# 		"""

# 		if entropy_loss is not None:
# 			loss = likelihood_loss + self.alpha * entropy_loss
# 		else:
# 			loss = likelihood_loss

# 		self.optimizer.zero_grad()
# 		loss.backward()
# 		# for parameter in self.parameters():
# 		# 	parameter.grad.data.clamp(-0.1, 0.1)
# 		self.optimizer.step()

# 		return loss.cpu().data.numpy().item()

# 	def evaluate(self, likelihood_loss, entropy_loss = None):

# 		if entropy_loss is not None:
# 			loss = likelihood_loss + self.alpha * entropy_loss
# 		else:
# 			loss = likelihood_loss

# 		return loss.cpu().data.numpy().item()

# 	def save_checkpoint(self, checkpoint_directory):
# 		print("... saving checkpoint ...")
# 		if not os.path.exists(checkpoint_directory):
# 			os.makedirs(checkpoint_directory)
# 		torch.save(self.state_dict(), checkpoint_directory + "/" + self.name + ".bin")

# 	def load_checkpoint(self, checkpoint_directory):
# 		print("... loading checkpoint ...")
# 		self.load_state_dict(torch.load(checkpoint_directory))

# 	def save_best(self, checkpoint_directory):
# 		print("... saving best checkpoint ...")
# 		if not os.path.exists(checkpoint_directory):
# 			os.makedirs(checkpoint_directory)
# 		torch.save(self.state_dict(), checkpoint_directory + "/" + self.name + "_best.bin")

##########################################################################

class DynamicsModel(nn.Module):
	"""Probabilistic (diagonal Gaussian) dynamics model.

	An MLP maps an input vector to ``2 * output_dim`` values that are split into
	a mean and a log standard deviation per output dimension; the log standard
	deviation is clamped to ``LOG_STD_BOUNDS``.

	The training loss is a Gaussian negative log-likelihood (without the constant
	``0.5 * log(2 * pi)`` term) summed over output dimensions. It is optionally
	regressed towards a per-sample ``baseline`` and optionally augmented with an
	``alpha``-weighted mean log-probability of samples drawn from the predicted
	distribution on ``me_input``. That term is a single-sample Monte-Carlo
	estimate of the negative entropy, so minimising it favours wider predictions.
	"""

	def __init__(self, input_dim, hidden_dim, output_dim,
		learning_rate = 0.0003, alpha = 0.01, name = "dynamics_model"):
		super(DynamicsModel, self).__init__()

		self.name = name
		self.alpha = alpha
		self.input_dim = input_dim
		self.hidden_dim = hidden_dim
		self.output_dim = output_dim
		self.hidden_depth = 4
		self.learning_rate = learning_rate

		# The last layer emits a mean and a log-std for every output dimension.
		self.network = mlp(self.input_dim, self.hidden_dim,
			self.output_dim * 2, self.hidden_depth, activation = Swish())
		self.apply(init_weights)

		# GaussianNLLLoss is not used by the methods of this class; it is kept as
		# an attribute only in case external code references it.
		self.GaussianNLLLoss = nn.GaussianNLLLoss()
		self.MSELoss = nn.MSELoss()

		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		self.to(self.device)

		# Build the optimizer only after the parameters are initialised and live
		# on their final device, as recommended by the torch.optim docs.
		# NOTE(review): assumes mlp(...) with hidden_depth = 4 returns an indexable
		# container whose Linear layers sit at indices 0, 2, 4, 6, 8 (activations
		# in between) - confirm against utils.mlp. Deeper layers get more decay.
		self.optimizer = optim.Adam([
			{"params": self.network[0].parameters(), "weight_decay": 0.000025},
			{"params": self.network[2].parameters(), "weight_decay": 0.00005},
			{"params": self.network[4].parameters(), "weight_decay": 0.000075},
			{"params": self.network[6].parameters(), "weight_decay": 0.000075},
			{"params": self.network[8].parameters(), "weight_decay": 0.0001}
			], lr = self.learning_rate)

	def forward(self, input):
		"""Return ``(mean, log_std)``, each with ``output_dim`` features on the last axis."""

		mean, log_std = self.network(input).chunk(2, dim = -1)
		log_std = torch.clamp(log_std, *LOG_STD_BOUNDS)

		return mean, log_std

	def _to_tensor(self, array):
		"""Convert an array-like (list, ndarray, tensor) to a float32 tensor on ``self.device``."""
		return torch.as_tensor(array, dtype = torch.float32, device = self.device)

	@staticmethod
	def _gaussian_nll(mean, log_std, target):
		"""Per-sample Gaussian NLL (constant term dropped), summed over dim 1.

		Result keeps dim 1 with size 1; for 2-D ``(batch, features)`` inputs
		(assumed - TODO confirm with callers) that is ``(batch, 1)``.
		"""
		return 0.5 * torch.sum(((mean - target) ** 2) / torch.exp(log_std * 2) + log_std * 2, dim = 1, keepdim = True)

	def _compute_loss(self, ml_input, ml_target, me_input, baseline):
		"""Shared loss computation for ``fit`` and ``evaluate``.

		Returns ``(loss, mse_loss, mean_variance)`` tensors; only ``loss`` carries
		gradients, the other two are detached diagnostics.
		"""
		ml_input = self._to_tensor(ml_input)
		ml_target = self._to_tensor(ml_target)

		ml_mean, ml_log_std = self.forward(ml_input)
		nll = self._gaussian_nll(ml_mean, ml_log_std, ml_target)

		mse_loss = self.MSELoss(ml_mean.detach(), ml_target)

		if baseline is not None:
			# Regress the per-sample NLL towards the supplied baseline values.
			# TODO: this objective could probably be improved.
			baseline = torch.as_tensor(baseline, dtype = nll.dtype, device = nll.device)
			loss = self.MSELoss(nll, baseline)
		else:
			loss = torch.mean(nll)

		if me_input is not None:
			me_mean, me_log_std = self.forward(self._to_tensor(me_input))
			distribution = Normal(me_mean, torch.exp(me_log_std))

			# Reparameterised sample so the term is differentiable w.r.t. the network.
			log_probs = distribution.log_prob(distribution.rsample())
			loss = loss + self.alpha * torch.mean(log_probs.sum(dim = 1))

		mean_variance = torch.mean(torch.exp(ml_log_std.detach() * 2))

		return loss, mse_loss, mean_variance

	def fit(self, ml_input, ml_target, me_input = None, baseline = None):
		"""Take one optimizer step on the model loss.

		Args:
			ml_input: model inputs (anything ``torch.as_tensor`` accepts).
			ml_target: regression targets matching the shape of the predicted mean.
			me_input: optional inputs for the ``alpha``-weighted sampled
				log-probability term; skipped when ``None``.
			baseline: optional per-sample NLL targets (presumably the output of
				``ml_loss``). When given, the loss is the MSE between the per-sample
				NLL and ``baseline`` instead of the mean NLL.

		Returns:
			``(loss, mse, mean_variance)`` as Python floats, all measured before
			the parameter update.
		"""
		loss, mse_loss, mean_variance = self._compute_loss(ml_input, ml_target, me_input, baseline)

		self.optimizer.zero_grad()
		loss.backward()
		self.optimizer.step()

		return loss.item(), mse_loss.item(), mean_variance.item()

	def evaluate(self, ml_input, ml_target, me_input = None, baseline = None):
		"""Compute the same quantities as ``fit`` without gradients or an update.

		Returns ``(loss, mse, mean_variance)`` as Python floats. The entropy term,
		when ``me_input`` is given, is stochastic because it uses a random sample.
		"""
		with torch.no_grad():
			loss, mse_loss, mean_variance = self._compute_loss(ml_input, ml_target, me_input, baseline)

		return loss.item(), mse_loss.item(), mean_variance.item()

	def predict(self, input):
		"""Return the predicted mean as a NumPy array (no sampling, no gradients)."""
		with torch.no_grad():
			mean, _ = self.forward(self._to_tensor(input))

		return mean.cpu().numpy()

	def ml_loss(self, ml_input, ml_target):
		"""Return the detached per-sample Gaussian NLL tensor (on ``self.device``)."""
		with torch.no_grad():
			ml_mean, ml_log_std = self.forward(self._to_tensor(ml_input))
			loss = self._gaussian_nll(ml_mean, ml_log_std, self._to_tensor(ml_target))

		return loss.detach()

	def _write_state_dict(self, checkpoint_directory, file_name):
		"""Save ``state_dict()`` as ``checkpoint_directory/file_name``, creating the directory if needed."""
		os.makedirs(checkpoint_directory, exist_ok = True)
		torch.save(self.state_dict(), os.path.join(checkpoint_directory, file_name))

	def save_checkpoint(self, checkpoint_directory):
		"""Save the weights to ``<checkpoint_directory>/<name>.bin``."""
		print("... saving checkpoint ...")
		self._write_state_dict(checkpoint_directory, self.name + ".bin")

	def load_checkpoint(self, checkpoint_directory):
		"""Load weights from a saved file.

		Despite its name, ``checkpoint_directory`` is the full path of the file
		passed to ``torch.load`` (e.g. ``<dir>/<name>.bin``). Tensors are mapped
		onto ``self.device`` so GPU-saved checkpoints load on CPU-only machines.
		"""
		print("... loading checkpoint ...")
		self.load_state_dict(torch.load(checkpoint_directory, map_location = self.device))

	def save_best(self, checkpoint_directory):
		"""Save the weights to ``<checkpoint_directory>/<name>_best.bin``."""
		print("... saving best checkpoint ...")
		self._write_state_dict(checkpoint_directory, self.name + "_best.bin")