#!/usr/bin/env python3
"""
Evaluation script for point classification downstream task.
"""
import os
import sys
import argparse
import gc

import torch

# make sure your FM4NPP modules can be imported
sys.path.append('../..')

from fm4npp.utils import YParams
from track_finding_trainer import DownstreamTrainer

def _build_parser():
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluation script for point classification downstream task"
    )
    parser.add_argument(
        "--yaml_config",
        type=str,
        required=True,
        help="Path to the YAML config file",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Model config name (e.g. d9_m64_k30_p20)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="",
        help="Path to the trained checkpoint (optional; overrides default)",
    )
    parser.add_argument(
        "--run_num",
        type=str,
        default="0",
        help="Run number / seed identifier",
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        default="",
        help="Root directory to store evaluation outputs",
    )
    parser.add_argument(
        "--eventnumber",
        type=int,
        default=70000,
        help="Number of events (samples) to evaluate",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation",
    )
    parser.add_argument(
        "--usepretrain",
        dest="usepretrain",
        action="store_true",
        help="enable using the pretrained model (default)",
    )
    parser.add_argument(
        "--no-pretrain",
        dest="usepretrain",
        action="store_false",
        help="disable pretrained model",
    )
    parser.set_defaults(usepretrain=True)
    parser.add_argument(
        "--global_log_dir",
        default='globallogs',
        type=str,
        help="Global dir to store logging only",
    )
    return parser


def main():
    """Run inference with a trained downstream tracking head.

    Loads hyperparameters from ``--yaml_config`` / ``--config`` via ``YParams``,
    overrides the evaluation-specific fields (event limit, batch sizes, loss
    weights, log/checkpoint names), then launches a ``DownstreamTrainer`` and
    calls its ``inference`` method on the downstream checkpoint.

    Exits with a usage error (status 2) when no checkpoint can be resolved
    for ``--config`` (neither ``--checkpoint`` nor a ``model2ckpt`` entry).
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Default mapping from config to checkpoint if not provided via --checkpoint.
    # The resolved value is stored in params.pretrained_ckpt below.
    model2ckpt = {

    }

    # Determine which checkpoint to use: an explicit --checkpoint wins,
    # otherwise fall back to the per-config default. Report a usage error
    # rather than crashing with a bare KeyError when neither is available.
    pretrained_ckpt = args.checkpoint or model2ckpt.get(args.config)
    if not pretrained_ckpt:
        parser.error(
            f"no checkpoint for config '{args.config}': pass --checkpoint "
            "or add an entry to model2ckpt"
        )

    # Prepare hyperparameters
    params = YParams(os.path.abspath(args.yaml_config), args.config)
    params.limit_data = True
    params.limit_size = args.eventnumber
    params.batch_size = args.eval_batch_size
    params.valid_batch_size = args.eval_batch_size
    params.pretrained_ckpt = pretrained_ckpt

    # NOTE(review): the checkpoint name presumably has to match what the
    # corresponding training run saved — confirm against the training script.
    checkpoint_base_name = f"{args.config}_nerf_tracking_head_d{params.limit_size}_{args.run_num}"
    log_base_name = f"{args.config}_eval_tracking_head_d{params.limit_size}_{args.run_num}"
    # Runs without the pretrained model get a "_nopretrain" tag in both names.
    suffix = "" if args.usepretrain else "_nopretrain"
    params.log_file_name = f"{log_base_name}{suffix}.log"
    checkpoint_name = f"{checkpoint_base_name}{suffix}_checkpoint.pth"

    params.num_embedder_layers = 0
    # NOTE(review): empty placeholders — presumably must be pointed at the
    # test data / downstream checkpoint directory for the local setup.
    params.data_root_test = ""
    checkpoint_base_dir = ""
    # os.path.join inserts the separator if a directory is filled in above.
    checkpoint_path = os.path.join(checkpoint_base_dir, checkpoint_name)

    params.loss_matched_ce_weight = 0.5
    params.loss_unmatched_ce_weight = 0.1
    params.loss_dice_weight = 1
    params.loss_focal_weight = 30
    params.return_reg_test = True

    # Ensure output directory exists. An empty --root_dir (the default) means
    # the current directory; os.makedirs("") would raise FileNotFoundError.
    log_dir = args.root_dir or os.curdir
    os.makedirs(log_dir, exist_ok=True)
    logfile = os.path.join(log_dir, params.log_file_name)

    # Launch and run inference
    trainer = DownstreamTrainer(params, args)
    trainer.launch()
    try:
        trainer.inference(
            checkpoint_path=checkpoint_path,
            pretrain=args.usepretrain,  # False only when --no-pretrain is given
            logfile=logfile,
        )
    finally:
        # Release trainer resources even when inference raises.
        trainer.cleanup()

    # Free GPU memory: drop the last reference and collect first so the
    # tensors are actually freed before the CUDA cache is emptied.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
