import skopt
import random
import torch

import agents
import environment

# Pick the compute device once: CUDA when PyTorch reports it available, CPU otherwise.
if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')

### Hyper-parameter search for the goal-conditioned agent training parameters ###
# The order of the dimensions matters: the optimizer reports the best point as a
# positional list that follows this order.
SPACE = [
    # mini-batch size passed to train_agent
    skopt.space.Categorical(
        categories=[50, 100, 200],
        prior=[1/3, 1/3, 1/3],
        name='N_batch',
    ),
    # learning rate
    skopt.space.Real(low=5e-4, high=5e-3, prior='log-uniform', name='lr'),
    # iteration budget; evaluate_agent multiplies this value by 10000
    skopt.space.Integer(low=2, high=5, prior='log-uniform', name='n_iter_max'),
    # weight decay passed to train_agent
    skopt.space.Real(low=1e-4, high=1., prior='log-uniform', name='weight_decay'),
    # number of hidden layers
    skopt.space.Integer(low=1, high=3, prior='uniform', name='n_hidden_layers'),
    # width shared by every hidden layer
    skopt.space.Integer(low=128, high=512, prior='log-uniform', name='d_hidden_layers'),
]


@skopt.utils.use_named_args(SPACE)
def evaluate_agent(**params):
    """Objective for the goal-conditioned agent search.

    Trains 10 fresh ``Agent_MLP_GC_BC`` agents, each on a new ``ENV(45)``,
    with the sampled hyper-parameters, and averages the values returned by
    ``train_agent``.

    Args:
        **params: one entry per dimension of ``SPACE``, keyed by dimension name.

    Returns:
        The negated mean of the ``train_agent`` return values, so that a
        minimizer maximizes the mean.
        NOTE(review): presumably ``train_agent`` returns a score where higher
        is better - confirm against ``agents``.
    """
    print(params)
    horizon = 45
    layer_sizes = [params['d_hidden_layers']]*params['n_hidden_layers']
    print(layer_sizes)

    # 'n_iter_max' is sampled in units of 10000 training iterations.
    max_iterations = params['n_iter_max']*10000

    n_runs = 10
    total = 0
    for _ in range(n_runs):
        env = environment.ENV(horizon)
        learner = agents.Agent_MLP_GC_BC(env, layer_sizes).to(dev)
        total += learner.train_agent(
            N_batch=params['N_batch'],
            lr=params['lr'],
            n_iter_max=max_iterations,
            weight_decay=params['weight_decay'],
        )
    print("\tresults : "+str(total/n_runs))
    return -total/n_runs

# Run the search for the goal-conditioned agent: 20 initial points out of 40
# evaluations in total, with an extra-trees ('ET') surrogate model.
results = skopt.forest_minimize(
    evaluate_agent,
    SPACE,
    n_calls=40,
    n_initial_points=20,
    base_estimator='ET',
)

# Persist under a stage-specific name. A later stage of this script also writes
# 'results.pkl', which used to overwrite the outcome of this search.
skopt.dump(results, 'results_gc_bc.pkl')

print(results)

# The objective returns the negated mean score, so flip the sign back for display.
print("best point : "+str(results['x'])+" evaluate at "+str(-results['fun']))


### Hyper-parameter search for the Q-learning agent training parameters ###
# The order of the dimensions matters: the optimizer reports the best point as a
# positional list that follows this order.
SPACE = [
    # mini-batch size passed to train_agent
    skopt.space.Categorical(categories=[10, 20], prior=[1/2, 1/2], name='N_batch'),
    # learning rate
    skopt.space.Real(low=1e-4, high=1e-2, prior='log-uniform', name='lr'),
    # weight decay passed to train_agent
    skopt.space.Real(low=1e-8, high=1e-1, prior='log-uniform', name='weight_decay'),
    # number of hidden layers
    skopt.space.Integer(low=2, high=3, prior='uniform', name='n_hidden_layers'),
    # width shared by every hidden layer
    skopt.space.Integer(low=128, high=512, prior='log-uniform', name='d_hidden_layers'),
    # forwarded to train_agent as n_iter_Q_update
    skopt.space.Categorical(
        categories=[1000, 2000],
        prior=[1/2, 1/2],
        name='n_iter_Q_update',
    ),
    # epsilon hyper-parameter
    skopt.space.Real(low=1e-2, high=1., prior='log-uniform', name='epsilon'),
]


