"""
BenchPOTS dataset wrapper for T1 PyPOTS-style imputation.
This module provides compatibility between BenchPOTS datasets and T1's PyPOTS implementation.
"""

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from typing import Tuple, Optional, Union, Dict
import warnings

# Installs a global "ignore" filter at import time: after this module is
# imported, warnings of every category are suppressed for the whole process.
# NOTE(review): this also hides warnings raised by unrelated code — confirm
# that blanket suppression is intended rather than a narrower filter.
warnings.filterwarnings('ignore')


class BenchPOTSWrapper(Dataset):
    """
    Wrapper class to use BenchPOTS datasets with T1's PyPOTS implementation.

    This class loads data from BenchPOTS and converts it to the format expected
    by T1's PyPOTS-style training (compatible with DatasetForMIT).

    After construction the following tensors are available:
        - ``X``: data with NaNs replaced by 0
        - ``X_ori``: ground truth with NaNs replaced by 0 (a copy of ``X`` when
          the dataset provides no ``<subset>_X_ori`` entry)
        - ``original_missing_mask``: 1.0 where ``X`` was observed, 0.0 where it
          was NaN
    """

    # Width of the placeholder time-feature tensors returned with every sample.
    _N_TIME_FEATURES = 4

    def __init__(
        self,
        dataset_name: str,
        subset: str = 'train',
        root_path: str = '../dataset/benchpots',
        mit_rate: float = 0.2,
        cache_mit_masks: bool = False,
        return_x_ori: bool = True,
        benchpots_missing_rate: float = 0.1,
        **kwargs
    ):
        """
        Args:
            dataset_name: Name of the BenchPOTS dataset
                ('physionet2012', 'air_quality' or 'electricity')
            subset: Data subset ('train', 'val', 'test')
            root_path: Root path for storing BenchPOTS data (stored on the
                instance; not used by the loading code in this class)
            mit_rate: MIT masking rate for training
            cache_mit_masks: Whether to cache MIT masks
            return_x_ori: Whether to return original data
            benchpots_missing_rate: Missing rate for BenchPOTS preprocessing
            **kwargs: Accepted for call-site compatibility and ignored

        Raises:
            ImportError: If importing BenchPOTS fails.
            RuntimeError: If anything else goes wrong while loading (unknown
                dataset name, unknown subset, preprocessing failure, ...).
        """
        super().__init__()

        self.dataset_name = dataset_name
        self.subset = subset
        self.root_path = root_path
        self.mit_rate = mit_rate
        self.cache_mit_masks = cache_mit_masks
        self.return_x_ori = return_x_ori
        self.benchpots_missing_rate = benchpots_missing_rate

        # Map subset names to the key prefixes used in the BenchPOTS result dict
        self.subset_map = {
            'train': 'train',
            'val': 'val',
            'test': 'test'
        }

        # Per-index cache of (mit_mask, indicating_mask); None when disabled
        self.mit_mask_cache = {} if cache_mit_masks else None

        # Load data
        self._load_benchpots_data()

    def _load_benchpots_data(self):
        """
        Load data from BenchPOTS and populate ``X``, ``X_ori``,
        ``original_missing_mask`` and the ``n_samples``/``n_steps``/``n_features``
        dimensions.

        Raises:
            ImportError: If an import fails while loading BenchPOTS.
            RuntimeError: For every other failure; the original exception is
                attached as ``__cause__``.
        """
        try:
            import benchpots  # noqa: F401  (availability check)

            # Reject a bad subset before the (potentially slow) preprocessing
            # runs, with a message that names the accepted values.
            if self.subset not in self.subset_map:
                raise ValueError(
                    f"Unknown subset: {self.subset!r} "
                    f"(expected one of {sorted(self.subset_map)})"
                )

            print(f"Loading {self.dataset_name} from BenchPOTS...")

            # Get the preprocessing function
            if self.dataset_name == 'physionet2012':
                from benchpots.datasets import preprocess_physionet2012
                data = preprocess_physionet2012(
                    subset='set-a',  # or 'set-b'
                    rate=self.benchpots_missing_rate
                )
            elif self.dataset_name == 'air_quality':
                from benchpots.datasets import preprocess_air_quality
                data = preprocess_air_quality(rate=self.benchpots_missing_rate)
            elif self.dataset_name == 'electricity':
                from benchpots.datasets import preprocess_electricity
                data = preprocess_electricity(rate=self.benchpots_missing_rate)
            else:
                raise ValueError(f"Unknown BenchPOTS dataset: {self.dataset_name}")

            # Extract data based on subset
            subset_key = self.subset_map[self.subset]
            self.X = torch.FloatTensor(data[f"{subset_key}_X"])

            # Handle datasets like PhysioNet2012 that don't have X_ori for training
            if f"{subset_key}_X_ori" in data:
                self.X_ori = torch.FloatTensor(data[f"{subset_key}_X_ori"])
            else:
                # For PhysioNet2012: use X as ground truth, MIT masking will be applied later
                self.X_ori = self.X.clone()
                if self.subset == 'train' and self.dataset_name == 'physionet2012':
                    print(f"PhysioNet2012: Using train_X as ground truth, MIT rate {self.mit_rate} will be applied for training")

            # The mask must be derived before NaNs are overwritten below
            self.original_missing_mask = (~torch.isnan(self.X)).float()

            # Fill NaN with 0 for input
            self.X = torch.nan_to_num(self.X, nan=0.0)
            self.X_ori = torch.nan_to_num(self.X_ori, nan=0.0)

            # Get data dimensions
            self.n_samples, self.n_steps, self.n_features = self.X.shape

            print(f"Loaded {self.dataset_name} {self.subset} set:")
            print(f"  Shape: {self.X.shape}")
            print(f"  Original missing rate: {1 - self.original_missing_mask.mean().item():.2%}")

        except ImportError as e:
            raise ImportError(
                "BenchPOTS is not installed. Please install it with: pip install benchpots"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Error loading BenchPOTS data: {e}") from e

    def _generate_mit_mask(self, X: torch.Tensor, original_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate MIT mask for a given sample.

        A random ``mit_rate`` fraction (truncated to an integer count) of the
        positions marked observed in ``original_mask`` is selected for
        artificial masking.

        Args:
            X: Sample tensor; only its shape/dtype are used as a template
            original_mask: 1 for observed, 0 for missing, same shape as ``X``

        Returns:
            mit_mask: New mask after MIT (1 for observed, 0 for missing)
            indicating_mask: Mask showing artificially masked values
        """
        # Only mask observed values
        observed_indices = original_mask.bool()
        n_observed = observed_indices.sum()
        n_to_mask = int(n_observed * self.mit_rate)

        if n_to_mask > 0:
            # Randomly select values to mask
            observed_flat = torch.where(observed_indices.flatten())[0]
            mask_indices = observed_flat[torch.randperm(len(observed_flat))[:n_to_mask]]

            # Create indicating mask
            indicating_mask = torch.zeros_like(X).flatten()
            indicating_mask[mask_indices] = 1
            indicating_mask = indicating_mask.reshape(X.shape)

            # Selected positions are a subset of the observed ones, so the
            # difference stays a 0/1 mask.
            mit_mask = original_mask - indicating_mask
        else:
            indicating_mask = torch.zeros_like(X)
            mit_mask = original_mask

        return mit_mask, indicating_mask

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
        """
        Get a sample compatible with T1's PyPOTS implementation.

        When ``return_x_ori`` is true, returns a tuple of:
            - X: Input with MIT masking applied
            - X_ori: Original data (ground truth)
            - missing_mask: Mask after MIT (1 for observed)
            - indicating_mask: Artificially masked values
            - X_mark: Time features (zeros for BenchPOTS)
            - y_mark: Time features (zeros for BenchPOTS)

        Otherwise returns ``(X, X_sample, X_mark, y_mark)`` where ``X_sample``
        is the zero-filled input without MIT masking.
        """
        # Get sample
        X_sample = self.X[index]
        X_ori_sample = self.X_ori[index]
        original_mask = self.original_missing_mask[index]

        if self.subset == 'train':
            # Apply MIT masking for training
            if self.cache_mit_masks and index in self.mit_mask_cache:
                mit_mask, indicating_mask = self.mit_mask_cache[index]
            else:
                mit_mask, indicating_mask = self._generate_mit_mask(X_ori_sample, original_mask)
                if self.cache_mit_masks:
                    self.mit_mask_cache[index] = (mit_mask, indicating_mask)

            # Build the input from X (zero at every originally-missing
            # position), not from X_ori: when the dataset ships a separate
            # X_ori, it holds ground-truth values at positions that are missing
            # in X, and copying it would leak them into the model input at
            # positions where mit_mask is 0.
            X_mit = X_sample.masked_fill(indicating_mask.bool(), 0.0)
        else:
            # For val/test, use original missing pattern
            X_mit = X_sample
            mit_mask = original_mask
            indicating_mask = torch.zeros_like(original_mask)

        # Create dummy time features (BenchPOTS doesn't provide these)
        seq_x_mark = torch.zeros(self.n_steps, self._N_TIME_FEATURES)
        seq_y_mark = torch.zeros(self.n_steps, self._N_TIME_FEATURES)

        if self.return_x_ori:
            return X_mit, X_ori_sample, mit_mask, indicating_mask, seq_x_mark, seq_y_mark
        else:
            return X_mit, X_sample, seq_x_mark, seq_y_mark

    def __len__(self) -> int:
        """Return the number of samples in the loaded subset."""
        return self.n_samples


class PyPOTSCustomDataset(Dataset):
    """
    Dataset class for custom PyPOTS-format data.

    Supports two on-disk layouts:
        - a single HDF5 file containing ``<subset>_X`` and, optionally,
          ``<subset>_X_ori`` datasets
        - a directory with one sub-directory per subset holding ``X.npy`` and,
          optionally, ``X_ori.npy``

    When no ``X_ori`` is provided, a copy of ``X`` is used as ground truth.
    """

    # Width of the placeholder time-feature tensors returned with every sample.
    _N_TIME_FEATURES = 4

    # File suffixes that select the HDF5 loader instead of the directory loader.
    _HDF5_SUFFIXES = ('.h5', '.hdf5')

    def __init__(
        self,
        data_path: str,
        subset: str = 'train',
        mit_rate: float = 0.2,
        cache_mit_masks: bool = False,
        return_x_ori: bool = True,
        **kwargs
    ):
        """
        Args:
            data_path: Path to data file (.h5 / .hdf5) or directory containing data
            subset: Data subset ('train', 'val', 'test')
            mit_rate: MIT masking rate
            cache_mit_masks: Whether to cache MIT masks
            return_x_ori: Whether to return original data
            **kwargs: Accepted for call-site compatibility and ignored
        """
        super().__init__()

        self.data_path = data_path
        self.subset = subset
        self.mit_rate = mit_rate
        self.cache_mit_masks = cache_mit_masks
        self.return_x_ori = return_x_ori

        # Per-index cache of (mit_mask, indicating_mask); None when disabled
        self.mit_mask_cache = {} if cache_mit_masks else None

        # Load data
        self._load_custom_data()

    def _load_custom_data(self):
        """
        Load custom PyPOTS format data and populate ``X``, ``X_ori``,
        ``original_missing_mask`` and the dimension attributes.
        """
        if self.data_path.endswith(self._HDF5_SUFFIXES):
            # Imported lazily so the directory-of-.npy layout keeps working on
            # machines where h5py is not installed.
            import h5py

            # Load from HDF5 file
            with h5py.File(self.data_path, 'r') as f:
                self.X = torch.FloatTensor(f[f'{self.subset}_X'][:])
                if f'{self.subset}_X_ori' in f:
                    self.X_ori = torch.FloatTensor(f[f'{self.subset}_X_ori'][:])
                else:
                    self.X_ori = self.X.clone()
        else:
            # Load from numpy files in directory
            data_dir = os.path.join(self.data_path, self.subset)
            x_ori_file = os.path.join(data_dir, 'X_ori.npy')
            self.X = torch.FloatTensor(np.load(os.path.join(data_dir, 'X.npy')))
            if os.path.exists(x_ori_file):
                self.X_ori = torch.FloatTensor(np.load(x_ori_file))
            else:
                self.X_ori = self.X.clone()

        # The mask must be derived before NaNs are overwritten below
        self.original_missing_mask = (~torch.isnan(self.X)).float()

        # Fill NaN with 0
        self.X = torch.nan_to_num(self.X, nan=0.0)
        self.X_ori = torch.nan_to_num(self.X_ori, nan=0.0)

        # Get dimensions
        self.n_samples, self.n_steps, self.n_features = self.X.shape

        print(f"Loaded custom data from {self.data_path}:")
        print(f"  Shape: {self.X.shape}")
        print(f"  Missing rate: {1 - self.original_missing_mask.mean().item():.2%}")

    def _generate_mit_mask(self, X: torch.Tensor, original_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate MIT mask for a given sample.

        A random ``mit_rate`` fraction (truncated to an integer count) of the
        positions marked observed in ``original_mask`` is selected for
        artificial masking.

        Args:
            X: Sample tensor; only its shape/dtype are used as a template
            original_mask: 1 for observed, 0 for missing, same shape as ``X``

        Returns:
            mit_mask: New mask after MIT (1 for observed, 0 for missing)
            indicating_mask: Mask showing artificially masked values
        """
        observed_indices = original_mask.bool()
        n_observed = observed_indices.sum()
        n_to_mask = int(n_observed * self.mit_rate)

        if n_to_mask > 0:
            observed_flat = torch.where(observed_indices.flatten())[0]
            mask_indices = observed_flat[torch.randperm(len(observed_flat))[:n_to_mask]]

            indicating_mask = torch.zeros_like(X).flatten()
            indicating_mask[mask_indices] = 1
            indicating_mask = indicating_mask.reshape(X.shape)

            # Selected positions are a subset of the observed ones, so the
            # difference stays a 0/1 mask.
            mit_mask = original_mask - indicating_mask
        else:
            indicating_mask = torch.zeros_like(X)
            mit_mask = original_mask

        return mit_mask, indicating_mask

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
        """
        Get one sample.

        When ``return_x_ori`` is true, returns a tuple of:
            - X: Input with MIT masking applied (training) or the zero-filled
              input as loaded (other subsets)
            - X_ori: Original data (ground truth)
            - missing_mask: Mask after MIT (1 for observed)
            - indicating_mask: Artificially masked values
            - X_mark: Dummy time features (zeros)
            - y_mark: Dummy time features (zeros)

        Otherwise returns ``(X, X_sample, X_mark, y_mark)`` where ``X_sample``
        is the zero-filled input without MIT masking.
        """
        X_sample = self.X[index]
        X_ori_sample = self.X_ori[index]
        original_mask = self.original_missing_mask[index]

        if self.subset == 'train':
            if self.cache_mit_masks and index in self.mit_mask_cache:
                mit_mask, indicating_mask = self.mit_mask_cache[index]
            else:
                mit_mask, indicating_mask = self._generate_mit_mask(X_ori_sample, original_mask)
                if self.cache_mit_masks:
                    self.mit_mask_cache[index] = (mit_mask, indicating_mask)

            # Build the input from X (zero at every originally-missing
            # position), not from X_ori: a separately supplied X_ori holds
            # ground-truth values at positions that are missing in X, and
            # copying it would leak them into the model input at positions
            # where mit_mask is 0.
            X_mit = X_sample.masked_fill(indicating_mask.bool(), 0.0)
        else:
            X_mit = X_sample
            mit_mask = original_mask
            indicating_mask = torch.zeros_like(original_mask)

        # Dummy time features
        seq_x_mark = torch.zeros(self.n_steps, self._N_TIME_FEATURES)
        seq_y_mark = torch.zeros(self.n_steps, self._N_TIME_FEATURES)

        if self.return_x_ori:
            return X_mit, X_ori_sample, mit_mask, indicating_mask, seq_x_mark, seq_y_mark
        else:
            return X_mit, X_sample, seq_x_mark, seq_y_mark

    def __len__(self) -> int:
        """Return the number of samples in the loaded subset."""
        return self.n_samples