"""
ScanObjectNN Inference Script

Note: This script requires the complete model (encoder + decoder) to be stored
in the checkpoint as a model object under the 'model' key. Only the decoder
code is provided in this package for review purposes; the checkpoint supplies
the complete pretrained model.

Usage:
    python tools/test_scanobjectnn_simple.py \
        --checkpoint checkpoints/scanobjectnn_morton_best.pth \
        --data_root /path/to/ScanObjectNN/h5_files \
        --output results/scanobjectnn
"""

import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import argparse
import json
import os
from tqdm import tqdm
from torch.utils.data import DataLoader

from data_utils import ScanObjectNNDataset, collate_fn_scanobjectnn

# ScanObjectNN category names; position i is the name reported for integer
# label i, so the order below must not change.
SCANOBJECTNN_CLASSES = (
    'bag bin box cabinet chair '
    'desk display door shelf table '
    'bed pillow sink sofa toilet'
).split()


def parse_args(argv=None):
    """Parse command-line options for ScanObjectNN inference.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, so command-line behavior is
            unchanged; passing a list lets tests and other code drive the
            parser directly.

    Returns:
        argparse.Namespace with ``checkpoint``, ``data_root``, ``variant``,
        ``output``, ``batch_size``, ``num_points`` and ``gpu`` attributes.
    """
    parser = argparse.ArgumentParser(description='ScanObjectNN Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Checkpoint path')
    parser.add_argument('--data_root', type=str, required=True,
                        help='ScanObjectNN data root (h5_files directory)')
    parser.add_argument('--variant', type=str, default='PB_T50_RS',
                        help='Dataset variant')
    parser.add_argument('--output', type=str, default='results/scanobjectnn',
                        help='Output directory')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--num_points', type=int, default=2048,
                        help='Number of points per object')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU ID')
    return parser.parse_args(argv)


def load_model(checkpoint_path, device):
    """Load the complete (encoder + decoder) model stored in a checkpoint.

    The checkpoint may be either a dict holding the model object under the
    ``'model'`` key, or the model object itself. A bare state_dict is
    rejected, because the model class needed to rebuild it is not part of
    this script.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: torch.device (or device string) to move the model to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If the checkpoint does not contain a complete model.
    """
    import inspect

    print(f"Loading checkpoint from {checkpoint_path}")

    # The checkpoint pickles a whole nn.Module, not just tensors. PyTorch >= 2.6
    # defaults to weights_only=True, which refuses to unpickle arbitrary
    # objects, so opt out explicitly wherever the argument exists.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(checkpoint_path, **load_kwargs)

    # The checkpoint should contain the complete model
    # (encoder + decoder) for classification task
    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
    elif isinstance(checkpoint, dict) and 'model' in checkpoint:
        model = checkpoint['model']
    else:
        raise ValueError(
            "Checkpoint does not contain complete model. "
            "For ScanObjectNN classification, the full model (encoder + decoder) is required."
        )

    # A state_dict stored under 'model' has no .to()/.eval(); fail with a
    # clear message instead of an AttributeError further down.
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            f"Checkpoint entry 'model' is a {type(model).__name__}, not a "
            "torch.nn.Module (likely a state_dict). For ScanObjectNN "
            "classification, the full model (encoder + decoder) is required."
        )

    model = model.to(device)
    model.eval()

    print("✓ Model loaded")
    print(f"  Model type: {type(model).__name__}")

    return model


def evaluate(model, dataloader, device, variant=None):
    """Run inference over ``dataloader`` and compute classification metrics.

    Args:
        model: Classifier called as ``model(coords)``. Its output's last
            dimension is argmax-ed to obtain the predicted class.
        dataloader: Iterable of batch dicts with ``'coords'`` and ``'labels'``
            tensors. It must also expose ``.dataset`` (used for the size
            report).
        device: Device the coordinates are moved to before the forward pass.
        variant: Dataset variant name recorded in the results. When None,
            ``dataloader.dataset.variant`` is used if that attribute exists,
            otherwise ``'PB_T50_RS'``.

    Returns:
        dict with ``dataset``, ``variant``, ``overall_accuracy``,
        ``mean_class_accuracy`` and ``per_class_accuracy`` (all accuracies
        in percent; per-class entries only for classes present in the labels).

    Raises:
        ValueError: If the dataloader yields no batches, or predictions and
            labels do not line up one-to-one.
    """
    model.eval()

    print(f"\nEvaluating on {len(dataloader.dataset)} objects...")
    all_predictions, all_labels = _collect_predictions(model, dataloader, device)

    # Compute metrics
    print("\nComputing metrics...")
    overall_acc, mean_class_acc, per_class_acc = _compute_metrics(
        all_predictions, all_labels)

    if variant is None:
        # NOTE(review): assumes the dataset object may carry a `variant`
        # attribute; otherwise fall back to the script's default variant.
        variant = getattr(dataloader.dataset, 'variant', 'PB_T50_RS')

    return {
        'dataset': 'scanobjectnn',
        'variant': str(variant),
        'overall_accuracy': float(overall_acc),
        'mean_class_accuracy': float(mean_class_acc),
        'per_class_accuracy': per_class_acc,
    }


