import gymnasium as gym
import numpy as np
import pandas as pd
from TD3.utils import *
import torch
import argparse
import os
import TD3.TD3 as TD3


class ObservationNoiseWrapper(gym.ObservationWrapper):
    """Corrupt every observation with additive, zero-mean Gaussian noise.

    Args:
        env: Environment to wrap.
        noise_std: Standard deviation of the per-element noise (default 0.01).
    """

    def __init__(self, env, noise_std=0.01):
        super().__init__(env)
        self.noise_std = noise_std

    def observation(self, observation):
        """Return ``observation`` plus i.i.d. N(0, noise_std**2) noise.

        The noise is drawn from NumPy's global RNG with the same shape as
        ``observation``.
        """
        perturbation = np.random.normal(0, self.noise_std, size=observation.shape)
        noisy_observation = observation + perturbation
        return noisy_observation




class ProcessNoiseWrapper(gym.Wrapper):
    """Perturb a MuJoCo environment's simulator state with clipped Gaussian noise after each step.

    After the wrapped environment steps, zero-mean Gaussian noise (clipped to
    +/- 3 standard deviations) is added to ``qpos`` and ``qvel`` of the
    *unwrapped* environment's ``data``. The returned observation is then rebuilt
    from the perturbed state with the unwrapped environment's ``_get_obs()``.
    Reward, ``terminated``, ``truncated`` and ``info`` are passed through
    unchanged from the underlying step, so they reflect the state before the
    noise was applied.

    Args:
        env: Environment to wrap. ``env.unwrapped`` must expose ``data.qpos``,
            ``data.qvel`` and ``_get_obs()``.
        noise_std_qpos: Standard deviation of the noise added to ``qpos``.
        noise_std_qvel: Standard deviation of the noise added to ``qvel``.

    Raises:
        TypeError: If the unwrapped environment lacks ``data``, ``data.qpos``
            or ``data.qvel``.
    """

    def __init__(self, env, noise_std_qpos=0.005, noise_std_qvel=0.005):
        super().__init__(env)
        self.noise_std_qpos = noise_std_qpos  # Standard deviation of noise in qpos (position)
        self.noise_std_qvel = noise_std_qvel  # Standard deviation of noise in qvel (velocity)

        # Validate with explicit exceptions rather than `assert`, which is
        # stripped under `python -O`.
        # Look `data` up on the unwrapped env: the outer wrappers added by
        # gym.make() no longer forward attribute access in gymnasium >= 1.0.
        base_env = self.env.unwrapped
        if not hasattr(base_env, "data"):
            raise TypeError("The environment must have a `data` attribute for qpos/qvel.")
        if not (hasattr(base_env.data, "qpos") and hasattr(base_env.data, "qvel")):
            raise TypeError("The environment's `data` must have `qpos` and `qvel` attributes.")

    @staticmethod
    def _clipped_noise(std, shape):
        """Draw N(0, std**2) noise of ``shape`` from NumPy's global RNG, clipped to +/- 3*std."""
        noise = np.random.normal(0, std, size=shape)
        return np.clip(noise, -3 * std, 3 * std)

    def step(self, action):
        """Step the env, perturb qpos/qvel, and return the observation of the perturbed state."""
        _obs, reward, terminated, truncated, info = self.env.step(action)

        base_env = self.env.unwrapped
        data = base_env.data
        # Draw order (qpos first, then qvel) is kept fixed so seeded runs stay reproducible.
        qpos_noise = self._clipped_noise(self.noise_std_qpos, data.qpos.shape)
        qvel_noise = self._clipped_noise(self.noise_std_qvel, data.qvel.shape)

        if hasattr(base_env, "set_state"):
            # MujocoEnv.set_state writes qpos/qvel and runs a forward pass.
            # Derived quantities that _get_obs() may read (e.g. body positions)
            # then match the perturbed state instead of the pre-noise one.
            base_env.set_state(data.qpos + qpos_noise, data.qvel + qvel_noise)
        else:
            data.qpos[:] += qpos_noise
            data.qvel[:] += qvel_noise

        # Rebuild the observation from the perturbed simulator state.
        new_obs = base_env._get_obs()
        return new_obs, reward, terminated, truncated, info


		
