from env import TSCEnv
from typing import List
import itertools
from utilities.utils import set_seed
from .sym_domain import SymDomain, StateTransition, get_pareto_front
import numpy as np
import geppy as gep
import operator
from .mcts import MCTS
from skopt import gp_minimize
from skopt.space import Integer


def run_a_step(env: TSCEnv, n_obs: List):
    """Let every agent choose an action, then advance ``env`` by one step.

    Each agent's ``pick_action`` receives the full observation list ``n_obs``
    (plus ``None`` as its second argument).

    Returns:
        ``(n_next_obs, n_rew, n_done, info, first_action)``: the 4-tuple
        unpacked from ``env.step`` followed by the action chosen by the first
        agent in ``env.n_agent``.
    """
    chosen = [member.pick_action(n_obs, None) for member in env.n_agent]
    next_obs, rewards, done_flags, step_info = env.step(chosen)
    return next_obs, rewards, done_flags, step_info, chosen[0]


def run_an_episode(env: TSCEnv, config: dict):
    """Reset ``env`` and simulate one episode with the agents' current policies.

    Args:
        env: environment to simulate; each decision step is delegated to
            ``run_a_step``.
        config: run configuration.
            - Reads ``'num_step'`` (the step budget) and ``'action_interval'``
              (the step-index increment per decision).
            - Writes ``'current_episode_step_idx'`` before every stop check, so
              any code sharing this dict can see the current step index
              (presumably agents/env read it — confirm against callers).
            - On return, that key holds the index at which the loop stopped.

    Returns:
        The ``info`` dict from the last ``env.step`` call, or ``{}`` if no step
        was taken (e.g. ``num_step <= 0``). Callers that index into the result
        must handle that empty case.
    """
    n_obs = env.reset()  # Observations of n agents
    n_done = [False]
    info = {}

    for step_idx in itertools.count(start=0, step=config['action_interval']):
        # Publish the step index through the shared config dict.
        config['current_episode_step_idx'] = step_idx
        if step_idx >= config['num_step'] or all(n_done):
            break
        # Per-step rewards and the chosen action are not needed here.
        n_obs, _, n_done, info, _ = run_a_step(env, n_obs)

    return info


def div(a, b):
    """Protected division used as a GEP primitive.

    Returns ``a / b``, except when ``|b| < 1e-6``, in which case the constant
    ``1`` is returned. This keeps evolved expressions from raising
    ``ZeroDivisionError`` or exploding for near-zero denominators.
    """
    return 1 if abs(b) < 1e-6 else a / b


def get_pset():
    """Build the geppy primitive set used for the symbolic search.

    Terminals are the eight named inputs ``wi, wo, ci, co, di, do, li, lo``
    (presumably paired in/out traffic features — confirm against the agent
    code that calls the evolved functions).

    Functions, registered in this order:
        - add/2, neg/1, mul/2
        - protected ``div``/2
        - builtin ``max``/2 and ``min``/2

    Subtraction is deliberately not registered; ``a - b`` remains expressible
    as ``add(a, neg(b))``.
    """
    primitive_set = gep.PrimitiveSet('TSC_pset',
                                     input_names=['wi', 'wo', 'ci', 'co', 'di', 'do', 'li', 'lo'])
    # Registration order is preserved so primitive indexing matches prior runs.
    for function, arity in (
        (operator.add, 2),
        (operator.neg, 1),
        (operator.mul, 2),
        (div, 2),
        (max, 2),
        (min, 2),
    ):
        primitive_set.add_function(function, arity)
    return primitive_set


