import argparse
import os
import pickle
import random

import gymnasium as gym
import numpy as np
import torch

import SPReD as RL

# Constants for normalising the network inputs (states and goals); read by process_inputs
clip_range = 5  # bound applied to the standardised (z-scored) values
clip_obs = 200  # bound applied to the raw states/goals before standardisation
clip_return = 50  # NOTE(review): not referenced anywhere in the code visible here

def process_inputs(o, g, o_mean, o_std, g_mean, g_std, ax=0):
    """Clip, standardise and concatenate a state and a goal into one network input.

    Each of ``o`` and ``g`` is first clipped to ``[-clip_obs, clip_obs]``, then
    standardised with its mean/std (``1e-6`` is added to the std to avoid division
    by zero) and finally clipped to ``[-clip_range, clip_range]``.

    Args:
        o: raw state.
        g: raw goal.
        o_mean, o_std: statistics used to standardise the state.
        g_mean, g_std: statistics used to standardise the goal.
        ax: axis along which the normalised state and goal are concatenated.

    Returns:
        A ``torch.float32`` tensor holding ``[normalised state, normalised goal]``.
    """
    def _normalise(raw, centre, spread):
        bounded = np.clip(raw, -clip_obs, clip_obs)
        return np.clip((bounded - centre) / (spread + 1e-6), -clip_range, clip_range)

    parts = [_normalise(o, o_mean, o_std), _normalise(g, g_mean, g_std)]
    return torch.tensor(np.concatenate(parts, axis=ax), dtype=torch.float32)

# Update mean and variance when a new input is added using Welford's algorithm
def update(existingAggregate, newValue):
    """Fold one new sample into a running ``(count, mean, M2)`` aggregate (Welford).

    Works element-wise for NumPy arrays as well as for scalars. A NEW aggregate is
    returned; the arrays inside ``existingAggregate`` are left untouched, so earlier
    aggregates and any statistics previously returned by ``finalize`` stay valid.

    Args:
        existingAggregate: tuple ``(count, mean, M2)`` where ``M2`` is the running
            sum of squared deviations from the mean.
        newValue: the new sample, broadcast-compatible with ``mean``.

    Returns:
        The updated ``(count, mean, M2)`` tuple.
    """
    count, mean, M2 = existingAggregate
    count += 1
    delta = newValue - mean
    # Out-of-place arithmetic on purpose: `mean += ...` / `M2 += ...` would mutate
    # the caller's arrays, silently changing every older aggregate (and every
    # finalize() result, which returns the same mean object) that shares them, and
    # would raise a casting error for integer-typed aggregate arrays.
    mean = mean + delta / count
    delta2 = newValue - mean
    M2 = M2 + delta * delta2
    return (count, mean, M2)

def finalize(existingAggregate):
    """Turn a Welford ``(count, mean, M2)`` aggregate into summary statistics.

    Returns:
        ``(mean, variance, sampleVariance)`` with ``variance = M2 / count``
        (population variance) and ``sampleVariance = M2 / (count - 1)`` (unbiased
        variance). With fewer than two samples, arrays of ones shaped like ``mean``
        are returned for both variances instead.
    """
    n_seen, running_mean, sq_dev_sum = existingAggregate
    if n_seen >= 2:
        return running_mean, sq_dev_sum / n_seen, sq_dev_sum / (n_seen - 1)
    # Too few samples for a variance estimate: fall back to unit variance
    return running_mean, np.ones_like(running_mean), np.ones_like(running_mean)

########################################################################

# Environment and hyperparameter settings
# The choice of hyperparameters follows the original Q-filter and TD3 papers
parser = argparse.ArgumentParser()

