# data_modules.py
import torch
from torch.utils.data import DataLoader, Dataset, Subset
import pytorch_lightning as pl
import pickle
import os
import numpy as np

# Assuming WilletDatasetGen2GoPeriod is in a file named willet_dataset_gen2_processing.py
# and is in the same directory or accessible via PYTHONPATH
# The Willet dataset class and its collate function are optional at import
# time: TextDataModule does not use them, while ECoGDataModule checks them for
# None and raises ImportError when constructed without them.
try:
    from willet_dataset_gen2_processing import (
        WilletDatasetGen2GoPeriod,
        willet_collate_fn,
    )
except ImportError:
    print("Warning: Could not import WilletDatasetGen2GoPeriod. ECoGDataModule will not work without it.")
    WilletDatasetGen2GoPeriod = willet_collate_fn = None  # placeholders

# --- Constants (ensure these match model.py and your data) ---
# Padding token ID (example value); default right-padding value for phoneme sequences.
PAD_IDX = 0

# --- For train_decoder_text_only.py ---
class PhonemeCorpusDataset(Dataset):  # Copied from train_decoder_text_only for DataModule use
    """In-memory dataset of fixed-length, right-padded phoneme-ID sequences.

    Each input sequence is converted to a ``torch.long`` tensor, truncated to
    ``max_len`` and padded with ``pad_idx`` so every item has shape
    ``(max_len,)``. Items are ``(padded_sequence, length)`` pairs, where
    ``length`` is a 0-d long tensor holding the number of real (unpadded)
    tokens.
    """

    def __init__(self, data_list, max_len=200, pad_idx=PAD_IDX):
        self.max_len = max_len
        self.pad_idx = pad_idx
        self.data = []
        self.lengths = []
        for raw_seq in data_list:
            token_ids = torch.tensor(raw_seq, dtype=torch.long)
            # Number of real tokens that survive truncation.
            kept = min(len(token_ids), max_len)

            padded = torch.full((max_len,), pad_idx, dtype=torch.long)
            padded[:kept] = token_ids[:kept]

            self.data.append(padded)
            self.lengths.append(torch.tensor(kept, dtype=torch.long))

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(padded_sequence, length)`` pair at position ``idx``."""
        sequence = self.data[idx]
        length = self.lengths[idx]
        return sequence, length

class TextDataModule(pl.LightningDataModule):
    """LightningDataModule serving the text-only phoneme corpus.

    Expected hparams: ``data_path_text_corpus`` (path to a pickled list of
    phoneme-ID lists), ``batch_size``, ``num_workers``,
    ``max_phoneme_seq_len_text``, ``val_split_ratio_text`` and, optionally,
    ``pad_idx`` (falls back to the module-level ``PAD_IDX``).
    """

    def __init__(self, hparams):
        super().__init__()
        # Accept a plain dict or a Namespace-like object (converted via vars()).
        self.hparams.update(vars(hparams) if not isinstance(hparams, dict) else hparams)

    def prepare_data(self):
        """Verify the corpus file exists; nothing is downloaded.

        Raises:
            FileNotFoundError: if ``data_path_text_corpus`` does not exist.
        """
        if not os.path.exists(self.hparams.data_path_text_corpus):
            raise FileNotFoundError(f"Text corpus not found at {self.hparams.data_path_text_corpus}")

    def setup(self, stage=None):
        """Load the corpus and split it sequentially into train/validation datasets.

        The last ``int(val_split_ratio_text * N)`` sequences form the
        validation set. No shuffling is applied before splitting.

        Raises:
            ValueError: if ``val_split_ratio_text`` is outside [0, 1]. Values
                outside that range make ``num_train`` or ``num_val`` negative.
                That would otherwise silently yield a meaningless split.
        """
        ratio = self.hparams.val_split_ratio_text
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"val_split_ratio_text must be in [0, 1], got {ratio!r}")

        # NOTE(review): pickle.load can execute arbitrary code stored in the
        # file; only point data_path_text_corpus at trusted, locally built data.
        with open(self.hparams.data_path_text_corpus, 'rb') as f:
            all_phoneme_sequences = pickle.load(f)  # Expected: list of lists of phoneme IDs

        num_total = len(all_phoneme_sequences)
        num_val = int(ratio * num_total)
        num_train = num_total - num_val

        # Sequential (unshuffled) split. Shuffle with a fixed seed beforehand if
        # the corpus order could bias the validation set.
        self.train_sequences = all_phoneme_sequences[:num_train]
        self.val_sequences = all_phoneme_sequences[num_train:]

        self.train_dataset = self._make_dataset(self.train_sequences)
        self.val_dataset = self._make_dataset(self.val_sequences)

    def _make_dataset(self, sequences):
        """Wrap ``sequences`` in a PhonemeCorpusDataset using the configured length and pad ID."""
        return PhonemeCorpusDataset(
            sequences,
            max_len=self.hparams.max_phoneme_seq_len_text,
            pad_idx=self.hparams.get('pad_idx', PAD_IDX),
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the shared batch-size / worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)
    

class ECoGDataModule(pl.LightningDataModule):
    """LightningDataModule for the Willet ECoG data, with optional curriculum learning.

    ``hparams.data_path_ecog`` must point to a pickled dict. Its ``"train"``
    and ``"test"`` entries are each wrapped in ``WilletDatasetGen2GoPeriod``.

    Optional training-set reductions:

    * ``train_data_fraction`` < 1.0 keeps a random subset of the training
      split, drawn with ``hparams.seed`` (default 42).
    * When ``curriculum_learning_enabled`` is true and ``training_stage`` is
      ``'joint_sequential_generation'``, ``train_dataloader`` keeps only
      samples whose ``phone_seq_lens`` value is <= a length limit. The limit
      grows linearly from ``curriculum_min_len`` to ``curriculum_max_len``
      over ``curriculum_epochs_total_sg_stage`` epochs.
      ``current_epoch`` is not advanced here; ``CurriculumUpdateCallback``
      increments it. A new limit only takes effect when the trainer calls
      ``train_dataloader`` again (presumably via
      ``reload_dataloaders_every_n_epochs=1`` — confirm in the training script).
    """

    # Shared default start length so __init__, setup and train_dataloader agree.
    _DEFAULT_CURRICULUM_MIN_LEN = 10

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        # Completed-epoch counter that drives the curriculum schedule.
        self.current_epoch = 0
        self.current_curriculum_max_len = self.hparams.get(
            'curriculum_min_len', self._DEFAULT_CURRICULUM_MIN_LEN
        )
        if WilletDatasetGen2GoPeriod is None:
            raise ImportError("WilletDatasetGen2GoPeriod is not available. Cannot use ECoGDataModule.")
        if willet_collate_fn is None:
            raise ImportError("willet_collate_fn is not available. Cannot use ECoGDataModule.")

    def prepare_data(self):
        """Verify the ECoG pickle exists; nothing is downloaded.

        Raises:
            FileNotFoundError: if ``data_path_ecog`` does not exist.
        """
        if not os.path.exists(self.hparams.data_path_ecog):
            raise FileNotFoundError(f"ECoG data not found at {self.hparams.data_path_ecog}")

    def _curriculum_active(self):
        """Return truthy when length-based curriculum filtering applies to this run."""
        # Short-circuits so training_stage is only read when the curriculum is enabled.
        return (self.hparams.get('curriculum_learning_enabled', False)
                and self.hparams.training_stage == 'joint_sequential_generation')

    def _build_willet_dataset(self, split_data, days_only):
        """Wrap one split of the loaded pickle in WilletDatasetGen2GoPeriod with the shared hparams."""
        return WilletDatasetGen2GoPeriod(
            split_data,
            neural_data_type=self.hparams.neural_data_type,
            neural_data_location=self.hparams.neural_data_location,
            daysOnly=days_only,
            smoothing=self.hparams.smoothing,
            gaussianSmoothWidth=self.hparams.gaussianSmoothWidth,
        )

    def _apply_train_fraction(self, dataset):
        """Return a seeded random Subset of ``dataset`` when ``train_data_fraction`` < 1.0, else ``dataset``."""
        train_fraction = self.hparams.get('train_data_fraction', 1.0)
        if train_fraction < 1.0:
            num_train_samples = int(train_fraction * len(dataset))
            # Fixed seed so the random subset is reproducible across runs.
            seed = self.hparams.get('seed', 42)
            generator = torch.Generator().manual_seed(seed)
            indices = torch.randperm(len(dataset), generator=generator).tolist()
            dataset = Subset(dataset, indices[:num_train_samples])
            print(f"--- SCALING LAW EXPERIMENT: Using {train_fraction*100:.1f}% of training data ({len(dataset)} samples) ---")
        return dataset

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        The training split is built for ``'fit'`` or ``None``. It receives the
        optional fraction subsetting and the initial curriculum filtering. The
        validation split is built for ``'fit'``, ``'validate'`` or ``None``, so
        ``trainer.validate(...)`` works without a prior fit. Other stages are
        no-ops.
        """
        build_train = stage in ('fit', None)
        build_val = stage in ('fit', 'validate', None)
        if not (build_train or build_val):
            return

        # NOTE(review): pickle.load can execute arbitrary code stored in the
        # file; data_path_ecog must point at trusted, locally produced data.
        with open(self.hparams.data_path_ecog, "rb") as handle:
            loaded_data = pickle.load(handle)

        if build_train:
            self.full_train_dataset = self._apply_train_fraction(
                self._build_willet_dataset(loaded_data["train"], self.hparams.get('daysOnly_train', None))
            )

        if build_val:
            self.val_dataset = self._build_willet_dataset(
                loaded_data["test"], self.hparams.get('daysOnly_val', None)
            )

        if build_train:
            self.current_train_dataset = self.full_train_dataset
            if self._curriculum_active():
                # Start the curriculum at its minimum length.
                self.current_curriculum_max_len = self.hparams.get(
                    'curriculum_min_len', self._DEFAULT_CURRICULUM_MIN_LEN
                )
                self.current_train_dataset = self._get_curriculum_subset(
                    self.full_train_dataset,
                    self.current_curriculum_max_len,
                )

    def _get_curriculum_subset(self, base_dataset, max_target_len):
        """Return the part of ``base_dataset`` whose ``phone_seq_lens`` value is <= ``max_target_len``.

        ``base_dataset`` may be a Subset (from ``train_data_fraction``). The
        result is always a Subset of the underlying dataset, using absolute
        indices. If no sample fits, it falls back to the samples with the
        smallest length in the parent set and prints a warning. If the parent
        set is empty, it falls back to sample 0 of the underlying dataset
        (when one exists). Returns ``base_dataset`` unchanged when the
        curriculum is inactive.
        """
        if not self._curriculum_active():
            return base_dataset

        if isinstance(base_dataset, Subset):
            original_dataset = base_dataset.dataset
            parent_indices = base_dataset.indices
        else:
            original_dataset = base_dataset
            parent_indices = list(range(len(original_dataset)))

        # Read each length once; indices are absolute w.r.t. original_dataset.
        lengths = [(idx, original_dataset.phone_seq_lens[idx].item()) for idx in parent_indices]
        indices = [idx for idx, length in lengths if length <= max_target_len]
        if indices:
            return Subset(original_dataset, indices)

        if not lengths:
            # Parent subset is empty: keep at least one sample from the
            # underlying dataset if there is any, otherwise return it as-is.
            if len(original_dataset) > 0:
                return Subset(original_dataset, [0])
            return base_dataset

        min_len_in_subset = min(length for _, length in lengths)
        indices = [idx for idx, length in lengths if length <= min_len_in_subset]
        print(f"Warning: No samples found for curriculum max_len {max_target_len}. "
              f"Using min_len {min_len_in_subset} from the available subset instead, found {len(indices)} samples.")
        return Subset(original_dataset, indices)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the shared batch-size / worker / collate settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=willet_collate_fn,
            # NOTE(review): persistent workers are deliberately off. The train
            # loader may be rebuilt each epoch for the curriculum — confirm
            # before switching to `self.hparams.num_workers > 0`.
            persistent_workers=False,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training set.

        When the curriculum is active, first recompute the length limit for
        ``current_epoch`` and refilter the training set accordingly.
        """
        if self._curriculum_active():
            total_curriculum_epochs = self.hparams.get('curriculum_epochs_total_sg_stage', 10)
            min_len = self.hparams.get('curriculum_min_len', self._DEFAULT_CURRICULUM_MIN_LEN)
            # Upper bound of the schedule; falls back to 500 when not configured.
            max_len = self.hparams.get('curriculum_max_len', 500)

            if total_curriculum_epochs > 0:
                progress = min(1.0, self.current_epoch / total_curriculum_epochs)
            else:
                progress = 1.0
            self.current_curriculum_max_len = int(min_len + progress * (max_len - min_len))

            self.current_train_dataset = self._get_curriculum_subset(
                self.full_train_dataset,
                self.current_curriculum_max_len,
            )
            print(f"Curriculum: epoch {self.current_epoch}/{total_curriculum_epochs}, "
                  f"max length: {self.current_curriculum_max_len}, "
                  f"samples: {len(self.current_train_dataset)}")

        return self._make_loader(self.current_train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation set."""
        return self._make_loader(self.val_dataset, shuffle=False)


