import os
import sys
root_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_folder)
sub_folder = os.path.join(root_folder, "src")
sys.path.append(sub_folder)
from src.meta_diffusion import *
from typing import Union

# Benchmark to run: (task, node distribution, number of nodes).
# The full tuple is the lookup key into QUERY_FILE_DICT and WEIGHT_PATH_DICT.
TEST_TYPE = ("TSP", "cluster", 50)
# Device string forwarded to MetaDiffEnv.
DEVICE = "cuda:0"

# Query instance files, keyed by (task, distribution, nodes_num).
# Every TSP file follows the pattern tsp{nodes}_{distribution}_query.txt.
QUERY_FILE_DICT = {
    ("TSP", distribution, n): f"query_dataset/tsp/tsp{n}_{distribution}_query.txt"
    for distribution in ("gaussian", "cluster")
    for n in (50, 100)
}

# Model checkpoints, keyed the same way as QUERY_FILE_DICT.
# Every TSP checkpoint follows the pattern tsp_{distribution}_{nodes}.pt.
WEIGHT_PATH_DICT = {
    ("TSP", distribution, n): f"weights/tsp_{distribution}_{n}.pt"
    for distribution in ("gaussian", "cluster")
    for n in (50, 100)
}

# Task name -> solver class that wraps the MetaDiffModel for that task.
SOLVER_DICT = dict(
    TSP=MetaDiffTSPSolver,
    ATSP=MetaDiffATSPSolver,
)

# Task name -> decoder class; instantiated with a decoding-mode string.
DECODER_DICT = dict(
    TSP=TSPDecoder,
    ATSP=ATSPDecoder,
)

# main
if __name__ == "__main__":
    # TEST_TYPE is (task, distribution, nodes_num). Only the task name is needed
    # on its own; the full tuple is the key into the path tables.
    task, _, _ = TEST_TYPE

    # Validate the configuration up front. A typo or an unregistered benchmark
    # (e.g. ATSP, which has a solver/decoder but no data/weight entries) then
    # fails with a clear message instead of a bare KeyError.
    if task not in SOLVER_DICT or task not in DECODER_DICT:
        raise ValueError(
            f"Unsupported task {task!r}; expected one of {list(SOLVER_DICT)}"
        )
    if TEST_TYPE not in QUERY_FILE_DICT or TEST_TYPE not in WEIGHT_PATH_DICT:
        raise ValueError(
            f"No query file or weights registered for {TEST_TYPE}; "
            f"available: {list(QUERY_FILE_DICT)}"
        )

    Decoder = DECODER_DICT[task]
    Solver = SOLVER_DICT[task]
    # NOTE(review): both paths are relative, so they resolve against the current
    # working directory -- presumably the repository root; confirm.
    weight_path = WEIGHT_PATH_DICT[TEST_TYPE]
    data_path = QUERY_FILE_DICT[TEST_TYPE]
    # A non-positive sparse_factor builds the encoder with sparse=False.
    sparse_factor = -1
    print(f"Testing benchmark: {TEST_TYPE}")

    solver: Union[MetaDiffTSPSolver, MetaDiffATSPSolver] = Solver(
        model=MetaDiffModel(
            inference_steps=1,
            env=MetaDiffEnv(
                task=task, mode="solve", sparse_factor=sparse_factor, device=DEVICE,
            ),
            encoder=GNNEncoder(
                # NOTE(review): the encoder is configured for both tasks,
                # presumably to match how the checkpoint was trained; confirm.
                task=["TSP", "ATSP"],
                sparse=sparse_factor > 0,
                shared_block_layers=[1, 2],
                separate_block_layers=[2, 1],
                hidden_dim=64,
            ),
            decoder=Decoder("greedy"),
            weight_path=weight_path,
        )
    )
    # Load query instances, solve them, and print the evaluation.
    # NOTE(review): ref=True presumably loads reference solutions, which
    # calculate_gap=True relies on; confirm.
    solver.from_txt(data_path, ref=True, show_time=True, normalize=True)
    solver.solve(show_time=True)
    print(solver.evaluate(calculate_gap=True))