import math
import multiprocessing

import cppimport.import_hook



def worker(problem, input_queue, result_queue):
    """Serve SISRs commands for one chunk of a batch inside a worker process.

    The SISRs module for ``problem`` is imported once (from ``.cpp.cvrp``,
    ``.cpp.vrptw`` or ``.cpp.pcvrp``). The function then loops reading
    ``(mode, data)`` commands from ``input_queue``:

    * ``"new_instance"`` -- ``data`` holds the problem arrays for this chunk.
      One ``SISRs.Instance`` and starting solution is built per row. Replies
      ``[solution_costs, tours]``.
    * ``"remove_recreate"`` -- ``data`` is ``(selected_nodes, recreate_n, T,
      beta, insert_in_new_tours_only, search_mode)`` with ``search_mode`` in
      ``{"allImp", "singleImp"}``. One remove/recreate step is run per held
      solution. Replies ``[candidate_costs, solution_costs, tours]``;
      ``candidate_costs`` is only filled in ``"singleImp"`` mode.
    * ``"stop"`` -- return without replying.

    Every command other than ``"stop"`` gets exactly one reply on
    ``result_queue``. If handling a command fails (including failure to
    import SISRs), the reply is a ``RuntimeError`` carrying the worker's
    traceback instead of a result list. The worker keeps serving, so the
    parent is never left blocking on a reply that will not come.

    Args:
        problem: ``"cvrp"``, ``"vrptw"`` or ``"pcvrp"``.
        input_queue: queue the parent sends ``(mode, data)`` commands on.
        result_queue: queue this worker replies on.
    """
    import traceback

    load_error = None
    try:
        if problem == "cvrp":
            from .cpp.cvrp import SISRs
        elif problem == "vrptw":
            from .cpp.vrptw import SISRs
        elif problem == "pcvrp":
            from .cpp.pcvrp import SISRs
        else:
            raise NotImplementedError(f"unsupported problem: {problem!r}")
    except Exception:
        # Do not die here: the parent would wait forever for replies.
        # Report the failure as the answer to every command instead.
        load_error = traceback.format_exc()

    # NOTE(review): `instances` is never read again; presumably it keeps the
    # C++ Instance objects alive while the solutions are in use -- confirm.
    instances = []
    solutions = []
    solution_costs = []
    tours = []

    while True:
        mode, data = input_queue.get()
        if mode == "stop":
            return

        try:
            if load_error is not None:
                raise RuntimeError(f"could not load SISRs for {problem!r}:\n{load_error}")

            if mode == "new_instance":
                instances = []
                solutions = []
                solution_costs = []
                tours = []

                # `problem` is one of the three supported names here, otherwise
                # load_error would have been set above.
                if problem == "cvrp":
                    problem_size, capacity, depot_node_demand_np, depot_node_xy_np = data
                elif problem == "vrptw":
                    (problem_size, capacity, depot_node_demand_np, depot_node_xy_np,
                     depot_node_tw_np, depot_node_sd_np) = data
                else:  # pcvrp
                    (problem_size, capacity, depot_node_demand_np, depot_node_xy_np,
                     depot_node_prizes_np) = data

                # Every per-instance array, capacity included, is indexed with
                # the chunk-local position i. The caller must therefore slice
                # all of them to this worker's chunk.
                for i in range(depot_node_xy_np.shape[0]):
                    if problem == "cvrp":
                        instance = SISRs.Instance(problem_size, capacity[i],
                                                  depot_node_demand_np[i],
                                                  depot_node_xy_np[i])
                    elif problem == "vrptw":
                        instance = SISRs.Instance(problem_size, capacity[i],
                                                  depot_node_demand_np[i],
                                                  depot_node_tw_np[i, :, 0],
                                                  depot_node_tw_np[i, :, 1],
                                                  depot_node_sd_np[i],
                                                  depot_node_xy_np[i])
                    else:  # pcvrp
                        instance = SISRs.Instance(problem_size, capacity[i],
                                                  depot_node_demand_np[i],
                                                  depot_node_xy_np[i],
                                                  depot_node_prizes_np[i])

                    # Starting-solution arguments depend on instance size. They
                    # mirror InstanceSet.init_instances_sp; their meaning is
                    # defined by the C++ create_starting_solution.
                    if problem_size < 500:
                        solution = SISRs.create_starting_solution(instance, 50, int(problem_size * 0.15))
                    else:
                        solution = SISRs.create_starting_solution(instance, 100, int(problem_size * 0.25))

                    instances.append(instance)
                    solutions.append(solution)
                    solution_costs.append(solution.totalCosts)
                    tours.append(solution.getTourList())

                result_queue.put([solution_costs, tours])

            elif mode == "remove_recreate":
                selected_nodes, recreate_n, T, beta, insert_in_new_tours_only, search_mode = data

                if len(solutions) != selected_nodes.shape[0]:
                    raise ValueError(
                        f"got selected_nodes for {selected_nodes.shape[0]} instances, "
                        f"but this worker holds {len(solutions)} solutions")
                if search_mode not in ("allImp", "singleImp"):
                    raise ValueError(f"unknown search mode: {search_mode!r}")

                candidate_costs = []
                for i in range(len(solutions)):
                    if search_mode == "allImp":
                        best_soln, _ = SISRs.remove_recreate_allImp(solutions[i],
                                                                    selected_nodes[i],
                                                                    beta, recreate_n, T, insert_in_new_tours_only)
                    else:  # singleImp; 5th argument is passed as False, as before
                        best_soln, c_costs = SISRs.remove_recreate_singleImp(solutions[i],
                                                                             selected_nodes[i],
                                                                             beta, recreate_n, False,
                                                                             insert_in_new_tours_only)
                        candidate_costs.append(c_costs)

                    solutions[i] = best_soln
                    solution_costs[i] = best_soln.totalCosts
                    tours[i] = best_soln.getTourList()

                result_queue.put([candidate_costs, solution_costs, tours])

            else:
                raise ValueError(f"unknown worker mode: {mode!r}")

        except Exception as error:
            print("An exception occurred:", error)
            # A RuntimeError holding a plain string always pickles, unlike some
            # extension-module exceptions. The parent can therefore always read it.
            result_queue.put(RuntimeError(
                f"worker for {problem!r} failed in mode {mode!r}:\n{traceback.format_exc()}"))


