'''
Evaluate the top heuristics of saved populations on TSP test instances.

Usage: python [FILENAME].py [problem_size(s)] [test_problem_size(s)] [external_data] [reduc] [n_trial] [(optional) n_trial_evored]
Note: to pass several problem_size(s) or test_problem_size(s), separate them with '+'.
Note: external_data is 1 to use the external test data from MCTS-AHD, 0 to use the local test data.
Note: reduc is 1 if a reduction was performed at all, 0 otherwise.
Note: evored=True if the sixth argument (n_trial_evored) is provided; it only applies when reduc=1.
Note: n_trial and n_trial_evored are either a single integer N (trials #1 to #N) or an inclusive range written with '-' (e.g., 2-3 for trials #2 and #3, 2-2 for #2 only).
Example: python [FILENAME].py 50 50+100+200 1 0 1 1
'''
import numpy as np
import pickle, json
import types
import warnings
from multiprocessing import Process, Value
import sys, os
import copy
import time
import pandas as pd
import numpy as np
import networkx as nx
from typing import Union


class Test():
    """Evaluate route-construction heuristics on a fixed set of TSP instances.

    Each dataset entry is a ``(coord_matrix, distance_matrix)`` pair. The
    objective is the negated mean closed-tour length over all instances, so
    larger values are better.
    """

    def __init__(self, data, problem_size: int):
        """
        Args:
            data: sequence of ``(coord_matrix, distance_matrix)`` pairs.
            problem_size: number of nodes every returned route must visit.
        """
        self.problem_size = problem_size
        self.n_instance = len(data)
        self._datasets = data

    def route_cost(self, coord_matrix, route, problem_size):
        """Return the Euclidean length of the closed tour ``route``.

        ``route`` must visit each node ``0..problem_size-1`` exactly once. A
        tour that explicitly repeats its start node at the end is accepted as
        well. Every edge is counted, including the one back to the start.

        Args:
            coord_matrix: node coordinates, assumed shaped ``(n_nodes, dim)``.
            route: sequence of node indices (ints or integral floats).
            problem_size: number of nodes the tour must cover.

        Raises:
            ValueError: if ``route`` is not a permutation of the nodes. This
                stops a route that skips or repeats nodes from scoring as a
                spuriously short tour.
        """
        nodes = np.asarray(route).astype(int).ravel()
        # Accept an explicitly closed tour [a, ..., a] by dropping the repeat.
        if len(nodes) == problem_size + 1 and nodes[0] == nodes[-1]:
            nodes = nodes[:-1]
        if not np.array_equal(np.sort(nodes), np.arange(problem_size)):
            raise ValueError(
                f"route must visit each of the {problem_size} nodes exactly once"
            )
        coords = np.asarray(coord_matrix, dtype=float)
        # Leg i runs nodes[i] -> nodes[i+1]; np.roll wraps the last node back to the first.
        legs = coords[nodes] - coords[np.roll(nodes, -1)]
        return float(np.linalg.norm(legs, axis=1).sum())

    def calObjectives(self, alg, reduction=None) -> float:
        """Run ``alg`` on every instance and return the negated mean tour length.

        Args:
            alg: object exposing ``find_route(coord_matrix, distance_matrix)``
                when ``reduction`` is None, or ``solve_B(input_B)`` otherwise.
            reduction: optional object exposing ``convert_input_A_to_B`` and
                ``convert_solution_B_to_A``. The heuristic then solves the
                converted problem B, and its solution is mapped back to a route.

        Returns:
            ``-mean(route_cost)`` over all datasets (larger is better).
        """
        dis = np.ones(self.n_instance)
        for i, (coord_matrix, distance_matrix) in enumerate(self._datasets):
            if reduction is None:
                route = alg.find_route(coord_matrix, distance_matrix)
            else:
                input_B = reduction.convert_input_A_to_B(coord_matrix, distance_matrix)
                solution_B = alg.solve_B(input_B)
                route = reduction.convert_solution_B_to_A(solution_B)
            dis[i] = self.route_cost(coord_matrix, route, self.problem_size)
        return -np.average(dis)

    def evaluate_new(self, code_string, objective, reduction=None):
        """Execute heuristic (and optional reduction) source and store the objective.

        Args:
            code_string: Python source defining ``find_route``, or ``solve_B``
                when a reduction is used.
            objective: object whose ``.value`` receives the result (e.g. a
                ``multiprocessing.Value`` shared with a parent process).
            reduction: optional Python source defining the reduction functions.

        NOTE(security): both sources are run with ``exec`` and no sandboxing.
        Only evaluate code you are willing to run, ideally in a separate
        process so that hangs and crashes stay contained.
        """
        reduction_module = None
        if reduction is not None:
            reduction_module = self._load_module("reduction_module", reduction)
        heuristic_module = self._load_module("heuristic_module", code_string)
        objective.value = self.calObjectives(heuristic_module, reduction_module)

    @staticmethod
    def _load_module(name, source):
        """Build module ``name`` from ``source`` and register it in ``sys.modules``."""
        module = types.ModuleType(name)
        exec(source, module.__dict__)
        # Registering makes the generated code importable by name.
        sys.modules[name] = module
        return module


