import torch
import numpy as np
import os
from torch.utils.data import Dataset
from data_provider.bottom_up_mask_loader import BottomUpMaskDataset
# ======================== Verification Script ========================

class MockArgs:
    """Hard-coded stand-in for the argument object given to the dataset.

    ``verify_dataset`` instantiates this class and passes it to
    ``BottomUpMaskDataset`` in place of real parsed arguments. Which of
    these attributes the dataset actually reads is not visible from this
    file -- TODO confirm against the loader.
    """

    # Model selection.
    model = "DNN"
    dnn_type = "FCN"

    # Dataset name; also echoed in the verification banner.
    dataset = "UWaveGestureLibrary"

    # Random seed value.
    seed = 42

def verify_dataset(percentage):
    """Print sanity checks for the first sample of ``BottomUpMaskDataset``.

    The dataset is built from ``MockArgs`` with ``flag='test'`` and the
    requested masking percentage. Every result is reported on stdout as
    PASS/FAIL; nothing is returned and a failed check does not raise.

    Checks run on sample 0:
      1. dataset length and the shapes of the input and of its mask,
      2. every input value lies in the closed interval [0, 1],
      3. the number of zeros in the mask equals the truncated value of
         ``numel * (percentage / 100.0)``,
      4. the mask holds no values other than 0 and 1.

    Args:
        percentage: Masking level in percent, forwarded to the dataset as
            ``k_percentage``.
    """

    def _verdict(passed):
        # Turn a truthy/falsy check result into the label used in the report.
        return 'PASS' if passed else 'FAIL'

    config = MockArgs()

    print(f"--- Verifying BottomUpMaskLoader: {config.dataset} (k={percentage}%) ---")

    # Any construction error is reported and ends the verification early.
    try:
        dataset = BottomUpMaskDataset(config, k_percentage=percentage, flag='test')
    except Exception as err:
        print(f"Initialization Failed: {err}")
        return

    if not len(dataset):
        print("Error: No samples found. Ensure Stage 1 .npy files exist.")
        return

    # All checks below look at the first sample only.
    sample, sample_mask = dataset[0]

    # 1. Dataset size and per-sample shapes.
    print(f"Total samples in dataset: {len(dataset)}")
    print(f"Sample Input Shape: {sample.shape} (Time, Channels)")
    print(f"Sample Mask Shape:  {sample_mask.shape}")

    # 2. Value range of the input.
    lowest = sample.min().item()
    highest = sample.max().item()
    in_unit_range = 0.0 <= lowest and highest <= 1.0
    print(f"Normalization Check: Min={lowest:.4f}, Max={highest:.4f} -> {_verdict(in_unit_range)}")

    # 3. Zero count in the mask versus the count implied by the percentage.
    # NOTE(review): the float product is truncated by int(); presumably the
    # loader derives its count the same way -- confirm against the loader.
    n_values = sample.numel()
    masked_fraction = percentage / 100.0
    zeros_wanted = int(n_values * masked_fraction)
    zeros_found = int((sample_mask == 0).sum())
    counts_match = zeros_found == zeros_wanted
    print(f"Expected Zeros: {zeros_wanted}")
    print(f"Actual Zeros in Mask: {zeros_found} -> {_verdict(counts_match)}")

    # 4. The mask may only contain the values 0 and 1.
    distinct = torch.unique(sample_mask)
    is_zero_or_one = (distinct == 0) | (distinct == 1)
    only_binary = is_zero_or_one.all()
    print(f"Mask is Binary (0s and 1s): {_verdict(only_binary)}")

if __name__ == "__main__":
    # Script entry point: run the verification once at a 5% masking level.
    verify_dataset(percentage=5)
