''' Test a heuristic from a specific exp folder and generation
Usage: python [FILENAME].py [path] [test_problem_config(s)] [gen no.] [heuristic no.]
Note: this only applies for external test data from MCTS-AHD and reduc=True
Note: to pass several test problem configs, separate them with '+'
Example: python [FILENAME].py 50_12/results/pops_mcts_evored_1 1k_100+1k_500+5k_100+5k_500+10k_100+10k_500 10 2
'''
import numpy as np
import pickle, json
import types
import warnings
from multiprocessing import Process, Value
import sys, os
import copy
import time
import pandas as pd
import numpy as np
import networkx as nx
from typing import Union


def get_valid_bin_indices(item: float, bins: np.ndarray) -> np.ndarray:
    """Return the positions in ``bins`` whose remaining capacity can hold ``item``.

    A bin qualifies when its remaining capacity minus ``item`` is non-negative.
    """
    leftover = bins - item
    fits = leftover >= 0
    return np.where(fits)[0]


class Test():
    """Evaluate a (possibly reduction-based) bin-packing heuristic on test instances.

    ``data`` is a sequence of ``(items, capacity, opt)`` tuples (as built by
    ``load_external_data``); ``opt`` is the reference bin count that the
    percentage gap is measured against.
    """

    def __init__(self, data, problem_config: str, is_online: bool):
        self.problem_config = problem_config
        self.is_online = is_online
        self.n_instance = len(data)
        self._datasets = data

    def calObjectives(self, alg, reduction=None) -> float:
        """Pack every instance with ``alg`` and return the negated mean gap (%).

        Args:
            alg: module/object exposing ``pack_items(items, bins)`` when
                ``reduction`` is None, otherwise ``solve_B(...)``.
            reduction: optional module/object exposing
                ``convert_input_A_to_B`` and ``convert_solution_B_to_A``.

        Returns:
            ``-mean(gap)`` where ``gap = (bins_used / opt - 1) * 100``. The mean
            number of bins used is printed as a side effect.

        Invalid packings (wrong bin count, a bin above capacity, an overfilled
        bin with negative remaining capacity, or item mass not conserved) print
        the 1-based instance index and contribute 0 to both averages.
        NOTE(review): a 0 gap for an invalid instance looks "optimal"; consider
        penalising it instead.
        """
        gaps = np.zeros(self.n_instance)
        objs = np.zeros(self.n_instance)

        for i, (items, capacity, opt) in enumerate(self._datasets):

            n_items = items.size
            # One empty bin per item; entries hold the *remaining* capacity.
            # NOTE(review): if capacity is an int and items are floats, the
            # in-place subtraction below truncates -- confirm data dtypes.
            bins = np.full(n_items, capacity)
            if reduction is None:
                bins = alg.pack_items(items, bins)
            elif self.is_online:
                for item in items:
                    valid_bin_indices = get_valid_bin_indices(item, bins)
                    input_B = reduction.convert_input_A_to_B(item, bins[valid_bin_indices])
                    try:
                        solution_B = alg.solve_B(input_B)
                    except TypeError:
                        # Some generated heuristics expect the converted input unpacked.
                        solution_B = alg.solve_B(*input_B)
                    # Scores are aligned with the valid bins; the best one receives the item.
                    scores = reduction.convert_solution_B_to_A(solution_B)
                    bin_index = valid_bin_indices[np.argmax(scores)]
                    bins[bin_index] -= item
            else:
                input_B = reduction.convert_input_A_to_B(items, bins)
                solution_B = alg.solve_B(input_B)
                bins = reduction.convert_solution_B_to_A(solution_B)  # packed bins

            invalid = (
                bins.size != n_items
                or np.any(bins > capacity)
                # A negative remaining capacity means the bin was overfilled.
                or np.any(bins < 0)
                or np.sum(items) != np.sum(capacity - bins)
            )
            if invalid:
                print(i+1, flush=True)
            else:
                n_bins_used = (bins != capacity).sum()
                gaps[i] = (n_bins_used/opt - 1)*100
                objs[i] = n_bins_used

        ave_gap = np.average(gaps)
        ave_obj = np.average(objs)
        print(ave_obj)
        return -ave_gap

    def evaluate_new(self, code_string, objective, reduction=None):
        """Compile the heuristic (and optional reduction) source and score it.

        Used as a ``multiprocessing.Process`` target: the score from
        ``calObjectives`` is written to ``objective.value``.

        SECURITY: both source strings are executed with ``exec``; only run this
        on population files you trust.
        """
        if reduction is not None:
            # Materialise the reduction source as an importable module.
            reduction_module = types.ModuleType("reduction_module")
            exec(reduction, reduction_module.__dict__)
            sys.modules[reduction_module.__name__] = reduction_module
        else:
            reduction_module = None

        # Materialise the heuristic source as an importable module.
        heuristic_module = types.ModuleType("heuristic_module")
        exec(code_string, heuristic_module.__dict__)
        sys.modules[heuristic_module.__name__] = heuristic_module

        objective.value = self.calObjectives(heuristic_module, reduction_module)


