from __future__ import division

import copy
import random
import sys

import geatpy as gt
from platypus.config import PlatypusConfig
from platypus.evaluator import Job
from platypus.indicators import *
from platypus.operators import TournamentSelector, RandomGenerator, \
    GAOperator, SBX, PM
from platypus.weights import random_weights, chebyshev
from pymoo.factory import get_reference_directions
import numpy as np

import os

SCRIPT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(SCRIPT_DIR)
from moea import moea_problems
from moea.agents import operator_type, operator_parameter, neighbor_size
from moea.maenv_register import get_maenv
from multiagentenv import MultiAgentEnv

from collections import deque
import copy


class _EvaluateJob(Job):
    """Evaluator job that evaluates exactly one solution when it is run."""

    def __init__(self, solution):
        """Remember the solution that :meth:`run` will evaluate."""
        super().__init__()
        self.solution = solution

    def run(self):
        """Evaluate the wrapped solution in place."""
        self.solution.evaluate()


def get_weights(n_objs, n_points):
    """
    Build the weight vectors used by the decomposition.

    @param n_objs: number of objectives (dimension of every weight vector)
    @param n_points: number of weight vectors requested
    @return: the "das-dennis" reference directions as a list of lists; a
             message is printed when the amount produced differs from the
             amount requested
    """
    directions = get_reference_directions("das-dennis", n_objs, n_points=n_points)
    weights = directions.tolist()
    produced = len(weights)
    if produced != n_points:
        print(f'Generate {produced} weights but asked to generate {n_points}')
    return weights


def choose_problems(env_name, n_obj):
    """
    Instantiate the benchmark problem called ``env_name`` from ``moea_problems``.

    Problems whose name contains 'ZDT' are constructed without arguments (a
    notice is printed when ``n_obj`` is not 2); every other problem is
    constructed with ``n_obj`` as its only argument.

    @param env_name: attribute name of the problem class in ``moea_problems``
    @param n_obj: requested number of objectives
    @return: the problem instance
    """
    is_zdt = 'ZDT' in env_name
    if is_zdt and n_obj != 2:
        print("ZDT problem only has 2 n_objs")
    factory = getattr(moea_problems, env_name)
    return factory() if is_zdt else factory(n_obj)


def get_variable_boundary(problem):
    """
    Collect the box constraints of the first ``problem.nvars`` decision variables.

    @param problem: object exposing ``nvars`` and an indexable ``types`` whose
                    items carry ``min_value`` and ``max_value``
    @return: ``(lb, ub)``, two float64 arrays of length ``problem.nvars``
    """
    variable_types = [problem.types[k] for k in range(problem.nvars)]
    lower = np.array([t.min_value for t in variable_types], dtype=np.float64)
    upper = np.array([t.max_value for t in variable_types], dtype=np.float64)
    return lower, upper


# Machine epsilon of Python floats; added to denominators to avoid division by zero.
EPSILON = sys.float_info.epsilon

# Fallback for the ``algorithm_args`` constructor argument.  A variator of
# None means the environment builds its default variator later on.
algorithm_args_default = dict(variator=None)


