'''
OBP (end-to-end)
'''
import sys
import pickle
import types
import warnings
import time
from typing import Any
import numpy as np
from prompts import GetPrompts


class OBP():
    """Evaluator for online bin packing problem.

    Holds a list of ``(items, capacity, opt)`` instances and scores a
    candidate heuristic by the negated average percentage gap between the
    number of bins it uses and ``opt``.
    """

    def __init__(self,
                 problem_config=None,  # train & test: '1k_100', '1k_500', '5k_100', '5k_500'; test only: '10k_100', '10k_500'
                 running_time=30,
                 dirname=None,  # or str e.g., 'dataset_mcts'
                 mode='train',
                 debug=False,
                 **kwargs):
        """Load the evaluation instances.

        Args:
            problem_config: String of the form ``'<size>_<capacity>'``; it is
                split on ``'_'``. Required unless ``dirname='dataset_mcts'``
                and ``mode='train'``.
            running_time: Time budget in seconds for one ``calObjectives``
                call.
            dirname: ``None`` to read ``all_data_<mode>.pkl`` from the working
                directory, or ``'dataset_mcts'`` to read the weibull pickles
                from that directory.
            mode: ``'train'`` or ``'test'``.
            debug: When true, the first packing produced through a reduction
                is dumped to ``example_packing.pkl``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``problem_config`` is needed to locate the data but
                is missing or does not contain two ``'_'``-separated parts.
        """
        self.time = running_time
        self.dirname = dirname
        self.debug = debug
        # Always define the attribute so a missing config is reported clearly
        # instead of surfacing later as an AttributeError.
        self.problem_config = None if problem_config is None else problem_config.split('_')
        self.prompts = GetPrompts()
        self._datasets = []

        if self.dirname is None:  # NOTE: unused
            self._load_all_data(mode)
        elif self.dirname == 'dataset_mcts':
            if mode == 'train':
                self._load_mcts_train()
            elif mode == 'test':
                self._load_mcts_test()

        # Set unconditionally so calObjectives never meets an undefined attribute.
        self.n_instance = len(self._datasets)
        if self.n_instance == 0:
            warnings.warn(
                f'No instances loaded for dirname={dirname!r}, mode={mode!r}; '
                'calObjectives will return (None, None).'
            )

    def _require_problem_config(self):
        """Return the first two parts of ``problem_config`` or raise ValueError."""
        if self.problem_config is None or len(self.problem_config) < 2:
            raise ValueError(
                "problem_config must look like '<size>_<capacity>' (e.g. '1k_100') "
                "for this dirname/mode combination"
            )
        return self.problem_config[0], self.problem_config[1]

    def _load_all_data(self, mode):
        """Load the instances for the configured problem from ``all_data_<mode>.pkl``."""
        size, capacity = self._require_problem_config()
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(f'all_data_{mode}.pkl', 'rb') as f:
            all_data = pickle.load(f)
        self._datasets = all_data[(size, int(capacity))]

    def _load_mcts_train(self):
        """Load the four fixed weibull training instances from ``self.dirname``."""
        for filename in ('1k_train1', '1k_train2', '5k_train1', '5k_train2'):
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            with open(f'{self.dirname}/weibull_{filename}.pickle', 'rb') as f:
                data = pickle.load(f)
            items = data['train_0']['items']  # np.ndarray (1D)
            capacity = data['train_0']['capacity']  # int
            opt = data['l1_bound']  # float
            self._datasets.append((items, capacity, opt))

    def _load_mcts_test(self):
        """Load every instance of the configured weibull test file from ``self.dirname``."""
        size, capacity = self._require_problem_config()
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(f'{self.dirname}/weibull_{size}_test_{capacity}.pickle', 'rb') as f:
            data = pickle.load(f)
        opt = data['l1_bound']  # float, shared by all instances in the file
        for name, instance in data.items():
            if name == 'l1_bound':
                continue
            items = instance['items']  # np.ndarray (1D)
            capacity = instance['capacity']  # int
            self._datasets.append((items, capacity, opt))

    def get_seed_algorithms(self):
        """Return the seed algorithms as a dict mapping name to source string."""
        algorithms = {
            'constructive': "...",
        }
        return algorithms

    # def evaluate_program(self, program_str: str, callable_func: callable) -> Any | None:
    #     return self.evaluate(callable_func)

    @staticmethod
    def _is_valid_packing(items, capacity, bins):
        """Check a packing given as the remaining capacity of each bin.

        A packing is valid when there is one bin per item, every remaining
        capacity lies in ``[0, capacity]`` and the total packed load equals
        the total item size.
        """
        return bool(
            bins.size == items.size
            and not np.any(bins > capacity)
            # A negative remaining capacity means the bin is overfilled.
            and not np.any(bins < 0)
            and np.sum(items) == np.sum(capacity - bins)
        )

    def calObjectives(self, alg, reduction=None):
        """Score a heuristic on all loaded instances.

        Args:
            alg: A Python module/object coded by LLM. It must provide
                ``pack_items(items, bins)`` when ``reduction`` is None and
                ``solve_B(input_B)`` otherwise.
            reduction: Optional object providing ``convert_input_A_to_B`` and
                ``convert_solution_B_to_A``; when given, the packing it
                returns is validated before being scored.

        Returns:
            ``(fitness, runtime)`` where ``fitness`` is the negated average
            gap in percent (higher is better) and ``runtime`` is the elapsed
            time in seconds, or ``(None, None)`` if there are no instances,
            the time budget is exceeded, or a reduced solution is invalid.
        """
        start_time = time.time()
        if len(self._datasets) == 0:
            # Averaging zero gaps would yield NaN; report "no score" instead.
            return None, None
        gaps = np.zeros(self.n_instance)

        for i, (items, capacity, opt) in enumerate(self._datasets):

            # The budget is checked between instances only; a single slow
            # instance is not interrupted.
            if time.time() - start_time > self.time:
                return None, None

            n_items = items.size
            bins = np.full(n_items, capacity)  # unpacked bins, one per item
            if reduction is None:
                bins = alg.pack_items(items, bins)
            else:
                input_B = reduction.convert_input_A_to_B(items, bins)
                solution_B = alg.solve_B(input_B)
                bins = reduction.convert_solution_B_to_A(solution_B)  # packed bins
                if not self._is_valid_packing(items, capacity, bins):
                    return None, None
                if self.debug and i == 0:
                    with open('example_packing.pkl', 'wb') as f:
                        pickle.dump(bins, f)

            # A bin counts as used once its remaining capacity differs from full.
            n_bins_used = (bins != capacity).sum()
            gaps[i] = (n_bins_used / opt - 1) * 100

        ave_gap = np.average(gaps)
        return -ave_gap, time.time() - start_time

    @staticmethod
    def _load_module(name, source):
        """Execute ``source`` in a fresh module and register it in ``sys.modules``."""
        module = types.ModuleType(name)
        # SECURITY: exec runs the given code with full interpreter privileges.
        # The source here is LLM-generated; run the evaluator in a sandbox.
        exec(source, module.__dict__)
        # Add the module to sys.modules so it can be imported
        sys.modules[module.__name__] = module
        return module

    def evaluate(self, code_string, reduction=None):
        """Evaluate an LLM-generated heuristic given as source code.

        Args:
            code_string: Source of the heuristic module.
            reduction: Optional source of a reduction module.

        Returns:
            ``(fitness, runtime)`` as returned by ``calObjectives``
            (fitness: higher is better).
        """
        # Suppress warnings raised by the generated code
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            reduction_module = None
            if reduction is not None:
                reduction_module = self._load_module("reduction_module", reduction)
            heuristic_module = self._load_module("heuristic_module", code_string)

            fitness, runtime = self.calObjectives(heuristic_module, reduction_module)

            return fitness, runtime