def _collect_predictions(model, dataloader, device):
    """Run the model over every batch; return flat 1-D (predictions, labels) arrays."""
    predictions = []
    labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            coords = batch['coords'].to(device)  # (B, N, 3)

            # Forward pass through complete model
            # Model signature: forward(xyz) -> logits (B, num_classes)
            logits = model(coords)

            # reshape(-1) guards against (B, 1) label tensors, which would
            # otherwise broadcast against (B,) predictions into a (B, B)
            # grid and silently corrupt the accuracy. Labels are only needed
            # on the CPU, so they are never moved to `device`.
            predictions.append(logits.argmax(dim=-1).reshape(-1).cpu().numpy())
            labels.append(batch['labels'].reshape(-1).cpu().numpy())

    if not predictions:
        raise ValueError("Dataloader yielded no batches; nothing to evaluate.")

    predictions = np.concatenate(predictions)
    labels = np.concatenate(labels)
    if predictions.shape != labels.shape:
        raise ValueError(
            f"Got {predictions.size} predictions for {labels.size} labels; "
            "expected one class prediction per object."
        )
    return predictions, labels


def _compute_metrics(predictions, labels):
    """Return (overall_acc, mean_class_acc, per_class_acc) in percent."""
    correct = predictions == labels
    overall_acc = correct.mean() * 100

    # Per-class accuracy, only for classes that occur in the labels
    per_class_acc = {}
    for cls_idx, cls_name in enumerate(SCANOBJECTNN_CLASSES):
        mask = labels == cls_idx
        if mask.any():
            per_class_acc[cls_name] = float(correct[mask].mean() * 100)

    mean_class_acc = (
        float(np.mean(list(per_class_acc.values())))
        if per_class_acc else float('nan')
    )
    return overall_acc, mean_class_acc, per_class_acc


def print_results(results, paper_oa=94.21, paper_variant='PB_T50_RS',
                  tolerance=0.5):
    """Pretty-print evaluation metrics and compare OA against a reference.

    Args:
        results: dict with ``variant``, ``overall_accuracy``,
            ``mean_class_accuracy`` and ``per_class_accuracy`` keys
            (accuracies in percent).
        paper_oa: Reference overall accuracy in percent. Pass None to skip
            the comparison.
        paper_variant: Dataset variant the reference OA belongs to. The
            comparison is printed only when ``results['variant']`` matches,
            because OA numbers are not comparable across variants.
        tolerance: Largest absolute OA difference (percentage points) still
            reported as a match.
    """
    rule = "=" * 70
    print("\n" + rule)
    print(f"ScanObjectNN {results['variant']} - Evaluation Results")
    print(rule)

    print("\n Overall Metrics:")
    print(f"  Overall Accuracy (OA):  {results['overall_accuracy']:.2f}%")
    print(f"  Mean Class Accuracy:    {results['mean_class_accuracy']:.2f}%")

    print("\n Per-Class Accuracy:")
    for cls_name, acc in results['per_class_accuracy'].items():
        print(f"  {cls_name:15s}: {acc:6.2f}%")

    print("\n" + rule)

    # NOTE(review): 94.21 is assumed to be the paper's OA on PB_T50_RS (the
    # script's default variant) -- confirm against the paper.
    if paper_oa is None or results['variant'] != paper_variant:
        return

    # Compare with paper
    diff = results['overall_accuracy'] - paper_oa
    print(f"\n Paper reported OA: {paper_oa:.2f}%")
    print(f" This run OA:       {results['overall_accuracy']:.2f}%")
    print(f" Difference:        {diff:+.2f}%")

    if abs(diff) < tolerance:
        print(" ✓ Results match paper!")
    else:
        print(f" ⚠ Results differ by {abs(diff):.2f}%")


def main():
    """Evaluate a checkpoint on the ScanObjectNN test split and save results.json.

    Reads CLI options, builds the test dataloader, loads the model, prints
    the metrics, and writes them as JSON to ``<output>/results.json``.
    """
    args = parse_args()

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    print(f"\nLoading data from {args.data_root}")
    dataset = ScanObjectNNDataset(
        data_root=args.data_root,
        variant=args.variant,
        split='test',
        num_points=args.num_points
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn_scanobjectnn
    )

    # Load model
    model = load_model(args.checkpoint, device)

    # Evaluate
    results = evaluate(model, dataloader, device)
    # Record the variant that was actually evaluated (from --variant) so the
    # printed header and results.json stay correct for non-default variants.
    results['variant'] = args.variant

    # Print results
    print_results(results)

    # Save results
    os.makedirs(args.output, exist_ok=True)
    result_file = os.path.join(args.output, 'results.json')
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\n Results saved to {result_file}")


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
