from __future__ import division

import copy
import math
import random
import sys
from operator import itemgetter

import geatpy as gt
import numpy as np
import scipy
import tianshou
from platypus.config import PlatypusConfig
from platypus.core import Archive
from platypus.evaluator import Job
from platypus.indicators import *
from platypus.operators import TournamentSelector, RandomGenerator, \
    GAOperator, SBX, PM
from platypus.weights import random_weights, chebyshev
from pymoo.factory import get_reference_directions
from scipy.spatial import distance

import os

SCRIPT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(SCRIPT_DIR)
from moea import moea_problems
from moea.agents import operator_type, operator_parameter, neighbor_size
from moea.maenv_register import get_maenv
from multiagentenv import MultiAgentEnv


class _EvaluateJob(Job):
    """Evaluator job that wraps one solution and evaluates it when run."""

    def __init__(self, solution):
        super().__init__()
        # The solution this job evaluates; read back by MOEAEnv.evaluate_all.
        self.solution = solution

    def run(self):
        """Call ``evaluate()`` on the wrapped solution."""
        target = self.solution
        target.evaluate()


def get_weights(n_objs, n_points):
    """Return Das-Dennis reference directions as a list of weight vectors.

    @param n_objs: dimensionality of each weight vector.
    @param n_points: number of weight vectors requested.
    @return: list of ``n_objs``-long lists. A message is printed when the count
        produced differs from ``n_points``.
    """
    directions = get_reference_directions("das-dennis", n_objs, n_points=n_points)
    weights = directions.tolist()
    produced = len(weights)
    if produced != n_points:
        print(f'Generate {produced} weights but asked to generate {n_points}')
    return weights


def choose_problems(env_name, n_obj):
    """Instantiate the problem class named ``env_name`` from ``moea_problems``.

    ZDT problems are constructed without arguments; every other problem receives
    ``n_obj``.
    NOTE(review): for ZDT with ``n_obj != 2`` only a message is printed and the
    ZDT problem is still returned, so callers that keep using ``n_obj`` will
    disagree with the problem's own objective count.
    """
    if 'ZDT' not in env_name:
        return getattr(moea_problems, env_name)(n_obj)
    if n_obj != 2:
        print("ZDT problem only has 2 n_objs")
    return getattr(moea_problems, env_name)()


def get_variable_boundary(problem):
    """Collect per-variable lower and upper bounds of ``problem``.

    @param problem: object with ``nvars`` and an indexable ``types`` whose first
        ``nvars`` items expose ``min_value`` and ``max_value``.
    @return: ``(lb, ub)``, two float64 NumPy arrays of length ``problem.nvars``.
    """
    var_types = [problem.types[k] for k in range(problem.nvars)]
    lower = np.array([t.min_value for t in var_types], dtype=float)
    upper = np.array([t.max_value for t in var_types], dtype=float)
    return lower, upper


# Machine epsilon for Python floats. update_weight uses it as the threshold below
# which an objective's offset from the ideal point is treated as zero.
EPSILON = sys.float_info.epsilon


