import os
import torch
import argparse
import math
import importlib
import random

import numpy as np
import torch.backends.cudnn as cudnn
# from bayes_opt import BayesianOptimization
import cma
from es.es import SimpleGA, CMAES, PEPG, OpenES
from open_source.rlpyt.rlpyt.envs.base import Env, EnvStep, EnvSpaces

# from envs.TwoStageDiffAmpEnv import TwoStageDiffAmpEnv
# from envs.LDOEnv import LDOEnv
# from envs.TwoStageTransAmpEnv import TwoStageTransAmpEnv
# from envs.ThreeStageDiffTransAmpEnv import ThreeStageDiffTransAmpEnv
# from envs.NGspiceOpampEnv_rlpyt import NGspiceOpampEnv_rlpyt
from envs.NGspiceOpampEnv_pvt import NGspiceOpampEnv_pvt
# from envs.foldedCascodeEnv_pvt_gen import foldedCascodeEnv_pvt_gen
# from envs.strongArmEnv_pvt_gen import strongArmEnv_pvt_gen
from copy import deepcopy
from agent.utils.utils import AttrDict, get_output_folder
from tensorboardX import SummaryWriter
import pdb


def fit_func(action):
    """Evaluate one candidate solution on the module-level environment.

    The raw candidate is squashed with ``tanh``, handed to ``env.step`` and
    the resulting reward is logged to the module-level ``writer``. Progress
    counters (``step``, ``episode``, ``episode_steps``, ``episode_reward``)
    and the running best ``max_reward`` are module globals that this function
    updates in place.

    Args:
        action: 1-D array-like of raw solver parameters (unbounded).

    Returns:
        tuple: ``(reward, reward_dict)`` where ``reward`` is element 1 of the
        value returned by ``env.step`` and ``reward_dict`` is element 5 of
        that value, or an empty dict when the environment class is
        ``NGspiceOpampEnv_pvt``.
    """
    # Only names that are rebound below need a ``global`` declaration;
    # ``env``, ``env_cl`` and ``writer`` are merely read.
    global step, episode, episode_steps, episode_reward, max_reward

    # Squash every component into (-1, 1) before handing it to the env.
    action = np.tanh(action)
    step_info = env.step(action)
    reward = step_info[1]
    done = step_info[2]
    if env_cl is NGspiceOpampEnv_pvt:
        reward_dict = {}
    else:
        # NOTE(review): index 5 is presumably a per-metric reward breakdown
        # returned by the non-NGspice environments -- confirm against them.
        reward_dict = step_info[5]

    # Update the running best *before* reporting it; otherwise the printed
    # and logged ``max_reward`` lags one evaluation behind and a best reward
    # found on the very last evaluation is never recorded.
    if reward > max_reward:
        max_reward = reward

    print('==> Episode={} \t max_reward={} \t current_reward={}'.format(episode, max_reward, reward))
    writer.add_scalar('reward', reward, episode)
    writer.add_scalar('max_reward', max_reward, episode)
    writer.add_scalars(
        'actions_origin',
        {'action' + str(i): value for i, value in enumerate(action)},
        episode,
    )

    step += 1
    episode_steps += 1
    episode_reward += reward
    if done:
        # End of episode: reset the per-episode accumulators.
        episode_steps = 0
        episode_reward = 0.
        episode += 1
    return reward, reward_dict

# Runs a solver's ask/tell loop, scoring every candidate with fit_func.
def test_solver(solver):
    """Run the solver's ask/tell loop for ``MAX_ITERATION`` generations.

    Each generation asks the solver for ``solver.popsize`` candidates, scores
    every candidate with ``fit_func`` and reports the fitness values back via
    ``solver.tell``.

    Args:
        solver: Object exposing ``ask()``, ``tell(fitness_list)``,
            ``result()`` and a ``popsize`` attribute.

    Returns:
        tuple: ``(history, reward_all)`` where ``history`` holds
        ``solver.result()[1]`` (the best fitness) after each generation and
        ``reward_all`` holds, per generation, the list of second return
        values of ``fit_func`` for every candidate.
    """
    history = []
    reward_all = []
    # Stays None when MAX_ITERATION is 0 so the summary below can be skipped
    # instead of failing on an unbound name.
    result = None
    for j in range(MAX_ITERATION):
        solutions = solver.ask()
        fitness_list = np.zeros(solver.popsize)
        reward_list = []
        for i in range(solver.popsize):
            fitness_list[i], reward_l = fit_func(solutions[i])
            reward_list.append(reward_l)
        solver.tell(fitness_list)
        # First element is the best solution, second is the best fitness.
        result = solver.result()
        history.append(result[1])
        reward_all.append(reward_list)
        if (j + 1) % 100 == 0:
            print("fitness at iteration", (j+1), result[1])
    if result is not None:
        print("local optimum discovered by solver:\n", result[0])
        print("fitness score at this local optimum:", result[1])
    return history, reward_all


