#!/usr/bin/env python3
"""
Evaluation script for point classification downstream task.
"""
import os
import sys
import argparse
import gc

import torch

# make sure the FM4NPP modules can be imported
sys.path.append('../..')


from fm4npp.utils import YParams
from point_classification_trainer import DownstreamTrainer

def _build_arg_parser():
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluation script for point classification downstream task"
    )
    parser.add_argument(
        "--yaml_config",
        type=str,
        required=True,
        help="Path to the YAML config file",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Model config name (e.g. d9_m64_k30_p20)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="",
        help="Path to the trained checkpoint (optional; overrides default)",
    )
    parser.add_argument(
        "--run_num",
        type=str,
        default="0",
        help="Run number / seed identifier",
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        default="",
        help="Root directory to store evaluation outputs",
    )
    parser.add_argument(
        "--eventnumber",
        type=int,
        default=70000,
        help="Number of events (samples) to evaluate",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation",
    )
    parser.add_argument(
        "--global_log_dir",
        default='globallogs',
        type=str,
        help="Global dir to store logging only",
    )
    return parser


def main() -> None:
    """Parse the CLI arguments, configure the run and evaluate a checkpoint.

    The function loads the hyperparameters named by ``--yaml_config`` /
    ``--config``, overrides the evaluation-specific fields (data limit, batch
    sizes, checkpoint, log file name), then drives a ``DownstreamTrainer``
    through ``launch()``, ``inference()`` and ``cleanup()``.

    Exits with a usage error (status 2) when no checkpoint can be determined,
    i.e. ``--checkpoint`` is empty and the config has no entry in the default
    mapping.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Default mapping from config name to checkpoint, consulted only when
    # --checkpoint is not given. NOTE(review): currently empty, so every run
    # has to pass --checkpoint until entries are registered here.
    model2ckpt = {}

    # Determine which checkpoint to use: an explicit --checkpoint wins over
    # the per-config default. Resolve it before loading anything so a missing
    # checkpoint fails fast with a readable message instead of a KeyError.
    pretrained_ckpt = args.checkpoint or model2ckpt.get(args.config)
    if not pretrained_ckpt:
        parser.error(
            f"no default checkpoint registered for config '{args.config}'; "
            "pass one explicitly with --checkpoint"
        )

    # Prepare hyperparameters
    params = YParams(os.path.abspath(args.yaml_config), args.config)
    params.limit_data = True
    params.limit_size = args.eventnumber
    params.batch_size = args.eval_batch_size
    params.valid_batch_size = args.eval_batch_size
    params.pretrained_ckpt = pretrained_ckpt
    params.log_file_name = f"{args.config}_eval_{params.task}_d{params.limit_size}_{args.run_num}.log"
    params.num_embedder_layers = 0
    # NOTE(review): the test data root and the checkpoint base directory are
    # left as empty placeholders here; fill them in for the target site.
    params.data_root_test = ""
    checkpoint_name = f"{args.config}_nerf_{params.task}_d{params.limit_size}_{args.run_num}_checkpoint.pth"
    checkpoint_base_dir = ""
    # os.path.join inserts the separator when the base directory is set and
    # yields the bare file name while it is still empty.
    checkpoint_path = os.path.join(checkpoint_base_dir, checkpoint_name)

    # Ensure output directory exists. An empty --root_dir means "current
    # directory"; os.makedirs("") would raise FileNotFoundError.
    log_dir = args.root_dir
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logfile = os.path.join(log_dir, params.log_file_name)

    # Launch and run inference
    trainer = DownstreamTrainer(params, args)
    trainer.launch()
    trainer.inference(
        checkpoint_path=checkpoint_path,
        pretrain=True,                # evaluation uses downstream checkpoint
        logfile=logfile
    )
    trainer.cleanup()

    # Free GPU memory: drop the last reference to the trainer and collect
    # garbage first, so the blocks it held are unoccupied by the time the
    # CUDA cache is emptied.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()
