import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

class PosePenetrationDataset(Dataset):
    """Frame-level dataset pairing pose vectors with penetration labels.

    Each motion listed in ``<root_dir>/<split>.txt`` (one name per line, blank
    lines ignored) is loaded from ``<root_dir>/ori_vecs/<name>.npy``. Every
    frame of a motion whose frame count lies in
    ``[min_motion_length, max_motion_length]`` becomes one sample. Motions whose
    pose file or label file is missing are skipped.

    Supported ``task_type`` values and their label sources:

    * ``'binary'`` -- ``penetration_mask/<name>.npy``, one label per frame.
      The file may be NumPy-format (``np.save``) or ``torch.save``-serialized.
    * ``'joint_score'`` -- ``penetration_score/<name>_penetration_score.npy``;
      labels are returned as float32 tensors.
    * ``'joint_binary'`` -- same file as ``'joint_score'``, turned into int64
      0/1 labels via ``score > penetration_threshold``.
    * ``'joint_joint_binary'`` -- ``joint_joint/<name>_joint_penetration.npy``,
      cast to int64.

    Args:
        root_dir: Dataset root directory.
        split: Split name; selects ``<root_dir>/<split>.txt``.
        min_motion_length: Minimum number of frames for a motion to be kept.
        max_motion_length: Maximum number of frames for a motion to be kept.
        mean: Optional mean used to normalize poses (only applied together
            with ``std``).
        std: Optional std used to normalize poses (only applied together
            with ``mean``).
        task_type: One of the task types listed above.
        penetration_threshold: Threshold used by ``'joint_binary'``.

    Raises:
        ValueError: If ``task_type`` is unknown, or a label file holds fewer
            frames than the matching pose file.
    """

    # Task types understood by ``_load_labels``.
    _TASK_TYPES = ('binary', 'joint_score', 'joint_binary', 'joint_joint_binary')

    # Leading bytes of every file written by ``np.save`` (NPY format magic).
    _NPY_MAGIC = b'\x93NUMPY'

    def __init__(self, root_dir, split='train', min_motion_length=40, max_motion_length=200, mean=None, std=None, task_type='binary', penetration_threshold=0.01):
        # Validate first so a misconfigured task fails immediately, even when
        # no motion would survive the filters below.
        if task_type not in self._TASK_TYPES:
            raise ValueError(f"Unknown task_type: {task_type}")

        self.root_dir = Path(root_dir)
        self.split = split
        self.min_motion_length = min_motion_length
        self.max_motion_length = max_motion_length
        self.mean = mean
        self.std = std
        self.task_type = task_type
        self.penetration_threshold = penetration_threshold

        split_file = self.root_dir / f"{split}.txt"
        with open(split_file, 'r', encoding='utf-8') as f:
            # Drop blank lines (e.g. a trailing newline) rather than treating
            # them as a motion named ''.
            self.motion_names = [name for name in (line.strip() for line in f) if name]

        # One dict per frame: pose row, label row, source motion and frame index.
        self.data = []
        for name in self.motion_names:
            pose_file = self.root_dir / "ori_vecs" / f"{name}.npy"
            if not pose_file.exists():
                continue
            poses = np.load(pose_file)

            if not (self.min_motion_length <= len(poses) <= self.max_motion_length):
                continue

            labels = self._load_labels(name)
            if labels is None:
                continue
            if len(labels) < len(poses):
                raise ValueError(
                    f"Label file for motion '{name}' has {len(labels)} frames "
                    f"but its pose file has {len(poses)}"
                )

            for i, pose in enumerate(poses):
                self.data.append({
                    'pose': pose,
                    'label': labels[i],
                    'motion_name': name,
                    'frame_idx': i,
                })

    def _load_labels(self, name):
        """Return the per-frame label array for motion ``name``, or None if its label file is missing."""
        if self.task_type == 'binary':
            label_file = self.root_dir / "penetration_mask" / f"{name}.npy"
        elif self.task_type in ('joint_score', 'joint_binary'):
            label_file = self.root_dir / "penetration_score" / f"{name}_penetration_score.npy"
        else:  # 'joint_joint_binary' (task_type was validated in __init__)
            label_file = self.root_dir / "joint_joint" / f"{name}_joint_penetration.npy"

        if not label_file.exists():
            return None

        if self.task_type == 'binary':
            return self._load_mask(label_file)

        labels = np.load(label_file)
        if self.task_type == 'joint_binary':
            # Vectorized equivalent of thresholding every frame separately.
            return (labels > self.penetration_threshold).astype(np.int64)
        if self.task_type == 'joint_joint_binary':
            return labels.astype(np.int64)
        return labels  # 'joint_score': raw scores, cast to float32 in __getitem__

    @classmethod
    def _load_mask(cls, path):
        """Load a per-frame mask stored either with ``np.save`` or ``torch.save``.

        The mask files carry a ``.npy`` name but may hold either format, so the
        NPY magic header is sniffed to pick the matching loader.
        """
        with open(path, 'rb') as f:
            header = f.read(len(cls._NPY_MAGIC))
        if header == cls._NPY_MAGIC:
            return np.load(path)
        masks = torch.load(path, map_location='cpu')
        if torch.is_tensor(masks):
            # Keep labels as NumPy like the other task types, so __getitem__
            # builds the label tensor from an array rather than copy-constructing
            # from a tensor (which torch warns about).
            masks = masks.detach().numpy()
        return masks

    def __len__(self):
        """Return the total number of frames across all kept motions."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a dict with keys 'pose', 'label', 'motion_name', 'frame_idx'.

        ``pose`` is returned as stored (a NumPy array), normalized as
        ``(pose - mean) / std`` when both statistics were provided. ``label``
        is a float32 tensor for ``'joint_score'`` and an int64 tensor otherwise.
        """
        item = self.data[idx]
        pose = item['pose']

        if self.mean is not None and self.std is not None:
            pose = (pose - self.mean) / self.std

        label_dtype = torch.float32 if self.task_type == 'joint_score' else torch.long
        return {
            'pose': pose,
            'label': torch.tensor(item['label'], dtype=label_dtype),
            'motion_name': item['motion_name'],
            'frame_idx': item['frame_idx'],
        }

