"""
S3DIS Inference Script

Usage:
    python tools/test_s3dis_simple.py \
        --checkpoint checkpoints/s3dis_area5_best.pth \
        --data_root /path/to/S3DIS/Area_5 \
        --output results/s3dis
"""

import sys
from pathlib import Path
# Make the project's top-level packages (data_utils, models) importable when
# this file is run directly as a script. resolve() turns a relative __file__
# (older Pythons, or running from inside this file's directory) into an
# absolute path, so PROJECT_ROOT really is two directories above this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import argparse
import json
import os
from tqdm import tqdm
from torch.utils.data import DataLoader

from data_utils import S3DISDataset, collate_fn_s3dis
from models import GISADecoder

# S3DIS semantic class names; a name's list position is the class id used by
# the metric code (per-class IoU i is reported under S3DIS_CLASSES[i]).
S3DIS_CLASSES = [
    'ceiling',   # 0
    'floor',     # 1
    'wall',      # 2
    'beam',      # 3
    'column',    # 4
    'window',    # 5
    'door',      # 6
    'table',     # 7
    'chair',     # 8
    'sofa',      # 9
    'bookcase',  # 10
    'board',     # 11
    'clutter',   # 12
]


def compute_iou_per_class(predictions, labels, num_classes=13):
    """Compute the intersection-over-union of each class.

    Args:
        predictions: Predicted class ids; any array-like (NumPy array, list,
            CPU tensor) broadcast-compatible with ``labels``.
        labels: Ground-truth class ids.
        num_classes: Number of classes; ids ``0 .. num_classes - 1`` are scored.

    Returns:
        float64 NumPy array of length ``num_classes``. Entry ``c`` is
        ``|pred == c AND label == c| / |pred == c OR label == c|``, or NaN when
        class ``c`` occurs in neither input (IoU undefined).
    """
    # Normalise inputs so plain lists / tensors behave like NumPy arrays
    # (``list == int`` would otherwise yield a single bool).
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Start with NaN so classes with an empty union stay "undefined".
    ious = np.full(num_classes, np.nan, dtype=np.float64)
    for cls in range(num_classes):
        pred_mask = predictions == cls
        label_mask = labels == cls

        union = np.count_nonzero(pred_mask | label_mask)
        if union:
            ious[cls] = np.count_nonzero(pred_mask & label_mask) / union

    return ious


def compute_miou(predictions, labels, num_classes=13):
    """Compute mean IoU, in percent.

    Classes whose IoU is NaN (absent from both predictions and labels) are
    excluded from the mean. When no class is present at all (e.g. empty
    inputs) the mIoU is undefined and NaN is returned directly, instead of
    averaging an empty array (which emits a RuntimeWarning).
    """
    ious = compute_iou_per_class(predictions, labels, num_classes)
    valid_ious = ious[~np.isnan(ious)]
    if valid_ious.size == 0:
        return float('nan')
    return valid_ious.mean() * 100


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``checkpoint``, ``data_root``, ``area``,
        ``output`` and ``gpu`` attributes.
    """
    parser = argparse.ArgumentParser(description='S3DIS Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Checkpoint path')
    parser.add_argument('--data_root', type=str, required=True,
                        help='S3DIS data root (Area_5 directory)')
    parser.add_argument('--area', type=int, default=5,
                        help='Test area (default: 5)')
    parser.add_argument('--output', type=str, default='results/s3dis',
                        help='Output directory')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU ID')
    return parser.parse_args(argv)


def _strip_prefix(key, prefix):
    """Return ``key`` without a leading ``prefix`` (unchanged if absent)."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _extract_state_dict(checkpoint):
    """Pick the parameter dict out of a loaded checkpoint dict.

    Tries ``model_state_dict``, then ``state_dict``, then a dict stored under
    ``model``; otherwise the checkpoint itself is treated as the state dict.
    """
    for key in ('model_state_dict', 'state_dict', 'model'):
        value = checkpoint.get(key)
        if isinstance(value, dict):
            return value
    return checkpoint