class CurriculumUpdateCallback(pl.Callback):
    """Advance ECoGDataModule's curriculum epoch counter and report its state.

    Only acts when ``trainer.datamodule`` is an ``ECoGDataModule``. This
    callback only increments ``current_epoch``. The curriculum length itself
    is recomputed by ``ECoGDataModule.train_dataloader``.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        """Print the curriculum state the upcoming epoch will train with."""
        datamodule = trainer.datamodule
        if isinstance(datamodule, ECoGDataModule):
            print(f"\n*** CURRICULUM STATUS: Epoch {datamodule.current_epoch}, "
                  f"Max Length: {datamodule.current_curriculum_max_len} ***\n")

    def on_train_epoch_end(self, trainer, pl_module):
        """Increment the datamodule's epoch counter and log curriculum metrics."""
        datamodule = trainer.datamodule
        if not isinstance(datamodule, ECoGDataModule):
            return

        datamodule.current_epoch += 1
        print(f"\n*** CURRICULUM UPDATED: Now at epoch {datamodule.current_epoch} ***\n")

        # `hasattr(trainer, "logger")` is always true. The attribute is None
        # when the Trainer runs with logger=False, so guard on the value.
        logger = getattr(trainer, "logger", None)
        if logger is not None:
            logger.log_metrics({
                "curriculum_max_len": datamodule.current_curriculum_max_len,
                "curriculum_dataset_size": len(datamodule.current_train_dataset)
            }, step=trainer.current_epoch)