''' Test a heuristic from a specific exp folder and generation
Usage: python [FILENAME].py [path] [test_data_folder] [test_problem_size(s)] [gen no.] [heuristic no.] [(optional) trial_no]
Note: this only applies for external test data from MCTS-AHD and reduc=True
Note: to test several problem sizes (or several TSPLib instance files) in one run, join them with '+'
Example: python [FILENAME].py 50_12/results/pops_mcts_evored_1 dataset_mcts 50+100+200 10 2
Example: python [FILENAME].py 50_14/results/pops_mcts_evored_1 dataset_tsplib eil51.tsp+rat99.tsp+kroB100.tsp+kroC100.tsp+bier127.tsp+ch130.tsp+kroA150.tsp+ts225.tsp+pr226.tsp+pr264.tsp+pr299.tsp+lin318.tsp+fl417.tsp+pr439.tsp+d493.tsp+d657.tsp+u724.tsp+fl1577.tsp+d1655.tsp+u1817.tsp+rl1889.tsp 20 5
'''
import numpy as np
import pickle, json
import types
import warnings
from multiprocessing import Process, Value
import sys, os
import copy
import time
import pandas as pd
import numpy as np
import networkx as nx
from typing import Union
import re

# TSPLib
# Reference tour length for every supported TSPLib instance, keyed by instance
# name (file name without the ".tsp" suffix). The script uses these values as
# the baseline when it reports the percentage gap of a heuristic
# (presumably the known optimal tour lengths -- confirm against TSPLib).
opt = dict(
    ts225=126643,
    rat99=1211,
    rl1889=316536,
    u1817=57201,
    d1655=62128,
    bier127=118282,
    lin318=42029,
    eil51=426,
    d493=35002,
    kroB100=22141,
    kroC100=20749,
    ch130=6110,
    pr299=48191,
    fl417=11861,
    d657=48912,
    kroA150=26524,
    fl1577=22249,
    u724=41910,
    pr264=49135,
    pr226=80369,
    pr439=107217,
)

# Instance names, in the same order as the entries of ``opt``.
names = list(opt)

class Test():
    """Evaluate a TSP heuristic (optionally behind a reduction) on a dataset.

    The dataset is a sequence of ``(coord_matrix, distance_matrix)`` pairs;
    the reported objective is the average closed-tour length over all pairs.
    """

    def __init__(self, data, problem_size: int):
        """Store the dataset together with the number of cities per instance."""
        self._datasets = data
        self.n_instance = len(data)
        self.problem_size = problem_size

    def route_cost(self, coord_matrix, route, problem_size):
        """Return the Euclidean length of the closed tour described by ``route``.

        The first ``problem_size`` entries of ``route`` are walked in order and
        the tour is closed by connecting ``route[-1]`` back to ``route[0]``.
        """
        def leg(src, dst):
            # Route entries may be floats or numpy scalars, hence the int() cast.
            return np.linalg.norm(coord_matrix[int(src)] - coord_matrix[int(dst)])

        total = 0
        for step in range(problem_size - 1):
            total += leg(route[step], route[step + 1])
        # Closing edge from the last visited city back to the start.
        return total + leg(route[-1], route[0])

    def calObjectives(self, alg, reduction=None) -> float:
        """Run ``alg`` on every instance and return the average tour length.

        Without a reduction, ``alg.find_route(coords, distances)`` must return
        the route directly. With a reduction, the instance is first converted
        via ``reduction.convert_input_A_to_B``, solved by ``alg.solve_B`` and
        mapped back with ``reduction.convert_solution_B_to_A``.
        """
        tour_lengths = np.ones(self.n_instance)

        for idx, instance in enumerate(self._datasets):
            coord_matrix, distance_matrix = instance

            if reduction is not None:
                reduced_input = reduction.convert_input_A_to_B(coord_matrix, distance_matrix)
                reduced_solution = alg.solve_B(reduced_input)
                route = reduction.convert_solution_B_to_A(reduced_solution)
            else:
                route = alg.find_route(coord_matrix, distance_matrix)

            tour_lengths[idx] = self.route_cost(coord_matrix, route, self.problem_size)

        return np.average(tour_lengths)

    @staticmethod
    def _build_module(module_name, source):
        """Execute ``source`` in a fresh module and register it in ``sys.modules``.

        NOTE: ``exec`` runs the given code with full privileges; only feed it
        code from a trusted experiment folder.
        """
        module = types.ModuleType(module_name)
        exec(source, module.__dict__)
        # Registering the module makes it importable by name afterwards.
        sys.modules[module.__name__] = module
        return module

    def evaluate_new(self, code_string, objective, reduction=None):
        """Compile the heuristic (and reduction) source and store the objective.

        ``code_string`` is the heuristic source, ``reduction`` the optional
        reduction source. The resulting average tour length is written to
        ``objective.value`` instead of being returned, so the method can be
        used as a ``multiprocessing.Process`` target with a shared ``Value``.
        """
        # The reduction module is built first, before the heuristic is executed.
        if reduction is None:
            reduction_module = None
        else:
            reduction_module = self._build_module("reduction_module", reduction)

        heuristic_module = self._build_module("heuristic_module", code_string)

        objective.value = self.calObjectives(heuristic_module, reduction_module)


