import os
import math
import h5py
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm

from termination_functions import termination_function

##########################################################################

class Swish(nn.Module):

	""" Swish activation: returns ``input * sigmoid(input)`` element-wise. """

	def __init__(self):
		super(Swish, self).__init__()

	def forward(self, input):
		gate = torch.sigmoid(input)
		return gate * input

##########################################################################

class Batcher():
    """Walk over parallel, equally long arrays in consecutive mini-batches.

    ``data`` is a sequence of indexable arrays; all are sliced with the same
    window, so row ``i`` of each array stays aligned. The last batch may be
    shorter than ``batch_size``.
    """

    def __init__(self, batch_size, data):
        self.batch_size = batch_size
        self.data = data
        self.num_entries = len(data[0])
        self.num_batches = int(math.ceil(self.num_entries / self.batch_size))
        self.reset()

    def reset(self):
        """Rewind to the first batch."""
        self.batch_start = 0
        self.batch_end = self.batch_size

    def end(self):
        """True once every entry has been handed out."""
        return not self.batch_start < self.num_entries

    def next_batch(self):
        """Return the current window of every array and advance the cursor."""
        window = slice(self.batch_start, self.batch_end)
        batch = [column[window] for column in self.data]
        self.batch_start = self.batch_end
        self.batch_end = min(self.batch_start + self.batch_size, self.num_entries)
        return batch

    def shuffle(self):
        """Apply one shared random permutation to every array (needs fancy indexing)."""
        order = np.random.permutation(self.num_entries)
        self.data = [column[order] for column in self.data]

##########################################################################

def init_weights(layer):

	""" Custom weight initialization, meant for ``module.apply(init_weights)``.

	``nn.Linear`` layers get an orthogonal weight matrix and a zero bias (when the
	layer has one); every other module type is left untouched.
	"""

	if not isinstance(layer, nn.Linear):
		return

	nn.init.orthogonal_(layer.weight.data)
	# bias is None for nn.Linear(..., bias = False), which has no ``data``.
	if hasattr(layer.bias, "data"):
		layer.bias.data.fill_(0.0)

# ##########################################################################

def mlp(input_dim, hidden_dim, output_dim, hidden_depth,
	activation = nn.ReLU(inplace = True), output_mod = None):

	""" Build a fully connected network as an ``nn.Sequential``.

	Args:
		input_dim: size of the input features.
		hidden_dim: width of every hidden layer.
		output_dim: size of the output features.
		hidden_depth: number of hidden layers; 0 gives a single Linear(input_dim, output_dim).
		activation: module inserted after every hidden Linear. The same instance is
			reused for all hidden layers (fine for parameter-free activations like ReLU).
		output_mod: optional, already-constructed module appended after the final
			Linear layer (e.g. ``nn.Tanh()``).

	Returns:
		nn.Sequential with the layers above, in order.
	"""

	if hidden_depth == 0:
		mods = [nn.Linear(input_dim, output_dim)]

	else:
		mods = [nn.Linear(input_dim, hidden_dim), activation]
		for _ in range(hidden_depth - 1):
			mods += [nn.Linear(hidden_dim, hidden_dim), activation]
		mods.append(nn.Linear(hidden_dim, output_dim))

	if output_mod is not None:
		# output_mod is a module instance: append it as-is (wrapping it in nn.Linear
		# would raise, since Linear expects integer feature sizes).
		mods.append(output_mod)

	return nn.Sequential(*mods)

##########################################################################
""" RolloutGenerator has been slightly modified so that it runs on HPC clusters. This code still needs to be double-checked. """

