from env import TSCEnv
import itertools
from typing import List
from agent.gp_light.gp import gplearn


def run_a_step(env: TSCEnv, n_obs: List):
    """Advance ``env`` by one decision step.

    Every agent in ``env.n_agent`` is queried with the *whole* observation
    list ``n_obs`` (and ``None`` as the second argument). The collected
    actions are then handed to ``env.step`` in a single call.

    Returns:
        A 5-tuple ``(n_next_obs, n_rew, n_done, info, first_action)``. The
        first four items are the values unpacked from ``env.step``;
        ``first_action`` is the action picked by the first agent in
        ``env.n_agent``.
    """
    # Agents are queried in env.n_agent order; the resulting list is passed
    # to env.step as-is.
    chosen = [agent.pick_action(n_obs, None) for agent in env.n_agent]
    step_result = env.step(chosen)
    # env.step must yield exactly four values for this unpacking to succeed.
    n_next_obs, n_rew, n_done, info = step_result
    return n_next_obs, n_rew, n_done, info, chosen[0]


def run_an_episode(env: TSCEnv, config: dict):
    """Roll out one episode in ``env`` and return the info of its last step.

    The environment is reset, then stepped via ``run_a_step``. Step indices
    start at 0 and advance by ``config['action_interval']``. The rollout
    stops once the index reaches ``config['num_step']`` or every entry of the
    latest ``n_done`` is truthy.

    Side effect: the current step index is written to
    ``config['current_episode_step_idx']`` on every iteration, including the
    one that terminates the loop. NOTE(review): presumably agents read this
    value from the shared config; confirm.

    Returns:
        The ``info`` returned by the final ``env.step`` call, or ``{}`` if no
        step was taken (e.g. ``config['num_step'] <= 0``).
    """
    observations = env.reset()
    # Seed the done flags with a single False so the first check never stops
    # the episode before any step has been taken.
    done_flags = [False]
    last_info = {}

    step_indices = itertools.count(start=0, step=config['action_interval'])
    for step_idx in step_indices:
        config['current_episode_step_idx'] = step_idx
        # num_step is looked up on each iteration, as in the loop header it replaces.
        reached_horizon = config['current_episode_step_idx'] >= config['num_step']
        if reached_horizon or all(done_flags):
            break
        # Rewards and the returned first action are not needed here.
        observations, _, done_flags, last_info, _ = run_a_step(env, observations)

    return last_info


class GPLOptimizer:
    """Search for a per-agent scoring function with ``gplearn``.

    Each candidate function produced by the search is installed on every
    agent of ``env``. One episode is simulated, and the episode's
    ``world_2_average_travel_time`` metric (first element) is used as the
    candidate's score.
    """

    def __init__(self, env: TSCEnv, config):
        # Only references are stored; no simulation happens until
        # evaluate()/search() is called.
        self.env = env
        self.config = config

    def evaluate(self, func):
        """Score ``func`` by running one full episode with it installed.

        Every agent is reset and then given ``func`` as its ``func``
        attribute, one agent at a time in ``env.n_agent`` order.

        Returns:
            ``info['world_2_average_travel_time'][0]`` from the episode's final
            step. A ``KeyError`` propagates if the episode produced no such key
            (e.g. when no step was taken and ``info`` is ``{}``).
        """
        for agent in self.env.n_agent:
            agent.reset()
            agent.func = func
        episode_info = run_an_episode(self.env, self.config)
        travel_time = episode_info['world_2_average_travel_time']
        return travel_time[0]

    def search(self, popsize=25, iterations=20):
        """Run the GP search and return ``(expr, func)`` from ``gplearn``.

        Args:
            popsize: population size forwarded to ``gplearn``.
            iterations: iteration count forwarded to ``gplearn``.

        The number of input features is taken from the first agent's
        ``get_phase_num_features()``.
        """
        num_features = self.env.n_agent[0].get_phase_num_features()

        def fitness(candidate):
            # Wrapped in a 1-tuple. NOTE(review): presumably gplearn expects a
            # DEAP-style tuple fitness; confirm against gplearn.
            score = self.evaluate(candidate)
            return (score,)

        expr, func = gplearn(fitness, num_features, popsize=popsize, iterations=iterations)
        return expr, func