class MOEAEnv(MultiAgentEnv):
    def __init__(self, key="WFG6_3", n_ref_points=1000, population_size=210, seed=2022,
                 budget_ratio=50, test=False, wo_obs=False, algorithm_args=None, save_history=False,
                 early_stop=True, baseline=False, replay=False,
                 replay_dir=None, C=5.0, W=0.5, D=1.0):
        """
        Set up the task list, the first problem, the MOEA/D settings and the
        FRRMAB bandit statistics.  No population is created here; that happens
        in reset().

        @param key: task identifier passed to get_maenv(); element 0 of the result
                    is used as the list of problem names and element 1 as the list
                    of objective counts.  When the text after the last "_" is
                    longer than one character the environment is flagged as a
                    mixture.
        @param n_ref_points: number of reference points requested from the
                             problem's get_ref_set() (used for the IGD indicator)
        @param population_size: number of solutions in the population
        @param seed: stored as self.seed; this class does not seed any random
                     generator with it in the code visible here
        @param budget_ratio: the budget is
                             population_size * budget_ratio * n_objs
        @param test: stored as self.test and consulted by step() for its report
        @param wo_obs: "without observations" flag; False means observations are
                       needed.  Only stored here.
        @param algorithm_args: dict providing a 'variator' entry; None selects
                               algorithm_args_default
        @param save_history: when true, step() also returns the history lists
        @param early_stop: when true, step() ends the run once the stagnation
                           counter exceeds a tenth of the generation budget
        @param baseline: stored as self.baseline; not read elsewhere in this class
        @param replay: when true, step() calls update_replay() on the final step
        @param replay_dir: directory scanned by save_replay()
        @param C: Scaling factor
        @param W: Size of Sliding Window, W * Population Size
        @param D: Decaying factor
        """
        self.C = C
        self.W = W
        self.D = D
        if algorithm_args is None:
            algorithm_args = algorithm_args_default
        # Environment Change: every (problem name, objective count) pair is a
        # task; tasks are visited in shuffled order (see reset()).
        self.key = key
        self.func_choice = get_maenv(key)[0]
        self.nobjs_choice = get_maenv(key)[1]
        self.func_select = [(func, nobjs) for func in self.func_choice for nobjs in self.nobjs_choice]
        self.fun_index = 0
        random.shuffle(self.func_select)
        # A key whose last "_"-separated token has more than one character
        # marks a mixture of tasks.
        self.mixture = False
        if len(key.split("_")[-1]) > 1:
            self.mixture = True
        # Problem Related
        self.n_ref_points = n_ref_points
        self.population_size = population_size
        self.budget_ratio = budget_ratio
        self._init_problem()  # needs func_select, population_size and budget_ratio
        # MDP Related
        self.n_actions = 4  # also the number of bandit arms (size of frr / ni below)
        if self.mixture is False:
            self.episode_limit = budget_ratio * self.n_objs + 1
        else:
            # NOTE(review): 7 is presumably the largest objective count in a
            # mixture -- confirm against get_maenv.
            self.episode_limit = budget_ratio * 7 + 1
        self.early_stop = early_stop
        self.wo_obs = wo_obs  # Without Obs, False means that we need obs
        # MOEA/D Algorithm Related
        self.generator = RandomGenerator()
        self.selector = TournamentSelector(2)
        self.variator = algorithm_args['variator']  # may be None; reset() then builds a default
        self.evaluator = PlatypusConfig.default_evaluator
        self.baseline = baseline  # Used MOEA/D Default Operator Type
        self.moead_neighborhood_size = 30  # NOTE: overwritten with 20 in the FRRMAB section below
        self.moead_neighborhood_maxsize = 30  # number of neighbours stored per subproblem
        self.moead_eta = 2  # maximum number of replaced parent
        self.moead_delta = 0.8  # Use Neighbor or Whole Population, Important Parameters
        self.moead_weight_generator = random_weights
        # Not Important
        self.replay = replay
        self.replay_dir = replay_dir
        self.test = test
        self.save_history = save_history
        self.seed = seed
        self._init_static()  # reads self.budget, so it must follow _init_problem()

        # One list per recorded quantity; update_replay() appends one entry
        # per finished run.
        self.replay_his = {
            "igd_his": [],
            "hv_his": [],
            "ndsort_ratio_his": [],
            "dis_his": [],
            "action_his": [],
            "population_his": []
        }

        # FRRMAB
        self.moead_neighborhood_size = 20  # effective value (replaces the 30 set above)
        self.slidingwindow_max = self.W * self.population_size  # may be fractional; only compared with len()
        self.scaling = self.C
        self.decaying = self.D
        self.frr = np.zeros(shape=self.n_actions, dtype=np.float32)  # per-operator credit
        self.ni = np.zeros(shape=self.n_actions, dtype=np.int32)  # per-operator usage count in the window

    def _init_problem(self):
        """
        Load the task selected by ``self.fun_index`` and derive everything that
        depends on it: the problem object, its variable bounds, the reference
        set, the IGD indicator and the evaluation budget.
        """
        self.env_name, self.n_objs = self.func_select[self.fun_index]
        problem = choose_problems(self.env_name, self.n_objs)
        self.problem = problem
        self.lb, self.ub = get_variable_boundary(problem)
        reference_set = problem.get_ref_set(n_ref_points=self.n_ref_points)
        self.problem_ref_points = reference_set
        self.igd_calculator = InvertedGenerationalDistance(reference_set=reference_set)
        # budget = population size x budget ratio x number of objectives
        self.budget = self.population_size * self.budget_ratio * self.n_objs

    def _init_static(self):
        self.moead_generation = 0  # Discarded
        self.best_value = 1e6
        self.last_value = 1e6
        self.inital_value = None
        self.last_bonus = 0
        self.stag_count = 0
        self.stag_count_max = (self.budget // self.population_size) / 10
        self.hv_his = []
        self.nds_ratio_his = []
        self.ava_dist_his = []
        self.value_his = []
        self.action_seq_his = []  # record the sequential history of action
        self.info_reward_his = []
        self.info_obs_his = []
        self.info_igd_his = []

        self.slidingwindow_idx = deque()
        self.slidingwindow_fir = deque()

    def moead_initial(self):
        """
        Create the weight vectors, the neighbourhood table and fold the whole
        population into the ideal point.

        Weights come from ``get_weights`` (Das-Dennis directions) rather than
        from ``self.moead_weight_generator``; that is faster but may not
        produce exactly ``population_size`` vectors.
        """
        self.weights = get_weights(self.n_objs, self.population_size)
        size = self.population_size
        limit = self.moead_neighborhood_maxsize
        # neighborhoods[k] holds the indices of the weights closest to weight k
        self.neighborhoods = [
            self.moead_sort_weights(self.weights[k], self.weights)[:limit]
            for k in range(size)
        ]
        for k in range(size):
            self.moead_update_ideal(self.population[k])

    def frrmab(self):
        self.frr[self.ni == 0] = np.inf
        self.ni[self.ni == 0] = 1
        op = self.frr + self.scaling * np.sqrt(2 * np.log(np.sum(self.ni)) / self.ni)
        return np.argmax(op)

    def credit_assignment(self):
        """
        Recompute ``self.ni`` and ``self.frr`` from the sliding window.

        Each operator's reward is the sum of the credits it earned inside the
        window.  Operators are ranked by that reward (rank 0 = highest), the
        reward is multiplied by ``decaying ** rank`` and the results are
        normalised by their sum.
        """
        n_ops = self.n_actions
        window_reward = np.zeros(shape=n_ops, dtype=np.float32)
        self.ni = np.zeros(shape=n_ops, dtype=np.int32)
        for used_op, credit in zip(self.slidingwindow_idx, self.slidingwindow_fir):
            window_reward[used_op] += credit
            self.ni[used_op] += 1
        # argsort(...)[::-1] lists operators from best to worst; taking the
        # argsort of that list gives every operator its position in it.
        rank = np.argsort(np.argsort(window_reward)[::-1])
        decayed = np.zeros(shape=n_ops, dtype=np.float32)
        for k in range(n_ops):
            decayed[k] = (self.decaying ** rank[k]) * window_reward[k]
        # EPSILON keeps the division defined when no credit was earned at all
        self.frr = decayed / (np.sum(decayed) + EPSILON)

    def moead_step(self):
        """
        Run one MOEA/D generation with bandit-driven operator selection.

        Every subproblem is visited once in random order.  For each one an
        operator is picked by frrmab(), five distinct extra parents are drawn
        from the mating pool, one offspring is produced and evaluated, the
        ideal point and the population are updated, and the credit returned by
        moead_update_solution() is pushed into the sliding window before the
        bandit statistics are recomputed.

        Takes no arguments; the offspring of this generation are kept in
        ``self.offspring_list`` and ``self.moead_generation`` is incremented.

        :raises ValueError: (from numpy) when fewer than five extra parents
                            are available
        :return: None
        """
        scale = 0.5  # scale factor handed to operator_type
        n_extra_parents = 5
        self.offspring_list = []
        for index in self.moead_get_subproblems():
            op = self.frrmab()
            mating_indices = self.moead_get_mating_indices(index)
            # The current solution is passed to the operator separately, so it
            # is excluded from the extra-parent pool only.  It must stay in
            # mating_indices: that list is also the replacement set, and
            # removing the index there (as the previous code did) meant an
            # offspring could never replace the incumbent of its own
            # subproblem.
            parent_candidates = [i for i in mating_indices if i != index]
            de_pool = np.random.choice(parent_candidates, n_extra_parents, replace=False)
            # Generate an offspring
            offspring = self.variator.mutation.evolve(
                operator_type(op, self.population, index, de_pool, scale, self.lb, self.ub))
            self.evaluate_all([offspring])
            self.offspring_list.append(offspring)
            self.moead_update_ideal(offspring)
            eta = self.moead_update_solution(offspring, mating_indices)  # selection
            # Record which operator was used and the credit it earned
            self.slidingwindow_idx.append(op)
            self.slidingwindow_fir.append(eta)
            if len(self.slidingwindow_idx) > self.slidingwindow_max:
                self.slidingwindow_idx.popleft()
                self.slidingwindow_fir.popleft()
            self.credit_assignment()
        self.moead_generation += 1

    def update_igd(self, value):
        self.value_his.append(value)
        if value < self.best_value:
            self.stag_count = 0
            self.best_value = value
        else:
            self.stag_count += 1
        self.last_value = value

    def step(self):
        """
        Returns reward, terminated, info
        :type action: int or np.int
        :param action: for dimension: Neighbor Size, Operator Type, Operator Parameter, Adaptive Weights
                [x, x, x, x]: shape=n_agents
        :return: obs, reward, is_done, info
        """

        self.moead_step()
        value = self.get_igd()
        self.update_igd(value)
        self.info_igd_his.append(self.last_value)  # Logger
        if self.moead_generation == 0 and self.test:
            # if test, it will report basic information of this run
            print("Problem {} ite {}, best igd is {}, last igd is {}".format(self.problem,
                                                                             self.nfe,
                                                                             self.best_value,
                                                                             self.last_value))
        # if stop, then return the information
        if self.moead_generation >= self.budget // self.population_size or \
                (self.stag_count > (self.budget // self.population_size) / 10 and self.early_stop):
            self.done = True
            if self.replay:
                self.update_replay()
        else:
            self.done = False
        if self.save_history:
            return self.done, {'best_igd': self.best_value,
                               'last_igd': self.last_value,
                               'igd_his': self.info_igd_his,
                               'reward_his': self.info_reward_his,
                               'obs_his': self.info_obs_his}
        else:
            return self.done, {'best_igd': self.best_value, 'last_igd': self.last_value}

    def reset(self):
        """ Returns initial observations and states """
        self.done = False
        self._init_problem()
        self._init_static()
        self.population = [self.generator.generate(self.problem) for _ in range(self.population_size)]
        self.evaluate_all(self.population)
        self.inital_value = self.get_igd()
        self.best_value = self.inital_value
        if self.n_objs < 5:
            tmp_population = [self.generator.generate(self.problem) for _ in range(int(1e5))]
            tmp_feasible = [s for s in tmp_population if s.constraint_violation == 0.0]
            self.evaluate_all(tmp_feasible)
            self.archive_maximum = (
                    1.1 * np.array([max([s.objectives[i] for s in tmp_feasible]) for i in range(self.n_objs)])).tolist()
            del tmp_population, tmp_feasible
        else:
            self.archive_maximum = [max([s.objectives[i] for s in self.population]) for i in range(self.n_objs)]
        self.archive_minimum = [min([s.objectives[i] for s in self.population]) for i in range(self.n_objs)]
        self.ideal_point = copy.deepcopy(self.archive_minimum)
        if self.variator is None:
            self.variator = GAOperator(SBX(probability=1.0, distribution_index=20.0),
                                       PM(probability=1 / self.problem.nvars))
        # Change Task Function
        self.fun_index += 1
        if self.fun_index == len(self.func_select):
            self.fun_index = 0
            random.shuffle(self.func_select)
        self.moead_initial()

    def update_replay(self):
        self.replay_his["igd_his"].append(self.info_igd_his)
        self.replay_his["hv_his"].append(self.hv_his)
        self.replay_his["ndsort_ratio_his"].append(self.nds_ratio_his)
        self.replay_his["dis_his"].append(self.ava_dist_his)
        self.replay_his["action_his"].append(self.action_seq_his)
        self.replay_his["population_his"].append(self.population)

    def save_replay(self):
        timesteps = []
        for name in os.listdir(self.replay_dir):
            full_name = os.path.join(self.replay_dir, name)
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))
        timestep_to_load = max(timesteps)
        replay_path = os.path.join(self.replay_dir, str(timestep_to_load), "replay.npz")
        np.savez(file=replay_path, info_stack=self.replay_his)

    def get_ref_set(self):
        return self.problem_ref_points

    def get_uniform_points(self, n_points):
        """
        Generate points with ``gt.crtup`` for the current number of objectives.

        :param n_points: number of points requested
        :return: the first item returned by ``gt.crtup``, converted to a list
        """
        points, _count = gt.crtup(self.n_objs, n_points)
        return points.tolist()

    def evaluate_all(self, solutions):
        """
        Evaluate every solution in ``solutions`` that is not evaluated yet.

        The work is handed to ``self.evaluator``.  When the evaluator hands
        back a different solution object than the one submitted, the results
        are copied onto the submitted solution.

        :param solutions: iterable of solutions; already evaluated ones are skipped
        """
        pending = [s for s in solutions if not s.evaluated]
        results = self.evaluator.evaluate_all([_EvaluateJob(s) for s in pending])

        for target, result in zip(pending, results):
            if target != result.solution:
                # if needed, update the original solution with the results
                evaluated = result.solution
                target.variables[:] = evaluated.variables[:]
                target.objectives[:] = evaluated.objectives[:]
                target.constraints[:] = evaluated.constraints[:]
                target.constraint_violation = evaluated.constraint_violation
                target.feasible = evaluated.feasible
                target.evaluated = evaluated.evaluated

    def moead_update_ideal(self, solution):
        for i in range(self.problem.nobjs):
            self.ideal_point[i] = min(self.ideal_point[i], solution.objectives[i])

    def moead_calculate_fitness(self, solution, weights):
        """
        Scalarise ``solution`` for one subproblem.

        :param solution: solution to score
        :param weights: weight vector of the subproblem
        :return: result of ``chebyshev`` for the solution, the current ideal
                 point and the weights
        """
        fitness = chebyshev(solution, self.ideal_point, weights)
        return fitness

    def moead_update_solution(self, solution, mating_indices):
        """
        repair solution, make constraint satisfiable
        :param solution:
        :param mating_indices:
        :return:
        """
        c = 0
        random.shuffle(mating_indices)
        eta = 0
        for i in mating_indices:
            candidate = self.population[i]
            weights = self.weights[i]
            replace = False

            if solution.constraint_violation > 0.0 and candidate.constraint_violation > 0.0:
                if solution.constraint_violation < candidate.constraint_violation:
                    replace = True
            elif candidate.constraint_violation > 0.0:
                replace = True
            elif solution.constraint_violation > 0.0:
                pass
            elif self.moead_calculate_fitness(solution, weights) < self.moead_calculate_fitness(candidate, weights):
                replace = True
                eta += (1 - self.moead_calculate_fitness(solution, weights) / self.moead_calculate_fitness(candidate,
                                                                                                           weights))

            if replace:
                self.population[i] = copy.deepcopy(solution)
                c = c + 1

            if c >= self.moead_eta:
                break

        return eta

    @staticmethod
    def moead_sort_weights(base, weights):
        """Returns the index of weights nearest to the base weight."""

        def compare(weight1, weight2):
            dist1 = math.sqrt(sum([math.pow(base[i] - weight1[1][i], 2.0) for i in range(len(base))]))
            dist2 = math.sqrt(sum([math.pow(base[i] - weight2[1][i], 2.0) for i in range(len(base))]))

            if dist1 < dist2:
                return -1
            elif dist1 > dist2:
                return 1
            else:
                return 0

        sorted_weights = sorted(enumerate(weights), key=functools.cmp_to_key(compare))
        return [i[0] for i in sorted_weights]

    def moead_get_subproblems(self):
        """
        Determines the subproblems to search.
        If :code:`utility_update` has been set, then this method follows the
        utility-based moea/D search.
        Otherwise, it follows the original moea/D specification.
        """
        indices = list(range(self.population_size))
        random.shuffle(indices)
        return indices

    def moead_get_mating_indices(self, index):
        """Determines the mating indices.

        Returns the population members that are considered during mating.  With
        probability :code:`delta`, the neighborhood is returned.  Otherwise,
        the entire population is returned.
        """
        if random.uniform(0.0, 1.0) <= self.moead_delta:
            return self.neighborhoods[index][:self.moead_neighborhood_size]
        else:
            return list(range(self.population_size))

    def get_igd(self):
        return self.igd_calculator.calculate(self.population)
