import os
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import numpy as np
from mmap_dataset import MMapIndexedDataset


class MMapPythiaDataset(Dataset):
    """PyTorch Dataset wrapper for MMapIndexedDataset.

    Serves the inclusive document range ``[start_idx, end_idx]`` of an on-disk
    ``MMapIndexedDataset`` as dicts with ``input_ids``, ``attention_mask`` and
    ``labels``. Optionally, label positions that complete a bigram already seen
    earlier in the same sequence are set to ``-100``, so they are excluded from
    the loss:

    * if ``equiv_lookup`` is given, bigrams match under that token-equivalence
      lookup;
    * otherwise, if ``mask_bigram_loss`` is true, bigrams must match exactly.
    """

    def __init__(
        self,
        path,
        start_idx=0,
        end_idx=None,
        equiv_lookup=None,
        mask_bigram_with_self=False,
        mask_bigram_loss=False,
        mask_bigram_loss_with_self_included=False,
    ):
        """Open the memory-mapped dataset and store the masking options.

        Args:
            path: Dataset prefix passed to ``MMapIndexedDataset``.
            start_idx: First document index (inclusive) served by this view.
            end_idx: Last document index (inclusive); defaults to the last
                document of the underlying dataset.
            equiv_lookup: Optional 2-D lookup indexed as
                ``equiv_lookup[token_a, token_b]``; truthy entries mark token
                pairs treated as equivalent. NOTE(review): presumably a
                (vocab, vocab) bool tensor on CPU -- confirm with the caller.
            mask_bigram_with_self: Intended to also mask the first token of a
                repeated bigram; not implemented yet, so a ``ValueError`` is
                raised as soon as a repeat is found.
            mask_bigram_loss: Enable exact-match bigram masking (only used
                when ``equiv_lookup`` is ``None``).
            mask_bigram_loss_with_self_included: Stored only; not read by
                this class.
        """
        self.dataset = MMapIndexedDataset(path, skip_warmup=True)
        self.start_idx = start_idx
        self.end_idx = end_idx if end_idx is not None else len(self.dataset) - 1
        self.seq_length = 2049
        self.equiv_lookup = equiv_lookup
        self.mask_bigram_with_self = mask_bigram_with_self
        self.mask_bigram_loss = mask_bigram_loss
        self.mask_bigram_loss_with_self_included = mask_bigram_loss_with_self_included

    def __len__(self):
        """Number of documents in the inclusive range ``[start_idx, end_idx]``."""
        return self.end_idx - self.start_idx + 1

    def _mask_repeated_bigrams(self, bigram_equiv):
        """Turn a pairwise bigram-match matrix into a per-token loss mask.

        ``bigram_equiv[i, j]`` is truthy when bigram ``i`` (tokens ``i`` and
        ``i + 1``) matches bigram ``j``. Bigram ``i`` is a repeat when it
        matches some earlier bigram ``j < i``; position ``i + 1`` (the repeat's
        second token) is then set to ``False``. The returned bool mask has
        ``bigram_equiv.shape[0] + 1`` entries.

        Raises:
            ValueError: if a repeat is found while ``mask_bigram_with_self``
                is enabled (that mode has not been tested yet).
        """
        num_bigrams = bigram_equiv.shape[0]
        loss_mask = torch.ones(num_bigrams + 1, dtype=torch.bool)

        # Strictly lower triangle: compare each bigram only with earlier ones,
        # so the first occurrence of any bigram is always kept.
        earlier = torch.tril(
            torch.ones(num_bigrams, num_bigrams, dtype=torch.bool), diagonal=-1
        )
        match_positions = torch.where((bigram_equiv & earlier).any(dim=1))[0]

        if len(match_positions) > 0:
            if getattr(self, "mask_bigram_with_self", False):
                # TODO: this mode would presumably also mask position i (the
                # repeat's first token); enable once it has been tested.
                raise ValueError("mask_bigram_with_self has not been tested yet")
            loss_mask[match_positions + 1] = False

        return loss_mask

    def _compute_bigram_mask(self, tokens):
        """Compute a loss mask for bigrams repeated under ``equiv_lookup``.

        Returns a bool tensor of length ``len(tokens)``. ``False`` marks the
        second token of a bigram that is equivalent to an earlier bigram in
        ``tokens``. Everything is kept when there is at most one bigram or no
        ``equiv_lookup`` is configured.
        """
        seq_len = len(tokens)
        if seq_len <= 2 or self.equiv_lookup is None:
            return torch.ones(seq_len, dtype=torch.bool)

        first, second = tokens[:-1], tokens[1:]
        # (S-1, S-1) matrices: entry [i, j] says whether the first (resp.
        # second) tokens of bigrams i and j are equivalent.
        first_equiv = self.equiv_lookup[first.unsqueeze(1), first.unsqueeze(0)]
        second_equiv = self.equiv_lookup[second.unsqueeze(1), second.unsqueeze(0)]
        return self._mask_repeated_bigrams(first_equiv & second_equiv)

    def bigram_loss_mask_exact(self, tokens):
        """Compute a loss mask for exactly repeated bigrams (no equivalence lookup).

        Same contract as ``_compute_bigram_mask`` but tokens must be equal.
        """
        seq_len = len(tokens)
        if seq_len <= 2:
            return torch.ones(seq_len, dtype=torch.bool)

        first, second = tokens[:-1], tokens[1:]
        same_first = first.unsqueeze(1) == first.unsqueeze(0)
        same_second = second.unsqueeze(1) == second.unsqueeze(0)
        return self._mask_repeated_bigrams(same_first & same_second)

    def __getitem__(self, idx):
        """Return one training sample with optional bigram masking applied.

        Args:
            idx: Position inside this view. Negative values count from the
                end, as for Python sequences.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels``, all
            ``torch.long``. ``labels`` equals ``input_ids`` except at masked
            positions, which hold ``-100``.

        Raises:
            IndexError: if ``idx`` lies outside this view. Without this check,
                a bad index would read documents of the neighbouring split.
        """
        length = len(self)
        position = idx + length if idx < 0 else idx
        if not 0 <= position < length:
            raise IndexError(
                f"index {idx} out of range for dataset of length {length}"
            )
        sequence = self.dataset[position + self.start_idx]

        if isinstance(sequence, np.ndarray) and sequence.ndim == 2:
            sequence = sequence[0]

        # Copy into int64 explicitly. This works for narrow unsigned dtypes
        # (e.g. uint16) that some PyTorch versions' torch.tensor() rejects,
        # and never aliases the read-only memory map.
        tokens = torch.from_numpy(np.array(sequence, dtype=np.int64))
        input_ids = tokens[:-1]
        # labels mirror input_ids. NOTE(review): presumably the model shifts
        # labels internally (HF causal-LM convention) -- confirm before changing.
        labels = tokens[:-1]

        loss_mask = None
        if self.equiv_lookup is not None:
            loss_mask = self._compute_bigram_mask(input_ids)
        elif self.mask_bigram_loss:
            loss_mask = self.bigram_loss_mask_exact(input_ids)

        if loss_mask is not None:
            # labels shares storage with input_ids; clone before writing -100.
            labels = labels.clone()
            labels[~loss_mask] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": labels,
        }