def _read_tsplib_coordinates(lines):
    """Extract the city coordinates from the lines of a TSPLib ``.tsp`` file.

    The coordinate rows are located via the ``NODE_COORD_SECTION`` marker and
    read until ``EOF`` (or the end of the file), so the number of header lines
    and trailing blank lines do not matter. Files without the marker fall back
    to the legacy fixed layout (6 header lines, 1 trailing line).
    """
    stripped = [line.strip() for line in lines]
    start = next(
        (i + 1 for i, line in enumerate(stripped) if line.startswith('NODE_COORD_SECTION')),
        None,
    )
    if start is None:
        rows = stripped[6:-1]
    else:
        rows = []
        for line in stripped[start:]:
            if not line:
                continue
            # Coordinate rows start with a numeric node id; "EOF" or the name
            # of a following section ends the coordinate block.
            if line[0].isalpha():
                break
            rows.append(line)
    # The first column is the node id; the remaining columns are coordinates.
    return np.array([[float(value) for value in row.split()[1:]] for row in rows])


def load_external_data(dirname, problem_size):
    """Load an external test dataset as a list of (coordinates, distances).

    Args:
        dirname: Dataset folder. If its name contains ``'tsplib'`` the TSPLib
            loader is used, otherwise a ``test{problem_size}_dataset.npy``
            file is expected inside it.
        problem_size: For TSPLib, the instance file name (e.g. ``'eil51.tsp'``);
            the number of cities is taken from the first run of digits in it.
            Otherwise the number of cities as a string (e.g. ``'50'``).

    Returns:
        TSPLib: ``(data, problem_size, scale)`` where ``scale`` is the factor
        by which the coordinates were divided to fit into the unit square.
        Otherwise: ``(data, problem_size)``. ``problem_size`` is always an int.
    """
    data = []
    if 'tsplib' in dirname:
        file_name = problem_size  # only one TSPLib instance is loaded per call
        with open(f'{dirname}/{file_name}') as f:
            lines = f.readlines()

        coordinates = _read_tsplib_coordinates(lines)

        # Scale the coordinates to [0, 1]^2 to align with the training data
        lower = coordinates.min(axis=0)
        scale = (coordinates.max(axis=0) - lower).max()
        coordinates = (coordinates - lower) / scale
        distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
        data.append((coordinates, distances))
        problem_size = int(re.split(r'(\d+)', file_name)[1])
        return data, problem_size, scale

    raw_data = np.load(f'{dirname}/test{problem_size}_dataset.npy')
    for coordinates in raw_data:
        distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
        data.append((coordinates, distances))
    # Converted outside the loop so that an empty dataset also yields an int.
    return data, int(problem_size)


# Upper bound (in seconds) on how long a single evaluation process may run.
_EVAL_TIMEOUT_S = 99999


def _resolve_reduction_code(path, reduc_dir, heuristic, problem, k):
    """Return the source of the reduction that belongs to ``heuristic``.

    The reduction for ``problem`` is looked up in generation ``k``. If that
    reduction is the initial one (``op == 'ri0'``) it is used directly.
    Otherwise earlier generations are walked backwards for as long as the
    heuristic is still part of the population, and the reduction of the
    earliest such generation is used.

    Raises:
        ValueError: if generation ``k`` has no reduction for ``problem``.
    """
    with open(f"{reduc_dir}/reduc_population_generation_{k}.json", 'r') as file:
        reduc_pop = json.load(file)
    current = next((r for r in reduc_pop if r["problem"] == problem), None)
    if current is None:
        raise ValueError(f"No reduction for the heuristic's problem found in generation {k}")

    reduc_code = current["code"]
    if current['op'] == 'ri0':
        return reduc_code

    # Backtrack previous generations to determine the actual reduction that is
    # responsible for the heuristic. The generation-k reduction is only the
    # starting value: once an earlier generation containing the heuristic has
    # supplied a reduction, that value must survive the moment the heuristic
    # disappears from the population. Previously it was overwritten with the
    # generation-k reduction again at that point.
    for gen in range(k - 1, -1, -1):
        with open(f'{path}/population_generation_{gen}.json') as file:
            earlier_pop = json.load(file)
        if heuristic not in earlier_pop:
            break
        with open(f"{reduc_dir}/reduc_population_generation_{gen}.json", 'r') as file:
            earlier_reducs = json.load(file)
        for reduc in earlier_reducs:
            if reduc["problem"] == problem:
                reduc_code = reduc["code"]
    return reduc_code