class RolloutGenerator():
	""" Generate synthetic transitions by rolling out a learned dynamics model.

	NOTE(review): this variant was adapted to run on HPC clusters and still needs to be
	double-checked. As written:
	- the policy and ``termination_function`` see the state without its first column
	  (``state[:, 1:]``);
	- ``dynamics_model.predict`` must return an array as wide as the observation. It is
	  added to the state as a delta, and its first column is also reported as the reward.
	TODO confirm this layout against the dataset and model.
	"""

	def __init__(self, dataset, env_name, dynamics_model, policy = None):

		self.policy = policy
		self.scaler = StandardScaler()
		self.env_name = env_name
		self.dynamics_model = dynamics_model

		self.observations = dataset["observations"]
		actions = dataset["actions"]
		# The scaler normalizes model inputs, i.e. rows of [observation, action].
		self.scaler.fit(np.concatenate([self.observations, actions], axis = 1))

		self.observation_shape = self.observations.shape[1]
		self.action_shape = actions.shape[1]

	def sample(self, batch_size, length):
		""" Roll out the dynamics model for up to ``length`` steps.

		``batch_size`` start states are drawn (with replacement) from the dataset
		observations. At each step:
		1. an action is chosen, uniform in [-1, 1] when no policy is set, otherwise
		   from ``policy.select_action``;
		2. the model prediction is added to the state;
		3. rollouts flagged by ``termination_function`` are dropped from later steps.

		Sampling stops early once every rollout has terminated.

		Returns:
			(states, actions, rewards, next_states, terminals), each stacked over all
			steps taken; rewards and terminals have a single column.
		"""
		env = self.env_name.split("_")[0]

		# Collect per-step chunks and stack once at the end (stacking inside the loop is
		# quadratic). The leading empty arrays fix column counts/dtypes even if no step runs.
		states = [np.zeros((0, self.observation_shape), dtype = np.float32)]
		actions = [np.zeros((0, self.action_shape), dtype = np.float32)]
		rewards = [np.zeros((0, 1), dtype = np.float32)]
		next_states = [np.zeros((0, self.observation_shape), dtype = np.float32)]
		# Builtin ``bool``: the ``np.bool`` alias is absent from NumPy 1.24-1.26.
		terminals = [np.zeros((0, 1), dtype = bool)]

		state = self.observations[np.random.choice(len(self.observations), batch_size)]
		for _ in range(length):

			if self.policy is None:
				action = np.random.uniform(-1, 1, (len(state), self.action_shape))
			else:
				# The policy does not receive the first state column.
				action = self.policy.select_action(state[:, 1:])

			model_input = self.scaler.transform(np.concatenate([state, action], axis = 1))
			# NOTE(review): the effect of deterministic prediction should be checked.
			prediction = self.dynamics_model.predict(model_input)
			# The prediction is a delta on the full state; its first column is the reward.
			reward = prediction[:, :1]

			next_state = state + prediction
			done = termination_function(state[:, 1:], action, next_state[:, 1:], env)

			states.append(state)
			actions.append(action)
			rewards.append(reward)
			terminals.append(done)
			next_states.append(next_state)

			# Continue only the surviving rollouts. Rebinding (instead of ``state +=``) keeps
			# the chunks stored above intact; casting back to the state's dtype matches the
			# previous in-place update.
			state = next_state[~done.flatten()].astype(state.dtype, copy = False)

			if len(state) == 0:
				break

		return (np.vstack(states), np.vstack(actions), np.vstack(rewards),
			np.vstack(next_states), np.vstack(terminals))

	def update_policy(self, policy):
		""" Replace the policy used to pick actions during rollouts. """
		self.policy = policy

	def update_model(self, model):
		""" Replace the dynamics model used to predict state deltas. """
		self.dynamics_model = model

##########################################################################

def meshgrid():
	""" Placeholder, not implemented yet: does nothing and returns None. """

# ##########################################################################

class StandardScaler():
	""" Per-feature standardization ``(x - mean) / std`` with statistics taken over axis 0.

	Features whose standard deviation is below 1e-12 get a std of 1, so they are only
	mean-shifted. NOTE: before ``fit`` is called, mean and std are both 0.0, so
	``transform`` would divide by a zero std.
	"""

	def __init__(self):
		self.std = 0.0
		self.mean = 0.0
		self.fitted = False

	def fit(self, data):
		""" Compute column-wise mean and std of ``data`` (kept as shape (1, ...) rows). """
		self.mean = np.mean(data, axis = 0, keepdims = True)
		sigma = np.std(data, axis = 0, keepdims = True)
		self.std = sigma
		# Avoid dividing by ~0 for (near-)constant features.
		sigma[sigma < 1e-12] = 1
		self.fitted = True

	def transform(self, data):
		""" Standardize ``data`` with the fitted statistics. """
		centered = data - self.mean
		return centered / self.std

	def inverse_transform(self, data):
		""" Map standardized ``data`` back to the original scale. """
		rescaled = data * self.std
		return rescaled + self.mean

	def get_vars(self):
		""" Return ``[mean, std]``. """
		return [self.mean, self.std]

##########################################################################

########################## Dataset Utilities #############################

##########################################################################

def get_keys(h5file):
	""" Return the full paths of every h5py Dataset found by ``h5file.visititems``.

	Groups are walked but not listed themselves; nested datasets appear as
	slash-separated paths such as ``"infos/qpos"``.
	"""
	found = []

	def _collect(path, obj):
		# Must return None: a non-None return value stops visititems early.
		if isinstance(obj, h5py.Dataset):
			found.append(path)

	h5file.visititems(_collect)
	return found

