'''
TSP (end-to-end)
'''
import sys
import pickle
import types
import warnings
import time
from typing import Any
import numpy as np
from prompts import GetPrompts


class TSP():
    """Evaluator for traveling salesman problem.

    Holds a list of TSP instances, each a ``(coordinates, distances)`` pair,
    and scores candidate heuristics by the negated average tour length over
    those instances (so a higher score is better).
    """

    def __init__(self,
                 problem_size=50,
                 running_time=30,
                 dirname=None,  # or str e.g., 'dataset_mcts'
                 mode='train',
                 debug=False,
                 **kwargs):
        """Load the evaluation instances.

        Args:
            problem_size: Number of nodes per instance. Used as the key into
                the pickled dataset, as part of the ``.npy`` file name, and
                as the expected tour length during evaluation.
            running_time: Time budget in seconds for one ``calObjectives``
                run over all instances.
            dirname: ``None`` to read ``all_data_{mode}.pkl`` from the
                current working directory; otherwise a directory containing
                ``{mode}{problem_size}_dataset.npy``.
            mode: Dataset split name embedded in the data file name.
            debug: When true, the first tour produced through a reduction is
                pickled to ``example_tour.pkl``.
            **kwargs: Accepted and ignored.
        """
        self.time = running_time
        self.dirname = dirname
        self.debug = debug
        self.problem_size = problem_size
        self.prompts = GetPrompts(problem_size)
        if self.dirname is None:
            # Load from files
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at dataset files you trust.
            with open(f'all_data_{mode}.pkl', 'rb') as f:
                all_data = pickle.load(f)
            self._datasets = all_data[self.problem_size]
            self.n_instance = len(self._datasets)
        else:
            # Any non-None dirname is treated as a directory of .npy datasets
            # (e.g. 'dataset_mcts' or 'dataset_mcts_aco'). The former guard
            # `dirname == 'dataset_mcts' or 'dataset_mcts_aco'` was always
            # true, so dropping it keeps the effective behaviour.
            data = np.load(f'{self.dirname}/{mode}{problem_size}_dataset.npy')
            self.n_instance = data.shape[0]
            self._datasets = []
            for i in range(self.n_instance):
                coordinates = data[i, :, :]
                # Full pairwise Euclidean distance matrix for this instance.
                distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
                self._datasets.append((coordinates, distances))

    def get_seed_algorithms(self):
        """Return the seed algorithms as a dict mapping a name to its source."""
        algorithms = {
            'constructive': "...",
        }
        return algorithms

    # def evaluate_program(self, program_str: str, callable_func: callable) -> Any | None:
    #     return self.evaluate(callable_func)

    def route_cost(self, coord_matrix, route, problem_size):
        """Return the length of the closed tour described by ``route``.

        Args:
            coord_matrix: Indexable collection of node coordinates; each
                entry must support subtraction (e.g. a NumPy array row).
            route: Sequence of node indices; each entry is cast with ``int``.
            problem_size: Number of leading route entries that are chained
                together before the closing edge is added.

        Returns:
            The sum of Euclidean distances between consecutive nodes plus
            the edge from the last route entry back to the first.
        """
        cost = 0
        for j in range(problem_size - 1):
            here = coord_matrix[int(route[j])]
            there = coord_matrix[int(route[j + 1])]
            cost += np.linalg.norm(here - there)
        # Close the tour: last visited node back to the starting node.
        cost += np.linalg.norm(coord_matrix[int(route[-1])] - coord_matrix[int(route[0])])
        return cost

    def generate_neighborhood_matrix(self, instance):
        """Return, for every node, all node indices ordered by distance.

        Args:
            instance: Node coordinates, one row per node.

        Returns:
            An ``(n, n)`` integer array whose row ``i`` lists the node
            indices sorted by ascending Euclidean distance from node ``i``.
        """
        instance = np.array(instance)
        n = len(instance)
        neighborhood_matrix = np.zeros((n, n), dtype=int)

        for i in range(n):
            distances = np.linalg.norm(instance[i] - instance, axis=1)
            sorted_indices = np.argsort(distances)  # sort indices based on distances
            neighborhood_matrix[i] = sorted_indices

        return neighborhood_matrix

    def calObjectives(self, alg, reduction=None):
        """Run ``alg`` on every instance and return its score and runtime.

        Args:
            alg: A Python module coded by LLM. It must expose
                ``find_route(coord_matrix, distance_matrix)`` when
                ``reduction`` is ``None``, otherwise ``solve_B(input_B)``.
            reduction: Optional module exposing ``convert_input_A_to_B`` and
                ``convert_solution_B_to_A``; when given, the instance is
                converted, solved by ``alg.solve_B`` and converted back.

        Returns:
            ``(-average_tour_length, elapsed_seconds)``, or ``(None, None)``
            when the time budget is exceeded or a reduction yields an
            invalid tour.
        """
        start_time = time.time()
        dis = np.ones(self.n_instance)

        for i, (coord_matrix, distance_matrix) in enumerate(self._datasets):

            # The budget is checked once per instance, before solving it.
            if time.time() - start_time > self.time:
                return None, None

            if reduction is None:
                # NOTE(review): tours returned by find_route are not
                # validated here; confirm whether that is intended.
                route = alg.find_route(coord_matrix, distance_matrix)
            else:
                input_B = reduction.convert_input_A_to_B(coord_matrix, distance_matrix)
                solution_B = alg.solve_B(input_B)
                # Normalise to an array so the validity check below also
                # works when the reduction returns a plain list or tuple.
                route = np.asarray(reduction.convert_solution_B_to_A(solution_B))
                # A valid tour visits exactly problem_size distinct nodes.
                if route.size != self.problem_size or np.unique(route).size != self.problem_size:  # invalid tours
                    return None, None
                if self.debug and i == 0:
                    with open('example_tour.pkl', 'wb') as f:
                        pickle.dump(route, f)

            dis[i] = self.route_cost(coord_matrix, route, self.problem_size)

        ave_dis = np.average(dis)
        # Negated so that shorter tours give a higher fitness.
        return -ave_dis, time.time() - start_time

    @staticmethod
    def _exec_as_module(name, source):
        """Execute ``source`` in a fresh module and register it under ``name``."""
        module = types.ModuleType(name)
        # SECURITY: exec runs the generated code with full interpreter
        # privileges; only evaluate code in an environment where that is
        # acceptable.
        exec(source, module.__dict__)
        # Registered in sys.modules so the code can be imported by name.
        sys.modules[module.__name__] = module
        return module

    def evaluate(self, code_string, reduction=None):
        """Score an LLM-generated heuristic given as source code.

        Args:
            code_string: Source of the heuristic module.
            reduction: Optional source of a reduction module; see
                ``calObjectives`` for the functions it must define.

        Returns:
            ``(fitness, runtime)`` as produced by ``calObjectives``; fitness
            is higher-is-better and both are ``None`` on failure.
        """
        # Suppress warnings raised while executing and running generated code.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # The reduction is built first so it is already registered when
            # the heuristic source is executed.
            if reduction is not None:
                reduction_module = self._exec_as_module("reduction_module", reduction)
            else:
                reduction_module = None

            heuristic_module = self._exec_as_module("heuristic_module", code_string)

            fitness, runtime = self.calObjectives(heuristic_module, reduction_module)

            return fitness, runtime

if __name__ == '__main__':
    import sys

    # Show the module search path (useful when the local imports fail).
    print(sys.path)


    def select_next_node(current_node: int, destination_node: int, unvisited_nodes: np.ndarray, distance_matrix: np.ndarray) -> int:
        """
        Example step rule: pick the unvisited node closest to the current node.

        Args:
        current_node: ID of the current node.
        destination_node: ID of the destination node (not used by this rule).
        unvisited_nodes: Array of IDs of unvisited nodes.
        distance_matrix: Distance matrix of nodes.

        Return:
        ID of the next node to visit.
        """
        # Distances from the current node to each still-unvisited node.
        candidate_distances = distance_matrix[current_node][unvisited_nodes]

        # Map the position of the smallest distance back to a node ID.
        return unvisited_nodes[np.argmin(candidate_distances)]


    tsp = TSP()
    # tsp.evaluate_program('_', select_next_node)
