''' Test a heuristic from a specific experiment folder and generation.
Usage: python [FILENAME].py [path] [test_problem_size(s)] [gen no.] [heuristic no.]
Note: this only applies to external test data from MCTS-AHD with reduc=True.
Note: to test several problem sizes, separate them with '+'. The heuristic number is 1-based.
Example: python [FILENAME].py 50_12/results/pops_mcts_evored_1 50+100+200 10 2
'''
import numpy as np
import pickle, json
import types
import warnings
from multiprocessing import Process, Value
import sys, os
import copy
import time
import pandas as pd
import numpy as np
import networkx as nx
from typing import Union


class Test:
    """Score a routing heuristic on a fixed list of test instances.

    Each element of ``data`` is unpacked as
    ``(coord_matrix, distance_matrix, demands, capacity)``.
    """

    def __init__(self, data, problem_size: int):
        self.problem_size = problem_size
        self.n_instance = len(data)
        # Raw instance tuples, consumed by calObjectives.
        self._datasets = data

    def route_cost(self, coord_matrix, route):
        """Return the Euclidean length of the closed tour described by ``route``.

        Args:
            coord_matrix: per-node coordinates, indexed by node id.
            route: sequence of node ids; the first entry is the depot (0) and
                the last entry is the final visited node. The leg from the last
                node back to ``route[0]`` is included in the total.
        """
        stops = [int(node) for node in route]
        # stops[-1] raises IndexError on an empty route, exactly like before.
        closing_leg = (stops[-1], stops[0])
        legs = [*zip(stops, stops[1:]), closing_leg]
        total = 0
        for src, dst in legs:
            total += np.linalg.norm(coord_matrix[src] - coord_matrix[dst])
        return total

    @staticmethod
    def _route_via_reduction(alg, reduction, instance):
        """Solve ``instance`` through ``reduction`` and flatten the result into one route.

        Every sub-route returned by ``convert_solution_B_to_A`` is prefixed with
        node 0; the pieces are then concatenated and cast to int.
        """
        input_B = reduction.convert_input_A_to_B(*instance)
        try:
            solution_B = alg.solve_B(input_B)
        except TypeError:
            # Fall back to passing the converted input as separate positional
            # arguments. NOTE(review): a TypeError raised *inside* solve_B also
            # lands here and triggers this retry.
            solution_B = alg.solve_B(*input_B)
        depot_prefixed = [
            np.concatenate((np.zeros(1,), subroute))
            for subroute in reduction.convert_solution_B_to_A(solution_B)
        ]
        return np.concatenate(depot_prefixed).astype(int)

    def calObjectives(self, alg, reduction=None) -> float:
        """Return the negated mean tour length produced by ``alg`` over all instances.

        Args:
            alg: module-like object. Without a reduction it must expose
                ``find_route(coord_matrix, distance_matrix, demands, capacity)``;
                with one it must expose ``solve_B``.
            reduction: optional module-like object exposing
                ``convert_input_A_to_B`` and ``convert_solution_B_to_A``.

        Returns:
            ``-mean(tour length)``, so larger values mean shorter tours.
        """
        costs = np.ones(self.n_instance)
        for idx, (coord_matrix, distance_matrix, demands, capacity) in enumerate(self._datasets):
            instance = (coord_matrix, distance_matrix, demands, capacity)
            route = (
                alg.find_route(*instance)
                if reduction is None
                else self._route_via_reduction(alg, reduction, instance)
            )
            costs[idx] = self.route_cost(coord_matrix, route)
        return -np.average(costs)

    @staticmethod
    def _exec_as_module(module_name, source):
        """Execute ``source`` in a fresh module called ``module_name`` and register it in ``sys.modules``."""
        module = types.ModuleType(module_name)
        # SECURITY: the code string is executed unsandboxed; only pass trusted code.
        exec(source, module.__dict__)
        sys.modules[module.__name__] = module
        return module

    def evaluate_new(self, code_string, objective, reduction=None):
        """Build the heuristic (and optional reduction) from source and store the score.

        Args:
            code_string: Python source of the heuristic module.
            objective: object with a writable ``value`` attribute (e.g. a
                ``multiprocessing.Value``) that receives the result of
                ``calObjectives``.
            reduction: optional Python source of the reduction module.
        """
        reduction_module = (
            None if reduction is None
            else self._exec_as_module("reduction_module", reduction)
        )
        heuristic_module = self._exec_as_module("heuristic_module", code_string)
        objective.value = self.calObjectives(heuristic_module, reduction_module)


def load_external_data(dirname, problem_size, capacity=50):
    """Load the external test instances for ``problem_size`` from ``dirname``.

    Reads ``{dirname}/test{problem_size}_dataset.npy``. For every instance
    along axis 0, column 0 is taken as the node demands and the last two
    columns as the node coordinates; the full pairwise Euclidean distance
    matrix is then computed.

    Args:
        dirname: directory containing the ``.npy`` dataset.
        problem_size: size tag used in the dataset file name.
        capacity: vehicle capacity attached to every instance. Defaults to 50,
            the previously hard-coded value. NOTE(review): confirm whether the
            benchmark expects a size-dependent capacity.

    Returns:
        list of ``(coordinates, distances, demands, capacity)`` tuples.
    """
    raw_data = np.load(f'{dirname}/test{problem_size}_dataset.npy')
    data = []
    for instance in raw_data:
        coordinates = instance[:, -2:]
        # Broadcast (n, 1, 2) - (n, 2) -> (n, n, 2), then take the norm over the coordinate axis.
        distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
        demands = instance[:, 0]
        data.append((coordinates, distances, demands, capacity))
    return data