parser.add_argument('--method', type=str, default="SPReDP")  # Available methods are "EnsQfilter", "SPReDP", "SPReDE", "Nonpara_pairwise" and "Nonpara_cross"
parser.add_argument('--env', type=str, default="FetchPickAndPlace-v2")  # Environment name
parser.add_argument("--seed", default=1, type=int)  # Seed for reproduction
parser.add_argument("--offset", default=100, type=int)  # Added to the seed to get a different seed for the evaluation environment
parser.add_argument('--test', type=str, default="Main")  # Can be "Main", "DemoQuality", "DemoSize"
parser.add_argument("--ensemble_size", default=10, type=int)  # Number of critic networks
# argparse applies `type` only to string defaults, so numeric defaults are given as
# ints directly; the lambda lets the command line accept values such as "4e6".
parser.add_argument("--max_steps", default=4_000_000, type=lambda x: int(float(x)))  # Total number of environment steps
parser.add_argument("--memory_size", default=1_000_000, type=lambda x: int(float(x)))  # Capacity of the replay buffer
parser.add_argument("--learning_starts", default=10 * 1024, type=int)  # Replay-buffer size required before training starts
parser.add_argument("--episodes_eval", default=25, type=int)  # Number of episodes per evaluation round
parser.add_argument("--eval_freq", default=100, type=int)  # Evaluate once every `eval_freq` training episodes
parser.add_argument("--batch_size_buffer", default=1024, type=int)  # Batch size from the replay buffer
parser.add_argument("--batch_size_demo", default=128, type=int)  # Batch size from the demonstration buffer
parser.add_argument("--lambda1", default=0.01, type=float)  # Weight of the policy improvement in actor loss
parser.add_argument("--lambda2", default=1/128, type=float)  # Weight of the imitation in actor loss
parser.add_argument("--gamma", default=0.98, type=float)  # Discount factor
parser.add_argument("--tau", default=0.005, type=float)  # Rate of target network updates
parser.add_argument('--lr', type=float, default=1e-3)  # Learning rate of actor and critic networks
# `type=float` is required: without it, values given on the command line arrive as strings
parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to actions in critic updates
parser.add_argument("--noise_clip", default=0.5, type=float)  # Clip the noise added to actions
parser.add_argument("--policy_freq", default=2, type=int)  # The frequency of actor updates (once per 2 critic updates)
parser.add_argument('--device', type=str, default="cuda:0")  # Torch device passed to the agent
parser.add_argument('--demo_quality', type=str, default="")  # If you set test=DemoQuality, use "_expert", "_suboptimal" or "_poor"
parser.add_argument('--demo_size', type=str, default="")  # If you set test=DemoSize, use "_demosize5", "_demosize10", "_demosize20", "_demosize50" or "_demosize100"
args = parser.parse_args()

# Create the environments (in this order): `env` supplies the spaces and the
# compute_* functions used for HER relabelling, `env_train` collects experience and
# `env_eval` runs the periodic evaluations (it is seeded inside the training loop).
env, env_train, env_eval = (gym.make(args.env) for _ in range(3))

# Seed every source of randomness for reproducibility
for seeded_env in (env, env_train):
    seeded_env.reset(seed=args.seed)
    seeded_env.action_space.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Dimensions of the goal-conditioned observation dict and of the action space
obs_space = env.observation_space
state_dim = obs_space['observation'].shape[0]
goal_dim = obs_space['desired_goal'].shape[0]
obs_dim = state_dim + goal_dim
action_space = env.action_space
action_dim = action_space.shape[0]
max_action = action_space.high[0]

# Load the pre-collected demonstrations
# The file is selected by the test type and the requested demo quality/size.
# NOTE: pickle.load can execute arbitrary code — only load trusted demonstration files.
demo_path = f"Demonstrations/{args.test}/{args.env}{args.demo_quality}{args.demo_size}.pkl"
with open(demo_path, "rb") as open_file:
    dataset = pickle.load(open_file)

