import os
import time
import gymnasium as gym
import numpy as np 
import torch
import matplotlib.pyplot as plt
import imageio.v2 as imageio

from torch import nn
from torch.nn import functional as F


def unpack_batch(batch):
    """Return the fields of a transition batch as a tuple.

    The order is ``(state, action, next_state, reward, done)``; ``batch`` is
    any object exposing those five attributes.
    """
    field_names = ("state", "action", "next_state", "reward", "done")
    return tuple(getattr(batch, name) for name in field_names)


class Timer:
    """Elapsed-time stopwatch for total run time and step throughput.

    Uses ``time.perf_counter`` (monotonic, high resolution) so measured
    intervals cannot go negative if the system clock is adjusted.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart both the total-time and the throughput clocks at step 0."""
        now = time.perf_counter()
        self._start_time = now
        self._step_time = now
        self._step = 0

    def set_step(self, step):
        """Make ``step`` (at the current time) the reference for ``steps_per_sec``."""
        self._step = step
        self._step_time = time.perf_counter()

    def time_cost(self):
        """Return seconds elapsed since construction or the last ``reset``."""
        return time.perf_counter() - self._start_time

    def steps_per_sec(self, step):
        """Return the step rate since the last reference point, then move it to ``step``.

        Returns 0.0 when no measurable time has elapsed, instead of raising
        ZeroDivisionError on back-to-back calls.
        """
        # Sample the clock once so the rate and the new reference agree.
        now = time.perf_counter()
        elapsed = now - self._step_time
        sps = (step - self._step) / elapsed if elapsed > 0 else 0.0
        self._step = step
        self._step_time = now
        return sps


def _append_rendered_frame(env, frames):
    """Render ``env`` and append the frame to ``frames`` unless it is None."""
    frame = env.render()
    if frame is not None:
        frames.append(frame)


def eval_policy(policy, eval_env, eval_episodes=50, video_path=None, record_episodes=1, fps=30):
    """Evaluate a policy by running full episodes and averaging the return.

    Args:
        policy: object with ``select_action(state)``; it receives the
            observation wrapped in ``np.array``.
        eval_env: environment whose ``reset()`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, terminated, truncated, ...)``.
        eval_episodes: number of episodes to run; must be positive.
        video_path: if given (and ``record_episodes > 0``), rendered frames
            from the first ``record_episodes`` episodes are written to this
            path with ``save_video``.
        record_episodes: how many leading episodes to capture.
        fps: frame rate passed to ``save_video``.

    Returns:
        The undiscounted sum of rewards per episode, averaged over
        ``eval_episodes``.

    Raises:
        ValueError: if ``eval_episodes`` is not positive.
    """
    if eval_episodes <= 0:
        raise ValueError(f"eval_episodes must be positive, got {eval_episodes}")

    capture_video = video_path is not None and record_episodes > 0
    frames = []
    total_reward = 0.0
    for episode_idx in range(eval_episodes):
        state, _ = eval_env.reset()
        is_recording = capture_video and episode_idx < record_episodes
        if is_recording:
            # Include the initial observation's frame before any step.
            _append_rendered_frame(eval_env, frames)
        done = False
        while not done:
            action = policy.select_action(np.array(state))
            state, reward, terminated, truncated, *_ = eval_env.step(action)
            done = terminated or truncated
            total_reward += reward
            if is_recording:
                _append_rendered_frame(eval_env, frames)

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    # frames can only be non-empty when capture_video is True.
    if frames:
        save_video(frames, video_path, fps=fps)
    return avg_reward



def weight_init(m):
    """Orthogonally initialize ``nn.Linear`` layers; leave other modules alone.

    Intended for ``nn.Module.apply``. For an ``nn.Linear`` the weight gets an
    orthogonal init (default gain) and the bias, if present, is zeroed. Other
    module types, including convolutions, are not modified.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions run under no_grad, so Parameters can be passed directly.
        nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class MLP(nn.Module):
    """Feed-forward network built by ``mlp`` and initialized with ``weight_init``.

    Args:
        input_dim: size of the input features (last dimension).
        hidden_dim: width of every hidden layer.
        output_dim: size of the output features.
        hidden_depth: number of hidden layers passed to ``mlp``.
        output_mod: optional module appended after the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
        super().__init__()
        network = mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod)
        self.trunk = network
        # Apply the custom init to every submodule of the trunk.
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return its output."""
        out = self.trunk(x)
        return out


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build an ``nn.Sequential`` multilayer perceptron with ELU activations.

    With ``hidden_depth == 0`` the network is a single
    ``Linear(input_dim, output_dim)``. Otherwise it is an input layer of width
    ``hidden_dim``, ``hidden_depth - 1`` further hidden layers, and a final
    ``Linear(hidden_dim, output_dim)``; each hidden layer is followed by an
    in-place ELU. ``output_mod``, if given, is appended last.
    """
    layers = []
    if hidden_depth != 0:
        layers.extend([nn.Linear(input_dim, hidden_dim), nn.ELU(inplace=True)])
        for _ in range(hidden_depth - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ELU(inplace=True)])
        layers.append(nn.Linear(hidden_dim, output_dim))
    else:
        layers.append(nn.Linear(input_dim, output_dim))
    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)

def to_np(t):
    """Convert a tensor to a NumPy array on the host.

    Returns None for None. A tensor with no elements maps to ``np.array([])``
    (a 1-D float64 array of length 0), regardless of the tensor's shape or
    dtype. Any other tensor is detached from autograd, copied to CPU if
    needed, and converted. For a tensor already on the CPU, the result
    shares memory with it.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    # Detach before the device copy so autograd does not record the transfer.
    return t.detach().cpu().numpy()

def plot_evals(evaluations, filename='evals.png'):
    """Plot evaluation rewards in order and save the figure to ``filename``.

    Creates the parent directory of ``filename`` if needed. The figure is
    always closed, even if saving fails, so repeated calls do not accumulate
    open figures.
    """
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(evaluations)
        ax.set_xlabel('Evaluation Number')
        ax.set_ylabel('Evaluation Average Reward')
        ax.set_title('Policy Evaluation over Time')
        fig.savefig(filename)
    finally:
        plt.close(fig)

def save_video(frames, filename='video.mp4', fps=30):
    """Write a sequence of frames to ``filename`` with ``imageio.mimsave``.

    Does nothing when ``frames`` is empty. Creates the parent directory if
    needed. Frames whose dtype is not uint8 are clipped to [0, 255] and cast
    to uint8 before writing.
    """
    if not frames:
        return

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    def as_uint8(frame):
        pixels = np.asarray(frame)
        if pixels.dtype == np.uint8:
            return pixels
        return np.clip(pixels, 0, 255).astype(np.uint8)

    uint8_frames = [as_uint8(frame) for frame in frames]
    imageio.mimsave(filename, uint8_frames, fps=fps)