import traceback
import gym
import numpy as np
from gym import spaces
from .evaluate_nas201 import evaluate_nas201
from .evaluate_trans101 import *
from .evaluate_nas101 import evaluate_nas101
import math
import concurrent


def process_individual(individual, args):
    """Execute an individual's generated code and score the function it defines.

    The source stored under ``individual['code']`` is executed in a fresh
    namespace, an entry-point callable is picked from that namespace, and the
    callable is handed to the evaluator selected by ``args.benchmark``
    (``'nasbench201'``, ``'transbench101'`` or ``'nasbench101'``).

    Args:
        individual: Mapping with a ``'code'`` entry holding Python source.
        args: Configuration object; ``args.benchmark`` selects the evaluator
            and the whole object is forwarded to it.

    Returns:
        The absolute value of the evaluator's score. ``0.`` is returned when
        the code defines no callable, the benchmark name is unknown, the
        score is NaN, or any exception is raised along the way (the traceback
        is printed in that case).
    """
    try:
        # SECURITY: ``exec`` runs the candidate source with full interpreter
        # privileges. Only feed this function code from a trusted generator,
        # or run it inside an external sandbox.
        namespace = {}
        exec(individual['code'], namespace, namespace)

        candidates = [value for value in namespace.values() if callable(value)]
        if not candidates:
            return 0.

        # Prefer a function defined by the executed code itself (its
        # ``__globals__`` is the exec namespace). Otherwise a leading
        # ``from module import name`` would make the imported callable the
        # entry point. Fall back to the first callable when nothing matches.
        func = next(
            (
                candidate for candidate in candidates
                if getattr(candidate, '__globals__', None) is namespace
            ),
            candidates[0],
        )

        if args.benchmark == 'nasbench201':
            score = evaluate_nas201(args, func)
        elif args.benchmark == 'transbench101':
            score = evaluate_trans101(args, func)
        elif args.benchmark == 'nasbench101':
            score = evaluate_nas101(args, func)
        else:
            score = 0.

        # A NaN score carries no ranking information; treat it as a failure.
        if math.isnan(score):
            score = 0.
        return abs(score)

    except Exception:
        # Best-effort boundary: a broken candidate must not abort the whole
        # population evaluation, so report the failure and score it as zero.
        error_trace = traceback.format_exc()
        print(error_trace)
        return 0.

class StrategyGenEnv(gym.Env):
    """Gym environment whose discrete actions pick one of several strategies.

    The observation is a flat ``float32`` vector of length
    ``history_len + history_len * num_strategies``: the rewards of the most
    recent steps, followed by a row-by-row flattened one-hot encoding of the
    actions taken at those steps. Slots without history are zero.

    Note that ``step`` does not follow the standard gym signature: it needs
    the population to evaluate and returns only ``(next_state, reward)``.
    """

    def __init__(self, num_strategies=3, history_len=5):
        """Set up the action/observation spaces and empty histories.

        Args:
            num_strategies: Number of discrete actions.
            history_len: Number of past steps kept in the observation.
        """
        super().__init__()

        self.num_strategies = num_strategies
        self.history_len = history_len
        # Parallel lists: entry i of each describes the same past step.
        self.strategy_history = []
        self.reward_history = []

        self.action_space = spaces.Discrete(num_strategies)

        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(history_len + history_len * num_strategies,),
            dtype=np.float32
        )

    def reset(self):
        """Clear both histories and return the resulting all-zero state."""
        self.strategy_history = []
        self.reward_history = []
        return self._get_state()

    def step(self, action, new_pop, args):
        """Record ``action`` together with the reward earned by ``new_pop``.

        Args:
            action: Index of the chosen strategy; must lie in ``action_space``.
            new_pop: Population passed to ``evaluate_strategy``.
            args: Configuration object forwarded to ``evaluate_strategy``.

        Returns:
            Tuple ``(next_state, reward)``.
        """
        assert self.action_space.contains(action)

        reward = self.evaluate_strategy(new_pop, args)

        self.strategy_history.append(action)
        self.reward_history.append(reward)

        # Keep only the most recent ``history_len`` steps.
        if len(self.strategy_history) > self.history_len:
            self.strategy_history.pop(0)
            self.reward_history.pop(0)

        next_state = self._get_state()
        return next_state, reward

    def evaluate_strategy(self, pop, args):
        """Score every unscored individual and return the mean absolute score.

        Individuals whose ``'score'`` entry is already set are left untouched.
        The others are evaluated with ``process_individual`` in a thread pool
        and get their ``'score'`` entry filled in place; an evaluation that
        raises is scored ``0.``.

        Args:
            pop: Sequence of individual dicts.
            args: Configuration object; ``args.timeout`` is read here and the
                whole object is forwarded to ``process_individual``.

        Returns:
            Mean of ``abs(individual['score'])`` over ``pop``, or ``0.`` for
            an empty population.
        """
        # ``import concurrent`` alone does not load the ``futures`` submodule,
        # so import the needed names explicitly instead of relying on some
        # other module having imported it first.
        from concurrent.futures import ThreadPoolExecutor, as_completed

        timeout_duration = args.timeout

        # An empty population has no mean; avoid dividing by zero.
        if len(pop) == 0:
            return 0.

        pending = [ind for ind in pop if ind.get('score') is None]

        if pending:
            with ThreadPoolExecutor() as executor:
                future_to_individual = {
                    executor.submit(process_individual, ind, args): ind
                    for ind in pending
                }
                # NOTE(review): ``as_completed`` only yields finished futures,
                # so the ``timeout`` passed to ``result`` never expires and
                # does not bound how long a candidate may run. Enforcing a
                # real limit needs a different mechanism (e.g. processes).
                for future in as_completed(future_to_individual):
                    try:
                        score = future.result(timeout=timeout_duration)
                    except Exception:
                        score = 0.
                    future_to_individual[future]['score'] = score

        return sum(abs(ind['score']) for ind in pop) / len(pop)

    def _get_state(self):
        """Build the observation vector from the current histories.

        Returns:
            1-D ``float32`` array: ``history_len`` rewards followed by the
            flattened ``history_len x num_strategies`` one-hot action matrix.
        """
        reward_part = np.zeros(self.history_len)
        strategy_part = np.zeros((self.history_len, self.num_strategies))

        history = zip(self.strategy_history, self.reward_history)
        for i, (strategy, reward) in enumerate(history):
            reward_part[i] = reward
            strategy_part[i, strategy] = 1

        state = np.concatenate([reward_part, strategy_part.flatten()])
        return state.astype(np.float32)