def load_model(checkpoint_path, device):
    """Load an evaluation model from a checkpoint file.

    Supported layouts:
      * a whole ``nn.Module`` saved with ``torch.save(model, path)``;
      * a dict holding an ``nn.Module`` under ``'model'``;
      * a dict holding a state dict under ``'model_state_dict'``,
        ``'state_dict'`` or ``'model'``, or a bare state dict. In these cases a
        ``GISADecoder`` is built and the decoder weights are loaded into it
        (a ``'module.'`` DataParallel prefix and a ``'decoder.'`` prefix are
        stripped; ``'encoder.'`` weights are ignored).

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Device the model is moved to.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        TypeError: if the checkpoint is neither an ``nn.Module`` nor a dict.
    """
    print(f"Loading checkpoint from {checkpoint_path}")

    # SECURITY: weights_only=False unpickles arbitrary objects, which is needed
    # for checkpoints that contain a whole nn.Module -- only load trusted files.
    # PyTorch >= 2.6 defaults to weights_only=True (breaking that case), and
    # torch < 1.13 does not accept the argument at all, hence the fallback.
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    except TypeError:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Complete model: saved directly, or stored under 'model'.
    complete_model = checkpoint
    if isinstance(checkpoint, dict):
        complete_model = checkpoint.get('model')
    if isinstance(complete_model, torch.nn.Module):
        model = complete_model.to(device)
        model.eval()
        print(f"✓ Loaded complete model from checkpoint")
        return model

    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type: {type(checkpoint).__name__}"
        )

    state_dict = _extract_state_dict(checkpoint)

    # Strip a leading DataParallel/DDP 'module.' prefix only; replacing every
    # occurrence would also mangle names like 'decoder.submodule.weight'.
    new_state_dict = {
        _strip_prefix(k, 'module.'): v for k, v in state_dict.items()
    }

    # Create decoder. NOTE(review): these hyper-parameters presumably must
    # match the training configuration of the checkpoint -- confirm.
    print("Creating GISA Decoder...")
    decoder = GISADecoder(
        in_dim=256,
        hidden_dim=256,
        num_classes=13,
        dropout=0.55,
        use_residual=True,
        deltanet_hidden_dim=512,
        num_heads=4,
        scan_mode='identity'  # S3DIS uses identity mode
    )

    # Keep decoder weights ('decoder.' prefix removed, again prefix-only) and
    # any un-prefixed keys; drop encoder weights.
    decoder_state_dict = {}
    for k, v in new_state_dict.items():
        if k.startswith('decoder.'):
            decoder_state_dict[_strip_prefix(k, 'decoder.')] = v
        elif not k.startswith('encoder.'):
            decoder_state_dict[k] = v

    # strict=False: report mismatches instead of failing, so surface a few
    # names to make silently-unloaded weights visible.
    missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False)

    if missing_keys:
        print(f"Warning: Missing keys: {len(missing_keys)}")
        print(f"  e.g. {list(missing_keys)[:5]}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {len(unexpected_keys)}")
        print(f"  e.g. {list(unexpected_keys)[:5]}")

    decoder = decoder.to(device)
    decoder.eval()

    print(f"✓ Model loaded")
    return decoder