def load_external_data(dirname, problem_size):
    """Load external test instances and pair each with its distance matrix.

    Reads ``{dirname}/test{problem_size}_dataset.npy`` and returns a list of
    ``(coordinates, distances)`` tuples, one per instance. ``distances`` holds
    the pairwise Euclidean distances between that instance's points.
    """
    instances = np.load(f'{dirname}/test{problem_size}_dataset.npy')
    return [
        (coords, np.linalg.norm(coords[:, np.newaxis] - coords, axis=2))
        for coords in instances
    ]


def _parse_trial_range(spec):
    """Convert a trial spec into a 0-based, half-open ``(start, end)`` range.

    ``'N'`` selects trials #1..#N and ``'A-B'`` selects trials #A..#B
    (inclusive), so ``range(start, end)`` yields 0-based trial indices.
    """
    bounds = spec.split('-')
    if len(bounds) > 1:
        return int(bounds[0]) - 1, int(bounds[1])
    return 0, int(bounds[0])


def _load_population(pops_dir, gen, reduc):
    """Load generation ``gen`` sorted best-first and, if ``reduc``, its reductions.

    Returns ``(pop, reduc_pop)``, where ``reduc_pop`` is None when ``reduc`` is
    falsy. The reduction population is read from the same generation as
    ``pop``.
    """
    with open(f"{pops_dir}/population_generation_{gen}.json", 'r') as file:
        pop = json.load(file)
    pop = sorted(pop, key=lambda x: x['objective'], reverse=True)  # needed when doing multi-obj
    reduc_pop = None
    if reduc:
        with open(f"{pops_dir}/reduc_population_generation_{gen}.json", 'r') as file:
            reduc_pop = json.load(file)
    return pop, reduc_pop


def _load_test_data(test_problem_size, external_data, all_data):
    """Return the test instances for ``test_problem_size`` (external or local)."""
    if external_data == 1:
        return load_external_data('dataset_mcts', test_problem_size)
    return all_data[test_problem_size]


def _find_reduction_code(reduc_pop, problem):
    """Return the source of the reduction whose ``problem`` equals ``problem``.

    Raises:
        LookupError: if no reduction matches. Raising is preferred over
            silently reusing the reduction of a previously evaluated heuristic.
    """
    for reduction in reduc_pop:
        if reduction["problem"] == problem:
            return reduction["code"]
    raise LookupError(f"No reduction found for problem {problem!r}")


def _evaluate_top_n(evaluator, pop, reduc_pop, top_n, timeout, failed):
    """Evaluate heuristics of ``pop`` in rank order until ``top_n`` succeed.

    Each heuristic runs in its own process, limited to ``timeout`` seconds.
    A run fails if it times out, crashes, or yields a non-finite objective.
    A failed heuristic's rank is added to ``failed`` and skipped by later
    calls that share the same set (i.e. for the remaining test sizes).

    Yields:
        ``(n_found, rank, operator, objective, runtime)`` for each success.
        ``n_found`` counts earlier successes in this call, and ``rank`` is
        the 0-based position in ``pop``. Iteration stops early if ``pop``
        runs out.
    """
    n_found = 0
    for rank, individual in enumerate(pop):
        if n_found >= top_n:
            return
        if rank in failed:
            continue
        reduc_code = None
        if reduc_pop is not None:
            reduc_code = _find_reduction_code(reduc_pop, individual["problem"])
        operator = individual["op"]

        time_start = time.time()
        objective = Value('d', np.inf)  # stays inf unless the child process succeeds
        p = Process(target=evaluator.evaluate_new, args=(individual["code"], objective, reduc_code))
        p.start()
        p.join(timeout)
        if p.is_alive():
            print('Timeout!')
            p.terminate()
            p.join()
        if not np.isfinite(objective.value):
            failed.add(rank)
            continue
        runtime = time.time() - time_start
        yield n_found, rank, operator, objective.value, runtime
        n_found += 1


def _format_heuristic(n_found, rank, last_gen, operator, objective, runtime):
    """Format one evaluated heuristic for the human-readable results file."""
    return (f"Heuristic #{n_found+1} (#{rank+1}/{last_gen}, from {operator}):\n"
            f"objective = {objective:.5f}, time = {runtime:.1f} s\n")