class InstanceSet:
    """A batch of routing-problem instances with one SISRs solution each.

    Instances are held either in this process (``use_multiprocessing=False``)
    or in ``num_processes`` :func:`worker` processes. In the latter case each
    worker owns one contiguous chunk of the batch and communicates through a
    pair of ``multiprocessing.Queue`` objects.

    After ``init_instances`` / ``remove_recreate``:
        batch_size -- number of instances in the current batch.
        costs      -- ``totalCosts`` of each instance's current solution.
        tours      -- ``getTourList()`` of each instance's current solution.

    ``get_solution`` / ``set_solution`` use ``self._solutions``, which is only
    populated in single-process mode.
    """

    # Search modes accepted by remove_recreate.
    _SEARCH_MODES = ("allImp", "singleImp")
    # Problem names init_instances knows how to build.
    _PROBLEMS = ("cvrp", "vrptw", "pcvrp")

    def __init__(self, problem, use_multiprocessing=True, num_processes=8):
        self.problem = problem
        self.batch_size = None
        self.tours = []
        self.costs = []

        # only used if use_multiprocessing is False
        self._instances = []
        self._solutions = []

        self.use_multiprocessing = use_multiprocessing

        if use_multiprocessing:
            # Each entry is [process, input_queue, output_queue].
            self.processes = []
            self.num_processes = num_processes
            for _ in range(self.num_processes):
                input_queue = multiprocessing.Queue()
                output_queue = multiprocessing.Queue()
                p = multiprocessing.Process(target=worker, args=(problem, input_queue, output_queue))
                p.start()
                self.processes.append([p, input_queue, output_queue])
        else:
            # The *_sp methods look SISRs up as a module-level global.
            global SISRs
            if problem == "cvrp":
                from .cpp.cvrp import SISRs
            elif problem == "vrptw":
                from .cpp.vrptw import SISRs
            elif problem == "pcvrp":
                from .cpp.pcvrp import SISRs
            else:
                raise NotImplementedError

    def __del__(self):
        # getattr: __init__ may have raised before these attributes existed.
        if not getattr(self, "use_multiprocessing", False):
            return
        for p, _, _ in getattr(self, "processes", []):
            p.terminate()
            # Join so the terminated child is reaped rather than left as a zombie.
            p.join(timeout=1)

    def init_instances(self, problem_data):
        """Build instances and starting solutions for every row of ``problem_data``."""
        if self.use_multiprocessing:
            return self.init_instances_mp(problem_data)
        else:
            return self.init_instances_sp(problem_data)

    def remove_recreate(self, selected_nodes, recreate_n, mode, T=0, beta=0.0, insert_in_new_tours_only=True):
        """Run one remove/recreate step on every instance.

        Args:
            selected_nodes: per-instance node selections, indexed by batch position.
            mode: ``"allImp"`` or ``"singleImp"``.

        Returns:
            The per-instance candidate costs (only filled in ``"singleImp"`` mode).

        Raises:
            ValueError: ``mode`` is not a supported search mode.
        """
        if self.use_multiprocessing:
            return self.remove_recreate_mp(selected_nodes, recreate_n, mode, T, beta, insert_in_new_tours_only)
        else:
            return self.remove_recreate_sp(selected_nodes, recreate_n, mode, T, beta, insert_in_new_tours_only)

    @classmethod
    def _check_mode(cls, mode):
        """Raise ValueError unless ``mode`` is a supported search mode."""
        if mode not in cls._SEARCH_MODES:
            raise ValueError(f"unknown remove/recreate mode: {mode!r}")

    def _chunks(self):
        """Yield ``(input_queue, slice)`` assigning contiguous batch chunks to workers in order."""
        per_process = math.ceil(self.batch_size / self.num_processes)
        for k, (_, p_in, _) in enumerate(self.processes):
            yield p_in, slice(k * per_process, (k + 1) * per_process)

    def _collect(self):
        """Read one reply from every worker, in worker order.

        All replies are drained before anything is raised, so the queues stay
        in step for the next command. If a worker replied with an exception
        instance instead of a result list, that exception is re-raised.
        """
        replies = [p_out.get() for _, _, p_out in self.processes]
        for reply in replies:
            if isinstance(reply, BaseException):
                raise reply
        return replies

    def init_instances_mp(self, problem_data):
        """Send each worker its chunk of ``problem_data`` and gather starting costs/tours."""
        name = problem_data.problem_name
        # Validate before sending anything, so that no worker is left holding
        # a reply that nobody will read.
        if name not in self._PROBLEMS:
            raise NotImplementedError
        self.batch_size = problem_data.depot_node_xy.shape[0]

        for p_in, chunk in self._chunks():
            p_data = [problem_data.problem_size,
                      # Sliced like every other per-instance array: the worker
                      # indexes capacity with its chunk-local position.
                      problem_data.capacity[chunk],
                      problem_data.depot_node_demand[chunk],
                      problem_data.depot_node_xy[chunk]]
            if name == "vrptw":
                p_data += [problem_data.depot_node_tw[chunk],
                           problem_data.depot_node_sd[chunk]]
            elif name == "pcvrp":
                p_data.append(problem_data.depot_node_prizes[chunk])
            p_in.put(["new_instance", p_data])

        self.tours = []
        self.costs = []
        for costs, tours in self._collect():
            self.costs.extend(costs)
            self.tours.extend(tours)

    def remove_recreate_mp(self, selected_nodes, recreate_n, mode, T, beta, insert_in_new_tours_only):
        """Run one remove/recreate step in the workers; return concatenated candidate costs."""
        self._check_mode(mode)
        for p_in, chunk in self._chunks():
            p_data = [selected_nodes[chunk], recreate_n, T, beta, insert_in_new_tours_only, mode]
            p_in.put(["remove_recreate", p_data])

        self.tours = []
        self.costs = []
        candidate_costs_set = []
        for candidate_costs, costs, tours in self._collect():
            self.costs.extend(costs)
            self.tours.extend(tours)
            candidate_costs_set.extend(candidate_costs)

        return candidate_costs_set

    def init_instances_sp(self, problem_data):
        """Build every instance and its starting solution in this process."""
        self.batch_size = problem_data.depot_node_xy.shape[0]
        problem_size = problem_data.problem_size
        name = problem_data.problem_name

        self._instances = []
        self._solutions = []
        self.costs = []
        self.tours = []

        for i in range(self.batch_size):
            if name == "cvrp":
                instance = SISRs.Instance(problem_size, problem_data.capacity[i],
                                          problem_data.depot_node_demand[i],
                                          problem_data.depot_node_xy[i])
            elif name == "vrptw":
                instance = SISRs.Instance(problem_size, problem_data.capacity[i],
                                          problem_data.depot_node_demand[i],
                                          problem_data.depot_node_tw[i, :, 0],
                                          problem_data.depot_node_tw[i, :, 1],
                                          problem_data.depot_node_sd[i],
                                          problem_data.depot_node_xy[i])
            elif name == "pcvrp":
                instance = SISRs.Instance(problem_size, problem_data.capacity[i],
                                          problem_data.depot_node_demand[i],
                                          problem_data.depot_node_xy[i],
                                          problem_data.depot_node_prizes[i])
            else:
                raise NotImplementedError

            # Starting-solution arguments depend on instance size (mirrors the
            # worker); their meaning is defined by the C++ create_starting_solution.
            if problem_size < 500:
                solution = SISRs.create_starting_solution(instance, 50, int(problem_size * 0.15))
            else:
                solution = SISRs.create_starting_solution(instance, 100, int(problem_size * 0.25))

            self._solutions.append(solution)
            self._instances.append(instance)
            self.costs.append(solution.totalCosts)
            self.tours.append(solution.getTourList())

    def remove_recreate_sp(self, selected_nodes, recreate_n, mode, T, beta, insert_in_new_tours_only):
        """Run one remove/recreate step per instance in this process."""
        self._check_mode(mode)
        candidate_costs_set = []
        for i in range(self.batch_size):
            if mode == "allImp":
                best_soln, _ = SISRs.remove_recreate_allImp(self._solutions[i],
                                                            selected_nodes[i],
                                                            beta, recreate_n, T, insert_in_new_tours_only)
            else:  # singleImp; 5th argument is passed as False, as before
                best_soln, candidate_costs = SISRs.remove_recreate_singleImp(self._solutions[i],
                                                                             selected_nodes[i],
                                                                             beta, recreate_n, False,
                                                                             insert_in_new_tours_only)
                candidate_costs_set.append(candidate_costs)

            self.set_solution(i, best_soln)

        return candidate_costs_set

    def getTours(self):
        """Return the tour list of every instance's current solution."""
        return self.tours

    def get_solution(self, idx):
        """Return the solution object for instance ``idx`` (single-process mode only)."""
        return self._solutions[idx]

    def set_solution(self, idx, sol):
        """Replace instance ``idx``'s solution and refresh its cached cost and tours."""
        self._solutions[idx] = sol
        self.costs[idx] = sol.totalCosts
        self.tours[idx] = sol.getTourList()

