from cma import CMAEvolutionStrategy
import gym
import mujoco_py
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import argparse
plt.style.use('ggplot')

# Command-line configuration; parsed once at import and read by run().
parser = argparse.ArgumentParser(
    description="Optimise a linear policy on a MuJoCo gym environment with CMA-ES.")
parser.add_argument('--sm', default=0.01, type=float,
                    help="initial CMA-ES step size (sigma)")
parser.add_argument('--env', default="Swimmer-v2",
                    choices=["HalfCheetah-v2", "Swimmer-v2", "Reacher-v2"],
                    help="gym environment id")
parser.add_argument('--steps', default=10000, type=int,
                    help="number of CMA-ES generations per run")
parser.add_argument('--reps', default=1, type=int,
                    help="number of repeated runs (histories are averaged for plotting)")

args = parser.parse_args()

envir_name = args.env  # one of the --env choices above

def plot_hist(total_hist, hist_name):
    """Plot the mean +/- one standard deviation of several reward histories, then show it.

    Args:
        total_hist: sequence of equal-length histories, one per run; it is
            converted to a 2-D array and reduced over the run axis (axis 0).
        hist_name: legend label for the mean curve.
    """
    total_hist = np.array(total_hist)
    total_mean = total_hist.mean(0)
    total_std = total_hist.std(0)
    total_steps = np.arange(len(total_mean))

    plt.plot(total_steps, total_mean, "--", label=hist_name, linewidth=2, alpha=0.8)
    # Shaded band: mean +/- one standard deviation across runs.
    plt.fill_between(total_steps, total_mean - total_std, total_mean + total_std, alpha=0.3)

    plt.ylabel("objective value")
    plt.xlabel("num query")
    plt.legend()
    plt.show()


def _rollout(env, predictor, horizon):
    """Return the total reward of the linear policy ``predictor`` over ``horizon`` steps.

    Each action is ``np.dot(predictor, observation)``. The environment is
    reset with ``seed=42`` before the first step, so every candidate is
    evaluated from the same seeded start. When an episode ends
    (``terminated`` or ``truncated``) the environment is reset without a seed
    and stepping continues, so the total can span several episodes.
    """
    cumulative_reward = 0
    observation, _info = env.reset(seed=42)
    for _ in range(horizon):
        action = np.dot(predictor, observation)
        observation, reward, terminated, truncated, _info = env.step(action)
        cumulative_reward += reward
        if terminated or truncated:
            observation, _info = env.reset()
    return cumulative_reward


def _optimise_once(env, log_file, x0, sigma, popsize, total_steps, horizon):
    """Run one CMA-ES optimisation of a linear policy and return its reward history.

    Args:
        env: environment used to score every candidate via ``_rollout``.
        log_file: open text file; one line of candidate returns is written
            (and flushed) per generation.
        x0: initial policy matrix; CMA-ES works on its flattened form and
            each candidate is reshaped back to ``x0.shape``.
        sigma: initial CMA-ES step size.
        popsize: candidates evaluated per generation.
        total_steps: number of generations (ask/tell rounds).
        horizon: environment steps per candidate evaluation.

    Returns:
        list: the return of every evaluated candidate, in evaluation order.
    """
    es = CMAEvolutionStrategy(list(x0.flatten()), sigma, {'popsize': popsize})
    hist = []
    for step in range(total_steps):
        if step % 10 == 0:
            print(step, 'out of', total_steps)

        solutions = es.ask()
        values = [_rollout(env, solution.reshape(x0.shape), horizon) for solution in solutions]

        # One line per generation: every return followed by a single space.
        log_file.write(''.join(str(value) + ' ' for value in values) + '\n')
        log_file.flush()

        if step % 10 == 0:
            print(step, 'out of', total_steps, np.mean(values))

        # CMA-ES minimises its objective, so pass negated returns to maximise reward.
        es.tell(solutions, [-v for v in values])
        hist.extend(values)
    return hist


def run():
    """Train a linear policy on ``envir_name`` with CMA-ES and return all reward histories.

    Settings come from the module-level ``args``: ``--sm`` is the initial
    CMA-ES step size, ``--steps`` the number of generations and ``--reps``
    the number of runs. Every run starts from the same ``(action_dim,
    state_dim)`` matrix, drawn from N(0, 1) after ``np.random.seed(1234)``.
    Candidate returns are logged to a text file whose name encodes the
    settings. The file and the environment are closed even if a run fails.

    Returns:
        list[list]: one reward history per run (see ``_optimise_once``).
    """
    np.random.seed(1234)

    reps = args.reps
    total_steps = args.steps
    horizon = 1000  # environment steps per candidate evaluation

    # (m,k)-ranking oracle: candidates evaluated per generation (CMA-ES popsize).
    num_queries_per_step = 5
    smoothing_scale = args.sm  # passed to CMA-ES as its initial step size

    log_path = ('logs_cma_smoothing_001_sm_{}_steps_{}_reps_{}'.format(smoothing_scale, total_steps, reps)
                + envir_name + '.txt')
    total_hist = []
    with open(log_path, 'w') as f:
        env = gym.make(envir_name)
        try:
            state_dim = env.observation_space.shape[0]
            action_dim = env.action_space.shape[0]

            # NOTE(review): every run restarts from this same matrix; whether runs
            # actually differ depends on CMA-ES's own random seeding — confirm.
            linear_predictor_init = np.random.normal(0, 1, size=(action_dim, state_dim))
            for rep in range(reps):
                print('rep number:', rep)
                f.write('rep number: ' + str(rep) + '\n')
                total_hist.append(_optimise_once(
                    env, f, linear_predictor_init, smoothing_scale,
                    num_queries_per_step, total_steps, horizon))
        finally:
            env.close()

    return total_hist


def main():
    """Run the CMA-ES experiment; plotting of the resulting histories is disabled."""
    total_hist = run()
    # Uncomment to plot mean +/- std of the reward histories across runs.
    #plot_hist(total_hist, "CMA-ES_" + envir_name)


# Only start the (long-running) experiment when executed as a script.
if __name__ == "__main__":
    main()