def main():
    """Command-line entry point; see the module docstring for the arguments."""
    current_file_path = os.path.abspath(__file__)
    parent_directory = os.path.dirname(current_file_path)
    sys.path.append(parent_directory)
    problem_sizes = [int(s) for s in sys.argv[1].split('+')]
    test_problem_sizes = [int(s) for s in sys.argv[2].split('+')]

    reduc = int(sys.argv[4])
    start, end = _parse_trial_range(sys.argv[5])
    reduc_evol = bool(reduc) and (len(sys.argv) > 6)
    if reduc_evol:
        start_evored, end_evored = _parse_trial_range(sys.argv[6])
        evored_tags = [f'_evored_{j+1}' for j in range(start_evored, end_evored)]
    else:
        evored_tags = ['']

    compare_last_only = False  # only evaluate last generation, otherwise evaluate the best individual at each generation
    last_gen = 10  # last generation to look at
    top_n = 1  # number of top individuals to evaluate per population; fewer if the population runs out
    timeout = 10000  # seconds allowed per heuristic evaluation

    external_data = int(sys.argv[3])
    if external_data == 1:
        all_data = None
        data_tag = '_mcts'
    else:
        # Load test data from local files.
        # NOTE(security): pickle.load can execute arbitrary code; only load a trusted file.
        with open('all_data_test.pkl', 'rb') as f:
            all_data = pickle.load(f)
        data_tag = ''

    for problem_size in problem_sizes:
        for j in range(start, end):
            exp_output_path = f"{problem_size}_{j+1}" if reduc else f"{problem_size}_{j+1}_raw"

            for evored in evored_tags:
                pops_dir = f"{exp_output_path}/results/pops{data_tag}{evored}"
                res_filename = f"{exp_output_path}/results/top_{top_n}_performance{data_tag}{evored}.txt"
                csv_filename = f"{exp_output_path}/results/top_{top_n}_performance{data_tag}{evored}_csv.txt"
                with open(res_filename, 'w+') as f, open(csv_filename, 'w+') as c:
                    if compare_last_only:
                        f.write(f'Generation #{last_gen}/{last_gen}\n')
                        c.write('|'.join(['Heuristic #', 'Operator', 'Problem Size', 'Objective', 'Runtime']) + '\n')
                        pop, reduc_pop = _load_population(pops_dir, last_gen, reduc)
                        failed = set()  # ranks that failed on an earlier test size
                        for test_problem_size in test_problem_sizes:
                            data = _load_test_data(test_problem_size, external_data, all_data)
                            f.write(f"problem size = {problem_size} -> {test_problem_size}\n")
                            evaluator = Test(data, test_problem_size)
                            for n_found, rank, operator, obj, runtime in _evaluate_top_n(
                                    evaluator, pop, reduc_pop, top_n, timeout, failed):
                                f.write(_format_heuristic(n_found, rank, last_gen, operator, obj, runtime))
                                f.flush()
                                c.write('|'.join([str(n_found+1), operator, str(test_problem_size), str(obj), str(runtime)]) + '\n')
                                c.flush()
                    else:
                        obj_per_gen, time_per_gen = np.zeros(last_gen+1), np.zeros(last_gen+1)
                        c.write('|'.join(['Generation', 'Operator', 'Problem Size', 'Objective', 'Runtime']) + '\n')
                        # Use range(last_gen+1) to evaluate every generation.
                        for k in range(last_gen, last_gen+1):
                            f.write(f'===========================\nGeneration #{k}/{last_gen}\n')
                            pop, reduc_pop = _load_population(pops_dir, k, reduc)
                            failed = set()  # ranks that failed on an earlier test size
                            for test_problem_size in test_problem_sizes:
                                data = _load_test_data(test_problem_size, external_data, all_data)
                                f.write(f"problem size = {problem_size} -> {test_problem_size}\n")
                                evaluator = Test(data, test_problem_size)

                                objs, times, operators = [], [], []
                                for n_found, rank, operator, obj, runtime in _evaluate_top_n(
                                        evaluator, pop, reduc_pop, top_n, timeout, failed):
                                    objs.append(obj)
                                    times.append(runtime)
                                    operators.append(operator)  # successful heuristics only
                                    f.write(_format_heuristic(n_found, rank, last_gen, operator, obj, runtime))
                                    f.flush()

                                f.write('\n')
                                # Average only over heuristics that actually succeeded.
                                obj_per_gen[k] = np.mean(objs) if objs else np.nan
                                time_per_gen[k] = np.mean(times) if times else np.nan
                                c.write('|'.join([str(k), '->'.join(operators), str(test_problem_size), str(obj_per_gen[k]), str(time_per_gen[k])]) + '\n')
                                c.flush()

                # Convert the '|'-separated scratch file into a real CSV, then remove it.
                pd.read_csv(csv_filename, delimiter='|').to_csv(f'{csv_filename[:-8]}.csv', index=False)
                os.remove(csv_filename)


if __name__ == '__main__':
    main()

