import argparse
import os

from mm_cvrp.utils import Loss
from mm_cvrp.utils import get_src_vector
from utils.git_util import get_git_commit_hash

import wandb


def main(args: argparse.Namespace) -> None:
    """Train an mm_cvrp model and track the run with Weights & Biases.

    Args:
        args: Parsed command-line options. Reads ``n_nodes``, ``n_agent``,
            ``capacity``, ``batch_size``, ``iteration``, ``lr_p``, ``lr_s``,
            ``key``, ``gpu``, ``output``, ``loss``, ``train_folder``,
            ``validation_folder``, ``disable_softmax`` and
            ``disable_augmentation``; ``seed`` and ``disable_wandb`` are
            read if present.

    Raises:
        ValueError: If ``capacity * n_agent`` is smaller than ``n_nodes``.
    """
    # NOTE: a check_no_uncommitted_changes() guard that refused to run with
    # unstaged git changes used to live here; it is currently disabled.

    # Fail fast, before selecting a GPU or opening a wandb run, so an invalid
    # configuration does not leave an empty run behind.
    if args.capacity * args.n_agent < args.n_nodes:
        raise ValueError("capacity is too small")

    # The device mask is set before the deferred torch import below so it is
    # in place before CUDA is initialised. Only set it when --gpu was given:
    # str(None) would produce CUDA_VISIBLE_DEVICES="None", hiding every GPU.
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    import torch
    from mm_cvrp.trainer import Trainer

    dev = "cuda" if torch.cuda.is_available() else "cpu"

    # Apply the --seed option so torch's RNG starts from a known state.
    seed = getattr(args, "seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    # Honour --disable-wandb through wandb's "disabled" mode, in which the
    # wandb calls below do not log anything.
    use_wandb = not getattr(args, "disable_wandb", False)
    if use_wandb:
        wandb.login(key="")  # Login with wandb account key
    # Start a new wandb run to track this script (resumed if `args.key`
    # already names an existing run).
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="mm_cvrp",
        # set resume configuration
        id=args.key,
        resume="allow",
        mode=None if use_wandb else "disabled",
        # track hyperparameters and run metadata
        config={
            "stage": "train",
            "optim": "LAX",
            "n_node": args.n_nodes,
            "n_agent": args.n_agent,
            "epochs": args.iteration,
            "lr_p": args.lr_p,
            "lr_s": args.lr_s,
            "commit_hash": get_git_commit_hash(),
        },
    )
    # Use the id of the returned run; the bare name `id` is Python's builtin.
    print("run id:{}".format(run.id))

    src_vector = torch.tensor(get_src_vector(args.n_nodes - 1, args.n_agent), device=dev)

    trainer = Trainer(
        n_nodes=args.n_nodes,
        n_agent=args.n_agent,
        batch_size=args.batch_size,
        in_chnl=2,
        hid_chnl=64,
        key_size_embd=32,
        key_size_policy=128,
        val_size=16,
        clipping=10,
        lr_p=args.lr_p,
        lr_s=args.lr_s,
        dev=dev,
        src_vector=src_vector,
        output=args.output,
        loss=args.loss,
        capacity=args.capacity,
        disable_softmax=args.disable_softmax,
        train_folder=args.train_folder,
        validation_folder=args.validation_folder,
        augmentation=not args.disable_augmentation,
    )
    trainer(args.iteration)

    wandb.finish()


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-agent", type=int, default=5, help="number of agents")
    parser.add_argument("--n-nodes", type=int, default=400, help="number of nodes")
    parser.add_argument("--batch-size", type=int, default=512, help="batch size")
    parser.add_argument(
        "--iteration", type=int, default=25, help="number of training iterations (logged as 'epochs')"
    )
    parser.add_argument("--capacity", type=int, default=25, help="capacity of each agent")
    parser.add_argument("--lr-p", type=float, default=1e-4, help="learning rate of policy network")
    parser.add_argument("--lr-s", type=float, default=1e-3, help="learning rate of surrogate network")
    parser.add_argument("--seed", type=int, default=86, help="random seed")
    parser.add_argument(
        "--key", type=str, required=True, help="wandb run id (an existing run with this id is resumed)"
    )
    parser.add_argument("--gpu", type=str, help="gpu id (exported as CUDA_VISIBLE_DEVICES)")
    parser.add_argument("--disable-wandb", action="store_true", help="disable Weights & Biases logging")
    parser.add_argument("--disable-softmax", action="store_true", help="disable softmax (forwarded to Trainer)")
    parser.add_argument(
        "--disable-augmentation", action="store_true", help="disable data augmentation (forwarded to Trainer)"
    )
    parser.add_argument("--output", type=str, default="model", help="output folder")
    parser.add_argument("--loss", choices=Loss.getValues(), required=True, help="loss function")
    parser.add_argument("--train-folder", type=str, required=True, help="train folder")
    parser.add_argument("--validation-folder", type=str, required=True, help="validation folder")
    return parser


if __name__ == "__main__":
    # NOTE: deterministic torch algorithms are intentionally left disabled:
    # torch.use_deterministic_algorithms(True)

    args = _build_parser().parse_args()
    main(args)