def evaluate(model, dataloader, device, area=5):
    """Run the model over every batch and compute S3DIS segmentation metrics.

    Args:
        model: Decoder-style module called as ``model(features, xyz=coords)``
            and returning per-point logits with classes on the last dimension.
        dataloader: Yields dicts with ``'coords'``, ``'features'`` and
            ``'labels'`` tensors.
        device: Device the model inputs are moved to.
        area: S3DIS area recorded in the returned report (default 5).

    Returns:
        dict with ``dataset``, ``area``, ``miou`` (percent, over classes
        present in predictions or labels), ``overall_accuracy`` (percent) and
        ``per_class_iou`` (percent per class name; classes absent from both
        predictions and labels are reported as 0.0 but excluded from mIoU).

    Raises:
        RuntimeError: if the forward pass fails (the original error is chained).
        ValueError: if the dataloader yields no batches.
    """
    model.eval()

    all_predictions = []
    all_labels = []

    print(f"\nEvaluating on {len(dataloader.dataset)} scenes...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            coords = batch['coords'].to(device)
            features = batch['features'].to(device)
            # Labels are only used on the host for metrics; no device copy.
            labels = batch['labels']

            # Only the decoder-only call signature is supported here:
            # forward(features, xyz=coords).
            try:
                logits = model(features, xyz=coords)  # (N, num_classes)
            except Exception as err:
                raise RuntimeError(
                    "Failed to run inference. This package provides decoder-only code. "
                    "Ensure your checkpoint contains the GISA decoder and data has pre-extracted features."
                ) from err

            predictions = logits.argmax(dim=-1)  # (N,)

            all_predictions.append(predictions.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    if not all_predictions:
        raise ValueError("Dataloader yielded no batches; nothing to evaluate.")

    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)

    # Compute metrics
    print("\nComputing metrics...")

    num_classes = len(S3DIS_CLASSES)
    # Per-class IoU is computed once; mIoU is the mean over defined (non-NaN)
    # classes, the same reduction compute_miou performs.
    per_class_ious = compute_iou_per_class(all_predictions, all_labels, num_classes=num_classes)
    valid_ious = per_class_ious[~np.isnan(per_class_ious)]
    miou = valid_ious.mean() * 100 if valid_ious.size else float('nan')

    overall_acc = (all_predictions == all_labels).mean() * 100

    results = {
        'dataset': 's3dis',
        'area': area,
        'miou': float(miou),
        'overall_accuracy': float(overall_acc),
        'per_class_iou': {
            S3DIS_CLASSES[i]: float(iou * 100) if not np.isnan(iou) else 0.0
            for i, iou in enumerate(per_class_ious)
        }
    }

    return results


def print_results(results, paper_miou=82.62, paper_area=5):
    """Pretty-print an evaluation report and compare it with the paper.

    Args:
        results: dict with ``area``, ``miou``, ``overall_accuracy`` (percent)
            and ``per_class_iou`` (class name -> percent).
        paper_miou: Reference mIoU (percent) reported in the paper; ``None``
            skips the comparison.
        paper_area: S3DIS area the reference number refers to (assumed to be
            Area 5, this script's default). The comparison is skipped for any
            other area because the numbers are not comparable.
    """
    rule = "=" * 70
    print("\n" + rule)
    print(f"S3DIS Area {results['area']} - Evaluation Results")
    print(rule)

    print("\n Overall Metrics:")
    print(f"  mIoU:     {results['miou']:.2f}%")
    print(f"  Accuracy: {results['overall_accuracy']:.2f}%")

    print("\n Per-Class IoU:")
    for cls_name, iou in results['per_class_iou'].items():
        print(f"  {cls_name:15s}: {iou:6.2f}%")

    print("\n" + rule)

    if paper_miou is None:
        return
    if results['area'] != paper_area:
        print(f"\n Paper comparison skipped: reference mIoU is for Area {paper_area}.")
        return

    # Compare with paper (difference in percentage points)
    diff = results['miou'] - paper_miou
    print(f"\n Paper reported mIoU: {paper_miou:.2f}%")
    print(f" This run mIoU:       {results['miou']:.2f}%")
    print(f" Difference:          {diff:+.2f}%")

    if abs(diff) < 1.0:
        print(" ✓ Results match paper!")
    else:
        print(f" ⚠ Results differ by {abs(diff):.2f}%")


def main():
    """Command-line entry point.

    Parses arguments, builds the S3DIS dataloader for ``--area``, loads the
    checkpoint, evaluates it, prints the report and writes it as JSON to
    ``<output>/results.json``.
    """
    args = parse_args()

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    print(f"\nLoading data from {args.data_root}")
    dataset = S3DISDataset(
        data_root=args.data_root,
        area=args.area,
        voxel_size=0.04
    )

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn_s3dis
    )

    # Load model
    model = load_model(args.checkpoint, device)

    # Evaluate
    results = evaluate(model, dataloader, device)
    # evaluate() is called without the area here, so its report says Area 5;
    # record the area that was actually evaluated so the printout and
    # results.json are correct for --area != 5.
    results['area'] = args.area

    # Print results
    print_results(results)

    # Save results
    os.makedirs(args.output, exist_ok=True)
    result_file = os.path.join(args.output, 'results.json')
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\n Results saved to {result_file}")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