def main(argv):
    """Evaluate one stored heuristic on external test data and write reports.

    ``argv`` follows the usage described in the module docstring:
    ``[script, path, test_data_folder, test_problem_size(s), gen, heuristic,
    (optional) trial_no]``. A text report and a CSV file are written into
    ``<exp folder>/results``.
    """
    if len(argv) < 6:
        # Print the usage text instead of failing with an IndexError.
        sys.exit(__doc__)

    sys.path.append(os.path.dirname(os.path.abspath(__file__)))

    path = argv[1]
    exp_output_path = path.split('/')[0]
    problem_size = exp_output_path.split('_')[0]
    exp_tag = path.split('/')[-1][4:]  # drop the leading "pops"
    test_dataset_dir = argv[2]
    if len(test_dataset_dir.split('_')) > 2:
        test_dataset_tag = f"-{test_dataset_dir.split('_', 1)[1]}"
    elif 'tsplib' in test_dataset_dir:
        test_dataset_tag = '-tsplib'
    else:
        test_dataset_tag = ''
    test_problem_sizes = argv[3].split('+')
    k = int(argv[4])
    h = int(argv[5])
    trial_tag = f'_i{int(argv[6]):02d}' if len(argv) > 6 else ''

    with open(f'{path}/population_generation_{k}.json') as file:
        pop = json.load(file)
    heuristic = pop[h-1]  # heuristic numbers on the command line are 1-based
    problem, code, operator = heuristic["problem"], heuristic["code"], heuristic["op"]

    reduc_code = _resolve_reduction_code(
        path, f"{exp_output_path}/results/pops{exp_tag}", heuristic, problem, k
    )

    res_filename = f"{exp_output_path}/results/performance{exp_tag}{test_dataset_tag}-gen{k}_h{h}{trial_tag}.txt"
    csv_filename = f"{exp_output_path}/results/performance{exp_tag}{test_dataset_tag}-gen{k}_h{h}{trial_tag}_csv.txt"

    # Both report files are closed even if an evaluation raises.
    with open(res_filename, 'w+') as f, open(csv_filename, 'w+') as c:
        c.write('|'.join(['Generation', 'Operator', 'Problem Size', 'Objective', 'Runtime']) + '\n')
        f.write(f'===========================\nGeneration #{k}\n')

        for test_problem_size in test_problem_sizes:
            dataset = load_external_data(test_dataset_dir, test_problem_size)
            f.write(f"problem size = {problem_size} -> {test_problem_size}\n")
            evaluator = Test(dataset[0], dataset[1])

            time_start = time.time()
            # Stays at +inf if the evaluation times out or crashes.
            objective = Value('d', np.inf)
            p = Process(target=evaluator.evaluate_new, args=(code, objective, reduc_code))
            p.start()
            p.join(_EVAL_TIMEOUT_S)
            if p.is_alive():
                print('Timeout!')
                p.terminate()
                p.join()
            elif p.exitcode != 0:
                print(f'Evaluation process exited with code {p.exitcode}')

            runtime = time.time()-time_start

            if len(dataset) > 2:  # TSPLib
                name = test_problem_size.split('.')[0]
                # Undo the coordinate scaling, then return the gap (in %) for TSPLib
                obj = (objective.value * dataset[2] - opt[name]) / opt[name] * 100
            else:
                obj = objective.value
            f.write(f"Heuristic #{h}, from {operator}):\nobjective = {obj:.5f}, time = {runtime:.1f} s\n")
            f.write('\n')
            f.flush()
            c.write('|'.join([str(k), operator, str(test_problem_size), str(obj), str(runtime)]) + '\n')
            c.flush()

    # Convert the '|'-separated temporary file into a regular CSV file;
    # [:-8] strips the "_csv.txt" suffix.
    pd.read_csv(csv_filename, delimiter='|').to_csv(f'{csv_filename[:-8]}.csv', index=False)
    os.remove(csv_filename)


if __name__ == '__main__':
    main(sys.argv)