def get_dataloader(cfg, split='train'):
    """Build a DataLoader over ``PosePenetrationDataset`` for ``split``.

    Required config fields: ``cfg.dataset.root_dir``,
    ``cfg.dataset.min_motion_length``, ``cfg.dataset.max_motion_length`` and
    ``cfg.train.batch_size``.

    Optional config fields:
        cfg.dataset.mean_std_path: Directory holding ``Mean.npy`` and
            ``Std.npy``. When set and existing, they are used to normalize
            poses. When set but missing, a warning is emitted and poses are
            left unnormalized.
        cfg.dataset.task_type: Defaults to ``'binary'``.
        cfg.dataset.penetration_threshold: Defaults to ``0.01``.
        cfg.train.num_workers: Defaults to ``4``.

    Only the ``'train'`` split is shuffled.
    """
    mean = None
    std = None
    # getattr with a None default also covers an explicit ``mean_std_path: None``,
    # which would otherwise crash in Path(None).
    mean_std_path = getattr(cfg.dataset, 'mean_std_path', None)
    if mean_std_path is not None:
        mean_std_path = Path(mean_std_path)
        if mean_std_path.exists():
            mean = np.load(mean_std_path / 'Mean.npy')
            std = np.load(mean_std_path / 'Std.npy')
        else:
            import warnings
            # Silently training without normalization is easy to miss; surface it.
            warnings.warn(
                f"mean_std_path {mean_std_path} does not exist; poses will not be normalized"
            )

    task_type = getattr(cfg.dataset, 'task_type', 'binary')
    penetration_threshold = getattr(cfg.dataset, 'penetration_threshold', 0.01)
    num_workers = getattr(cfg.train, 'num_workers', 4)

    dataset = PosePenetrationDataset(
        root_dir=cfg.dataset.root_dir,
        split=split,
        min_motion_length=cfg.dataset.min_motion_length,
        max_motion_length=cfg.dataset.max_motion_length,
        mean=mean,
        std=std,
        task_type=task_type,
        penetration_threshold=penetration_threshold,
    )

    return DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
        pin_memory=True,
    )