import os
import sys
root_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_folder)
sub_folder = os.path.join(root_folder, "src")
sys.path.append(sub_folder)
from src.meta_diffusion import *
from typing import Union

# Benchmark to run: a ``(task, graph_type, scale)`` key into the path tables below.
TEST_TYPE = "MIS", "hk", "small"
# Device string passed to ``MetaDiffEnv`` when the solver is built.
DEVICE = "cuda:0"

# Pretrained checkpoint for every supported benchmark key
# ``(task, graph_type, scale)``; only the "small" scale is listed.
# NOTE(review): paths are relative, presumably resolved against the current
# working directory when the checkpoint is loaded — confirm.
WEIGHT_PATH_DICT = {
    (task, graph_type, "small"): path
    for task, path_by_graph in (
        ("MIS", {
            "ba": "weights/mis_ba_small.pt",
            "hk": "weights/mis_hk_small.pt",
            "ws": "weights/mis_ws_small.pt",
        }),
        ("MCl", {
            "ba": "weights/mcl_ba_small.pt",
            "hk": "weights/mcl_hk_small.pt",
            "ws": "weights/mcl_ws_small.pt",
        }),
        ("MCut", {
            "rb": "weights/mcut_rb_small.pt",
            "ws": "weights/mcut_ws_small.pt",
            "hk": "weights/mcut_hk_small.pt",
        }),
    )
    for graph_type, path in path_by_graph.items()
}

# Query dataset (text file) for every supported benchmark key
# ``(task, graph_type, scale)``; only the "small" scale is listed.
# NOTE(review): MIS file names use "_small_" while MCl/MCut names use
# "-small_"; kept exactly as-is — presumably this mirrors the files on disk,
# verify. Paths are relative, presumably to the current working directory.
QUERY_FILE_DICT = {
    (task, graph_type, "small"): path
    for task, path_by_graph in (
        ("MIS", {
            "ba": "query_dataset/mis/mis_ba_small_query.txt",
            "hk": "query_dataset/mis/mis_hk_small_query.txt",
            "ws": "query_dataset/mis/mis_ws_small_query.txt",
        }),
        ("MCl", {
            "ba": "query_dataset/mcl/mcl_ba-small_query.txt",
            "hk": "query_dataset/mcl/mcl_hk-small_query.txt",
            "ws": "query_dataset/mcl/mcl_ws-small_query.txt",
        }),
        ("MCut", {
            "rb": "query_dataset/mcut/mcut_rb-small_query.txt",
            "hk": "query_dataset/mcut/mcut_hk-small_query.txt",
            "ws": "query_dataset/mcut/mcut_ws-small_query.txt",
        }),
    )
    for graph_type, path in path_by_graph.items()
}

# Solver class for each task name (the first element of a benchmark key).
SOLVER_DICT = dict(
    MIS=MetaDiffMISSolver,
    MCl=MetaDiffMClSolver,
    MCut=MetaDiffMCutSolver,
)

# Decoder class for each task name (the first element of a benchmark key).
DECODER_DICT = dict(
    MIS=MISDecoder,
    MCl=MClDecoder,
    MCut=MCutDecoder,
)

def resolve_benchmark(test_type, weight_paths=None, query_files=None):
    """Look up the task name, checkpoint and query file for a benchmark key.

    Args:
        test_type: Benchmark key ``(task, graph_type, scale)``, e.g.
            ``("MIS", "hk", "small")``. Any 3-item sequence is accepted; it is
            converted to a tuple before lookup.
        weight_paths: Mapping from benchmark key to checkpoint path.
            Defaults to ``WEIGHT_PATH_DICT``.
        query_files: Mapping from benchmark key to query-file path.
            Defaults to ``QUERY_FILE_DICT``.

    Returns:
        ``(task, weight_path, data_path)``.

    Raises:
        ValueError: If ``test_type`` does not have exactly three items, or is
            missing from either mapping.
    """
    if weight_paths is None:
        weight_paths = WEIGHT_PATH_DICT
    if query_files is None:
        query_files = QUERY_FILE_DICT

    key = tuple(test_type)
    # Unpacking enforces the 3-item shape (ValueError otherwise).
    task, _graph_type, _scale = key
    try:
        weight_path = weight_paths[key]
        data_path = query_files[key]
    except KeyError:
        # A bare KeyError on a tuple is hard to act on; name the valid keys.
        supported = ", ".join(str(k) for k in weight_paths)
        raise ValueError(
            f"Unsupported benchmark {key!r}; expected one of: {supported}"
        ) from None
    return task, weight_path, data_path


def main(test_type=None, device=None):
    """Solve and evaluate one benchmark's query set with a pretrained model.

    Builds the task-specific MetaDiff solver, loads the query file, solves it,
    evaluates the solutions with ``calculate_gap=True`` and prints the results.

    Args:
        test_type: Benchmark key ``(task, graph_type, scale)``.
            Defaults to ``TEST_TYPE``.
        device: Device string forwarded to ``MetaDiffEnv``.
            Defaults to ``DEVICE``.

    Returns:
        The value returned by ``solver.evaluate(calculate_gap=True)``, which is
        also printed.

    Raises:
        ValueError: If ``test_type`` is not a supported benchmark key.
    """
    test_type = TEST_TYPE if test_type is None else test_type
    device = DEVICE if device is None else device

    task, weight_path, data_path = resolve_benchmark(test_type)
    decoder_cls = DECODER_DICT[task]
    solver_cls = SOLVER_DICT[task]
    print(f"Testing benchmark: {test_type}")

    solver: Union[MetaDiffMISSolver, MetaDiffMClSolver, MetaDiffMCutSolver]
    solver = solver_cls(
        model=MetaDiffModel(
            env=MetaDiffEnv(
                task=task, mode="solve", sparse_factor=1, device=device,
            ),
            # NOTE(review): these encoder hyper-parameters presumably must match
            # the architecture the checkpoint was trained with — confirm
            # before changing them.
            encoder=GNNEncoder(
                sparse=True,
                shared_block_layers=[1, 2],
                separate_block_layers=[2, 1],
                hidden_dim=64,
            ),
            decoder=decoder_cls(decoding_type="greedy"),
            weight_path=weight_path,
        ),
    )
    # NOTE(review): ref=True presumably loads reference objective values that
    # evaluate(calculate_gap=True) compares against — confirm in from_txt.
    solver.from_txt(data_path, ref=True, show_time=True)
    solver.solve(show_time=True)
    results = solver.evaluate(calculate_gap=True)
    print(results)
    return results


if __name__ == "__main__":
    main()