class SymbolicOptimizer:
    """Search for a symbolic control function for the agents of ``env``.

    Candidate expressions are built over the primitive set from ``get_pset()``.
    They are scored by simulating full episodes in ``env``. The search itself is
    driven by ``MCTS`` over a ``SymDomain``.
    """

    def __init__(self, env: TSCEnv, config):
        """
        Args:
            env: environment whose ``n_agent`` agents receive candidate functions.
            config: run configuration.
                - Reads ``'func_length'`` here.
                - The module-level ``run_an_episode`` also reads ``'num_step'``
                  and ``'action_interval'``.
        """
        self.env = env
        self.config = config

        self.pset = get_pset()

        # Starting exploration weight handed to each MCTS tree (1/sqrt(2), the
        # value commonly used for UCT); ``search`` scales it up every round.
        self.exploration = 1 / np.sqrt(2)
        # ``self`` is passed to the domain, presumably so it can call
        # ``evaluate``/``run_an_episode`` — confirm in SymDomain.
        self.domain = SymDomain(self.pset, self, None, func_length=self.config['func_length'])

    def evaluate(self, func):
        """Simulate one episode with ``func`` installed on every agent.

        Returns:
            The first entry of ``info['world_2_average_travel_time']``
            (presumably average travel time, so lower is better).
        """
        info = self.run_an_episode(func)
        return info['world_2_average_travel_time'][0]

    def evaluate_eff_range(self, eff_range):
        """Score the baseline policy ``f(...) = di`` for a given effective range.

        Each agent is reset and given the control function that returns its
        ``di`` input. Its ``eff_range`` is then set to the (scalar) value.

        Args:
            eff_range: the effective range, either as a scalar or as a
                one-element list/tuple. ``skopt.gp_minimize`` calls its
                objective with a *list* of parameter values. The single value
                is therefore unwrapped, so ``agent.eff_range`` is always a
                scalar — the same form ``constant_search`` assigns.

        Returns:
            The first entry of ``info['world_2_average_travel_time']`` for the
            episode; this is the value ``constant_search`` minimizes.
        """
        if isinstance(eff_range, (list, tuple)):
            eff_range = eff_range[0]
        for agent in self.env.n_agent:
            agent.reset()
            agent.func = lambda wi, wo, ci, co, di, do, li, lo: di
            agent.eff_range = eff_range
        info = run_an_episode(self.env, self.config)
        return info['world_2_average_travel_time'][0]

    def run_an_episode(self, func):
        """Reset every agent, install ``func`` as its control function, and run one episode.

        Returns:
            The ``info`` dict from the module-level ``run_an_episode``.
        """
        for agent in self.env.n_agent:
            agent.reset()
            agent.func = func
        # Inside a method this name resolves to the module-level function,
        # not to this method.
        return run_an_episode(self.env, self.config)

    def constant_search(self, steps, low=5, high=220):
        """Tune one shared ``eff_range`` with ``skopt.gp_minimize``.

        The objective is ``evaluate_eff_range``. The best integer found in
        ``[low, high]`` is printed and assigned to every agent's ``eff_range``.

        Args:
            steps: total number of objective evaluations (``n_calls``); the
                first 3 are random initial points.
            low: inclusive lower bound of the integer search space
                (default keeps the original value 5).
            high: inclusive upper bound of the integer search space
                (default keeps the original value 220).
        """
        result = gp_minimize(self.evaluate_eff_range,
                             dimensions=[Integer(low, high)],
                             n_calls=steps, n_initial_points=3, x0=None)
        print(f'best effective range: {result.x[0]}')
        for ag in self.env.n_agent:
            ag.eff_range = result.x[0]

    def search(self, episodes, iterations=50):
        """Run repeated MCTS rounds until ``self.domain.num_eval`` reaches ``episodes``.

        Each round:
            - creates a fresh ``MCTS`` over ``self.domain``;
            - rebuilds ``self.domain.ST`` from the currently retained programs;
            - performs ``iterations`` MCTS steps.

        Across rounds, up to ``MAX_LEN`` distinct high-reward programs are
        retained. Termination relies on the domain increasing ``num_eval`` as
        it evaluates candidates (not visible here — confirm in SymDomain).

        Args:
            episodes: evaluation budget compared against ``self.domain.num_eval``.
            iterations: MCTS steps per round.

        Returns:
            ``get_pareto_front(self.domain)``.
        """
        # NOTE: ``self.constant_search(50)`` can be called first to tune
        # ``eff_range``; it is currently disabled.

        MAX_LEN = 10
        # (reward, program) pairs, kept sorted ascending by (reward, -len(program)):
        # index 0 is the weakest retained entry, index -1 the strongest.
        best_states = []
        episode = 0
        while self.domain.num_eval < episodes:
            episode += 1
            mcts = MCTS(self.domain, exploration=self.exploration)
            # The exploration weight grows 10x per round and persists on ``self``
            # across calls. NOTE(review): confirm this growth is intended.
            self.exploration *= 10
            self.domain.ST = StateTransition([s for _, s in best_states]).build()
            for it in range(iterations):
                node, final_state, reward = mcts.step()
                # 1 / reward is displayed, presumably converting the reward back
                # to travel time — confirm against SymDomain.
                print(f'\r{episode}-{it} {1 / reward:.4f}| {1 / mcts.best_reward:.4f}| {final_state} | {node.N}',
                      end='')

                if not best_states:
                    best_states.append((reward, final_state))
                elif reward > best_states[0][0] and final_state not in [prog for _, prog in best_states]:
                    best_states.append((reward, final_state))
                    if len(best_states) > MAX_LEN:
                        # The list was sorted before this append, so dropping the
                        # head evicts the lowest-reward entry.
                        best_states = best_states[-MAX_LEN:]
                best_states.sort(key=lambda x: (x[0], -len(x[1])))

            print('')
            print('-' * 100)
            print('States: ', [f'{1/r:.4f}=>{len(prog)}' for r, prog in reversed(best_states)])
            if best_states:
                # best_states[0] is the weakest retained entry, i.e. the score a
                # new program must beat to be admitted.
                print(1 / best_states[0][0], len(best_states),
                      f'eval num: {self.domain.num_eval}/{episodes}')
            else:
                # No MCTS step ran this round (e.g. iterations == 0).
                print(f'eval num: {self.domain.num_eval}/{episodes}')

        return get_pareto_front(self.domain)

