import os
import sys
import time
root_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_folder)
sub_folder = os.path.join(root_folder, "src")
sys.path.append(sub_folder)
from src.meta_diffusion import (
    MetaDiffModel, MetaDiffEnv, GNNEncoder, 
    MISDecoder, MCutDecoder, MClDecoder
)

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------
# FINETUNE: False -> meta-pretraining run over the tasks in TASK_POOL;
#           True  -> finetuning run, which requires exactly one task
#                    (its name, TASK_POOL[0], is used in the output paths).
FINETUNE = False
DEVICE = "cuda:0"
# Task identifiers; passed to the env/encoder and used as keys into the
# per-task data and decoder tables.
TASK_POOL = ["MIS", "MCl", "MCut"]
# Selects which top-level entry of the *_DICT lookup tables is used.
TRAIN_TYPE = "SMALL"

# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# which would silently disable this configuration check.
if FINETUNE and len(TASK_POOL) != 1:
    raise ValueError("finetuning can only be applied on SINGLE specific task!")

# Training data folders, looked up first by TRAIN_TYPE, then by task name.
TRAIN_FOLDER_DICT = dict(
    SMALL=dict(
        MIS="train_dataset/mis",  # toy data for quick debugging
        MCl="train_dataset/mcl",
        MCut="train_dataset/mcut",
    ),
)

# Validation files, looked up first by TRAIN_TYPE, then by task name.
VAL_FILE_DICT = dict(
    SMALL=dict(
        MIS="test_dataset/mis/mis_rb-small_kamis-60s_20.090.txt",
        MCl="test_dataset/mcl/mcl_rb-small_gurobi_19.082.txt",
        MCut="test_dataset/mcut/mcut_ba-small_gurobi_60s_727.844.txt",
    ),
)

# One decoder instance per task name; the whole mapping is handed to
# MetaDiffModel as its `decoder` argument.
DECODER_DICT = dict(
    MIS=MISDecoder(),
    MCl=MClDecoder(),
    MCut=MCutDecoder(),
)

# Local-time stamp used to give each run its own weight/plot directory.
timestamp = time.time()
formatted_time = time.strftime("%Y%m%d%H%M", time.localtime(timestamp))
# Last 8 characters of "YYYYmmddHHMM" -> "mmddHHMM" (the year is dropped).
# Computed once instead of re-slicing in every path below.
run_tag = formatted_time[-8:]

if not FINETUNE:
    # Meta-pretraining outputs. The trailing "{}-{}" placeholders are left
    # unformatted here; presumably the model fills them when saving a
    # checkpoint (TODO confirm their meaning in MetaDiffModel).
    WEIGHT_PATH_DICT = {
        "SMALL": f"weights/meta_pretrain/node_small/{run_tag}/" + "node_small-epoch-{}-{}.pt",
        "LARGE": f"weights/meta_pretrain/node_large/{run_tag}/" + "node_large-epoch-{}-{}.pt",
    }
    PLOT_FOLDER = f"train_plots/meta_pretrain/node/pretrain/node-{TRAIN_TYPE}-{run_tag}"
else:
    # Finetuning outputs: SMALL and LARGE share one directory named after the
    # single task being finetuned (TASK_POOL[0]).
    _finetune_dir = f"weights/meta_pretrain/finetune/{TASK_POOL[0]}-{TRAIN_TYPE}-{run_tag}/"
    WEIGHT_PATH_DICT = {
        "SMALL": _finetune_dir + "epoch-{}-{}.pt",
        "LARGE": _finetune_dir + "epoch-{}-{}.pt",
    }
    PLOT_FOLDER = f"train_plots/meta_pretrain/node/finetune/{TASK_POOL[0]}-{TRAIN_TYPE}-{run_tag}"


if __name__ == "__main__":
    # Fail fast with a readable message instead of a bare KeyError raised
    # midway through model construction when TRAIN_TYPE is missing from one
    # of the lookup tables (e.g. a type present in WEIGHT_PATH_DICT but not
    # in the data tables).
    for _table_name, _table in (
        ("TRAIN_FOLDER_DICT", TRAIN_FOLDER_DICT),
        ("VAL_FILE_DICT", VAL_FILE_DICT),
        ("WEIGHT_PATH_DICT", WEIGHT_PATH_DICT),
    ):
        if TRAIN_TYPE not in _table:
            raise ValueError(
                f"TRAIN_TYPE={TRAIN_TYPE!r} has no entry in {_table_name} "
                f"(available: {sorted(_table)})"
            )

    # Environment: data sources for training and validation.
    env = MetaDiffEnv(
        task=TASK_POOL,
        mode="train",
        train_data_size=1,  # virtual
        val_data_size=16,
        train_batch_size=4,
        val_batch_size=1,
        device=DEVICE,
        train_folder=TRAIN_FOLDER_DICT[TRAIN_TYPE],
        val_path=VAL_FILE_DICT[TRAIN_TYPE],
    )

    # Encoder shared across the tasks in TASK_POOL.
    encoder = GNNEncoder(
        shared_block_layers=[1, 2],
        separate_block_layers=[2, 1],
        hidden_dim=64,
        sparse=True,
        task=TASK_POOL,
    )

    model = MetaDiffModel(
        grad_norm=True,
        env=env,
        encoder=encoder,
        decoder=DECODER_DICT,
        train_outer_steps=20,
        inner_lr=5e-5,
        outer_lr=5e-4,
        train_inner_samples=4,
        save_n_epochs=5,
        save_path=WEIGHT_PATH_DICT[TRAIN_TYPE],
        plot_folder=PLOT_FOLDER,
        weight_path=None,  # no checkpoint is loaded for this run
    )

    model.model_train()