import gym
import mujoco_py
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
plt.style.use('ggplot')

import argparse

# Command-line configuration. Arguments are parsed once, at module level,
# and the resulting `args` / `envir_name` globals are read by run().
parser = argparse.ArgumentParser()
# Base of the geometric step-size decay used by the line-search candidates.
parser.add_argument("--dc", type=float, default=0.3)
# Initial learning rate; run() also uses it as the initial smoothing scale.
parser.add_argument("--lr", type=float, default=0.5)
# Value the linear learning-rate / smoothing schedules approach as steps run out.
parser.add_argument("--lr_min", type=float, default=0.01)
# Gym environment id; restricted to the ids listed in `choices`.
parser.add_argument(
    "--env",
    default="Swimmer-v2",
    choices=["HalfCheetah-v2", "Swimmer-v2", "Reacher-v2"],
)
# Number of optimisation steps per repetition.
parser.add_argument("--steps", type=int, default=10000)
# Number of repetitions, all starting from the same initial predictor.
parser.add_argument("--reps", type=int, default=1)

args = parser.parse_args()

# Environment id selected through --env.
envir_name = args.env

def plot_hist(total_hist, hist_name):
    """Plot the mean curve of several runs with a +/- one-std band and show it.

    Args:
        total_hist: Sequence of per-run value histories. It is stacked with
            ``np.array`` and reduced along axis 0, so the runs are expected
            to have equal length.
        hist_name: Legend label for the mean curve.

    The function draws on the current matplotlib figure, labels the axes,
    adds the legend and finally calls ``plt.show()``.
    """
    total_hist = np.array(total_hist)
    total_mean = total_hist.mean(0)
    total_std = total_hist.std(0)

    # Dashed mean curve, one point per recorded value.
    plt.plot(total_mean, "--", label=hist_name, linewidth=2, alpha=0.8)

    # Shaded band of one standard deviation around the mean.
    total_steps = np.arange(len(total_mean))
    plt.fill_between(
        total_steps,
        total_mean - total_std,
        total_mean + total_std,
        alpha=0.3,
    )

    # NOTE: the previous version called plot_hist() on itself here, which
    # recursed without bound and never reached the lines below.
    plt.ylabel("objective value")
    plt.xlabel("num query")
    plt.legend()
    plt.show()


def _rollout(env, predictor, horizon):
    """Return the summed reward of the linear policy ``predictor``.

    The environment is reset with ``seed=42`` and then stepped exactly
    ``horizon`` times using ``action = predictor @ observation``. If an
    episode terminates or is truncated before that, the environment is
    reset (without a seed) and stepping continues.
    """
    cumulative_reward = 0
    observation, _ = env.reset(seed=42)
    for _ in range(horizon):
        action = np.dot(predictor, observation)
        observation, reward, terminated, truncated, _ = env.step(action)
        cumulative_reward += reward
        if terminated or truncated:
            observation, _ = env.reset()
    return cumulative_reward


def _rank_gradient(noises, ranks, num_queries, top_k):
    """Combine perturbation directions into a rank-weighted direction.

    Args:
        noises: List of perturbation arrays, one per query.
        ranks: Indices into ``noises`` of the ``top_k`` highest-reward
            queries, best first.
        num_queries: Number of queries per step.
        top_k: Number of ranked queries.

    A direction ranked ``r`` (1 = highest reward) gets weight
    ``2 * r - num_queries - 1``; unranked directions get weight ``top_k``.
    The weighted sum is divided by
    ``top_k * (top_k - 1) / 2 + top_k * (num_queries - top_k)``.
    """
    rank_of = {index: position + 1 for position, index in enumerate(ranks)}
    gradient = np.zeros_like(noises[0])
    for index, direction in enumerate(noises):
        if index in rank_of:  # in top k
            weight = 2 * rank_of[index] - num_queries - 1
        else:  # not in top k
            weight = top_k
        gradient += weight * direction
    gradient /= top_k * (top_k - 1) / 2 + top_k * (num_queries - top_k)
    return gradient


