import numpy as np
import math
import torch
import time

def check(input):
    """Convert a NumPy array to a torch tensor; return anything else as-is.

    Args:
        input: Any object. ``np.ndarray`` instances, including subclasses
            such as ``np.memmap``, are wrapped with ``torch.from_numpy``.

    Returns:
        A ``torch.Tensor`` when ``input`` is a NumPy array, otherwise the
        very same ``input`` object.
    """
    # isinstance instead of an exact type comparison, so that ndarray
    # subclasses are converted too rather than silently passed through.
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    return input
        
def get_gard_norm(it):
    """Return the overall L2 norm of the gradients found on ``it``.

    Args:
        it: Iterable of objects exposing a ``.grad`` attribute. Entries
            whose ``.grad`` is ``None`` contribute nothing.

    Returns:
        The square root of the summed squared per-entry gradient norms,
        as a Python float (``0.0`` when no entry has a gradient).
    """
    gradients = (param.grad for param in it)
    squared_total = sum(g.norm() ** 2 for g in gradients if g is not None)
    return math.sqrt(squared_total)

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set every param group's learning rate on a linear decay schedule.

    The rate starts at ``initial_lr`` for ``epoch == 0`` and reaches zero
    when ``epoch == total_num_epochs``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts.
        epoch: Current epoch index.
        total_num_epochs: Epoch count at which the rate reaches zero.
        initial_lr: Learning rate at epoch 0.
    """
    progress = epoch / float(total_num_epochs)
    decayed_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def huber_loss(e, d):
    """Element-wise Huber loss of the error ``e`` with threshold ``d``.

    Where ``|e| <= d`` the loss is ``e**2 / 2``; elsewhere it is
    ``d * (|e| - d / 2)``.

    Args:
        e: Error tensor (the comparison results must support ``.float()``).
        d: Threshold separating the quadratic and linear regions.

    Returns:
        Tensor of per-element losses.
    """
    magnitude = abs(e)
    quadratic_mask = (magnitude <= d).float()
    linear_mask = (magnitude > d).float()
    quadratic_term = quadratic_mask * e ** 2 / 2
    linear_term = linear_mask * d * (magnitude - d / 2)
    return quadratic_term + linear_term

def mse_loss(e):
    """Return half of the squared error, i.e. ``e**2 / 2``."""
    squared_error = e ** 2
    return squared_error / 2

def get_shape_from_obs_space(obs_space):
    """Return the observation shape described by ``obs_space``.

    The space is identified by its class *name*: an object whose class is
    called ``Box`` yields its ``.shape``; a plain ``list`` is taken to be
    the shape already and is returned unchanged.

    Raises:
        NotImplementedError: For any other class name.
    """
    kind = obs_space.__class__.__name__
    if kind == 'Box':
        return obs_space.shape
    if kind == 'list':
        return obs_space
    raise NotImplementedError

def get_shape_from_act_space(act_space):
    """Return ``(act_shape, latent_act_shape)`` for an action space.

    Spaces are recognised by class name. Per kind:

    * ``Discrete``: ``(1, n)``
    * ``MultiDiscrete``: ``(shape[0], nvec.sum())``
    * ``Box``: ``(shape[0], shape[0])``
    * ``MultiBinary``: ``(shape[0], shape[0])``
    * ``list`` of ``Box`` / ``Discrete`` / ``MultiDiscrete`` spaces: the
      element-wise sum of the per-space pairs above.

    Args:
        act_space: A single action space object or a list of them.

    Returns:
        Tuple ``(act_shape, latent_act_shape)``.

    Raises:
        NotImplementedError: If ``act_space`` is not one of the kinds above.
        RuntimeError: If a list contains an unsupported space kind.
    """
    space_kind = act_space.__class__.__name__
    if space_kind == 'Discrete':
        act_shape = 1
        latent_act_shape = act_space.n
    elif space_kind == "MultiDiscrete":
        act_shape = act_space.shape[0]
        latent_act_shape = act_space.nvec.sum()
    elif space_kind == "Box":
        act_shape = act_space.shape[0]
        latent_act_shape = act_space.shape[0]
    elif space_kind == "MultiBinary":
        act_shape = act_space.shape[0]
        # This branch previously left latent_act_shape unassigned, so the
        # return below raised UnboundLocalError. One latent entry per binary
        # flag mirrors the Box case -- NOTE(review): confirm this is the
        # encoding downstream consumers expect.
        latent_act_shape = act_space.shape[0]
    elif isinstance(act_space, list):
        act_shape = 0
        latent_act_shape = 0
        for _act_space in act_space:
            sub_kind = _act_space.__class__.__name__
            if sub_kind == "Box":
                act_shape += _act_space.shape[0]
                latent_act_shape += _act_space.shape[0]
            elif sub_kind == "Discrete":
                act_shape += 1
                latent_act_shape += _act_space.n
            elif sub_kind == "MultiDiscrete":
                act_shape += _act_space.shape[0]
                latent_act_shape += _act_space.nvec.sum()
            else:
                raise RuntimeError(f"{_act_space.__class__.__name__} is not supported")
    else:
        # Fail with a descriptive error instead of an UnboundLocalError at
        # the return statement. NotImplementedError is a RuntimeError
        # subclass, so it is compatible with the error raised for list items.
        raise NotImplementedError(f"{space_kind} is not supported")
    return act_shape, latent_act_shape


def tile_images(img_nhwc):
    """
    Arrange N equally sized images on a single grid image.

    The grid has ``ceil(sqrt(N))`` rows and just enough columns to hold
    all N images (so a square N gives a square grid). Images fill the grid
    row by row; unused cells are filled with ``img_nhwc[0] * 0``.

    input: img_nhwc, list or array of images, ndim=4 once turned into array
        n = batch index, h = height, w = width, c = channel
    returns:
        ndarray with ndim=3 and shape (rows * h, cols * w, c)
    """
    batch = np.asarray(img_nhwc)
    count, height, width, channels = batch.shape
    rows = int(np.ceil(np.sqrt(count)))
    cols = int(np.ceil(float(count) / rows))

    # Pad the batch with blank frames until it fills the whole grid.
    frames = list(batch)
    frames.extend(batch[0] * 0 for _ in range(count, rows * cols))
    padded = np.array(frames)

    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows*h, cols*w, c)
    grid = padded.reshape(rows, cols, height, width, channels)
    interleaved = grid.transpose(0, 2, 1, 3, 4)
    return interleaved.reshape(rows * height, cols * width, channels)

def extend_buffer(buffer, extend_traj_length):
    """Lengthen every rollout thread's trajectory in ``buffer`` and re-pad.

    For each rollout thread the per-step fields are passed through
    ``extend_trajectory`` (a project helper whose exact semantics are not
    visible here), the final time step of the ``+1``-length fields is put
    back, all threads are padded to a common length, and the results are
    written back onto ``buffer`` as NumPy arrays. ``value_preds``,
    ``returns`` and ``advantages`` are replaced by zero arrays of the new
    length, so whatever they held before is discarded.

    Args:
        buffer: Rollout buffer exposing the attributes used below as NumPy
            arrays (``torch.from_numpy`` is applied to them), plus
            ``n_rollout_threads``, ``value_preds`` and ``advantages``.
            It is modified in place.
        extend_traj_length: Forwarded unchanged to ``extend_trajectory``.

    Returns:
        The same ``buffer`` object.
    """
    from onpolicy.algorithms.gail.gail_utils import extend_trajectory

    # The string below is not a docstring; it is an inert note recording the
    # array layouts the buffer is expected to have.
    '''
    self.share_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, num_agents, *share_obs_shape),
                                dtype=np.float32)
    self.obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, num_agents, *obs_shape), dtype=np.float32)

    self.rnn_states = np.zeros(
        (self.episode_length + 1, self.n_rollout_threads, num_agents, self.recurrent_N, self.hidden_size),
        dtype=np.float32)
    self.rnn_states_critic = np.zeros(
        (self.episode_length + 1, self.n_rollout_threads, num_agents, self.recurrent_N, self.hidden_size),
        dtype=np.float32)

    self.value_preds = np.zeros(
        (self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
    self.returns = np.zeros(
        (self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
    self.advantages = np.zeros(
        (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)

    self.available_actions = None

    self.actions = np.zeros(
        (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32)
    self.action_log_probs = np.zeros(
        (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32)
    self.rewards = np.zeros(
        (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)

    self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
    self.bad_masks = np.ones((self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
    self.active_masks = np.ones((self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
    '''

    # Swap each listed attribute for a torch tensor so the torch.cat /
    # torch.stack calls below can operate on them.
    for k in ['share_obs', 'obs', 'rnn_states', 'rnn_states_critic', 'actions', 'action_log_probs', 'rewards', 'masks', 'bad_masks', 'active_masks']:
        e = getattr(buffer, k)
        e = torch.from_numpy(e)
        setattr(buffer, k, e)

    # One entry per rollout thread: the list of that thread's extended fields.
    ret = []

    for b in range(buffer.n_rollout_threads):
        # Slice out thread b (axis 1). Fields indexed with [:-1] leave their
        # last time step out; it is appended again after the extension.
        # The attributes are already tensors here, so check() returns them
        # unchanged.
        share_obs = check(buffer.share_obs[:-1, b])
        obs = check(buffer.obs[:-1, b])
        rnn_states = check(buffer.rnn_states[:-1, b])
        rnn_states_critic = check(buffer.rnn_states_critic[:-1, b])
        actions = check(buffer.actions[:, b])
        action_log_probs = check(buffer.action_log_probs[:, b])
        rewards = check(buffer.rewards[:, b])
        masks = check(buffer.masks[:-1, b])
        bad_masks = check(buffer.bad_masks[:-1, b])
        active_masks = check(buffer.active_masks[:-1, b])

        # 1 wherever the mask of the following step is 0.
        terminals = 1 - check(buffer.masks[1:, b])

        # fill_value is positional: one value per field, in the same order
        # as the list of fields. NOTE(review): presumably the value used for
        # the steps extend_trajectory inserts -- confirm against gail_utils.
        share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs, rewards, masks, bad_masks, active_masks = extend_trajectory(terminals, 
                         [share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs, rewards, masks, bad_masks, active_masks], 
                         extend_traj_length=extend_traj_length, 
                         fill_value=[0, 0, 0, 0, 0, 0, 0, 1, 1, 0])

        # Re-attach the final time step that the [:-1] slices left out.
        share_obs = torch.cat([share_obs, buffer.share_obs[-1:, b]], dim=0)
        obs = torch.cat([obs, buffer.obs[-1:, b]], dim=0)
        rnn_states = torch.cat([rnn_states, buffer.rnn_states[-1:, b]], dim=0)
        rnn_states_critic = torch.cat([rnn_states_critic, buffer.rnn_states_critic[-1:, b]], dim=0)
        masks = torch.cat([masks, buffer.masks[-1:, b]], dim=0)
        bad_masks = torch.cat([bad_masks, buffer.bad_masks[-1:, b]], dim=0)
        active_masks = torch.cat([active_masks, buffer.active_masks[-1:, b]], dim=0)

        ret.append([share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs, rewards, masks, bad_masks, active_masks])

    # Thread lengths are measured on share_obs (index 0 of each entry).
    ret_lengths = [len(r[0]) for r in ret]
    max_length = max(ret_lengths)

    def extend(x, l, v):
        # Append l rows filled with v; .to(x) matches x's dtype and device.
        return torch.cat([x, torch.ones(l, *x.shape[1:]).to(x) * v], dim=0)
    
    _ret = []
    for r in ret:
        # NOTE(review): the names unpacked on the next line are not used
        # afterwards; the padding below works on r directly.
        share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs, rewards, masks, bad_masks, active_masks = r
        # Padding values, in field order. Unlike the fill_value passed to
        # extend_trajectory above, masks are padded with 0 here.
        fill_value = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
        # Every field of the thread gets the same number of extra rows,
        # derived from the share_obs length.
        r = [extend(x, max_length - len(r[0]), v) for x, v in zip(r, fill_value)]
        _ret.append(r)
    
    # Stack the threads back along axis 1 and return to NumPy.
    ret = [torch.stack([r[i] for r in _ret], dim=1).numpy() for i in range(len(_ret[0]))]
    buffer.share_obs, buffer.obs, buffer.rnn_states, buffer.rnn_states_critic, buffer.actions, buffer.action_log_probs, buffer.rewards, buffer.masks, buffer.bad_masks, buffer.active_masks = ret

    # Fresh zero arrays sized for the new time length; trailing dimensions
    # are kept from the old arrays.
    buffer.value_preds = np.zeros(
        (max_length, *buffer.value_preds.shape[1:]), dtype=np.float32)
    buffer.returns = np.zeros_like(buffer.value_preds)
    buffer.advantages = np.zeros(
        (max_length - 1, *buffer.advantages.shape[1:]), dtype=np.float32)

    return buffer



def print_banner(s, separator="-", num_star=60):
    """Print ``s`` with a rule line above and below it, flushing each line.

    Args:
        s: Text printed between the two rules.
        separator: String repeated to build each rule.
        num_star: Number of times ``separator`` is repeated per rule.
    """
    rule = separator * num_star
    for line in (rule, s, rule):
        print(line, flush=True)


class Progress:
    """Console progress readout that redraws itself in place.

    Each redraw moves the cursor back up over the lines printed by the
    previous redraw (ANSI ``ESC[F``, "cursor previous line"), overwrites
    them with spaces, and prints a fresh status line followed by the
    supplied parameters laid out ``ncol`` per row. Constructing an instance
    prints one newline (see :meth:`resume`).

    Args:
        total: Step count corresponding to 100%. A falsy value switches the
            readout to a plain ``"<n> iterations"`` counter without a bar.
        name: Label used by :meth:`stamp`.
        ncol: Number of parameters shown per row.
        max_length: Maximum characters kept for each formatted parameter.
            The bar is ``ncol * max_length`` characters wide.
        indent: Number of spaces placed before each parameter row.
        line_width: Number of spaces written to blank out a line.
        speed_update_freq: The rate-measurement window restarts whenever the
            step counter is a multiple of this value.
    """

    def __init__(self, total, name='Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100):
        self.total = total
        self.name = name
        self.ncol = ncol
        self.max_length = max_length
        self.indent = indent
        self.line_width = line_width
        self._speed_update_freq = speed_update_freq

        self._step = 0
        self._prev_line = '\033[F'
        self._clear_line = ' ' * self.line_width

        self._pbar_size = self.ncol * self.max_length
        self._complete_pbar = '#' * self._pbar_size
        self._incomplete_pbar = ' ' * self._pbar_size

        self.lines = ['']
        self.fraction = '{} / {}'.format(0, self.total)
        # stamp() reads _speed, but _format_speed() only stores it once at
        # least one step has been measured. Start with a placeholder so that
        # stamping early does not raise AttributeError.
        self._speed = '{:.1f} Hz'.format(0.0)

        self.resume()

    def update(self, description, n=1):
        """Advance the step counter by ``n`` and redraw with ``description``."""
        self._step += n
        if self._step % self._speed_update_freq == 0:
            # Restart the window over which the rate is measured.
            self._time0 = time.time()
            self._step0 = self._step
        self.set_description(description)

    def resume(self):
        """Print a newline to draw on and restart the rate-measurement window."""
        self._skip_lines = 1
        print('\n', end='')
        self._time0 = time.time()
        self._step0 = self._step

    def pause(self):
        """Blank the lines drawn by the last redraw."""
        self._clear()
        self._skip_lines = 1

    def set_description(self, params=None):
        """Redraw the readout with the given parameters.

        Args:
            params: Either a dict (shown sorted by key) or a sequence of
                ``(key, value)`` pairs (shown in the given order).
                ``None`` means no parameters.
        """
        if params is None:
            params = []
        # isinstance so that dict subclasses (e.g. OrderedDict) are turned
        # into a list as well; a mapping cannot be sliced by _chunk().
        if isinstance(params, dict):
            params = sorted(params.items())

        # Position: wipe whatever the previous redraw printed.
        self._clear()

        # Percent
        percent, fraction = self._format_percent(self._step, self.total)
        self.fraction = fraction

        # Speed
        speed = self._format_speed(self._step)

        # Params
        num_params = len(params)
        nrow = math.ceil(num_params / self.ncol)
        params_split = self._chunk(params, self.ncol)
        params_string, lines = self._format(params_split)
        self.lines = lines

        description = '{} | {}{}'.format(percent, speed, params_string)
        print(description)
        # One status line plus one line per parameter row was printed.
        self._skip_lines = nrow + 1

    def append_description(self, descr):
        """Add an extra text segment to what :meth:`stamp` prints."""
        self.lines.append(descr)

    def _clear(self):
        """Move up ``_skip_lines`` lines, blank them, and move back up."""
        position = self._prev_line * self._skip_lines
        empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)])
        print(position, end='')
        print(empty)
        print(position, end='')

    def _format_percent(self, n, total):
        """Return ``(status_text, fraction_text)`` for step ``n`` of ``total``."""
        if total:
            percent = n / float(total)

            # Keep the filled part within the bar: when n exceeds total the
            # unclamped count would overflow the brackets.
            complete_entries = int(percent * self._pbar_size)
            complete_entries = min(max(complete_entries, 0), self._pbar_size)
            incomplete_entries = self._pbar_size - complete_entries

            pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries]
            fraction = '{} / {}'.format(n, total)
            string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent * 100))
        else:
            fraction = '{}'.format(n)
            string = '{} iterations'.format(n)
        return string, fraction

    def _format_speed(self, n):
        """Return the step rate since the current window began, as text."""
        num_steps = n - self._step0
        t = time.time() - self._time0
        # The clock may report no elapsed time (coarse timer, or the window
        # was restarted an instant ago); report 0 instead of dividing by zero.
        speed = num_steps / t if t > 0 else 0.0
        string = '{:.1f} Hz'.format(speed)
        if num_steps > 0:
            # Remember only real measurements, so that stamp() does not show
            # the zero rate produced right after a window restart.
            self._speed = string
        return string

    def _chunk(self, l, n):
        """Split sequence ``l`` into consecutive pieces of at most ``n`` items."""
        return [l[i:i + n] for i in range(0, len(l), n)]

    def _format(self, chunks):
        """Return ``(multi_line_text, lines)`` for rows of parameters."""
        lines = [self._format_chunk(chunk) for chunk in chunks]
        # The leading empty entry makes the joined text start with a line
        # break, placing the first row below the status line.
        lines.insert(0, '')
        padding = '\n' + ' ' * self.indent
        string = padding.join(lines)
        return string, lines

    def _format_chunk(self, chunk):
        """Format one row of ``(key, value)`` pairs."""
        line = ' | '.join([self._format_param(param) for param in chunk])
        return line

    def _format_param(self, param):
        """Format one ``(key, value)`` pair, cut to ``max_length`` characters."""
        k, v = param
        return '{} : {}'.format(k, v)[:self.max_length]

    def stamp(self):
        """Replace the live readout with a single-line summary.

        When there is nothing to summarise (``lines`` is still ``['']``) the
        readout is only cleared.
        """
        if self.lines != ['']:
            params = ' | '.join(self.lines)
            string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed)
            self._clear()
            print(string, end='\n')
            self._skip_lines = 1
        else:
            self._clear()
            self._skip_lines = 0

    def close(self):
        """Clear the readout; same as :meth:`pause`."""
        self.pause()


class Silent:
    """No-op stand-in: accepts any constructor arguments and any method call.

    Every attribute that is not found normally resolves to a callable that
    ignores its arguments and returns ``None``.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __getattr__(self, attr):
        # Accept keyword arguments as well as positional ones, so calls such
        # as ``update(description, n=2)`` work instead of raising TypeError.
        return lambda *args, **kwargs: None


class EarlyStopping(object):
    """Signal a stop once validation loss exceeds training loss repeatedly.

    A call counts as a "bad" check when ``validation_loss - train_loss`` is
    greater than ``min_delta``. After ``tolerance`` consecutive bad checks
    the call returns ``True``; any check that is not bad resets the count.

    Args:
        tolerance: Number of consecutive bad checks required to stop.
        min_delta: Gap between validation and training loss that is still
            considered acceptable.
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        # Mirrors the value returned by the most recent call.
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        """Record one check and return whether training should stop."""
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            # Keep the public flag in step with the return value; it was
            # previously initialised but never updated.
            self.early_stop = self.counter >= self.tolerance
        else:
            self.counter = 0
            self.early_stop = False
        return self.early_stop