@skopt.utils.use_named_args(SPACE)
def evaluate_agent(**params):
    """Objective for the Q-learning agent search.

    A first batch of 5 fresh ``Agent_Q`` agents (each on a new ``ENV(12)``) is
    trained. Only when the summed ``train_agent`` results equal the number of
    runs is a second batch of 20 runs trained, and the value returned is then
    based on that second batch alone.

    Args:
        **params: one entry per dimension of ``SPACE``, keyed by dimension name.

    Returns:
        The negated mean of the ``train_agent`` return values of the last batch
        that was run.
        NOTE(review): presumably ``train_agent`` returns a score in [0, 1], so
        "sum == number of runs" means every run was perfect - confirm.
        NOTE(review): ``params['epsilon']`` is sampled by the search space but
        is never forwarded to ``train_agent`` here - confirm whether it should
        be passed.
    """
    print(params)
    horizon = 12
    layer_sizes = [params['d_hidden_layers']]*params['n_hidden_layers']
    print(layer_sizes)

    def run_trials(n_runs):
        # Train n_runs fresh agents, print their mean result, return the raw sum.
        total = 0
        for _ in range(n_runs):
            env = environment.ENV(horizon)
            q_agent = agents.Agent_Q(env, layer_sizes).to(dev)
            total += q_agent.train_agent(
                N_batch=params['N_batch'],
                lr=params['lr'],
                weight_decay=params['weight_decay'],
                n_iter_Q_update=params['n_iter_Q_update'],
            )
        print("\tresults : "+str(total/n_runs))
        return total

    # Cheap screening batch first; the longer batch only runs after a full score.
    n_runs = 5
    total = run_trials(n_runs)
    if total == n_runs:
        n_runs = 20
        total = run_trials(n_runs)
    return -total/n_runs

# Run the search for the Q-learning agent: 45 initial points out of 75
# evaluations in total, with an extra-trees ('ET') surrogate model.
results = skopt.forest_minimize(
    evaluate_agent,
    SPACE,
    n_calls=75,
    n_initial_points=45,
    base_estimator='ET',
)

# Persist under a stage-specific name. The PPO stage of this script also writes
# 'results.pkl', which used to overwrite the outcome of this search.
skopt.dump(results, 'results_q.pkl')

print(results)

# The objective returns the negated mean score, so flip the sign back for display.
print("best point : "+str(results['x'])+" evaluate at "+str(-results['fun']))


### Hyper-parameter search for the PPO agent training parameters ###
# The order of the dimensions matters: the optimizer reports the best point as a
# positional list that follows this order.
SPACE = [
    # mini-batch size passed to train_agent
    skopt.space.Categorical(
        categories=[10, 50, 100],
        prior=[1/3, 1/3, 1/3],
        name='N_batch',
    ),
    # learning rate
    skopt.space.Real(low=1e-4, high=1e-2, prior='log-uniform', name='lr'),
    # weight decay passed to train_agent
    skopt.space.Real(low=1e-8, high=1e-1, prior='log-uniform', name='weight_decay'),
    # width of each hidden layer
    skopt.space.Integer(low=128, high=512, prior='log-uniform', name='d_hidden_layers'),
    # forwarded to train_agent as epsilon
    skopt.space.Real(low=1e-2, high=0.2, prior='log-uniform', name='epsilon'),
    # forwarded to train_agent as beta
    skopt.space.Real(low=1e-3, high=1e-1, prior='log-uniform', name='beta'),
]


@skopt.utils.use_named_args(SPACE)
def evaluate_agent(**params):
    """Objective for the PPO agent search.

    A first batch of 5 fresh ``Agent_PPO`` agents (each on a new ``ENV(17)``,
    two hidden layers of the sampled width) is trained. Only when the summed
    ``train_agent`` results equal the number of runs is a second batch of 20
    runs trained, and the value returned is then based on that second batch
    alone.

    Args:
        **params: one entry per dimension of ``SPACE``, keyed by dimension name.

    Returns:
        The negated mean of the ``train_agent`` return values of the last batch
        that was run.
        NOTE(review): presumably ``train_agent`` returns a score in [0, 1], so
        "sum == number of runs" means every run was perfect - confirm.
    """
    print(params)
    horizon = 17
    layer_sizes = [params['d_hidden_layers']]*2
    print(layer_sizes)

    def run_trials(n_runs):
        # Train n_runs fresh agents, print their mean result, return the raw sum.
        total = 0
        for _ in range(n_runs):
            env = environment.ENV(horizon)
            ppo_agent = agents.Agent_PPO(env, layer_sizes).to(dev)
            total += ppo_agent.train_agent(
                N_batch=params['N_batch'],
                lr=params['lr'],
                weight_decay=params['weight_decay'],
                epsilon=params['epsilon'],
                beta=params['beta'],
            )
        print("\tresults : "+str(total/n_runs))
        return total

    # Cheap screening batch first; the longer batch only runs after a full score.
    n_runs = 5
    total = run_trials(n_runs)
    if total == n_runs:
        n_runs = 20
        print("pass first test")
        total = run_trials(n_runs)
    return -total/n_runs

# Run the search for the PPO agent: 75 initial points out of 100 evaluations in
# total, with an extra-trees ('ET') surrogate model.
results = skopt.forest_minimize(
    evaluate_agent,
    SPACE,
    n_calls=100,
    n_initial_points=75,
    base_estimator='ET',
)

# Save the full optimization result to disk.
skopt.dump(results, 'results.pkl')

print(results)

# The objective returns the negated mean score, so flip the sign back for display.
print(f"best point : {results['x']!s} evaluate at {-results['fun']!s}")