# Seconds to wait for one evaluation subprocess before terminating it.
EVAL_TIMEOUT_S = 99999


def load_json(file_path):
    """Read and return the JSON document stored at ``file_path``."""
    with open(file_path, 'r') as file:
        return json.load(file)


def first_reduction_for(reduc_pop, problem):
    """Return the first entry of ``reduc_pop`` whose "problem" equals ``problem``, or None."""
    return next((reduc for reduc in reduc_pop if reduc["problem"] == problem), None)


def resolve_reduction_code(path, reduc_dir, k, heuristic):
    """Return the reduction code to pair with ``heuristic`` taken from generation ``k``.

    The generation-``k`` reduction for the heuristic's problem is used directly
    when its op is 'ri0'. Otherwise earlier generations are walked backwards
    while the heuristic is still present in their population, and the
    reduction of the earliest generation in that unbroken run is used (a
    generation with no matching reduction keeps the later one).

    Args:
        path: folder holding ``population_generation_{gen}.json`` files.
        reduc_dir: folder holding ``reduc_population_generation_{gen}.json`` files.
        k: generation the heuristic was taken from.
        heuristic: the heuristic dict (must contain "problem").

    Raises:
        ValueError: no generation-``k`` reduction matches the heuristic's problem.
    """
    problem = heuristic["problem"]
    reduction = first_reduction_for(
        load_json(f"{reduc_dir}/reduc_population_generation_{k}.json"), problem)
    if reduction is None:
        raise ValueError(f"no reduction for problem {problem!r} in generation {k}")
    reduc_code = reduction["code"]
    if reduction['op'] == 'ri0':
        return reduc_code
    # Backtrack previous generations to determine the reduction that is
    # actually responsible for the heuristic.
    for gen in range(k - 1, -1, -1):
        if heuristic not in load_json(f'{path}/population_generation_{gen}.json'):
            # The heuristic first appeared in generation gen + 1, whose
            # reduction is already held in reduc_code -- do not overwrite it.
            break
        older = first_reduction_for(
            load_json(f"{reduc_dir}/reduc_population_generation_{gen}.json"), problem)
        if older is not None:
            reduc_code = older["code"]
    return reduc_code


def main(argv=None):
    """Evaluate one heuristic on the external test sets and write the reports.

    ``argv`` defaults to ``sys.argv``; the expected arguments are described in
    the module docstring. Writes ``performance{tag}-gen{k}_h{h}.txt`` and
    ``.csv`` under ``{exp_output_path}/results``.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 5:
        print(__doc__)
        sys.exit(1)

    parent_directory = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(parent_directory)

    path = argv[1]
    exp_output_path = path.split('/')[0]
    problem_size = exp_output_path.split('_')[0]
    exp_tag = path.split('/')[-1][4:]  # folder name without its first 4 characters ('pops')
    test_problem_sizes = argv[2].split('+')
    k = int(argv[3])
    h = int(argv[4])  # 1-based index into the generation's population

    heuristic = load_json(f'{path}/population_generation_{k}.json')[h - 1]
    code, operator = heuristic["code"], heuristic["op"]
    reduc_code = resolve_reduction_code(
        path, f"{exp_output_path}/results/pops{exp_tag}", k, heuristic)

    base = f"{exp_output_path}/results/performance{exp_tag}-gen{k}_h{h}"
    csv_filename = f"{base}_csv.txt"
    # Context managers guarantee both reports are closed even if an evaluation step raises.
    with open(f"{base}.txt", 'w+') as f, open(csv_filename, 'w+') as c:
        c.write('|'.join(['Generation', 'Operator', 'Problem Size', 'Objective', 'Runtime']) + '\n')
        f.write(f'===========================\nGeneration #{k}\n')

        for test_problem_size in map(int, test_problem_sizes):
            data = load_external_data('dataset_mcts', test_problem_size)
            f.write(f"problem size = {problem_size} -> {test_problem_size}\n")
            evaluator = Test(data, test_problem_size)

            time_start = time.time()
            # Shared double initialised to inf; the child overwrites it on success.
            objective = Value('d', np.inf)
            p = Process(target=evaluator.evaluate_new, args=(code, objective, reduc_code))
            p.start()
            p.join(EVAL_TIMEOUT_S)
            if p.is_alive():
                print('Timeout!')
                p.terminate()
                p.join()
            elif p.exitcode != 0:
                # Surface crashes instead of silently recording the initial value.
                print(f'Evaluation failed for problem size {test_problem_size} '
                      f'(exit code {p.exitcode})')

            runtime = time.time() - time_start

            f.write(f"Heuristic #{h}, from {operator}:\nobjective = {objective.value:.5f}, time = {runtime:.1f} s\n")
            f.flush()
            f.write('\n')
            c.write('|'.join([str(k), operator, str(test_problem_size), str(objective.value), str(runtime)]) + '\n')
            c.flush()

    pd.read_csv(csv_filename, delimiter='|').to_csv(f'{base}.csv', index=False)
    os.remove(csv_filename)


if __name__ == '__main__':
    main()