# Keep the first six fields of every demo transition — presumably the same
# (state, action, reward, next_state, goal, done) layout as the replay buffer; confirm.
demos = []
# Use the mean and variance of the demonstration inputs to initialise the running
# normalisation statistics
states_agg = (0, np.zeros(state_dim), np.zeros(state_dim))
goals_agg = (0, np.zeros(goal_dim), np.zeros(goal_dim))
for transition in dataset:
    demos.append(tuple(transition[k] for k in range(6)))
    states_agg = update(states_agg, np.array(transition[0]))
    goals_agg = update(goals_agg, np.array(transition[4]))

# Online replay buffer, logging containers and counters
replay_buffer = []
score_history = []
success_history = []
average_accept_demos = []
steps = 0
episodes = 0

# The TD3 target-smoothing and delay settings come from the command line instead of
# being hard-coded, so --policy_noise / --noise_clip / --policy_freq take effect.
# float() guards against string values if those options are parsed without a type.
agent = RL.Agent(state_dim, goal_dim, action_dim, max_action, hidden_dim=(256, 256),
                 method=args.method, ensemble_size=args.ensemble_size, lambda1=args.lambda1,
                 lambda2=args.lambda2, batch_size_buffer=args.batch_size_buffer,
                 batch_size_demo=args.batch_size_demo,
                 gamma=args.gamma, tau=args.tau, lr=args.lr,
                 policy_noise=float(args.policy_noise), noise_clip=float(args.noise_clip),
                 policy_freq=args.policy_freq, device=args.device)

def _store_transition(transition):
    """Append ``transition`` to the replay buffer, evicting the oldest entry (FIFO)
    once the buffer holds more than ``args.memory_size`` transitions."""
    replay_buffer.append(transition)
    if len(replay_buffer) > args.memory_size:
        # NOTE(review): list.pop(0) is O(len(buffer)) once the buffer is full; a ring
        # buffer would be O(1), but agent.train's access pattern is not visible here,
        # so the list type and FIFO order are kept.
        replay_buffer.pop(0)


def _current_normalizers():
    """Return ``(state_mean, state_std, goal_mean, goal_std)`` from the running stats."""
    state_stats = finalize(states_agg)
    goal_stats = finalize(goals_agg)
    return (state_stats[0], np.sqrt(state_stats[1]), goal_stats[0], np.sqrt(goal_stats[1]))