"""class MMapPythiaDatasetShuffled(Dataset):
    def __init__(self, path, indices):
        self.dataset = MMapIndexedDataset(path, skip_warmup=True)
        self.indices = indices
        self.seq_length = 2049

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        absolute_idx = self.indices[idx]
        sequence = self.dataset[absolute_idx]

        if isinstance(sequence, np.ndarray) and len(sequence.shape) == 2:
            sequence = sequence[0]

        tokens = torch.tensor(sequence, dtype=torch.long)
        input_ids = tokens[:-1]
        labels = tokens[1:]
        
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }"""


class MMapPythiaDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for Pythia model using MMapIndexedDataset.

    ``setup`` builds three datasets:

    * ``train_dataset`` -- training split, with the configured bigram masking;
    * ``valid_dataset`` -- validation split, without masking;
    * ``valid_dataset_with_masking`` -- validation split, with masking.

    With ``single_file=True``, the first ``train_valid_split`` fraction of the
    documents in ``train_path`` is used for training and the remainder for
    validation. Otherwise ``valid_path`` is opened as a separate dataset.
    """

    def __init__(
        self,
        train_path,
        valid_path=None,
        batch_size=16,
        num_workers=4,
        train_valid_split=0.95,
        single_file=True,
        shuffle=False,
        equiv_lookup=None,
        mask_bigram_with_self=False,
        mask_bigram_loss=False,
        mask_bigram_loss_with_self_included=False,
    ):
        """Initialize DataModule.

        ``valid_path`` falls back to ``train_path`` when not given. The
        masking options are forwarded unchanged to ``MMapPythiaDataset``.
        """
        super().__init__()
        self.train_path = train_path
        self.valid_path = valid_path if valid_path else train_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_valid_split = train_valid_split
        self.single_file = single_file
        self.shuffle = shuffle
        self.equiv_lookup = equiv_lookup
        self.mask_bigram_with_self = mask_bigram_with_self
        self.mask_bigram_loss = mask_bigram_loss
        self.mask_bigram_loss_with_self_included = mask_bigram_loss_with_self_included

        # Populated by setup(); declared here so every attribute always exists.
        self.train_dataset = None
        self.valid_dataset = None
        self.valid_dataset_with_masking = None

    def prepare_data(self):
        """Check that the dataset files exist, printing detailed debugging info.

        Raises:
            FileNotFoundError: if the training ``.idx`` or ``.bin`` file is
                missing, or (with ``single_file=False``) a validation file is.
        """
        train_idx_path = self.train_path + ".idx"
        train_bin_path = self.train_path + ".bin"

        print(f"Checking for dataset at {self.train_path}")
        print(f"Full idx path: {train_idx_path}")
        print(f"Current working directory: {os.getcwd()}")

        directory = os.path.dirname(self.train_path)
        print(f"Directory exists: {os.path.exists(directory)}")
        if os.path.exists(directory):
            print(f"Files in directory {directory}:")
            for file_name in os.listdir(directory):
                print(f"  - {file_name}")

        idx_exists = os.path.exists(train_idx_path)
        bin_exists = os.path.exists(train_bin_path)
        print(f"IDX file exists: {idx_exists}")
        print(f"BIN file exists: {bin_exists}")

        # Debug aid: also probe a "document" prefix in the same directory.
        alt_path = os.path.join(directory, "document")
        print(f"Alternative path: {alt_path}")
        print(f"Alt IDX file exists: {os.path.exists(alt_path + '.idx')}")
        print(f"Alt BIN file exists: {os.path.exists(alt_path + '.bin')}")

        if idx_exists:
            print(f"IDX file permissions: {oct(os.stat(train_idx_path).st_mode)[-3:]}")
        if bin_exists:
            print(f"BIN file permissions: {oct(os.stat(train_bin_path).st_mode)[-3:]}")

        if not idx_exists:
            raise FileNotFoundError(f"Dataset not found at {self.train_path}")
        # Presumably MMapIndexedDataset also reads <prefix>.bin; fail here with
        # a clear message rather than later inside setup().
        if not bin_exists:
            raise FileNotFoundError(f"Dataset data file not found at {train_bin_path}")

        if not self.single_file:
            for ext in (".idx", ".bin"):
                if not os.path.exists(self.valid_path + ext):
                    raise FileNotFoundError(
                        f"Validation dataset not found at {self.valid_path}"
                    )

    def setup(self, stage=None):
        """Set up train and validation datasets.

        Runs for ``stage`` in ``("fit", "validate", None)``. ``"validate"`` is
        needed so that ``trainer.validate()`` gets real datasets. Other stages
        are no-ops, because this module defines no test or predict data.
        """
        if stage not in ("fit", "validate", None):
            return

        masking = dict(
            equiv_lookup=self.equiv_lookup,
            mask_bigram_with_self=self.mask_bigram_with_self,
            mask_bigram_loss=self.mask_bigram_loss,
            mask_bigram_loss_with_self_included=self.mask_bigram_loss_with_self_included,
        )

        if self.single_file:
            # Split one file by document index: [0, train_size) for training,
            # [train_size, dataset_size) for validation.
            dataset_size = len(MMapIndexedDataset(self.train_path, skip_warmup=True))
            train_size = int(dataset_size * self.train_valid_split)
            train_range = dict(start_idx=0, end_idx=train_size - 1)
            valid_range = dict(start_idx=train_size, end_idx=dataset_size - 1)
            valid_path = self.train_path
        else:
            train_range = {}
            valid_range = {}
            valid_path = self.valid_path

        self.train_dataset = MMapPythiaDataset(self.train_path, **train_range, **masking)
        self.valid_dataset = MMapPythiaDataset(valid_path, **valid_range)
        self.valid_dataset_with_masking = MMapPythiaDataset(
            valid_path, **valid_range, **masking
        )

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader using this module's batch and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return training dataloader."""
        return self._make_dataloader(self.train_dataset, shuffle=self.shuffle)

    def val_dataloader(self):
        """Return validation dataloader (no bigram masking)."""
        return self._make_dataloader(self.valid_dataset, shuffle=False)

    def val_dataloader_with_masking(self):
        """Return validation dataloader with masking."""
        return self._make_dataloader(self.valid_dataset_with_masking, shuffle=False)


