'''
End-to-end evaluator for LLM-generated heuristics on the multidimensional
0-1 knapsack problem (MKP).
'''
import sys
import pickle
import types
import warnings
import time
from typing import Any
import numpy as np
from prompts import GetPrompts


class MKP():
    """Evaluator for the multidimensional 0-1 knapsack problem (MKP).

    Each instance is a tuple ``(values, weights, constraints)``: a packing is
    a set of distinct item indices ``items`` such that
    ``weights[:, items].sum(axis=1) <= constraints`` holds row by row.
    LLM-generated code is scored by the average packed value over all
    loaded instances (higher is better).
    """

    def __init__(self,
                 problem_size=100,
                 running_time=30,
                 dirname=None,  # or str e.g., 'dataset_mcts'
                 mode='train',
                 debug=False,
                 **kwargs):
        '''
        problem_size: number of items per instance; used to pick the dataset.
        running_time: wall-clock budget in seconds for one calObjectives call.
        dirname: None to load ``all_data_{mode}.pkl`` from the working
            directory, or 'dataset_mcts' to load
            ``dataset_mcts/{mode}{problem_size}_dataset.npz``.
        mode: split name embedded in the dataset file names (e.g. 'train').
        debug: if True, pickle the first instance's packing to example_bag.pkl.
        kwargs: accepted and ignored.

        Raises ValueError if ``dirname`` is neither None nor 'dataset_mcts'.
        '''
        self.time = running_time  # time budget (seconds) for calObjectives
        self.dirname = dirname
        self.debug = debug
        self.problem_size = problem_size
        self.prompts = GetPrompts()
        if self.dirname is None:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # trusted files.
            with open(f'all_data_{mode}.pkl', 'rb') as f:
                all_data = pickle.load(f)
            self._datasets = all_data[self.problem_size]
            self.n_instance = len(self._datasets)
            # Entries are unpacked as (values, weights, constraints) in
            # calObjectives. self.constraints used to be left unset on this
            # path, which made the self.m assignment below raise.
            # NOTE(review): assumes all instances share the first instance's
            # constraint vector — confirm against the pickle's producer.
            self.constraints = (np.asarray(self._datasets[0][2])
                                if self.n_instance > 0 else np.ones(0))
        elif self.dirname == 'dataset_mcts':
            # NpzFile keeps the archive open; the context manager closes it.
            with np.load(f'{self.dirname}/{mode}{problem_size}_dataset.npz') as data:
                values_all, weights_all = data['prizes'], data['weights']
            self.n_instance = weights_all.shape[0]
            # Unit capacity for every constraint row, shared by all instances.
            self.constraints = np.ones(weights_all.shape[-1])
            # Transpose so that rows index constraints and columns index items,
            # matching the weights[:, items] access in calObjectives.
            self._datasets = [
                (values_all[i], weights_all[i].transpose(), self.constraints)
                for i in range(self.n_instance)
            ]
        else:
            raise ValueError(
                f"Unsupported dirname {self.dirname!r}; "
                "expected None or 'dataset_mcts'")
        self.m = self.constraints.size  # number of knapsack constraints

    def get_seed_algorithms(self):
        """Return placeholder seed heuristics keyed by heuristic type."""
        algorithms = {
            'constructive': "...",
        }
        return algorithms

    # def evaluate_program(self, program_str: str, callable_func: callable) -> Any | None:
    #     return self.evaluate(callable_func)

    @staticmethod
    def _fresh(arr):
        """Return a private ndarray copy so evaluated code cannot mutate shared data."""
        return np.array(arr, copy=True)

    @staticmethod
    def _is_valid_packing(items, weights, constraints):
        """Return True if `items` is a non-empty 1-D array of distinct, in-range
        integer item indices whose summed weights satisfy every constraint."""
        if items.ndim != 1 or items.size == 0:
            return False
        if not np.issubdtype(items.dtype, np.integer):
            return False
        # Negative indices would alias real items (e.g. -1 == n-1) and slip
        # past the duplicate check; out-of-range ones would raise IndexError.
        if np.any(items < 0) or np.any(items >= weights.shape[1]):
            return False
        if np.unique(items).size != items.size:
            return False
        return not np.any(np.sum(weights[:, items], axis=1) > constraints)

    def calObjectives(self, alg, reduction=None):
        '''
        Average the value of `alg`'s packings over all loaded instances.

        alg: module-like object built from LLM code. Without `reduction` it
            must provide ``pack_items(values, weights, constraints)`` returning
            item indices; with `reduction` it must provide ``solve_B``.
        reduction: optional module-like object providing
            ``convert_input_A_to_B(values, weights, constraints)`` and
            ``convert_solution_B_to_A(solution_B)``.

        Returns (average_value, elapsed_seconds), or (None, None) if there are
        no instances, the time budget is exceeded before an instance starts,
        or a reduced solution is not a valid packing.
        '''
        start_time = time.time()
        if self.n_instance == 0:
            return None, None
        total_value = 0

        for i, (values, weights, constraints) in enumerate(self._datasets):
            if time.time() - start_time > self.time:
                return None, None

            # Generated code gets copies: the constraint vector is shared by
            # all instances, so in-place edits would corrupt later instances
            # and the feasibility check below.
            args = (self._fresh(values), self._fresh(weights),
                    self._fresh(constraints))

            if reduction is None:
                # NOTE(review): pack_items output is not checked for
                # feasibility on this path.
                items = alg.pack_items(*args)
            else:
                input_B = reduction.convert_input_A_to_B(*args)
                try:
                    solution_B = alg.solve_B(input_B)
                except TypeError:
                    # solve_B may expect the converted input unpacked.
                    solution_B = alg.solve_B(*input_B)
                items = np.array(reduction.convert_solution_B_to_A(solution_B))
                if not self._is_valid_packing(items, weights, constraints):
                    return None, None
                if self.debug and i == 0:
                    with open('example_bag.pkl', 'wb') as f:
                        pickle.dump(items, f)

            total_value += np.sum(values[items])

        ave_value = total_value / self.n_instance
        return ave_value, time.time() - start_time

    @staticmethod
    def _load_module(name, source):
        """Exec `source` into a fresh module `name`, register it in sys.modules, return it."""
        module = types.ModuleType(name)
        exec(source, module.__dict__)
        sys.modules[name] = module
        return module

    def evaluate(self, code_string, reduction=None):
        '''
        Execute LLM-generated code and score it with calObjectives.

        code_string: source defining ``pack_items`` (direct mode) or
            ``solve_B`` (when `reduction` is given).
        reduction: optional source defining ``convert_input_A_to_B`` and
            ``convert_solution_B_to_A``.

        Return: (fitness, runtime); fitness is the average packed value
        (higher is better), or None on an invalid packing or timeout.
        '''
        # NOTE(review): exec runs arbitrary code in this process; only pass
        # trusted or sandboxed sources.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # silence warnings from generated code
            reduction_module = None
            if reduction is not None:
                reduction_module = self._load_module("reduction_module", reduction)
            heuristic_module = self._load_module("heuristic_module", code_string)
            return self.calObjectives(heuristic_module, reduction_module)

