import os
from typing import List
import numpy as np
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
import pandas as pd
from torch.utils.data import DataLoader, Subset
from pytorch_lightning import LightningDataModule
from pathlib import Path
from typing import List, Sequence, Optional, Dict, Union
import torch
import pandas as pd          
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from melp.datasets.pretrain_dataset import SleepEpochDataset    
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from typing import List, Sequence, Optional, Dict, Union
from pathlib import Path
from melp.datasets.pretrain_dataset import SleepEpochDataset



class SleepDataModule(LightningDataModule):
    """Lightning data module that builds ``SleepEpochDataset`` data loaders.

    Two modes are selected by ``is_pretrain``:

    * ``is_pretrain == 1``: the ``"pretrain"``, ``"pretrain-val"`` and
      ``"pretrain-test"`` splits are used.
    * otherwise (downstream): the ``"train"``, ``"val"`` and ``"test"`` splits
      are used and the regression / downstream-dataset options are forwarded
      to the dataset.

    Datasets are constructed lazily, inside each ``*_dataloader`` call.
    """

    def __init__(
        self,
        csv_dir: str | Path,
        *,
        is_pretrain,
        data_pct=1,
        val_dataset_list: Optional[List[str]] = None,
        downstream_dataset_name=None,
        batch_size: int = 128,
        num_workers: int = 4,
        patient_cols: Optional[Union[str, Sequence[str]]] = None,
        event_cols: Optional[Union[str, Sequence[str]]] = None,
        train_edf_cols: Sequence[str] | None,  # passed to Dataset
        transforms=None,
        n_views: int = 1,
        cache_size: int = 8,                   # passed to Dataset
        sample_rate: int = 128,
        window_size: int = 30,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        # ===== unified CSV loading =====
        data_source: str = "auto",  # "auto", "pretrain", "downstream", "both"
        include_datasets: Optional[List[str]] = None,
        # ===== regression task =====
        regression_targets: Optional[List[str]] = None,  # e.g., ["HR", "SPO2"]
        # ===== regression value filtering =====
        regression_filter_config: Optional[Dict] = None,  # e.g., {"SPO2_mean": {"min": 70}, "HR_mean": {"min": 30, "max": 200}}
        # ===== few-shot by sample count =====
        n_train_samples: Optional[int] = None,
        val_batch_size: Optional[int] = None,
        val_data_pct: Optional[float] = None,
        # ===== multi-label visualization =====
        return_all_event_cols: bool = False,
        # ===== demographic mapping =====
        return_nsrrid: bool = False,
    ):
        """Store the configuration; no dataset is created here.

        Args:
            csv_dir: Forwarded to ``SleepEpochDataset`` as ``csv_dir``.
            is_pretrain: Pretraining mode when equal to ``1`` (``True`` also
                matches), downstream mode otherwise.
            data_pct: Forwarded to the train and validation datasets as
                ``data_pct``. It is not passed to the test datasets.
            val_dataset_list: Optional dataset names. When non-empty, one
                validation dataset (and one loader) is built per name.
            downstream_dataset_name: Forwarded to the downstream datasets.
            batch_size: Training batch size; also used for val/test unless
                ``val_batch_size`` is set.
            num_workers: Worker count for every ``DataLoader``.
            patient_cols: Forwarded to the dataset.
            event_cols: Forwarded to the dataset.
            train_edf_cols: Forwarded to the dataset (required keyword).
            transforms: Transform applied to training datasets only;
                validation and test datasets receive ``transform=None``.
            n_views: Stored on the module; not used by the loaders here.
            cache_size: Forwarded to the dataset.
            sample_rate: Forwarded to the dataset.
            window_size: Forwarded to the dataset.
            pin_memory: Forwarded to every ``DataLoader``.
            persistent_workers: Requested ``persistent_workers`` value for the
                pretraining loaders (downstream loaders always request it).
                It is only enabled when ``num_workers > 0``.
            data_source: Forwarded to the dataset.
            include_datasets: Forwarded to the dataset.
            regression_targets: Forwarded to the downstream datasets.
            regression_filter_config: Forwarded to the downstream datasets.
            n_train_samples: When set and positive, the training set is
                sub-sampled: up to this many samples *per class* when class
                labels are available on the dataset, otherwise this many
                samples in total. ``data_pct`` is still passed to the dataset.
            val_batch_size: Batch size for validation and test loaders.
            val_data_pct: When strictly between 0 and 1, each validation
                dataset is randomly sub-sampled to this fraction. Test
                datasets are not sub-sampled.
            return_all_event_cols: Forwarded to the downstream *training*
                dataset only.
            return_nsrrid: Forwarded to the downstream *training* dataset only.
        """
        super().__init__()
        # Must be called from __init__ so the constructor arguments are captured.
        self.save_hyperparameters(ignore=["transforms"])
        self.downstream_dataset_name = downstream_dataset_name
        self.csv_dir = csv_dir
        self.transforms = transforms
        self.n_views = n_views
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.is_pretrain = is_pretrain
        self.patient_cols = patient_cols
        self.event_cols = event_cols
        self.data_pct = data_pct
        self.data_source = data_source
        self.include_datasets = include_datasets
        self.regression_targets = regression_targets
        self.regression_filter_config = regression_filter_config
        self.n_train_samples = n_train_samples
        self.val_batch_size = val_batch_size
        self.val_data_pct = val_data_pct
        self.return_all_event_cols = return_all_event_cols
        self.return_nsrrid = return_nsrrid

    # ---------- private helpers ----------
    def _base_dataset_kwargs(self, split: str, transform) -> Dict:
        """Return the keyword arguments shared by every ``SleepEpochDataset`` call."""
        return {
            "csv_dir": self.csv_dir,
            "split": split,
            "train_edf_cols": self.hparams.train_edf_cols,
            "transform": transform,
            "sample_rate": self.hparams.sample_rate,
            "window_size": self.hparams.window_size,
            "cache_size": self.hparams.cache_size,
            "data_source": self.data_source,
            "include_datasets": self.include_datasets,
        }

    def _eval_dataset_kwargs(self, split: str) -> Dict:
        """Return dataset kwargs for a val/test split (no transform, with label columns)."""
        kwargs = self._base_dataset_kwargs(split, transform=None)
        kwargs["patient_cols"] = self.patient_cols
        kwargs["event_cols"] = self.event_cols
        return kwargs

    def _regression_kwargs(self) -> Dict:
        """Return the regression options forwarded to downstream datasets."""
        return {
            "regression_targets": self.regression_targets,
            "regression_filter_config": self.regression_filter_config,
        }

    def _eval_batch_size(self) -> int:
        """Return ``val_batch_size`` when set, else the training batch size."""
        if self.val_batch_size is not None:
            return self.val_batch_size
        return self.hparams.batch_size

    def _make_loader(self, dataset, *, batch_size, shuffle, persistent_workers):
        """Wrap ``dataset`` in a ``DataLoader`` using the module-wide settings."""
        num_workers = self.hparams.num_workers
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=self.pin_memory,
            # DataLoader raises ValueError when persistent_workers is requested
            # without worker processes, so only enable it when workers exist.
            persistent_workers=bool(persistent_workers) and num_workers > 0,
            # NOTE(review): drop_last=True is kept from the original for all
            # splits; on val/test it discards the final partial batch, and a
            # dataset smaller than one batch yields no batches at all.
            drop_last=True,
        )

    def _stratified_few_shot_indices(self, train_set, rng) -> Optional[List[int]]:
        """Sample up to ``n_train_samples`` row positions per class.

        Returns ``None`` when the dataset does not expose what is needed
        (``event_cols``, ``all_epoch_df`` containing the first event column,
        and a non-``None`` ``num_classes``), so the caller can fall back to
        plain random sampling.
        """
        event_cols = getattr(train_set, "event_cols", None)
        if not event_cols or not hasattr(train_set, "all_epoch_df"):
            return None

        label_col = event_cols[0]
        if label_col not in train_set.all_epoch_df.columns:
            return None

        num_classes = getattr(train_set, "num_classes", None)
        if num_classes is None:
            return None

        # assumes row positions in all_epoch_df match dataset indices and that
        # labels are coded 0..num_classes-1 -- TODO confirm against the dataset
        labels = train_set.all_epoch_df[label_col].values
        picked: List[int] = []
        for c in range(num_classes):
            class_indices = np.where(labels == c)[0]
            n_per_class = min(self.n_train_samples, len(class_indices))
            if n_per_class > 0:
                sampled = rng.choice(class_indices, size=n_per_class, replace=False)
                picked.extend(sampled.tolist())
                print(f"[Few-shot] Class {c}: sampled {n_per_class}/{len(class_indices)} samples")
        return picked

    def _few_shot_subset(self, train_set):
        """Return a ``Subset`` of ``train_set`` limited by ``n_train_samples``.

        Class-stratified sampling is attempted first; otherwise
        ``n_train_samples`` items are drawn uniformly without replacement.
        The generator is seeded with 42 so the selection is reproducible.
        """
        n_total = len(train_set)
        rng = np.random.default_rng(seed=42)

        indices = self._stratified_few_shot_indices(train_set, rng)
        if indices is not None:
            print(f"[Few-shot] Total: {len(indices)}/{n_total} samples ({self.n_train_samples}-shot per class)")
            return Subset(train_set, indices)

        # Fallback: random sampling (pretrain, or no usable class labels).
        n_keep = min(self.n_train_samples, n_total)
        indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
        print(f"[Few-shot] Using {n_keep}/{n_total} training samples (random, n_train_samples={self.n_train_samples})")
        return Subset(train_set, indices)

    # ---------- DataLoaders ----------
    def train_dataloader(self):
        """Build the shuffled training ``DataLoader``.

        In downstream mode the full dataset object is kept on
        ``self._train_dataset`` (before any few-shot sub-sampling) so that
        ``get_class_distribution`` can query it.
        """
        if self.is_pretrain == 1:
            kwargs = self._base_dataset_kwargs("pretrain", transform=self.transforms)
            kwargs["data_pct"] = self.data_pct
            train_set = SleepEpochDataset(**kwargs)
            persistent_workers = self.persistent_workers
        else:
            kwargs = self._base_dataset_kwargs("train", transform=self.transforms)
            kwargs.update(
                data_pct=self.data_pct,
                patient_cols=self.patient_cols,
                event_cols=self.event_cols,
                downstream_dataset_name=self.downstream_dataset_name,
                return_all_event_cols=self.return_all_event_cols,
                return_nsrrid=self.return_nsrrid,
                **self._regression_kwargs(),
            )
            train_set = SleepEpochDataset(**kwargs)
            # Store dataset object for class distribution access.
            self._train_dataset = train_set
            persistent_workers = True

        if self.n_train_samples is not None and self.n_train_samples > 0:
            train_set = self._few_shot_subset(train_set)

        return self._make_loader(
            train_set,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            persistent_workers=persistent_workers,
        )

    def get_class_distribution(self) -> Optional[torch.Tensor]:
        """
        Get class distribution from training dataset.
        Returns [num_classes] tensor of class counts, or None if not available.

        The dataset is only available after ``train_dataloader`` has run in
        downstream mode; the counts come from the dataset's
        ``get_class_counts()`` and refer to the dataset before few-shot
        sub-sampling.
        """
        if hasattr(self, '_train_dataset'):
            counts = self._train_dataset.get_class_counts()
            if counts is not None:
                return torch.from_numpy(counts).float()
        return None

    def val_dataloader(self):
        """Build the list of validation ``DataLoader`` objects.

        One loader is returned per validation dataset. When
        ``val_dataset_list`` is non-empty, one dataset is built per listed
        name (passed as ``downstream_dataset_name``); otherwise a single
        dataset is built. Datasets are optionally sub-sampled by
        ``val_data_pct``.
        """
        names = self.hparams.val_dataset_list

        if self.is_pretrain == 1:
            persistent_workers = self.persistent_workers
            if names:
                val_sets = [
                    SleepEpochDataset(
                        data_pct=self.data_pct,
                        downstream_dataset_name=name,
                        **self._eval_dataset_kwargs("pretrain-val"),
                    )
                    for name in names
                ]
            else:
                val_sets = [
                    SleepEpochDataset(
                        data_pct=self.data_pct,
                        **self._eval_dataset_kwargs("pretrain-val"),
                    )
                ]
        else:
            persistent_workers = True
            # Previously a non-empty val_dataset_list in downstream mode left
            # val_sets unassigned and raised UnboundLocalError. Mirror the
            # pretrain branch: one "val" dataset per listed name, or the single
            # configured downstream dataset when no list is given.
            target_names = names if names else [self.downstream_dataset_name]
            val_sets = [
                SleepEpochDataset(
                    data_pct=self.data_pct,
                    downstream_dataset_name=name,
                    **self._eval_dataset_kwargs("val"),
                    **self._regression_kwargs(),
                )
                for name in target_names
            ]

        # Subsample val sets if val_data_pct is set.
        if self.val_data_pct is not None and 0 < self.val_data_pct < 1.0:
            subsampled_val_sets = []
            for ds in val_sets:
                n_total = len(ds)
                n_keep = max(1, int(n_total * self.val_data_pct))
                # A fresh generator with a fixed seed per dataset keeps the
                # selection reproducible and independent of list order.
                rng = np.random.default_rng(seed=42)
                indices = rng.choice(n_total, size=n_keep, replace=False).tolist()
                subsampled_val_sets.append(Subset(ds, indices))
                print(f"[Val subsample] Using {n_keep}/{n_total} val samples ({self.val_data_pct*100:.1f}%)")
            val_sets = subsampled_val_sets

        val_bs = self._eval_batch_size()
        return [
            self._make_loader(
                ds,
                batch_size=val_bs,
                shuffle=False,
                persistent_workers=persistent_workers,
            )
            for ds in val_sets
        ]

    def test_dataloader(self):
        """Build the test ``DataLoader`` (unshuffled, no ``data_pct``)."""
        if self.is_pretrain == 1:
            # The original also passed ``exclude_datasets=self.exclude_datasets``,
            # an attribute this module never defines (AttributeError). No other
            # dataset call passes it, so it is omitted here.
            test_set = SleepEpochDataset(**self._eval_dataset_kwargs("pretrain-test"))
            persistent_workers = self.persistent_workers
        else:
            test_set = SleepEpochDataset(
                downstream_dataset_name=self.downstream_dataset_name,
                **self._eval_dataset_kwargs("test"),
                **self._regression_kwargs(),
            )
            persistent_workers = True

        return self._make_loader(
            test_set,
            batch_size=self._eval_batch_size(),
            shuffle=False,
            persistent_workers=persistent_workers,
        )