def _first_data_path(config, key, default):
    """Return the first path under ``config[key]``; accepts a list or a bare string."""
    value = config.get(key, [default])
    # A bare string would otherwise be indexed to its first character.
    if isinstance(value, str):
        return value
    return value[0]


def _clean_data_path(path, label):
    """Strip whitespace and a trailing ``.bin``/``.idx`` to get the dataset prefix."""
    path = path.strip()
    for ext in (".bin", ".idx"):
        if path.endswith(ext):
            print(f"⚠️ Removing {ext} extension from {label}")
            path = path[: -len(ext)]
    return path


def setup_pythia_data(config, equiv_lookup=None):
    """Factory function to set up data module for training with detailed debugging.

    Args:
        config: Mapping of training options. Reads ``train-data-paths``,
            ``valid-data-paths`` (each a list whose first entry is used, or a
            single string), ``use_single_file``, ``shuffle``,
            ``train_micro_batch_size_per_gpu``, ``num_workers``,
            ``train_valid_split``, ``mask_bigram_loss`` and
            ``mask_bigram_loss_with_self_included``. Every key is optional.
        equiv_lookup: Optional token-equivalence lookup forwarded to the datasets.

    Returns:
        A configured ``MMapPythiaDataModule``.
    """
    default_path = "/fsx/pile/pile_20B_tokenizer_text_document"
    train_path = _first_data_path(config, "train-data-paths", default_path)
    valid_path = _first_data_path(config, "valid-data-paths", default_path)

    print("==== SETUP DATA DEBUG INFO ====")
    print(f"Train path from config: {train_path}")
    print(f"Valid path from config: {valid_path}")
    print(f"Path type: {type(train_path)}")

    train_path = _clean_data_path(train_path, "train_path")
    valid_path = _clean_data_path(valid_path, "valid_path")
    print(f"Clean train path: {train_path}")
    print(f"Clean valid path: {valid_path}")

    for ext, label in ((".idx", "IDX"), (".bin", "BIN")):
        file_path = train_path + ext
        if os.path.exists(file_path):
            print(f"✅ Verified {label} file exists at {file_path}")
        else:
            print(f"❌ {label} file NOT found at {file_path}")
    print("================================")

    single_file = config.get("use_single_file", True)
    # The "self included" option also drives mask_bigram_with_self, which is
    # the flag the dataset's masking code actually reads.
    self_included = config.get("mask_bigram_loss_with_self_included", False)

    return MMapPythiaDataModule(
        train_path=train_path,
        valid_path=valid_path if not single_file else None,
        batch_size=config.get("train_micro_batch_size_per_gpu", 16),
        num_workers=config.get("num_workers", 4),
        train_valid_split=config.get("train_valid_split", 0.95),
        single_file=single_file,
        shuffle=config.get("shuffle", False),
        equiv_lookup=equiv_lookup,
        mask_bigram_with_self=self_included,
        mask_bigram_loss=config.get("mask_bigram_loss", False),
        mask_bigram_loss_with_self_included=self_included,
    )