def load_external_data(dirname, problem_config, is_online):
    """Load MCTS-AHD test instances as a list of ``(items, capacity, opt)`` tuples.

    Online: ``problem_config`` of the form ``'<A>_<B>'`` selects
    ``<dirname>/weibull_<A>_test_<B>.pickle``. That is a dict of instances
    (each with ``'items'`` and ``'capacity'``) plus an ``'l1_bound'`` entry,
    which is used as ``opt`` for every instance.

    Offline: ``problem_config`` is the instance size and selects
    ``<dirname>/test<size>_dataset.npz``. Each row of its ``'demands'`` array
    is one instance, with a fixed capacity of 150 and a per-size ``opt``
    constant.

    Raises:
        ValueError: for an offline size with no known ``opt`` constant.
    """
    if is_online:
        parts = problem_config.split('_')
        # SECURITY: pickle can execute code; only load trusted dataset files.
        with open(f'{dirname}/weibull_{parts[0]}_test_{parts[1]}.pickle', 'rb') as f:
            raw_data = pickle.load(f)
        return [
            (instance['items'], instance['capacity'], raw_data['l1_bound'])
            for name, instance in raw_data.items()
            if name != 'l1_bound'
        ]

    # Reference bin counts for the offline datasets, keyed by instance size.
    offline_opts = {500: 203.125, 1000: 405.172}
    capacity = 150
    size = int(problem_config)
    if size not in offline_opts:
        raise ValueError(
            f"no reference optimum for offline size {size}; "
            f"supported sizes: {sorted(offline_opts)}")
    opt = offline_opts[size]
    with np.load(f'{dirname}/test{problem_config}_dataset.npz') as raw_data:
        demands = raw_data['demands']
    return [(items, capacity, opt) for items in demands]


def resolve_reduction_code(pop_dir, reduc_dir, heuristic, k):
    """Return the source of the reduction responsible for ``heuristic``.

    The reduction whose ``problem`` matches ``heuristic["problem"]`` is looked
    up in ``<reduc_dir>/reduc_population_generation_<k>.json``. If its ``op``
    is ``'ri0'``, its code is returned directly. Otherwise the function walks
    backwards through earlier generations while ``heuristic`` is still present
    in ``<pop_dir>/population_generation_<gen>.json``. It returns the
    reduction of the earliest generation that contains the heuristic.

    Raises:
        LookupError: if generation ``k`` has no reduction for the problem.
    """
    problem = heuristic["problem"]

    def load_json(file_path):
        with open(file_path, 'r') as file:
            return json.load(file)

    def reduction_for(gen):
        # NOTE(review): assumes at most one reduction per problem per generation.
        reduc_pop = load_json(f"{reduc_dir}/reduc_population_generation_{gen}.json")
        return next((r for r in reduc_pop if r["problem"] == problem), None)

    reduction = reduction_for(k)
    if reduction is None:
        raise LookupError(f"no reduction for problem {problem!r} in generation {k}")
    reduc_code = reduction["code"]
    if reduction['op'] == 'ri0':
        return reduc_code

    for gen in range(k - 1, -1, -1):
        if heuristic not in load_json(f"{pop_dir}/population_generation_{gen}.json"):
            # The heuristic first appeared in generation gen + 1, whose
            # reduction is already held in reduc_code -- keep it.
            break
        earlier = reduction_for(gen)
        if earlier is not None:
            reduc_code = earlier["code"]
    return reduc_code


if __name__ == '__main__':

    current_file_path = os.path.abspath(__file__)
    parent_directory = os.path.dirname(current_file_path)
    sys.path.append(parent_directory)

    path = sys.argv[1]
    exp_output_path = path.split('/')[0]
    problem_config = exp_output_path.split('_')[0]
    # Strip the 4-char folder prefix (expected 'pops'; it is re-added below).
    exp_tag = path.split('/')[-1][4:]
    test_problem_configs = sys.argv[2].split('+')
    k = int(sys.argv[3])
    h = int(sys.argv[4])

    with open(f'{path}/population_generation_{k}.json') as file:
        pop = json.load(file)
    heuristic = pop[h-1]
    code, operator = heuristic["code"], heuristic["op"]

    reduc_code = resolve_reduction_code(
        path, f"{exp_output_path}/results/pops{exp_tag}", heuristic, k)

    res_filename = f"{exp_output_path}/results/performance{exp_tag}-gen{k}_h{h}.txt"
    csv_filename = f"{exp_output_path}/results/performance{exp_tag}-gen{k}_h{h}_csv.txt"

    with open(res_filename, 'w+') as f, open(csv_filename, 'w+') as c:
        c.write('|'.join(['Generation', 'Operator', 'Problem Size', 'Objective', 'Runtime']) + '\n')
        f.write(f'===========================\nGeneration #{k}\n')

        for test_problem_config in test_problem_configs:
            # Online configs look like '<A>_<B>'; offline ones are a bare size.
            is_online = '_' in test_problem_config
            data = load_external_data('dataset_mcts', test_problem_config, is_online)
            f.write(f"problem size = {problem_config} -> {test_problem_config}\n")
            evaluator = Test(data, test_problem_config, is_online)

            time_start = time.time()
            # Stays inf if the child process times out or crashes.
            objective = Value('d', np.inf)
            p = Process(target=evaluator.evaluate_new, args=(code, objective, reduc_code))
            p.start()
            p.join(99999)
            if p.is_alive():
                print('Timeout!')
                p.terminate()
                p.join()

            runtime = time.time()-time_start

            f.write(f"Heuristic #{h}, from {operator}):\nobjective = {objective.value:.5f}, time = {runtime:.1f} s\n")
            f.flush()
            f.write('\n')
            c.write('|'.join([str(k), operator, str(test_problem_config), str(objective.value), str(runtime)]) + '\n')
            c.flush()

    # Convert the '|'-delimited temp file into a proper CSV ('_csv.txt' -> '.csv').
    pd.read_csv(csv_filename, delimiter='|').to_csv(f'{csv_filename[:-8]}.csv', index=False)
    os.remove(csv_filename)