def run():
    """Optimise a linear policy with rank-based zeroth-order updates.

    Reads its configuration from the module-level ``args`` / ``envir_name``,
    seeds NumPy with 1234, and for each repetition starts from the same
    random initial predictor. Every step evaluates a set of candidate
    predictors by rollout, logs their rewards to a text file (one line per
    step), ranks them by reward and updates the predictor.

    Returns:
        A list with one entry per repetition; each entry is the flat list
        of all candidate rewards observed during that repetition.

    The log file and the environment are closed even if an error occurs.
    """
    np.random.seed(1234)

    reps = args.reps  # independent repetitions (for plotting mean +- std)
    total_steps = args.steps
    horizon = 1000

    # (m,k)-ranking oracle
    num_queries_per_step = 5
    top_k = 5
    smoothing_scale = args.lr  # smoothing starts at the same value as lr
    lr = args.lr
    decay = args.dc
    lr_min = args.lr_min
    mm = 0.  # momentum coefficient; 0 means the raw estimate is used

    log_name = 'logs_smoothing_001_lr_{}_sm_{}_decay_{}_steps_{}_reps_{}_lr_min{}'.format(
        lr, smoothing_scale, decay, total_steps, reps, lr_min
    ) + envir_name + '.txt'

    total_hist = []
    with open(log_name, 'w') as f:
        env = gym.make(envir_name)
        try:
            state_dim = env.observation_space.shape[0]
            action_dim = env.action_space.shape[0]

            linear_predictor_init = np.random.normal(
                0, 1, size=(action_dim, state_dim)
            )

            for rep in range(reps):
                print('rep number:', rep)
                f.write('rep number: ' + str(rep) + '\n')
                hist = []

                linear_predictor = np.copy(linear_predictor_init)
                momentum = None
                mode = "grad_est"
                for step in range(total_steps):
                    # Both schedules interpolate linearly towards lr_min.
                    learning_rate = (lr * (total_steps - step) + lr_min * step) / total_steps
                    sm = (smoothing_scale * (total_steps - step) + lr_min * step) / total_steps

                    if mode == "grad_est":
                        noises = [
                            np.random.normal(0, scale=1, size=(action_dim, state_dim))
                            for _ in range(num_queries_per_step)
                        ]
                        new_predictors = [linear_predictor + sm * noise for noise in noises]
                    else:
                        # NOTE(review): this line-search branch is unreachable
                        # while `mode` is never switched away from "grad_est";
                        # it is kept unchanged.
                        new_predictors = [linear_predictor] + [
                            linear_predictor + learning_rate * gradient_estimation * np.power(decay, ii)
                            for ii in range(4)
                        ]

                    new_predictor_values = [
                        _rollout(env, predictor, horizon) for predictor in new_predictors
                    ]

                    if step % 10 == 0:
                        print(step, 'out of', total_steps, np.mean(new_predictor_values), learning_rate, sm)
                    # One log line per step: space-terminated reward values.
                    f.write(''.join(str(value) + ' ' for value in new_predictor_values) + '\n')
                    f.flush()

                    # Indices of the top_k candidates, highest reward first.
                    ranks = np.argsort(-np.array(new_predictor_values))[:top_k].tolist()

                    if mode == "grad_est":
                        gradient_estimation = _rank_gradient(
                            noises, ranks, num_queries_per_step, top_k
                        )
                        if momentum is None:
                            momentum = gradient_estimation.copy()
                        else:
                            momentum = mm * momentum + (1 - mm) * gradient_estimation
                        linear_predictor -= learning_rate * momentum
                        # Line search is disabled; setting
                        # mode = "line_search" here would enable it.
                    else:
                        # NOTE(review): ranks[-1] is the lowest-reward entry
                        # among the ranked candidates -- confirm this is
                        # intended before re-enabling the line search.
                        linear_predictor = new_predictors[ranks[-1]]
                        mode = "grad_est"

                    hist.extend(new_predictor_values)

                total_hist.append(hist)
        finally:
            env.close()

    return total_hist


def main():
    """Run the experiment and return the reward histories from run().

    Plotting is currently switched off (the call was commented out).
    """
    total_hist = run()
    # To plot the result, call
    # plot_hist(total_hist, "ZO-RankSGD_" + envir_name) here.
    return total_hist


# Only start the experiment when executed as a script, not on import.
if __name__ == "__main__":
    main()