class DynamicsDataCollector:
    """Roll out an agent and record fixed-length windows of (state, action) pairs.

    Every sample is one flat vector. For each of the ``history_size +
    prediction_horizon`` most recent steps, it holds the state the agent
    observed followed by the action it chose, concatenated in time order.

    A window is recorded every ``sample_every`` steps once it is full. The
    window is cleared between episodes, so no sample spans two episodes.
    Episodes shorter than the window yield no samples, so
    ``collect_dynamics_data`` will not terminate if every episode is that short.

    Args:
        env: Environment with a gymnasium-style ``reset`` returning
            ``(obs, info)`` and ``step`` returning a 5-tuple.
        env_addons: ``"Observation-Noise"`` wraps ``env`` in
            ``ObservationNoiseWrapper``. ``"Process-Noise"`` wraps it in
            ``ProcessNoiseWrapper``. Any other value uses ``env`` as-is.
        agent: Object whose ``select_action(state)`` returns an action.
        history_size: Number 'q' of past steps in each window.
        prediction_horizon: Number 'h' of future steps in each window.
        sample_every: Stride, in environment steps, between recorded windows.

    Raises:
        ValueError: If ``sample_every`` is smaller than 1.
    """

    def __init__(self, env, env_addons, agent, history_size=5, prediction_horizon=3, sample_every=10):
        if env_addons == "Observation-Noise":
            self.env = ObservationNoiseWrapper(env)
        elif env_addons == "Process-Noise":
            self.env = ProcessNoiseWrapper(env)
        else:
            self.env = env

        if sample_every < 1:
            raise ValueError(f"sample_every must be >= 1, got {sample_every}")

        self.agent = agent
        self.history_size = history_size  # 'q' the most recent observations
        self.prediction_horizon = prediction_horizon  # 'h' timesteps to predict in the future
        self.sample_every = sample_every
        self.history_buffer = []  # Sliding window of (state, action, reward, next_state) tuples
        self.data = []  # Collected samples; an ndarray after collect_dynamics_data()

    def reset(self, seed=0):
        """Reset the environment with ``seed``, clear the history window, and return the initial state."""
        state = self.env.reset(seed=seed)[0]
        self.history_buffer = []
        return state

    def step(self, action):
        """Step the environment and collapse ``terminated``/``truncated`` into a single ``done`` flag."""
        next_state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        return next_state, reward, done, info

    def collect_dynamics_data(self, data_size=1):
        """Run whole episodes until at least ``data_size`` samples exist, then keep the first ``data_size``.

        Samples from earlier calls are kept and extended. ``self.data`` ends up
        as an ndarray with one row per sample.
        """
        window = self.history_size + self.prediction_horizon
        # self.data is an ndarray after a previous call; work on a list so
        # repeated calls can keep appending instead of failing on ndarray.append.
        samples = list(self.data)
        while len(samples) < data_size:
            state = self.env.reset()[0]
            # Start every episode with an empty window so no sample straddles episodes.
            self.history_buffer = []
            done = False
            step_idx = 0
            while not done:
                action = self.agent.select_action(state)
                next_state, reward, done, info = self.step(action)
                self.history_buffer.append((state, action, reward, next_state))

                # Keep only the most recent `window` steps.
                if len(self.history_buffer) > window:
                    self.history_buffer.pop(0)

                if step_idx % self.sample_every == 0 and len(self.history_buffer) == window:
                    samples.append(
                        np.concatenate(
                            [np.concatenate([step_state, step_action]) for step_state, step_action, _, _ in self.history_buffer]
                        )
                    )
                state = next_state
                step_idx += 1
            print("Episode length:", step_idx)
            self.history_buffer = []
        self.data = np.array(samples[:data_size])

    def save_data(self, filename):
        """Shuffle the collected samples in place and write them to ``filename`` with ``np.save``.

        The shuffle uses NumPy's global RNG.
        """
        np.random.shuffle(self.data)
        np.save(filename, self.data)