class MOEAEnv(MultiAgentEnv):
    """Multi-agent environment in which agents configure a MOEA/D run.

    Every ``step`` runs one MOEA/D generation. The four agents choose, in order,
    the neighbourhood size, the operator type, the operator parameter and whether
    to adapt the weight vectors. Progress is tracked with IGD against the
    problem's reference set.
    """

    def __init__(self, key="WFG6_3", n_ref_points=1000, population_size=210, seed=2022,
                 budget_ratio=50, test=False, wo_obs=False, save_history=False,
                 early_stop=True, baseline=False, adaptive_open=False, replay=False,
                 replay_dir=None, ban_agent=4, reward_type=0):
        """

        @param key: Task Name; resolved by ``get_maenv`` into candidate problems and
            objective counts. A last ``_``-separated token longer than one character
            marks a mixture of objective counts.
        @param n_ref_points: number of points requested for the IGD reference set.
        @param population_size: number of solutions / subproblems (weight vectors).
        @param seed: stored on the instance; this class does not seed any RNG with it.
        @param budget_ratio: evaluation budget is population_size * budget_ratio * n_objs.
        @param test: print progress every 50 steps.
        @param wo_obs: Without Obs, False means that we need obs.
        @param save_history: Save History in return Info
        @param early_stop: stop once the IGD has not improved for more than a tenth of
            the step budget.
        @param baseline: Use Baseline Operator (the GA variator instead of the agent's operator)
        @param adaptive_open: Use Adaptive Weights
        @param replay: Save Replay in Replay dir
        @param replay_dir: directory written by ``save_replay``.
        @param ban_agent: 0,1,2,3 , if not , means no ban (only used in the replay file name here)
        @param reward_type: 0 is Triangles, 1,2,3 is defined in DEDDQN (only used in the replay file name here)
        """
        # Environment Change
        self.key = key
        self.func_choice = get_maenv(key)[0]
        self.nobjs_choice = get_maenv(key)[1]
        self.func_select = [(func, nobjs) for func in self.func_choice for nobjs in self.nobjs_choice]
        self.fun_index = 0
        random.shuffle(self.func_select)
        self.mixture = False
        if len(key.split("_")[-1]) > 1:
            self.mixture = True  # Multi-nobjs training
        # Problem Related
        self.n_ref_points = n_ref_points
        self.population_size = population_size
        self.budget_ratio = budget_ratio
        self._init_problem()
        self.ban_agent = ban_agent
        self.reward_type = reward_type
        # Adaptive Weights Agent Related
        self.adaptive_open = adaptive_open
        self._init_adaptive_weights()
        # MDP Related
        self.n_agents = 4
        self.n_actions = 4
        self.n_obs = 22
        if self.mixture is False:
            self.episode_limit = budget_ratio * self.n_objs + 1
        else:
            self.episode_limit = budget_ratio * 7 + 1
        self.early_stop = early_stop
        self.obs = np.zeros(self.n_obs)
        self.wo_obs = wo_obs  # Without Obs, False means that we need obs
        # MOEA/D Algorithm Related
        self.generator = RandomGenerator()
        self.selector = TournamentSelector(2)
        self.variator = None  # if variator is none, reset() builds the default variator
        self._auto_variator = None  # the default variator built by reset(), if any
        self.evaluator = PlatypusConfig.default_evaluator
        self.baseline = baseline  # Used MOEA/D Default Operator Type
        self.moead_neighborhood_size = 30
        self.moead_neighborhood_maxsize = 30
        self.moead_eta = 2  # maximum number of replaced parent
        self.moead_delta = 0.8  # Use Neighbor or Whole Population, Important Parameters
        self.moead_weight_generator = random_weights
        # Not Important
        self.replay = replay
        self.replay_dir = replay_dir
        self.test = test
        self.save_history = save_history
        # NOTE(review): this instance attribute shadows the ``seed()`` method below.
        # It is kept because callers may read ``env.seed`` as an int.
        self.seed = seed
        self._init_static()

        self.replay_his = {
            "igd_his": [],
            "hv_his": [],
            "ndsort_ratio_his": [],
            "dis_his": [],
            "action_his": [],
            "population_his": []
        }

    def _init_problem(self):
        """Load the task at ``fun_index``: problem, bounds, IGD reference set and budget."""
        self.env_name = self.func_select[self.fun_index][0]
        self.n_objs = self.func_select[self.fun_index][1]
        self.problem = choose_problems(self.env_name, self.n_objs)
        self.lb, self.ub = get_variable_boundary(self.problem)
        self.problem_ref_points = self.problem.get_ref_set(n_ref_points=self.n_ref_points)
        self.igd_calculator = InvertedGenerationalDistance(reference_set=self.problem_ref_points)
        self.budget = self.population_size * self.budget_ratio * self.n_objs

    def _init_adaptive_weights(self):
        """Reset the external population (EP) and the weight-adaptation settings."""
        self.EP = []
        self.EP_MaxSize = int(self.population_size * 1.5)
        self.rate_update_weight = 0.05  # rate_update_weight * N = nus
        self.nus = int(
            self.rate_update_weight * self.population_size)  # maximal number of subproblems needed to be adjusted
        self.wag = 100  # adaptive iteration interval, Units are Iterations
        self.rate_evol = 0.8  # fraction of the generation budget before adaptation starts

    def _init_static(self):
        """Reset all per-episode counters, statistics and histories."""
        self.nfe = 0  # Discarded
        self.moead_generation = 0
        self.best_value = 1e6
        self.last_value = 1e6
        self.inital_value = None
        self.last_bonus = 0
        self.stag_count = 0
        self.stag_count_max = (self.budget // self.population_size) / 10

        self.hv_his = []
        self.hv_last5 = tianshou.utils.MovAvg(size=5)

        self.nds_ratio_his = []
        self.nds_ratio_last5 = tianshou.utils.MovAvg(size=5)

        self.ava_dist_his = []
        self.ava_dist_last5 = tianshou.utils.MovAvg(size=5)

        self.value_his = []

        self.hv_running = tianshou.utils.RunningMeanStd()
        self.nds_ratio_running = tianshou.utils.RunningMeanStd()
        self.ava_dist_running = tianshou.utils.RunningMeanStd()

        self.action_count = np.zeros(shape=(self.n_agents, self.n_actions), dtype=int)
        self.action_seq_his = []  # record the sequential history of action

        self.info_reward_his = []
        self.info_obs_his = []
        self.info_igd_his = []
        self.info_hv_his = []

    def _rebuild_neighborhoods(self):
        """Recompute, for each subproblem, the indices of its nearest weight vectors."""
        # the i-th entry holds subproblem i's neighbour indices, nearest first
        self.neighborhoods = [
            self.moead_sort_weights(self.weights[i], self.weights)[:self.moead_neighborhood_maxsize]
            for i in range(self.population_size)
        ]

    def _crowding_scores(self, points, others=None):
        """Return, per row of ``points``, the product of its ``n_objs`` smallest distances.

        Distances are Euclidean, measured to the rows of ``others``, or to the other
        rows of ``points`` when ``others`` is None (the zero self-distance is
        excluded). Larger scores mean less crowded points.
        """
        if others is None:
            dist = distance.cdist(points, points)
            np.fill_diagonal(dist, np.inf)
        else:
            dist = distance.cdist(points, others)
        dist.sort(axis=1)
        return np.prod(dist[:, 0:self.n_objs], axis=1)

    def moead_initial(self):
        """Create the weights and neighbourhoods and seed the ideal point from the population."""
        # Das-Dennis weights; get_weights prints a message if it cannot produce
        # exactly population_size of them.
        self.weights = get_weights(self.n_objs, self.population_size)
        self._rebuild_neighborhoods()
        for i in range(self.population_size):
            self.moead_update_ideal(self.population[i])

    def moead_step(self, action=None):
        """
        Run one MOEA/D generation configured by ``action``.

        @param action: indexable of four choices: [neighbour size, operator type,
            operator parameter, adaptive-weights flag]. ``action[3]`` is set to 0 in
            place when adaptive weights are disabled, and must not exceed 1.
        @raise Exception: if ``action`` is None or ``action[3] > 1``.
        """
        if action is None:
            raise Exception("Action Is None.")
        if self.adaptive_open is False:
            action[3] = 0
        # Validate before touching the population so a bad action leaves the state intact.
        if action[3] > 1:
            raise Exception("action[3] > 1.")
        self.moead_neighborhood_size = neighbor_size(action[0])
        scale = operator_parameter(action[2])  # Operator Parameter
        # step() increments action_count before calling here, so this is true on the first generation.
        if np.sum(self.action_count) / self.n_agents <= 1:
            self.moead_initial()
        self.offspring_list = []
        for index in self.moead_get_subproblems():
            mating_indices = self.moead_get_mating_indices(index)
            # Generate an offspring
            if self.baseline is False:
                # de_pool: five distinct mating partners other than ``index`` itself
                if index in mating_indices:
                    mating_indices.remove(index)
                de_pool = np.random.choice(mating_indices, 5, replace=False)
                candidate = operator_type(action[1], self.population, index, de_pool, scale, self.lb, self.ub)
                offspring = [self.variator.mutation.evolve(candidate)]
            else:
                partners = np.random.permutation(mating_indices)[:(self.variator.arity - 1)]
                parents = [self.population[index]] + [self.population[i] for i in partners]
                offspring = self.variator.evolve(parents)
            self.evaluate_all(offspring)
            self.offspring_list.extend(offspring)
            for child in offspring:
                self.moead_update_ideal(child)
                self.moead_update_solution(child, mating_indices)  # selection
        # Weight adaptation only starts after rate_evol of the generation budget.
        if self.moead_generation >= (self.rate_evol * self.budget_ratio * self.n_objs):
            if self.adaptive_open:
                if len(self.EP) == 0:
                    self.EP.extend(self.population)
                self.update_ep()
            if action[3] == 1 and (self.moead_generation % self.wag == 0):
                self.update_weight()
        self.moead_generation += 1

    def nondominated_solution(self, solutions):
        """
        Set ``rank`` on each solution: 0 if a platypus ``Archive`` built from
        ``solutions`` keeps it (non-dominated), otherwise 1.
        @param solutions: list of platypus solutions; modified in place.
        """
        archive = Archive()  # An archive only containing non-dominated solutions.
        archive += solutions
        for solution in solutions:
            solution.rank = 0 if solution in archive else 1

    def update_ep(self):
        """Merge this generation's offspring into the EP and keep it non-dominated and bounded.

        When the EP exceeds ``EP_MaxSize``, only the ``EP_MaxSize`` least crowded
        members are kept.
        """
        self.EP.extend(self.offspring_list)
        self.nondominated_solution(self.EP)
        self.EP = [e for e in self.EP if e.rank == 0]
        if len(self.EP) <= self.EP_MaxSize:
            return
        # Delete the overcrowded solutions in EP
        scores = self._crowding_scores([e.objectives for e in self.EP])
        keep = np.argpartition(scores, -self.EP_MaxSize)[-self.EP_MaxSize:]
        # A comprehension (unlike itemgetter(*keep)) also works when only one index is kept.
        self.EP = [self.EP[i] for i in keep]

    def update_weight(self):
        """Swap the most crowded subproblems for sparse EP solutions.

        Removes the ``nus`` population members (and their weights) with the
        smallest crowding score. Then it adds the ``nus`` EP members least crowded
        relative to the remaining population, with weights derived from their
        offset to the ideal point, and rebuilds the neighbourhoods. ``nus`` is
        capped by the EP size.
        """
        nus = min(len(self.EP), self.nus)
        if nus <= 0:
            # Nothing to exchange. Without this guard, the ``[-0:]`` slices below would select
            # the whole EP and grow the population beyond population_size.
            return
        # Delete the overcrowded subproblems
        keep = self.population_size - nus
        scores = self._crowding_scores(
            [self.population[i].objectives for i in range(self.population_size)])
        kept = np.argpartition(scores, -keep)[-keep:]
        self.population = [self.population[i] for i in kept]
        self.weights = [self.weights[i] for i in kept]
        # Add new subproblems
        scores = self._crowding_scores(
            [e.objectives for e in self.EP],
            [s.objectives for s in self.population])
        chosen = np.argpartition(scores, -nus)[-nus:]
        add_EP = [self.EP[i] for i in chosen]
        add_weights = []
        for e in add_EP:
            ans = np.asarray(e.objectives) - np.asarray(self.ideal_point)
            ans[ans < EPSILON] = 1  # objective (numerically) at the ideal point
            ans = 1 / ans
            ans[ans == np.inf] = 1  # when f = z
            add_weights.append((ans / np.sum(ans)).tolist())
        self.population.extend(add_EP)
        self.weights.extend(add_weights)
        # Update the neighbor
        self._rebuild_neighborhoods()

    def update_igd(self, value):
        """Record an IGD value and update the best value and the stagnation counter."""
        self.value_his.append(value)
        if value < self.best_value:
            self.stag_count = 0
            self.best_value = value
        else:
            self.stag_count += 1
        self.last_value = value

    def step(self, action):
        """
        Run one MOEA/D generation with the agents' joint action.
        :param action: one choice per agent: Neighbor Size, Operator Type, Operator Parameter, Adaptive Weights
                [x, x, x, x]: shape=n_agents
        :return: (done, info). ``info`` always holds ``best_igd`` and ``last_igd``; with
                 ``save_history`` it also holds the IGD, reward, obs and HV histories.
        """
        for agent_id in range(self.n_agents):
            self.action_count[agent_id][action[agent_id]] += 1
        self.action_seq_his.append(action)
        self.moead_step(action)
        value = self.get_igd()
        self.update_igd(value)
        self.info_igd_his.append(self.last_value)  # Logger
        n_steps = np.sum(self.action_count[0])
        if n_steps % 50 == 0 and self.test:
            # if test, it will report basic information of this run
            print("Problem {} ite {},  action is {}, best igd is {}, last igd is {}".format(self.problem,
                                                                                            self.moead_generation,
                                                                                            action,
                                                                                            self.best_value,
                                                                                            self.last_value))
        # Stop when the step budget is used up or (with early_stop) the IGD has stagnated.
        budget_used = n_steps >= self.budget // self.population_size
        stagnated = self.stag_count > self.stag_count_max and self.early_stop
        if budget_used or stagnated:
            self.done = True
            if self.replay:
                self.update_replay()
        else:
            self.done = False
        if self.save_history:
            return self.done, {'best_igd': self.best_value,
                               'last_igd': self.last_value,
                               'igd_his': self.info_igd_his,
                               'reward_his': self.info_reward_his,
                               'obs_his': self.info_obs_his,
                               'hv_his': self.hv_his}
        return self.done, {'best_igd': self.best_value, 'last_igd': self.last_value}

    def reset(self):
        """ Returns initial observations and states

        Loads the current task, resets all per-episode state, and samples and
        evaluates a random population whose IGD becomes the initial/best/last value.
        It also estimates objective bounds and ensures a variator exists, then
        advances to the next task (reshuffling after a full pass).
        """
        self.done = False
        self._init_problem()
        self._init_adaptive_weights()
        self._init_static()
        self.population = [self.generator.generate(self.problem) for _ in range(self.population_size)]
        self.evaluate_all(self.population)
        self.inital_value = self.get_igd()
        self.best_value = self.inital_value
        self.last_value = self.inital_value
        if self.n_objs < 5:
            # Upper bounds from a large random sample of feasible solutions, inflated by 10%.
            tmp_population = [self.generator.generate(self.problem) for _ in range(int(1e5))]
            tmp_feasible = [s for s in tmp_population if s.constraint_violation == 0.0]
            self.evaluate_all(tmp_feasible)
            self.archive_maximum = (
                    1.1 * np.array([max([s.objectives[i] for s in tmp_feasible]) for i in range(self.n_objs)])).tolist()
            del tmp_population, tmp_feasible
        else:
            self.archive_maximum = [max([s.objectives[i] for s in self.population]) for i in range(self.n_objs)]
        self.archive_minimum = [min([s.objectives[i] for s in self.population]) for i in range(self.n_objs)]
        self.ideal_point = copy.deepcopy(self.archive_minimum)
        # The default mutation rate is 1 / nvars, and nvars can change between tasks,
        # so rebuild the default variator for every task unless a custom one was set.
        if self.variator is None or self.variator is self._auto_variator:
            self._auto_variator = GAOperator(SBX(probability=1.0, distribution_index=20.0),
                                             PM(probability=1 / self.problem.nvars))
            self.variator = self._auto_variator
        # Change Task Function
        self.fun_index += 1
        if self.fun_index == len(self.func_select):
            self.fun_index = 0
            random.shuffle(self.func_select)
        return self.obs

    def close(self):
        """Close the environment.

        NOTE(review): this performs a full ``reset()``. That samples and evaluates a
        new population (plus 1e5 extra solutions when n_objs < 5) and advances the
        task index.
        """
        self.reset()

    def seed(self):
        """Return the stored seed.

        NOTE(review): ``__init__`` assigns ``self.seed``, which shadows this method on
        instances, so ``env.seed`` is the stored value and ``env.seed()`` fails.
        """
        return self.seed

    def update_replay(self):
        """Append this episode's histories and final population to ``replay_his``."""
        self.replay_his["igd_his"].append(self.info_igd_his)
        self.replay_his["hv_his"].append(self.hv_his)
        self.replay_his["ndsort_ratio_his"].append(self.nds_ratio_his)
        self.replay_his["dis_his"].append(self.ava_dist_his)
        self.replay_his["action_his"].append(self.action_seq_his)
        self.replay_his["population_his"].append(self.population)

    def save_replay(self):
        """Write ``replay_his`` to ``<replay_dir>/ban{ban_agent}_R{reward_type}_replay.npz``."""
        if not os.path.exists(self.replay_dir):
            # Clear the umask only while creating the directory, so that it really gets
            # mode 0o777, then restore the process-wide setting for everything else.
            old_umask = os.umask(0)
            try:
                os.makedirs(self.replay_dir, mode=0o777, exist_ok=True)
            finally:
                os.umask(old_umask)
        token = f"ban{str(self.ban_agent)}_R{str(self.reward_type)}_replay.npz"
        replay_path = os.path.join(self.replay_dir, token)
        np.savez(file=replay_path, info_stack=self.replay_his)

    def get_ref_set(self):
        """Return the reference set used for IGD."""
        return self.problem_ref_points

    def get_uniform_points(self, n_points):
        """
        return a list of uniformly spread points (geatpy ``crtup``)
        :param n_points: number of points requested
        :return: list of ``n_objs``-long lists
        """
        uniformPoints, _ = gt.crtup(self.n_objs, n_points)
        return uniformPoints.tolist()

    def evaluate_all(self, solutions):
        """Evaluate every not-yet-evaluated solution through ``self.evaluator``.

        If the evaluator returns different solution objects (presumably copies
        made by a parallel evaluator; TODO confirm), their results are copied back
        onto the originals.
        """
        unevaluated = [s for s in solutions if not s.evaluated]

        jobs = [_EvaluateJob(s) for s in unevaluated]
        results = self.evaluator.evaluate_all(jobs)

        # if needed, update the original solution with the results
        for original, result in zip(unevaluated, results):
            evaluated = result.solution
            if original != evaluated:
                original.variables[:] = evaluated.variables[:]
                original.objectives[:] = evaluated.objectives[:]
                original.constraints[:] = evaluated.constraints[:]
                original.constraint_violation = evaluated.constraint_violation
                original.feasible = evaluated.feasible
                original.evaluated = evaluated.evaluated

    def moead_update_ideal(self, solution):
        """Lower each ideal-point coordinate to ``solution``'s objective where smaller."""
        for i in range(self.problem.nobjs):
            self.ideal_point[i] = min(self.ideal_point[i], solution.objectives[i])

    def moead_calculate_fitness(self, solution, weights):
        """Chebyshev scalarisation of ``solution`` w.r.t. the ideal point and ``weights``."""
        return chebyshev(solution, self.ideal_point, weights)

    def moead_update_solution(self, solution, mating_indices):
        """
        Let ``solution`` replace up to ``moead_eta`` subproblem solutions it beats.

        Candidates are visited in random order (``mating_indices`` is shuffled in place).
        Between two infeasible solutions the smaller constraint violation wins. A
        feasible solution beats an infeasible one. Between feasible solutions the
        smaller Chebyshev value under the candidate's weights wins. Replacements
        store a deep copy of ``solution``.
        :param solution: evaluated offspring
        :param mating_indices: subproblem indices eligible for replacement
        """
        replaced = 0
        random.shuffle(mating_indices)

        for i in mating_indices:
            candidate = self.population[i]
            weights = self.weights[i]
            replace = False

            if solution.constraint_violation > 0.0 and candidate.constraint_violation > 0.0:
                if solution.constraint_violation < candidate.constraint_violation:
                    replace = True
            elif candidate.constraint_violation > 0.0:
                replace = True
            elif solution.constraint_violation > 0.0:
                pass
            elif self.moead_calculate_fitness(solution, weights) < self.moead_calculate_fitness(candidate, weights):
                replace = True

            if replace:
                self.population[i] = copy.deepcopy(solution)
                replaced += 1

            if replaced >= self.moead_eta:
                break

    @staticmethod
    def moead_sort_weights(base, weights):
        """Return the indices of ``weights`` ordered by Euclidean distance to ``base``.

        Equal distances keep their original relative order (the sort is stable).
        """

        def distance_to_base(item):
            weight = item[1]
            return math.sqrt(sum([math.pow(base[i] - weight[i], 2.0) for i in range(len(base))]))

        # A key function computes each distance once, instead of twice per comparison.
        return [index for index, _ in sorted(enumerate(weights), key=distance_to_base)]

    def moead_get_subproblems(self):
        """
        Return every subproblem index once, in random order (one generation's visiting order).
        """
        indices = list(range(self.population_size))
        random.shuffle(indices)
        return indices

    def moead_get_mating_indices(self, index):
        """Determines the mating indices.

        Returns the population members that are considered during mating. With
        probability :code:`moead_delta` the first ``moead_neighborhood_size``
        neighbours of ``index`` are returned. Otherwise all population indices are
        returned. Either way the result is a fresh list.
        """
        if random.uniform(0.0, 1.0) <= self.moead_delta:
            return self.neighborhoods[index][:self.moead_neighborhood_size]
        else:
            return list(range(self.population_size))

    def get_igd(self):
        """IGD of the current population against the reference set."""
        return self.igd_calculator.calculate(self.population)
