'''
CVRP (end-to-end)
'''
import sys
import pickle
import types
import warnings
import time
from typing import Any
import numpy as np
from prompts import GetPrompts


class CVRP:
    """Evaluator for the capacitated vehicle routing problem (CVRP).

    Each loaded instance is a tuple ``(coord_matrix, distance_matrix,
    demands, capacity)``. A heuristic is scored by the negated average
    Euclidean tour length it achieves over all instances (higher fitness is
    better).
    """

    def __init__(self,
                 problem_size=50,
                 running_time=30,
                 dirname=None,  # or str e.g., 'dataset_mcts'
                 mode='train',
                 debug=False,
                 capacity=50,
                 **kwargs):
        """Load the evaluation instances.

        Args:
            problem_size: key into the pickled data, and the size suffix of
                the ``.npy`` file name.
            running_time: wall-clock budget in seconds for one evaluation.
            dirname: ``None`` to load ``all_data_{mode}.pkl`` from the
                working directory, or ``'dataset_mcts'`` to load
                ``dataset_mcts/{mode}{problem_size}_dataset.npy``.
            mode: split name used in the file names (e.g. ``'train'``).
            debug: if True, the first validated reduced tour is pickled to
                ``example_tour.pkl`` during evaluation.
            capacity: vehicle capacity given to every instance loaded from
                ``dataset_mcts`` (pickled instances carry their own).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``dirname`` is neither ``None`` nor
                ``'dataset_mcts'``; previously such an object was left
                without any instances and failed later.
        """
        if dirname not in (None, 'dataset_mcts'):
            raise ValueError(f"Unsupported dirname: {dirname!r}")

        self.time = running_time
        self.dirname = dirname
        self.debug = debug
        self.problem_size = problem_size
        self.prompts = GetPrompts()
        if self.dirname is None:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(f'all_data_{mode}.pkl', 'rb') as f:
                all_data = pickle.load(f)
            self._datasets = all_data[self.problem_size]
        else:  # 'dataset_mcts'
            data = np.load(f'{self.dirname}/{mode}{problem_size}_dataset.npy')
            self._datasets = []
            for instance in data:
                # Per-node rows: column 0 holds the demand, the last two
                # columns the coordinates.
                coordinates = instance[:, -2:]
                distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
                demands = instance[:, 0]
                self._datasets.append((coordinates, distances, demands, capacity))
        self.n_instance = len(self._datasets)

    def get_seed_algorithms(self):
        """Return the named seed algorithm sources (placeholder text for now)."""
        algorithms = {
            'constructive': "...",
        }
        return algorithms

    def route_cost(self, coord_matrix, route):
        """Return the Euclidean length of ``route`` treated as a closed tour.

        Args:
            coord_matrix: node coordinates indexed by node id.
            route (Numpy 1D array): first entry is 0 (depot), and last entry
                is the last visited node. Intermediate 0 entries mark returns
                to the depot between sub-routes.

        Returns:
            float: sum of distances between consecutive entries plus the leg
            from the last entry back to ``route[0]``. An empty route costs 0.
        """
        nodes = np.asarray(route).astype(int)
        points = np.asarray(coord_matrix)[nodes]
        # Pair every node with its successor; the last wraps back to the first.
        legs = points - np.roll(points, -1, axis=0)
        return float(np.linalg.norm(legs, axis=1).sum())

    def _route_from_reduction(self, alg, reduction, instance):
        """Solve ``instance`` via ``reduction``; return a CVRP route or None.

        The returned route is the int array ``[0, *sub_1, 0, *sub_2, ...]``.
        ``None`` means the mapped-back solution is infeasible: no sub-routes,
        a node id outside ``[0, len(demands))``, a node repeated within or
        across sub-routes, an overloaded vehicle, or an unvisited node.
        """
        _, _, demands, capacity = instance
        input_B = reduction.convert_input_A_to_B(*instance)
        try:
            solution_B = alg.solve_B(input_B)
        except TypeError:
            # solve_B may expect the converted input unpacked instead.
            # NOTE(review): a TypeError raised inside solve_B also lands here.
            solution_B = alg.solve_B(*input_B)

        routes = [np.asarray(subroute).astype(int)
                  for subroute in reduction.convert_solution_B_to_A(solution_B)]
        if not routes:
            return None

        n_nodes = len(demands)
        served = set()
        for subroute in routes:
            # Negative ids would silently wrap around in numpy indexing;
            # ids past the end would raise.
            if subroute.size and (subroute.min() < 0 or subroute.max() >= n_nodes):
                return None
            stops = set(subroute.tolist())
            if (len(stops) != subroute.size                 # node revisited by this vehicle
                    or np.sum(demands[subroute]) > capacity  # vehicle capacity exceeded
                    or not served.isdisjoint(stops)):        # node already served elsewhere
                return None
            served |= stops

        route = np.concatenate([np.concatenate(([0], subroute)) for subroute in routes]).astype(int)
        # Every node 0..problem_size (depot included) must appear.
        if np.unique(route).size != self.problem_size + 1:
            return None
        return route

    def calObjectives(self, alg, reduction=None):
        '''
        Score a heuristic on every loaded instance.

        Args:
            alg: module built from LLM-generated code. Without a reduction it
                must define ``find_route(coord_matrix, distance_matrix,
                demands, capacity)``; with one it must define ``solve_B``.
            reduction: optional module defining ``convert_input_A_to_B`` and
                ``convert_solution_B_to_A``.

        Returns:
            (fitness, runtime): fitness is the negated average tour length
            (higher is better) and runtime the elapsed seconds, or
            ``(None, None)`` if the time budget ``self.time`` is exceeded or
            a reduced solution is infeasible. Exceptions raised by ``alg`` or
            ``reduction`` propagate.
        '''
        start_time = time.time()
        dis = np.ones(self.n_instance)

        for i, instance in enumerate(self._datasets):
            # The budget is only checked between instances; a single slow
            # call is not interrupted.
            if time.time() - start_time > self.time:
                return None, None

            coord_matrix, distance_matrix, demands, capacity = instance
            if reduction is None:
                # NOTE(review): routes returned by find_route are not checked
                # for feasibility; only the reduction path is validated.
                route = alg.find_route(coord_matrix, distance_matrix, demands, capacity)
            else:
                route = self._route_from_reduction(alg, reduction, instance)
                if route is None:
                    return None, None
                if self.debug and i == 0:
                    with open('example_tour.pkl', 'wb') as f:
                        pickle.dump(route, f)

            dis[i] = self.route_cost(coord_matrix, route)

        return -np.average(dis), time.time() - start_time

    @staticmethod
    def _load_module(name, source):
        """Execute ``source`` in a fresh module called ``name``, register it in sys.modules, and return it."""
        module = types.ModuleType(name)
        exec(source, module.__dict__)
        sys.modules[name] = module
        return module

    def evaluate(self, code_string, reduction=None):
        '''
        Compile and score LLM-generated code.

        Args:
            code_string: Python source of the heuristic (see ``calObjectives``
                for the functions it must define).
            reduction: optional Python source of a reduction module; when
                given, the heuristic is evaluated through it.

        Returns:
            Fitness score (higher is better) and runtime, as returned by
            ``calObjectives``.

        Side effects:
            Registers ``reduction_module`` (if given) and ``heuristic_module``
            in ``sys.modules``, replacing earlier registrations. All warnings
            are suppressed while the code runs.

        Warning:
            Both sources are executed with ``exec`` in this process without
            any sandboxing; only pass code you are prepared to run.
        '''
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            reduction_module = None
            if reduction is not None:
                reduction_module = self._load_module("reduction_module", reduction)
            heuristic_module = self._load_module("heuristic_module", code_string)
            fitness, runtime = self.calObjectives(heuristic_module, reduction_module)
            return fitness, runtime