if __name__ == "__main__":
    # Command-line entry point. It loads a trained policy and rolls it out in
    # the (optionally noise-wrapped) environment. Sampled windows are saved to
    # ./datasets_dir/<env>/<env_addons>/<dataset_type>.npy.

    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="TD3", choices=["TD3", "OurDDPG", "DDPG"])  # Policy name
    parser.add_argument("--env", default="HalfCheetah-v4")  # Gymnasium environment id
    parser.add_argument("--seed", default=0, type=int)  # Sets Gym, PyTorch and NumPy seeds
    # The following five options are not read by this script.
    parser.add_argument("--start_timesteps", default=25_000, type=int)  # Time steps initial random policy is used
    parser.add_argument("--eval_freq", default=5_000, type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1_000_000, type=int)  # Max time steps to run environment
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)  # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate
    # type=float is required here. Without it, a value given on the command line
    # stays a str, and `args.policy_noise * max_action` below raises TypeError.
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)  # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
    parser.add_argument("--load_model", default="default")  # Model file under ./models; "default" uses file_name
    parser.add_argument("--dataset_type", required=True, help="train or test data")  # Output file stem
    parser.add_argument(
        "--env_addons",
        required=True,
        choices=["Normal", "Process-Noise", "Observation-Noise"],
        help="Normal, Process-Noise, or Observation-Noise",
    )
    # Accept both spellings; args.dataset_size is the attribute either way.
    parser.add_argument("--dataset-size", "--dataset_size", dest="dataset_size", required=True, type=int)
    parser.add_argument("--device", required=True)
    parser.add_argument("--history_size", default=50, type=int)  # Past steps per sample window
    parser.add_argument("--prediction_horizon", default=300, type=int)  # Future steps per sample window
    args = parser.parse_args()

    # The model file name always ends in "_0", independent of --seed.
    file_name = f"{args.policy}_{args.env}_0"
    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    os.makedirs("./datasets_dir", exist_ok=True)

    env = gym.make(args.env)

    # Set seeds. The reset result (obs, info) is not needed here.
    env.reset(seed=args.seed)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
        "device": args.device,
    }

    # Initialize policy
    if args.policy == "TD3":
        # Target policy smoothing is scaled wrt the action scale
        kwargs["policy_noise"] = args.policy_noise * max_action
        kwargs["noise_clip"] = args.noise_clip * max_action
        kwargs["policy_freq"] = args.policy_freq
        policy = TD3.TD3(**kwargs)
    elif args.policy == "OurDDPG":
        # NOTE(review): OurDDPG is not imported explicitly in this file. This
        # branch only works if `from TD3.utils import *` provides it; confirm.
        policy = OurDDPG.DDPG(**kwargs)
    else:  # "DDPG" (argparse choices exclude anything else)
        # NOTE(review): DDPG is not imported explicitly in this file; confirm.
        policy = DDPG.DDPG(**kwargs)

    policy_file = file_name if args.load_model == "default" else args.load_model
    policy.load(f"./models/{policy_file}")
    print("---------------------------------------")
    print(f"{args.policy} policy loaded successfully...")
    print("---------------------------------------")

    # Initialize the data collector
    collector = DynamicsDataCollector(
        env,
        args.env_addons,
        policy,
        history_size=args.history_size,
        prediction_horizon=args.prediction_horizon,
    )

    # Run whole episodes until `dataset_size` samples have been collected.
    collector.collect_dynamics_data(data_size=args.dataset_size)

    dataset_dir = os.path.join("./datasets_dir", args.env, args.env_addons)
    os.makedirs(dataset_dir, exist_ok=True)
    # Save the (shuffled) samples as a NumPy .npy file.
    collector.save_data(os.path.join(dataset_dir, args.dataset_type + ".npy"))
