#!/usr/bin/env python3
"""
Downstream tracking training script.
"""
import os
import sys
import argparse
import gc

import torch

# ensure your fm4npp modules can be found
sys.path.append('../..')

from fm4npp.utils import YParams
from point_classification_trainer import DownstreamTrainer

def main():
    """Parse CLI options, build the run configuration and train a downstream model.

    Hyperparameters are loaded from ``--yaml_config`` (section ``--config``)
    via ``YParams`` and then overridden from the command line (batch size,
    event limit, checkpoint, log file name). A ``DownstreamTrainer`` is
    launched and trained, and its cleanup always runs afterwards.

    Raises:
        SystemExit: (via ``argparse``) if ``--yaml_config`` is not an
            existing file.
        KeyError: if no pretrained checkpoint is registered in
            ``model2ckpt`` for ``--config``.
    """
    parser = argparse.ArgumentParser(description="Downstream tracking training script")
    parser.add_argument("--yaml_config", default='', type=str, help="Path to YAML config file")
    parser.add_argument("--config", default='', type=str, help="Model config name")
    parser.add_argument("--run_num", default='0', type=str, help="Sub run number")
    parser.add_argument("--root_dir", default='', type=str, help="Root dir to store results")
    parser.add_argument("--global_log_dir", default='globallogs', type=str, help="Global dir to store logging only")
    parser.add_argument("--eventnumber", default=70000, type=int, help="downstream training event number")
    # NOTE(review): with type=str any value given on the command line (even
    # "False") is a non-empty, truthy string, while the default is the bool
    # True. This function never reads it (train() below always gets
    # pretrain=True); it only reaches DownstreamTrainer through ``args``.
    # Confirm how the trainer interprets it before changing the type.
    parser.add_argument("--usepretrain", default=True, type=str, help="use pretrain model")
    parser.add_argument("--train_batch_size", default=32, type=int, help="train batch size")
    args = parser.parse_args()

    # Example overrides for running in a notebook; uncomment to hardcode
    # args.config = "d9_m96_k5_p20"
    # args.run_num = "2"

    yaml_path = os.path.abspath(args.yaml_config)
    # Fail fast with a usage message. The default '' resolves to the current
    # directory, which would otherwise surface as an obscure error in YParams.
    if not os.path.isfile(yaml_path):
        parser.error(
            f"--yaml_config must point to an existing YAML file (got {args.yaml_config!r})"
        )

    # Pretrained checkpoint path for each model config name (--config).
    # Fill in before running; unregistered configs are rejected below.
    model2ckpt = {

    }
    try:
        pretrained_ckpt = model2ckpt[args.config]
    except KeyError:
        # Keep the original exception type, but say what to fix.
        raise KeyError(
            f"No pretrained checkpoint registered for config {args.config!r}; "
            "add an entry to model2ckpt in main()"
        ) from None

    # Initialize parameters from the YAML file, then apply CLI overrides.
    params = YParams(yaml_path, args.config)
    params.continue_from_best = True
    params.batch_size = int(args.train_batch_size)
    params.limit_data = True
    params.limit_size = int(args.eventnumber)
    params.valid_batch_size = 1
    params.pretrained_ckpt = pretrained_ckpt
    # params.task is not set here; presumably it is read from the YAML config.
    params.log_file_name = f"{args.config}_nerf_{params.task}_d{params.limit_size}_{args.run_num}.log"
    params.num_embedder_layers = 0

    # Launch and train.
    trainer = DownstreamTrainer(params, args)
    trainer.launch()
    try:
        # pretrain=True / train_from_checkpoint=False presumably mean
        # "initialise from params.pretrained_ckpt, do not resume a run".
        # Confirm in DownstreamTrainer.train.
        trainer.train(pretrain=True, train_from_checkpoint=False, checkpoint_path=None)
    finally:
        # Always run trainer cleanup and free cached CUDA memory, even if
        # training raises.
        trainer.cleanup()
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    main()
