import os
import sys
import time
root_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_folder)
sub_folder = os.path.join(root_folder, "src")
sys.path.append(sub_folder)
from src.meta_diffusion import (
    MetaDiffModel, MetaDiffEnv, GNNEncoder,
    TSPDecoder, ATSPDecoder
)

# Run-mode switches: edit these two values before launching a run.
FINETUNE = False  # must stay False: this version only implements meta pre-training
DEVICE = "cuda:0"  # device string handed to MetaDiffEnv for training

# Training-data folders, indexed first by problem size and then by task name.
TRAIN_FOLDER_DICT = {
    50: dict(
        TSP="train_dataset/tsp",
        ATSP="train_dataset/atsp",
    ),
}

# Validation files, indexed first by problem size and then by task name.
# NOTE(review): the trailing number in each filename looks like the reference
# solver's (concorde / LKH) average tour length — confirm before relying on it.
VAL_FILE_DICT = {
    50: dict(
        TSP="test_dataset/tsp/tsp50_concorde_5.68759.txt",
        ATSP="test_dataset/atsp/atsp50_uniform_lkh_1000_1.55448.txt",
    ),
}

# Launch time of this run, rendered as a 12-digit local-time string
# (YYYYMMDDHHMM) that is used to tag output folders.
timestamp = time.time()
formatted_time = time.strftime(
    "%Y%m%d%H%M",
    time.localtime(timestamp),
)

# One decoder instance per task name; passed as-is to MetaDiffModel.
# Keyword arguments are evaluated left to right, so TSPDecoder is built first.
DECODER_DICT = dict(
    TSP=TSPDecoder(),
    ATSP=ATSPDecoder(),
)

# Resolve the settings for this run. Only meta pre-training is implemented in
# this version, so a truthy FINETUNE aborts before any path is configured.
if FINETUNE:
    raise NotImplementedError(
        "Fine-tuning is not supported in this version; "
        "set FINETUNE = False to run meta pre-training."
    )

TASK_POOL = ["TSP", "ATSP"]  # modify here: tasks handed to MetaDiffEnv
TRAIN_TYPE = 50  # modify here: problem size; must be a key of the dicts above

# Last 8 characters of formatted_time (MMDDHHMM for the "%Y%m%d%H%M" format),
# shared by the weight and plot folders so both outputs of one run line up.
_RUN_TAG = formatted_time[-8:]

WEIGHT_PATH_DICT = {
    # The trailing "{}-{}" placeholders are deliberately left unformatted here;
    # presumably MetaDiffModel fills them when saving — verify against it.
    50: f"weights/meta_pretrain/edge_small/{TRAIN_TYPE}-{_RUN_TAG}/" + "edge_50-epoch-{}-{}.pt",
}
plot_folder = f"train_plots/meta_pretrain/edge/pretrain/edge-{TRAIN_TYPE}-{_RUN_TAG}"

train_folder = TRAIN_FOLDER_DICT[TRAIN_TYPE]
train_path = None  # no single training file; train_folder is passed instead
val_path = VAL_FILE_DICT[TRAIN_TYPE]
save_path = WEIGHT_PATH_DICT[TRAIN_TYPE]

# A non-positive value makes the encoder dense (it is built with
# sparse=(sparse_factor > 0)). Larger instances might instead use
# `-1 if TRAIN_TYPE <= 100 else 50`.
sparse_factor = -1

if __name__ == "__main__":
    # Environment over the configured task pool: training data comes from
    # train_folder and validation instances from val_path.
    train_env = MetaDiffEnv(
        task=TASK_POOL,
        mode="train",
        train_data_size=1280,
        val_data_size=1,
        train_batch_size=32,
        val_batch_size=1,
        device=DEVICE,
        train_folder=train_folder,
        train_path=train_path,
        val_path=val_path,
        sparse_factor=sparse_factor,
    )

    # GNN encoder; it switches to sparse mode only for a positive sparse_factor.
    # NOTE(review): the task list here is hard-coded rather than taken from
    # TASK_POOL — confirm that is intended if TASK_POOL is ever edited.
    gnn_encoder = GNNEncoder(
        shared_block_layers=[1, 2],
        separate_block_layers=[2, 1],
        hidden_dim=64,
        sparse=(sparse_factor > 0),
        task=["TSP", "ATSP"],
    )

    # Presumably outer = meta-update loop and inner = per-task adaptation
    # loop — confirm against MetaDiffModel. weight_path=None means no
    # checkpoint is passed in.
    model = MetaDiffModel(
        env=train_env,
        encoder=gnn_encoder,
        decoder=DECODER_DICT,
        train_outer_steps=50,
        train_inner_steps=1,
        inner_lr=5e-5,
        outer_lr=2e-4,
        save_n_epochs=5000,
        train_inner_samples=4,
        save_path=save_path,
        plot_folder=plot_folder,
        weight_path=None,
    )

    model.model_train()