import torch
import argparse
import torch.nn as nn
from tqdm import tqdm as tqdm_load
from pancreas_utils import *
from test_util import *
from dataloaders import get_ema_model_and_dataloader, create_Vnet

def parse_args():
    """Build the command-line parser for pancreas evaluation and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``ckpt``, ``output`` (both required),
        ``data_root``, ``batch_size``, ``lr``, ``label_percent``, ``s_xy`` and
        ``s_z`` (all optional with defaults).
    """
    cli = argparse.ArgumentParser(description='Test model on pancreas dataset')

    # Mandatory string paths.
    required_opts = (
        ('--ckpt', 'Path to checkpoint file'),
        ('--output', 'Path to output metrics file'),
    )
    for flag, text in required_opts:
        cli.add_argument(flag, type=str, required=True, help=text)

    # Optional settings as (flag, type, default, help); order matches --help output.
    optional_opts = (
        ('--data_root', str, '/path/to/data/pancreas/data_split', 'Path to data root'),
        ('--batch_size', int, 2, 'Batch size'),
        ('--lr', float, 1e-3, 'Learning rate'),
        ('--label_percent', int, 20, 'Label percentage'),
        ('--s_xy', int, 16, 'Stride in xy plane'),
        ('--s_z', int, 4, 'Stride in z direction'),
    )
    for flag, kind, fallback, text in optional_opts:
        cli.add_argument(flag, type=kind, default=fallback, help=text)

    return cli.parse_args()

def test_model(net, test_loader, s_xy=16, s_z=4):
    """Evaluate ``net`` on the dataset backing ``test_loader``.

    Args:
        net: model passed straight through to ``test_calculate_metric``.
        test_loader: only its ``.dataset`` attribute is used; the loader's
            batching is bypassed.
        s_xy: stride in the xy plane, forwarded to ``test_calculate_metric``.
        s_z: stride along z, forwarded to ``test_calculate_metric``.

    Returns:
        The ``(avg_metric, m_list)`` pair produced by ``test_calculate_metric``.
    """
    print('Testing model...')
    dataset = test_loader.dataset
    average, per_case = test_calculate_metric(net, dataset, s_xy=s_xy, s_z=s_z)
    return average, per_case

def main():
    """Evaluate a V-Net checkpoint on the pancreas test split and save the metrics.

    Steps: parse CLI options, seed RNGs, build the model with ``create_Vnet``,
    load weights from ``ckpt['net']`` in ``--ckpt``, evaluate on the test
    loader returned by ``get_ema_model_and_dataloader``, print the average
    metrics, and write the average and per-case metrics to ``--output``.
    """
    import os

    args = parse_args()

    # Set random seed for reproducibility
    seed_test = 2020
    seed_reproducer(seed=seed_test)

    # Create model
    net = create_Vnet()

    # Load the checkpoint onto the CPU first. load_state_dict copies each tensor
    # onto the device of the matching parameter, so this works even if the
    # checkpoint was saved from a GPU that is not available here.
    print(f"Loading checkpoint from {args.ckpt}")
    ckpt = torch.load(args.ckpt, map_location='cpu')
    net.load_state_dict(ckpt['net'])
    # Switch to inference mode. This affects layers such as BatchNorm/Dropout, if
    # the V-Net has any. It is harmless if test_calculate_metric also calls eval().
    net.eval()

    # Get data loader
    split_name = 'pancreas'
    _, _, _, _, _, _, _, test_loader = get_ema_model_and_dataloader(
        args.data_root, split_name, args.batch_size, args.lr, labelp=args.label_percent
    )

    # Test model
    avg_metric, m_list = test_model(net, test_loader, s_xy=args.s_xy, s_z=args.s_z)
    print(f"Average metrics: {avg_metric}")

    # --output is a required argument, so always persist the results there.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        f.write(f"Average metrics: {avg_metric}\n")
        # NOTE(review): m_list is presumably one entry per test case -- confirm
        # against test_calculate_metric.
        for idx, case_metric in enumerate(m_list):
            f.write(f"Case {idx}: {case_metric}\n")
    print(f"Metrics written to {args.output}")
    

if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()