if __name__ == '__main__':
    # Script entry point: parse the CLI, load the layered kwargs config,
    # build the TensorBoard writer and the simulation environment, run CMA-ES
    # through test_solver and pickle the collected per-candidate rewards.
    #
    # Every name assigned here is a module global; fit_func and test_solver
    # read env_cl, env, writer, step, episode, episode_steps, episode_reward,
    # max_reward and MAX_ITERATION directly, so those names must not change.
    import pickle

    parser = argparse.ArgumentParser()
    parser.add_argument('kwargs', default=None)
    parser.add_argument('--devices', default=None)
    parser.add_argument('--remote_server', default=None)
    # Reject a missing or unknown environment up front; previously this only
    # surfaced later as a NameError on ``env``.
    parser.add_argument('--env', type=str, required=True,
                        choices=('ng', 'strongArm', 'fold'))

    # parse arguments
    args = parser.parse_args()
    # Turn a config file path into a dotted module path. Only a trailing
    # ".py" is stripped, so directory names that happen to start with "py"
    # (e.g. "configs/pytorch/x.py") are left intact.
    config_path = args.kwargs
    if config_path.endswith('.py'):
        config_path = config_path[:-len('.py')]
    args.kwargs = config_path.replace('/', '.')

    # set devices
    if args.devices is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
        cudnn.benchmark = True
        device = 'cuda'
    else:
        device = 'cpu'

    # load kwargs: merge the ``defaults`` module of every parent package,
    # outermost first, then let the selected config module override them.
    kwargs = {}
    prefix = ''
    for name in args.kwargs.split('.')[:-1]:
        prefix += name + '.'
        kwargs = {**kwargs, **importlib.import_module(prefix + 'defaults').kwargs}
    kwargs = {**kwargs, **importlib.import_module(args.kwargs).kwargs}

    # save runs dir for further use
    kwargs['runs_dir'] = args.kwargs
    kwargs['gpu_id'] = args.devices
    kwargs['remote_server'] = args.remote_server

    print('==> parsed arguments')
    for k, v in kwargs.items():
        print('[{}] = {}'.format(k, v))

    # set random seed
    if 'seed' in kwargs and kwargs['seed'] is not None:
        random.seed(kwargs['seed'])
        np.random.seed(kwargs['seed'])
        torch.manual_seed(kwargs['seed'])

        if device == 'cuda':
            cudnn.deterministic = True
            torch.cuda.manual_seed(kwargs['seed'])

    kwargs = AttrDict(kwargs)

    nb_actions = kwargs['nb_actions']
    NPARAMS = kwargs['nb_actions']
    NPOPULATION = kwargs['rmsize']

    # tensorboard file writer
    kwargs['log_dir'] = 'runs'
    log_dir = 'runs/es' + '/' + kwargs['runs_dir']
    log_dir = get_output_folder(log_dir, 'es_' + str(nb_actions))
    writer = SummaryWriter(log_dir=log_dir)

    # Progress counters shared with fit_func.
    step = episode = episode_steps = 0
    episode_reward = 0.
    max_reward = -6
    max_action = []
    max_observation = []
    MAX_ITERATION = kwargs['train_global_stp']

    # Bounded region of parameter space (not read by the CMA-ES path below).
    pbounds = {}
    for i in range(kwargs['nb_actions']):
        pbounds['x'+str(i)] = (-1, 1)

    kwargs['log'] = True
    if args.env == 'ng':
        env_cl = NGspiceOpampEnv_pvt
        env = env_cl(kwargs=kwargs)
    elif args.env == 'strongArm':
        # Imported lazily: the file-level import of this class is commented
        # out, so referencing it directly raised NameError.
        # NOTE(review): module path taken from that commented-out import.
        from envs.strongArmEnv_pvt_gen import strongArmEnv_pvt_gen
        env_cl = strongArmEnv_pvt_gen
        env = env_cl(kwargs=kwargs, writer=writer)
    elif args.env == 'fold':
        # Same lazy-import reasoning as for strongArm above.
        from envs.foldedCascodeEnv_pvt_gen import foldedCascodeEnv_pvt_gen
        env_cl = foldedCascodeEnv_pvt_gen
        env = env_cl(kwargs=kwargs, writer=writer)
    else:
        # Unreachable through the CLI because of ``choices``; kept as a guard.
        raise ValueError('unsupported --env: {}'.format(args.env))

    # defines CMA-ES algorithm solver
    cmaes = CMAES(NPARAMS,
                  popsize=NPOPULATION,
                  weight_decay=0.0,
                  sigma_init=0.5
                  )

    try:
        cma_history, rewards = test_solver(cmaes)
    finally:
        # Flush pending TensorBoard events even if a simulation fails.
        writer.close()

    # Resolves to es_ng.pkl / es_strongArm.pkl / es_fold.pkl; the context
    # manager guarantees the file is flushed and closed.
    with open('es_' + args.env + '.pkl', 'wb') as file_to_write:
        pickle.dump(rewards, file_to_write)
