import torch
import numpy as np
from torch.utils.data import DataLoader
# Ensure this matches the file name where you saved the loader
from data_provider.explanation_data_loader import AmortizedExplanationLoader

class Configs:
    """Hard-coded stand-in for the experiment's argument namespace.

    Within this file, ``verify_pipeline`` reads ``dataset``, ``data_root``
    and ``batch_size`` directly and hands the whole object to
    ``AmortizedExplanationLoader``. The remaining fields are not read here;
    presumably the loader consumes them -- confirm against the loader.
    """

    # Where the data lives: "<data_root>/<dataset>" is the loader root path.
    data_root: str = "./data/UEA_multivariate"
    dataset: str = "UWaveGestureLibrary"

    # Model identifiers (not read in this file).
    model: str = "DNN"
    dnn_type: str = "FCN"

    # Run settings; batch_size is passed to the DataLoader.
    batch_size: int = 8
    seed: int = 42

def _check_batch(meta_input, meta_target, tol=1e-4):
    """Run the shape, channel and normalization checks on one batch.

    Args:
        meta_input: Batched input tensor; the checks expect [B, T, 2*C].
        meta_target: Batched target tensor; the checks expect [B, T, C].
        tol: Slack allowed when testing that an instance's minimum is 0 and
            its maximum is 1.

    Returns:
        False if any check printed [FAIL], True otherwise. A normalization
        [WARNING] is reported but does not change the result.
    """
    print("\n[PASS] Batch successfully loaded.")
    print(f"Input Shape (Raw + Gradient): {meta_input.shape}")  # Expected [B, T, 2*C]
    print(f"Target Shape (BestEnsemble):  {meta_target.shape}") # Expected [B, T, C]

    # 1. Check Shapes -- everything below indexes three axes.
    if meta_input.dim() != 3 or meta_target.dim() != 3:
        print(
            "[FAIL] Expected 3-D [B, T, C] tensors, got "
            f"{meta_input.dim()}-D input and {meta_target.dim()}-D target."
        )
        return False

    passed = True
    if meta_input.shape[:2] != meta_target.shape[:2]:
        print(
            "[FAIL] Batch/time mismatch. "
            f"Input {tuple(meta_input.shape[:2])}, target {tuple(meta_target.shape[:2])}."
        )
        passed = False

    # 2. Verify Channel Concatenation
    c_in = meta_input.shape[-1]
    c_out = meta_target.shape[-1]
    if c_in == 2 * c_out:
        print(f"[PASS] Input channels ({c_in}) are correctly doubled for Raw+Saliency fusion.")
    else:
        print(f"[FAIL] Channel mismatch. Expected {2*c_out}, got {c_in}.")
        passed = False

    # 3. Verify Per-Instance Normalization
    print(f"Target Min/Max: {meta_target.min().item():.4f} / {meta_target.max().item():.4f}")

    # Reduce over each instance separately: a batch-wide min/max can look
    # like [0, 1] even when individual instances are not normalized.
    per_instance = meta_target.reshape(meta_target.shape[0], -1)
    inst_min = per_instance.amin(dim=1)
    inst_max = per_instance.amax(dim=1)
    # NaN extrema fail every comparison, so they are flagged as well.
    normalized = (
        (inst_min >= 0.0)
        & (inst_min <= tol)
        & (inst_max >= 1.0 - tol)
        & (inst_max <= 1.0)
    )
    if bool(normalized.all()):
        print("[PASS] Targets are correctly normalized to [0, 1] per instance.")
    else:
        n_bad = int((~normalized).sum())
        print("[WARNING] Targets range outside [0, 1]. Check _normalize_instance_wise.")
        print(f"          {n_bad}/{normalized.numel()} instance(s) do not span [0, 1].")

    return passed


def verify_pipeline(args=None, dataset=None):
    """Load one batch and sanity-check its shapes and target normalization.

    Args:
        args: Configuration object providing ``dataset``, ``data_root`` and
            ``batch_size``. Defaults to ``Configs()``.
        dataset: Optional ready-made dataset yielding ``(input, target)``
            pairs. When omitted, an ``AmortizedExplanationLoader`` is built
            with ``flag='TRAIN'`` and ``target_set='train'`` rooted at
            ``{data_root}/{dataset}``.

    Returns:
        True when a batch was loaded and no check printed [FAIL]; False when
        a check failed or loading raised. Exceptions are reported on the
        console rather than propagated.
    """
    import traceback

    if args is None:
        args = Configs()

    print(f"--- Starting Verification for {args.dataset} ---")

    try:
        if dataset is None:
            # Construct root path as expected by the loader
            root_path = f"{args.data_root}/{args.dataset}"
            dataset = AmortizedExplanationLoader(
                args,
                root_path,
                flag='TRAIN',
                target_set='train'
            )

        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

        # Pull a single batch; the default avoids a bare StopIteration whose
        # message would be empty.
        batch = next(iter(loader), None)
        if batch is None:
            print("\n[ERROR] Test failed: the loader yielded no batches.")
            return False

        meta_input, meta_target = batch
        return _check_batch(meta_input, meta_target)

    except FileNotFoundError as e:
        print(f"\n[ERROR] Missing Stage 1 results: {e}")
        return False
    except Exception as e:
        # Top-level boundary of a diagnostic script: keep it from crashing,
        # but show the traceback so the cause is not reduced to str(e).
        traceback.print_exc()
        print(f"\n[ERROR] Test failed: {e}")
        return False


if __name__ == "__main__":
    # Surface the outcome as the process exit status so shells and CI can
    # detect a failed verification.
    raise SystemExit(0 if verify_pipeline() else 1)