while steps < args.max_steps:
    # Training #
    done = False
    obs_ep = []
    obs = env_train.reset()[0]
    state = obs['observation']
    desired_goal = obs['desired_goal']
    # Update the running statistics with the initial state and goal
    states_agg = update(states_agg, np.array(state))
    goals_agg = update(goals_agg, np.array(desired_goal))

    # Interact with the environment to save transitions in the replay buffer
    while not done:
        o_mean, o_std, g_mean, g_std = _current_normalizers()
        inputs = process_inputs(state, desired_goal, o_mean=o_mean, o_std=o_std,
                                g_mean=g_mean, g_std=g_std)
        action = agent.choose_action(inputs)
        # Gaussian exploration noise with std = 10% of the action bound
        noise = np.random.normal(0, max_action * 0.1, size=action_dim)
        action = np.clip(action + noise, -max_action, max_action)
        next_obs, reward, terminated, truncated, info = env_train.step(action)
        next_state = next_obs['observation']
        next_desired_goal = next_obs['desired_goal']
        states_agg = update(states_agg, np.array(next_state))
        goals_agg = update(goals_agg, np.array(next_desired_goal))
        # The episode ends on termination OR time-limit truncation, but only a true
        # termination may stop bootstrapping, so `terminated` is stored rather than
        # `done`. Assumes agent.train uses this flag as a terminal mask — confirm in SPReD.
        done = terminated or truncated
        _store_transition((state, action, reward, next_state, desired_goal, terminated))
        steps += 1

        # Save the related information of transitions for HER later
        obs_ep.append((obs, action, next_obs, info))
        # Advance to the next step
        obs = next_obs
        state = next_state
        desired_goal = next_desired_goal

    # HER ("final" strategy): store every transition again, relabelled with the goal
    # actually achieved at the end of the episode
    substitute_goal = obs["achieved_goal"].copy()
    for observation, action, next_observation, info in obs_ep:
        state = observation['observation']
        next_state = next_observation['observation']
        states_agg = update(states_agg, np.array(state))
        goals_agg = update(goals_agg, np.array(substitute_goal))
        # Reward and termination of (s, a, s') are functions of the goal achieved
        # *after* the step, i.e. the achieved goal of next_observation
        next_achieved_goal = next_observation["achieved_goal"]
        substitute_reward = env.unwrapped.compute_reward(next_achieved_goal, substitute_goal, info)
        substitute_terminated = env.unwrapped.compute_terminated(next_achieved_goal, substitute_goal, info)
        # As above, truncation is not a terminal state, so only termination is stored
        _store_transition((state, action, substitute_reward, next_state, substitute_goal,
                           substitute_terminated))

    if len(replay_buffer) > args.learning_starts:
        agent.train(replay_buffer, demos, normalizers=_current_normalizers(), iterations=2)
    episodes += 1

    # Evaluation
    if episodes % args.eval_freq == 0:
        # Re-seed right before each evaluation round so every round starts from the
        # same environment RNG state
        env_eval.reset(seed=args.seed + args.offset)
        env_eval.action_space.seed(args.seed + args.offset)
        score_temp = []
        fin_temp = []
        for _ in range(args.episodes_eval):
            done_eval = False
            obs_eval = env_eval.reset()[0]
            state_eval = obs_eval['observation']
            desired_goal_eval = obs_eval['desired_goal']
            score_eval = 0
            while not done_eval:
                with torch.no_grad():
                    o_mean, o_std, g_mean, g_std = _current_normalizers()
                    inputs = process_inputs(state_eval, desired_goal_eval, o_mean=o_mean,
                                            o_std=o_std, g_mean=g_mean, g_std=g_std)
                    action_eval = agent.choose_action(inputs)
                obs_eval, reward_eval, terminated_eval, truncated_eval, info_eval = env_eval.step(action_eval)
                done_eval = terminated_eval or truncated_eval
                state_eval = obs_eval['observation']
                desired_goal_eval = obs_eval['desired_goal']
                score_eval += reward_eval
            # Save the success flag and cumulative score of each evaluated episode
            score_temp.append(score_eval)
            fin_temp.append(info_eval['is_success'])
        # Record the average score and success rate over the evaluated episodes
        score_eval = np.mean(score_temp)
        fin_eval = np.mean(fin_temp)
        score_history.append(score_eval)
        success_history.append(fin_eval)
        print("Episode", episodes, "Env Steps", steps, "Score %.2f" % score_eval, "Success rate %.2f" % fin_eval)
        # Report the mean of the last (up to) 10 acceptance rates to smooth the curve
        if args.method == "EnsQfilter":
            if len(agent.accept_history) == 0:
                print("Acceptance Rate of Demos = 0 ")
            else:
                recent_accept = np.mean(agent.accept_history[-10:])
                average_accept_demos.append(recent_accept)
                print("Acceptance Rate of Demos = %.2f " % recent_accept)

# Save score, success rate and the demo acceptance rate (EnsQfilter) or the
# imitation weights (other methods) in files
results_dir = f"Results/{args.env}{args.demo_quality}{args.demo_size}/{args.method}"
# Create the output directory first, otherwise np.save raises FileNotFoundError
# after the whole run has completed and the results are lost
os.makedirs(results_dir, exist_ok=True)
result_prefix = f"{results_dir}/EnsSize_{args.ensemble_size}_S{args.seed}"
np.save(f"{result_prefix}_score", score_history)
np.save(f"{result_prefix}_success", success_history)
if args.method == "EnsQfilter":
    np.save(f"{result_prefix}_accept", average_accept_demos)
else:
    np.save(f"{result_prefix}_weight", agent.weight_history)

# Release the simulator resources held by the environments
for opened_env in (env, env_train, env_eval):
    opened_env.close()
