"""
PM25 dataset wrapper for T1 PyPOTS-style imputation.
This module provides compatibility between CSDI PM25 dataset and T1's PyPOTS implementation.
"""

import os
import pickle
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from typing import Tuple, Optional, Dict


class PM25Wrapper(Dataset):
    """
    Wrapper class to use CSDI PM25 dataset with T1's PyPOTS implementation.

    The PM25 dataset contains air quality data from 36 monitoring stations with:
    - Fixed train/test split by months
    - Ground truth data with natural missing values (13.25%)
    - Pre-defined artificial missing values for evaluation (24.58% total)

    Every selected month is cut into overlapping windows of ``eval_length``
    consecutive rows (stride 1); windows never span two months.
    """

    # Months shared by the train and validation subsets (following CSDI);
    # exactly one of them, chosen by ``validindex``, is held out for validation.
    _TRAIN_VAL_MONTHS = (1, 2, 4, 5, 7, 8, 10, 11)
    # Months reserved for the test subset.
    _TEST_MONTHS = (3, 6, 9, 12)

    def __init__(
        self,
        subset: str = 'train',
        root_path: str = '/ssd/datasets/TimeSeries/pm25',
        eval_length: int = 36,
        target_dim: int = 36,
        return_x_ori: bool = True,
        validindex: int = 0,
        **kwargs
    ):
        """
        Args:
            subset: Data subset ('train', 'val', 'test')
            root_path: Root path for PM25 data
            eval_length: Sequence length for evaluation (window size in rows)
            target_dim: Number of features (monitoring stations); stored only,
                the actual feature count is taken from the loaded data
            return_x_ori: Whether to return original data
            validindex: Which month to use for validation (0-7)
            **kwargs: Accepted for call-site compatibility and ignored

        Raises:
            ValueError: If ``subset`` is not one of 'train', 'val', 'test', or
                if the selection yields no windows.
            IndexError: If ``subset`` is 'train' and ``validindex`` is out of range.
        """
        super().__init__()

        self.subset = subset
        self.root_path = root_path
        self.eval_length = eval_length
        self.target_dim = target_dim
        self.return_x_ori = return_x_ori
        self.validindex = validindex

        # Resolve the months first so that an invalid subset fails fast with a
        # clear message, before any file is touched.
        self.month_list = self._select_months(subset, validindex)

        # Load mean/std for normalization
        meanstd_path = os.path.join(root_path, "pm25_meanstd.pk")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # root_path at a trusted copy of the dataset.
        with open(meanstd_path, "rb") as f:
            self.train_mean, self.train_std = pickle.load(f)

        # Load data
        self._load_pm25_data()

    @classmethod
    def _select_months(cls, subset: str, validindex: int) -> list:
        """Return the calendar months belonging to ``subset`` (following CSDI)."""
        if subset == "train":
            months = list(cls._TRAIN_VAL_MONTHS)
            months.pop(validindex)  # Remove validation month
            return months
        if subset == "val":
            return list(cls._TRAIN_VAL_MONTHS[validindex : validindex + 1])
        if subset == "test":
            return list(cls._TEST_MONTHS)
        raise ValueError(
            f"Unknown subset {subset!r}; expected 'train', 'val' or 'test'"
        )

    def _month_windows(self, ground_values, missing_values):
        """Cut one month of raw values into normalized sliding windows.

        Args:
            ground_values: 2-D array (rows x features) of ground-truth values,
                NaN where naturally missing.
            missing_values: Array of the same shape with the additional
                artificial gaps set to NaN.

        Returns:
            Four lists of equal length (empty when the month has fewer than
            ``eval_length`` rows): normalized input windows, normalized
            ground-truth windows, observed masks, and indicating masks.
        """
        # Create masks
        ground_mask = ~np.isnan(ground_values)
        missing_mask = ~np.isnan(missing_values)

        # Indicating mask: positions that are artificially masked,
        # i.e. True where ground has a value but missing doesn't
        indicating_mask = ground_mask & ~missing_mask

        # Normalize using train statistics. NaNs are deliberately kept: the
        # downstream PyPOTS pipeline is expected to fill them (with 0).
        ground_normalized = (ground_values - self.train_mean) / self.train_std
        missing_normalized = (missing_values - self.train_mean) / self.train_std

        length = self.eval_length
        # range() over a non-positive count is empty, so short months
        # contribute no windows.
        starts = range(len(ground_values) - length + 1)
        data = [missing_normalized[i:i + length] for i in starts]
        data_ori = [ground_normalized[i:i + length] for i in starts]
        masks = [missing_mask[i:i + length] for i in starts]
        indicating = [indicating_mask[i:i + length] for i in starts]
        return data, data_ori, masks, indicating

    def _load_pm25_data(self):
        """Load PM25 data from CSDI format and build the window tensors.

        Populates ``X``, ``X_ori``, ``missing_mask``, ``indicating_mask`` and
        the ``n_samples`` / ``n_steps`` / ``n_features`` dimensions.

        Raises:
            ValueError: If no window could be built for the selected months.
        """
        # Load ground truth and missing data
        ground_path = os.path.join(
            self.root_path, "Code/STMVL/SampleData/pm25_ground.txt"
        )
        missing_path = os.path.join(
            self.root_path, "Code/STMVL/SampleData/pm25_missing.txt"
        )

        df_ground = pd.read_csv(ground_path, index_col="datetime", parse_dates=True)
        df_missing = pd.read_csv(missing_path, index_col="datetime", parse_dates=True)

        # Collect data for selected months
        all_data = []
        all_data_ori = []
        all_masks = []
        all_indicating_masks = []

        for month in self.month_list:
            # Get data for this month and convert to numpy
            ground_values = df_ground[df_ground.index.month == month].values
            missing_values = df_missing[df_missing.index.month == month].values

            data, data_ori, masks, indicating = self._month_windows(
                ground_values, missing_values
            )
            all_data.extend(data)
            all_data_ori.extend(data_ori)
            all_masks.extend(masks)
            all_indicating_masks.extend(indicating)

        if not all_data:
            # Without this guard the shape unpacking below fails with an
            # uninformative "not enough values to unpack" error.
            raise ValueError(
                f"No windows of length {self.eval_length} could be built for "
                f"subset {self.subset!r} (months {self.month_list}, "
                f"validindex {self.validindex})"
            )

        # Convert to tensors
        self.X = torch.FloatTensor(np.array(all_data))
        self.X_ori = torch.FloatTensor(np.array(all_data_ori))
        self.missing_mask = torch.FloatTensor(np.array(all_masks))
        self.indicating_mask = torch.FloatTensor(np.array(all_indicating_masks))

        # Get dimensions
        self.n_samples, self.n_steps, self.n_features = self.X.shape

        print(f"Loaded PM25 {self.subset} set:")
        print(f"  Shape: {self.X.shape}")
        print(f"  Natural + artificial missing rate: {1 - self.missing_mask.mean().item():.2%}")
        print(f"  Artificial missing rate: {self.indicating_mask.mean().item():.2%}")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
        """
        Get a sample compatible with T1's PyPOTS implementation.

        When ``return_x_ori`` is true, returns a 6-tuple of:
            - X: Input with missing values
            - X_ori: Original data (ground truth)
            - missing_mask: Mask for observed values (1 for observed)
            - indicating_mask: Artificially masked values
            - X_mark: Time features (zeros for PM25)
            - y_mark: Time features (zeros for PM25)

        Otherwise returns the 4-tuple ``(X, X, X_mark, y_mark)``, with the
        same input window occupying both data slots.
        """
        # Get sample
        X_sample = self.X[index]

        # Create dummy time features (PM25 doesn't use time features)
        seq_x_mark = torch.zeros(self.n_steps, 4)
        seq_y_mark = torch.zeros(self.n_steps, 4)

        if not self.return_x_ori:
            return X_sample, X_sample, seq_x_mark, seq_y_mark

        return (
            X_sample,
            self.X_ori[index],
            self.missing_mask[index],
            self.indicating_mask[index],
            seq_x_mark,
            seq_y_mark,
        )

    def __len__(self) -> int:
        """Return the number of windows in the dataset."""
        return self.n_samples