import numpy as np
import torch
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
#import mo_gymnasium as mo_gym
import argparse
import os

from tensorboardX import SummaryWriter
from gymnasium import spaces
from gymnasium.spaces import flatten_space

from utils import util, buffer
from agent.sac import sac_agent
from agent.vlsac import vlsac_agent
from agent.ctrlsac import ctrlsac_agent
from agent.rffsac import rffsac_agent
from agent.rffsac import rffsac_agent_bonus

# Probability of replacing the policy's action with a uniformly sampled one
# once the initial random-action phase is over (epsilon-greedy exploration).
EPS_GREEDY = 0.01

if __name__ == "__main__":
	
  # Command-line configuration. Every numeric option declares an explicit
  # ``type`` so that values given on the command line are parsed as numbers
  # instead of being left as strings.
  parser = argparse.ArgumentParser()
  parser.add_argument("--dir", default=0, type=int)                      # Run index, used as a sub-directory of the log path
  parser.add_argument("--alg", default="rffsac")                         # Alg name (sac, vlsac, ctrlsac, rffsac, rffsac_bonus, random)
  parser.add_argument("--env", default="HalfCheetah-v5")                 # Environment name
  parser.add_argument("--seed", default=0, type=int)                     # Sets Gym, PyTorch and Numpy seeds
  parser.add_argument("--start_timesteps", default=25e3, type=float)     # Time steps initial random policy is used
  parser.add_argument("--eval_freq", default=int(5e3), type=int)         # How often (time steps) we evaluate
  parser.add_argument("--max_timesteps", default=1e6, type=float)        # Max time steps to run environment
  parser.add_argument("--expl_noise", default=0.1, type=float)           # Std of Gaussian exploration noise
  parser.add_argument("--batch_size", default=256, type=int)             # Batch size for both actor and critic
  parser.add_argument("--hidden_dim", default=256, type=int)             # Network hidden dims
  parser.add_argument("--feature_dim", default=256, type=int)            # Latent feature dim
  parser.add_argument("--discount", default=0.99, type=float)            # Discount factor
  parser.add_argument("--tau", default=0.005, type=float)                # Target network update rate
  parser.add_argument("--learn_bonus", action="store_true")              # Bonus-learning switch (parsed, but not read in the visible script)
  parser.add_argument("--save_model", action="store_true")               # Save model and optimizer parameters
  parser.add_argument("--extra_feature_steps", default=3, type=int)      # Forwarded to the feature-based agents
  parser.add_argument("--save_video", action="store_true", default=False)  # Save evaluation videos
  parser.add_argument("--bonus_coef", default=1.0, type=float)           # Bonus settings, used by the rffsac_bonus agent
  parser.add_argument("--bonus_lambda", default=1.0, type=float)
  parser.add_argument("--bonus_clip", default=1.0, type=float)
  args = parser.parse_args()

  
  # Resolve the Gymnasium id and the extra constructor arguments of the
  # training environment, then build the training and evaluation environments.
  #   * 'SlimHumanoid-v5' is an alias for Humanoid-v5; args.env keeps the alias.
  #   * A trailing "-ET" is stripped from args.env before the environment is made.
  if args.env == 'SlimHumanoid-v5':
    gym_id = 'Humanoid-v5'
    train_kwargs = {'terminate_when_unhealthy': True}
  elif args.env.endswith("-ET"):
    args.env = args.env[:-3]
    gym_id = args.env
    train_kwargs = {'terminate_when_unhealthy': True}
  else:
    gym_id = args.env
    train_kwargs = {}

  env = gym.make(gym_id, **train_kwargs)

  # The evaluation environment renders to RGB arrays when the constructor
  # accepts ``render_mode``; otherwise it is built without rendering.
  try:
    eval_env = gym.make(gym_id, render_mode="rgb_array")
  except TypeError:
    eval_env = gym.make(gym_id)

  class RewardWrapper(gym.RewardWrapper):
    """Swap the environment's native reward for a task-specific one.

    The task is selected by ``args.env`` from the enclosing script scope.
    Each shaped branch ignores the incoming ``reward`` and recomputes the
    value from the unwrapped environment's current state; tasks without a
    dedicated branch keep their native reward.
    """

    def reward(self, reward):
        """Return the reward to report for the step that was just taken.

        Args:
            reward: Reward produced by the wrapped environment.

        Returns:
            The native ``reward`` for HalfCheetah-v5 and any unlisted task,
            otherwise a value recomputed from the simulator state.
        """
        if args.env == 'HalfCheetah-v5':
            return reward
        elif args.env == 'Pendulum-v1':
            th, thdot = self.env.unwrapped.state
            u = self.env.unwrapped.last_u
            return -np.cos(th) - 0.1 * np.sin(th) - 0.1 * (thdot ** 2) - 0.001 * (u ** 2)
        elif args.env == 'Acrobot-v1':
            # Acrobot's internal state holds the two joint angles followed by
            # the two joint velocities. Only the angles are needed, so index
            # them rather than unpacking a fixed number of entries (unpacking
            # six values from that state raises ValueError).
            state = self.env.unwrapped.state
            th1, th2 = state[0], state[1]
            return -np.cos(th1) - np.cos(th1 + th2)
        elif args.env in ['Reacher-v2', 'Reacher-v5']:
            obs = self.env.unwrapped._get_obs()
            # NOTE(review): assumes the last two observation entries are the
            # planar fingertip-to-target offset — confirm for Reacher-v2.
            vec = np.asarray(obs[-2:], dtype=np.float64)
            dist = np.linalg.norm(vec)
            a = np.asarray(self.env.unwrapped.data.ctrl, dtype=np.float64)
            return -dist - np.sum(a ** 2)
        elif args.env == 'Humanoid-v5':
            data = self.env.unwrapped.data
            xvel = data.qvel[0]
            z = data.qpos[2]
            a = data.ctrl
            impact = np.sum(np.square(data.cfrc_ext))
            # Alive bonus of 5 while z stays inside [1.0, 2.0].
            reward_alive = 5 * (1.0 <= z <= 2.0)
            return (50/3) * xvel - 0.1 * np.sum(a**2) - 5e-6 * impact + reward_alive
        elif args.env == 'SlimHumanoid-v5':
            # Same shape as the Humanoid-v5 reward, without the impact term.
            data = self.env.unwrapped.data
            xvel = data.qvel[0]
            z = data.qpos[2]
            a = np.asarray(data.ctrl, dtype=np.float64)
            reward_alive = 5 * ((1.0 <= z) and (z <= 2.0))
            return (50/3) * xvel - 0.1 * np.sum(a**2) + reward_alive
        elif args.env == 'Hopper-v5':
            data = self.env.unwrapped.data
            xvel = data.qvel[0]
            z = data.qpos[1]
            a = np.asarray(data.ctrl, dtype=np.float64)
            # Quadratic penalty on the deviation of z from 1.3.
            return xvel - 0.1 * np.sum(a**2) - 3.0 * (z - 1.3)**2
        elif args.env == "InvertedPendulum-v5":
            obs = self.env.unwrapped._get_obs()
            theta = float(obs[1])
            return -(theta ** 2)
        else: # include Swimmer
            return reward
  # Apply the shaped reward to both environments. A single layer is enough:
  # every branch of RewardWrapper.reward either passes the reward through or
  # recomputes it from the simulator state, so a second, stacked wrapper would
  # only repeat the same computation on each step.
  env = RewardWrapper(env)
  eval_env = RewardWrapper(eval_env)
  
  # Seed the environments through their first reset.
  state, _ = env.reset(seed=args.seed)
  state_eval, _ = eval_env.reset(seed=args.seed)
  # The action space owns a separate RNG that reset(seed=...) does not seed;
  # seed it explicitly so env.action_space.sample() is reproducible per seed.
  env.action_space.seed(args.seed)

  # setup log
  log_path = f'log/{args.env}/{args.alg}/{args.dir}/{args.seed}'
  summary_writer = SummaryWriter(log_path)

  # set seeds
  torch.manual_seed(args.seed)
  np.random.seed(args.seed)

  # 
  # The "random" baseline samples actions for the whole run and keeps the
  # evaluation frequency as given. Every other algorithm caps both the random
  # warm-up and the evaluation period at one tenth of the run length.
  if args.alg == 'random':
    args.start_timesteps = args.max_timesteps
  else:
    tenth_of_run = args.max_timesteps // 10
    args.start_timesteps = min(tenth_of_run, args.start_timesteps)
    # Clamp to at least 1: for runs shorter than 10 steps the cap is 0, and
    # `(t + 1) % args.eval_freq` would then raise ZeroDivisionError.
    args.eval_freq = max(1, min(args.eval_freq, tenth_of_run))

  # 
  # Length of the observation vector; Dict observation spaces are measured
  # after flattening.
  obs_space = env.observation_space
  needs_flatten = isinstance(obs_space, spaces.Dict)
  vector_space = flatten_space(obs_space) if needs_flatten else obs_space
  state_dim = vector_space.shape[0]
  
  def convert_obs(obs):
    """Flatten ``obs`` with the Dict observation space when ``needs_flatten``
    is set; otherwise return it unchanged."""
    return spaces.flatten(obs_space, obs) if needs_flatten else obs
  
  # Derive the action dimensionality and whether actions are discrete.
  # Only Box and Discrete action spaces are supported.
  action_space = env.action_space
  if isinstance(action_space, gym.spaces.Discrete):
    discrete = True
    action_dim = action_space.n
    max_action = action_space.n - 1   # optional placeholder
  elif isinstance(action_space, gym.spaces.Box):
    discrete = False
    action_dim = action_space.shape[0]
    max_action = float(action_space.high[0])
  else:
    raise ValueError(f"Unsupported action space type: {type(env.action_space)}")

  # Per-environment learning-rate overrides; any other task uses 3e-4.
  lr_by_env = {
    'InvertedPendulum-v5': 6e-5,
    'Humanoid-v5': 1e-6,
    'Pendulum-v1': 1e-6,
    'Hopper-v5': 1e-3,
  }
  lr = lr_by_env.get(args.env, 3e-4)
  

  # Constructor arguments shared by every agent.
  kwargs = {
    "state_dim": state_dim,
    "action_dim": action_dim,
    "action_space": env.action_space,
    "discount": args.discount,
    "tau": args.tau,
    "hidden_dim": args.hidden_dim,
    "lr": lr
  }

  # Initialize policy
  if args.alg == "sac" or args.alg == "random":
    # "random" also builds a SAC agent; with start_timesteps set to
    # max_timesteps for that baseline, the loop never trains it.
    agent = sac_agent.SACAgent(**kwargs)
  elif args.alg == 'vlsac':
    kwargs['extra_feature_steps'] = args.extra_feature_steps
    kwargs['feature_dim'] = args.feature_dim
    agent = vlsac_agent.VLSACAgent(**kwargs)
  elif args.alg == 'ctrlsac':
    kwargs['extra_feature_steps'] = args.extra_feature_steps
    # hardcoded for now
    kwargs['feature_dim'] = 2048
    kwargs['hidden_dim'] = 1024
    agent = ctrlsac_agent.CTRLSACAgent(**kwargs)
  elif args.alg == 'rffsac':
    kwargs['extra_feature_steps'] = args.extra_feature_steps
    # hardcoded for now
    kwargs['feature_dim'] = 1024
    kwargs['hidden_dim'] = 1024
    agent = rffsac_agent.RFFSACAgent(**kwargs)
  elif args.alg == 'rffsac_bonus':
    kwargs['extra_feature_steps'] = args.extra_feature_steps
    # hardcoded for now
    kwargs['feature_dim'] = 1024
    kwargs['hidden_dim'] = 1024
    # bonus_lambda and bonus_clip are passed scaled by bonus_coef.
    agent = rffsac_agent_bonus.RFFSACAgentBonus(
      bonus_coef=args.bonus_coef,
      bonus_lambda=args.bonus_lambda * args.bonus_coef,
      bonus_clip=args.bonus_clip * args.bonus_coef,
      **kwargs,
    )
  else:
    # Fail here with a clear message instead of hitting a NameError on
    # `agent` further down.
    raise ValueError(f"Unsupported algorithm: {args.alg}")
  
  replay_buffer = buffer.ReplayBuffer(state_dim, action_dim)

  # Evaluate untrained policy; a video is written only when --save_video is set.
  if args.save_video:
    eval_kwargs = {'video_path': os.path.join(log_path, 'videos', 'eval_0.mp4')}
  else:
    eval_kwargs = {}
  evaluation = util.eval_policy(agent, eval_env, **eval_kwargs)
  log_info = {'evaluation': evaluation}
  summary_writer.add_scalar('info/evaluation', log_info['evaluation'], 0)


  # Begin the first training episode and initialise the bookkeeping.
  state, _ = env.reset()
  state = convert_obs(state)
  # This action is overwritten before it is used in the training loop.
  # NOTE(review): presumably harmless, but the call may draw from the RNG, so
  # it is kept to leave the random stream as it was.
  action = agent.select_action(state, explore=True)
  done = False
  episode_reward = episode_timesteps = episode_num = 0
  timer = util.Timer()
  train_info = {}

  # Main interaction loop: act, store the transition, train, and periodically
  # evaluate and log.
  for t in range(int(args.max_timesteps)):

    episode_timesteps += 1

    # Select action randomly or according to policy
    if t < args.start_timesteps:
      action = env.action_space.sample()
    else:
      # epsilon greedy as mentioned in the CTRL paper
      if np.random.uniform(0, 1) < EPS_GREEDY:
        action = env.action_space.sample()
      else:
        # NumPy check avoids building a torch tensor on every step.
        if np.isnan(state).any():
          print("Warning: state contains NaN values.")
        action = agent.select_action(state, explore=True)

    # Perform action
    next_state, reward, terminated, truncated, _ = env.step(action)
    next_state = convert_obs(next_state)
    # `done` ends the episode for bookkeeping; both causes trigger a reset.
    done = terminated or truncated

    if t >= args.start_timesteps:
      if hasattr(agent, "update_precision_matrix"):
        agent.update_precision_matrix(state, action)

    # Store data in replay buffer. Only a true termination is recorded as
    # terminal: a time-limit truncation ends the episode but not the task,
    # so value targets must still bootstrap from next_state.
    replay_buffer.add(state, action, next_state, reward, terminated)

    state = next_state
    episode_reward += reward

    # Train agent after collecting sufficient data
    if t >= args.start_timesteps:
      train_info = agent.train(replay_buffer, batch_size=args.batch_size)

    if done:
      # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
      if (episode_num + 1) % 100 == 0:
        print(f"Total T: {t+1} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
      # Reset environment
      state, _ = env.reset()
      state = convert_obs(state)
      done = False
      episode_reward = 0
      episode_timesteps = 0
      episode_num += 1

    # Evaluate episode
    if (t + 1) % args.eval_freq == 0:
      steps_per_sec = timer.steps_per_sec(t+1)
      if args.save_video:
        evaluation = util.eval_policy(agent, eval_env, video_path=os.path.join(log_path, 'videos', f'eval_{t+1}.mp4'))
      else:
        evaluation = util.eval_policy(agent, eval_env)
      # Training statistics are logged only once training has started.
      log_info = dict(train_info) if t >= args.start_timesteps else {}
      log_info['evaluation'] = evaluation
      for key, value in log_info.items():
        summary_writer.add_scalar(f'info/{key}', value, t+1)
      summary_writer.flush()

      print('Step {}. Steps per sec: {:.4g}.'.format(t+1, steps_per_sec))

  summary_writer.close()

  print('Total time cost {:.4g}s.'.format(timer.time_cost()))