def get_dataset(dataset_name):
	""" Load every dataset stored in ``<dataset_dir>/<dataset_name>.hdf5`` into a dict.

	The directory comes from the ``D4RL_DATASET_DIR`` environment variable and defaults
	to ``~/.d4rl/datasets``. Keys are the full HDF5 paths returned by ``get_keys``, so
	nested groups give keys such as ``"infos/qpos"``. ``rewards`` and ``terminals``
	stored as (N, 1) columns are flattened to (N,), where N is the number of
	observations.

	Raises:
		AssertionError: if ``observations``, ``actions``, ``rewards`` or ``terminals`` is
			missing, or if rewards/terminals do not end up with shape (N,).
	"""
	dataset_path = os.environ.get('D4RL_DATASET_DIR', os.path.expanduser('~/.d4rl/datasets'))
	h5path = os.path.join(dataset_path, dataset_name + ".hdf5")

	data_dict = {}
	with h5py.File(h5path, "r") as dataset_file:
		for k in tqdm(get_keys(dataset_file), desc = "load datafile"):
			try:
				data_dict[k] = dataset_file[k][:]
			except ValueError:
				# Scalar datasets cannot be sliced with ``[:]``; read them with ``[()]``.
				data_dict[k] = dataset_file[k][()]

	for key in ['observations', 'actions', 'rewards', 'terminals']:
		assert key in data_dict, 'Dataset is missing key %s' % key
	n_samples = data_dict['observations'].shape[0]

	if data_dict['rewards'].shape == (n_samples, 1):
		data_dict['rewards'] = data_dict['rewards'][:, 0]
	assert data_dict['rewards'].shape == (n_samples,), 'Reward has wrong shape: %s' % (
		str(data_dict['rewards'].shape))

	if data_dict['terminals'].shape == (n_samples, 1):
		data_dict['terminals'] = data_dict['terminals'][:, 0]
	assert data_dict['terminals'].shape == (n_samples,), 'Terminals has wrong shape: %s' % (
		str(data_dict['terminals'].shape))

	return data_dict

def _transition_keep_mask(dataset, terminate_on_end, max_episode_steps):
	""" Decide which transitions i -> i + 1 (0 <= i < N - 1) to keep; N = len(rewards).

	An episode ends at ``dataset["timeouts"][i]`` when that key exists. Otherwise it ends
	after ``max_episode_steps`` steps, counted since the last terminal or episode end.
	Unless ``terminate_on_end`` is set, the transition at an episode's final timestep
	is dropped.

	Returns:
		Boolean array of length max(N - 1, 0).
	"""
	n = max(dataset["rewards"].shape[0] - 1, 0)

	if "timeouts" in dataset:
		# With explicit timeouts the step counter plays no role, so the rule vectorizes.
		if terminate_on_end:
			return np.ones(n, dtype = bool)
		return ~np.asarray(dataset["timeouts"][:n], dtype = bool)

	terminals = dataset["terminals"]
	keep = np.zeros(n, dtype = bool)
	episode_step = 0
	for i in range(n):
		final_timestep = (episode_step == max_episode_steps - 1)
		if (not terminate_on_end) and final_timestep:
			episode_step = 0
			continue
		if bool(terminals[i]) or final_timestep:
			episode_step = 0
		keep[i] = True
		episode_step += 1
	return keep

def _gather_transitions(dataset, keep, observations):
	""" Build the transition dict from per-sample ``observations`` and a keep mask.

	Row i of ``next_observations`` is ``observations[i + 1]`` for each kept index i.
	Arrays are float32 (terminals bool) and keep their column count even when empty.
	"""
	n = len(keep)
	return {
		"observations": observations[:n][keep].astype(np.float32),
		"actions": dataset["actions"][:n][keep].astype(np.float32),
		"next_observations": observations[1:n + 1][keep].astype(np.float32),
		"rewards": dataset["rewards"][:n][keep].astype(np.float32),
		"terminals": dataset["terminals"][:n][keep].astype(bool),
	}

def qlearning_dataset(dataset_name, terminate_on_end = False, max_episode_steps = 1000):
	""" Load ``dataset_name`` via ``get_dataset`` and return (s, a, r, s', done) transitions.

	Args:
		dataset_name: file stem of the HDF5 dataset.
		terminate_on_end: keep the transition at each episode's final timestep instead of
			dropping it.
		max_episode_steps: episode length assumed when the file has no ``timeouts`` key.

	Returns:
		dict with float32 ``observations``, ``actions``, ``next_observations``,
		``rewards`` and bool ``terminals``.
	"""
	dataset = get_dataset(dataset_name)
	keep = _transition_keep_mask(dataset, terminate_on_end, max_episode_steps)
	return _gather_transitions(dataset, keep, dataset["observations"])

def full_dataset(dataset_name, terminate_on_end = False, max_episode_steps = 1000):
	""" Same as ``qlearning_dataset``, but observations are [infos/qpos, infos/qvel].

	Each observation concatenates the ``infos/qpos`` and ``infos/qvel`` rows of the same
	sample. All other outputs and arguments behave as in ``qlearning_dataset``.
	"""
	dataset = get_dataset(dataset_name)
	keep = _transition_keep_mask(dataset, terminate_on_end, max_episode_steps)
	full_obs = np.concatenate([dataset["infos/qpos"], dataset["infos/qvel"]], axis = 1)
	return _gather_transitions(dataset, keep, full_obs)