from torch.utils.data import Dataset
import random
import torch.distributed as dist
import numpy as np

class Wikitext(Dataset):
    """Common base for the Wikitext dataset wrappers in this module.

    Holds the wrapped dataset plus a few bookkeeping attributes and reports
    the wrapped dataset's length as its own. ``__getitem__`` is not defined
    here; subclasses provide it.

    Args:
        base_dataset: The wrapped dataset; must support ``len()``.
        with_idx: Flag stored on the instance as ``with_idx``.
        batch_size: Value stored on the instance as ``batch_size``.
    """

    def __init__(self, base_dataset, with_idx=False, batch_size=8):
        self.base_dataset, self.with_idx, self.batch_size = (
            base_dataset,
            with_idx,
            batch_size,
        )
        # The epoch counter always starts at zero.
        self.epoch = 0

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        wrapped = self.base_dataset
        return len(wrapped)

class AQSGD_Wikitext(Wikitext):
    """Wikitext wrapper whose items always carry their dataset index.

    Every item is returned as a dict with the keys ``'input_ids'``,
    ``'attention_mask'``, ``'labels'`` and ``'indices'``.

    Args:
        base_dataset: Indexable dataset whose items expose ``'input_ids'``
            and support ``in`` membership tests for optional keys.
        with_idx: Forwarded to the base class. It is not consulted by
            ``__getitem__``; the indexed dict is returned regardless.
        batch_size: Forwarded to the base class.
    """

    def __init__(self, base_dataset, with_idx=False, batch_size=8):
        super().__init__(base_dataset, with_idx, batch_size)

    def __getitem__(self, idx):
        """Return the item at ``idx`` as a dict that also records ``idx``.

        ``'attention_mask'`` is ``None`` when the base item has no such key,
        and ``'labels'`` falls back to the item's ``'input_ids'``.
        """
        record = self.base_dataset[idx]
        tokens = record['input_ids']

        if 'attention_mask' in record:
            mask = record['attention_mask']
        else:
            mask = None

        if 'labels' in record:
            targets = record['labels']
        else:
            # No explicit labels: the input ids double as the targets.
            targets = record['input_ids']

        return {
            'input_ids': tokens,
            'attention_mask': mask,
            'labels': targets,
            'indices': idx,
        }

class LazyWikitext(Wikitext):
    """Dataset wrapper that lazily refreshes a persistent set of sample indices.

    The ``idx`` passed to ``__getitem__`` is ignored. The dataset keeps
    ``batch_size`` slots of previously drawn indices; each call visits one
    slot and redraws it with probability ``p_t(current_batch)`` (always
    during the first batch of epoch 0). Otherwise the slot's previous sample
    is served again.

    Args:
        base_dataset: Indexable dataset supporting ``len()``. When
            ``with_idx`` is true its items must expose ``'input_ids'`` and
            support ``in`` membership tests.
        p_t: Callable mapping the 1-based batch number to a resampling
            probability. ``None`` means "always resample".
        batch_size: Number of index slots. Presumably this should equal the
            DataLoader batch size -- verify against the caller.
        with_idx: If true, items are returned as dicts that also contain the
            sampled index under ``'indices'``; otherwise the raw base item
            is returned.
        pipeline_parallel: If true, every rank uses the same seed and draws
            from the whole dataset. Otherwise each rank uses its own seed
            and draws only from its contiguous partition of the dataset.
    """

    def __init__(self, base_dataset, p_t=None, batch_size=8, with_idx=False, pipeline_parallel=False):
        super().__init__(base_dataset, with_idx, batch_size)
        self.p_t = p_t if p_t is not None else lambda x: 1
        self.prev_indices = [0] * batch_size
        self.cnt = 0
        self.pipeline_parallel = pipeline_parallel

        # Fall back to a single-process layout when no process group exists,
        # so the dataset is usable outside a distributed launch.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

        self.rng = self._make_rng(0)

    def _make_rng(self, epoch):
        """Build the RNG for ``epoch`` (shared seed in pipeline-parallel mode)."""
        # NOTE(review): in non-pipeline mode the seed for (epoch, rank) equals
        # the seed for (epoch + 1, rank - 1); kept unchanged so existing runs
        # stay reproducible.
        rank_offset = 0 if self.pipeline_parallel else self.rank
        return np.random.RandomState(42 + epoch + rank_offset)

    def _index_bounds(self):
        """Return the half-open ``(low, high)`` index range this rank samples from."""
        total = len(self.base_dataset)
        if self.pipeline_parallel:
            return 0, total
        per_rank = total // self.world_size
        low = self.rank * per_rank
        # The last rank also takes the remainder of the integer division.
        high = total if self.rank == self.world_size - 1 else low + per_rank
        return low, high

    def update_epoch(self, epoch):
        """Record the new epoch, reset the item counter and reseed the RNG."""
        self.epoch = epoch
        self.cnt = 0
        self.rng = self._make_rng(epoch)

    def __getitem__(self, idx):
        """Return the sample held in the next slot, redrawing it if selected.

        Args:
            idx: Ignored; sampling is driven by the internal call counter.

        Returns:
            The base dataset item, or (when ``with_idx`` is true) a dict with
            ``'input_ids'``, ``'attention_mask'``, ``'labels'`` and
            ``'indices'``.
        """
        self.cnt += 1
        current_batch = (self.cnt - 1) // self.batch_size + 1
        # Slots are visited in the rotated order 1, ..., batch_size - 1, 0;
        # every slot is still visited exactly once per full batch.
        slot = self.cnt % self.batch_size

        low, high = self._index_bounds()
        first_fill = current_batch == 1 and self.epoch == 0
        # Short-circuit keeps the RNG stream untouched by the probability
        # draw while the slots are being filled for the very first time.
        if first_fill or self.rng.random() < self.p_t(current_batch):
            self.prev_indices[slot] = self.rng.randint(low, high)

        sample_idx = self.prev_indices[slot]
        # Index the base dataset once; each lookup may materialise a full row.
        item = self.base_dataset[sample_idx]
        if not self.with_idx:
            return item

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'] if 'attention_mask' in item else None,
            'labels': item['labels'] if 'labels' in item else item['input_ids'],
            'indices': sample_idx,
        }

class LargeBatchWikitext(Wikitext):
    """Dataset wrapper that serves one sampled batch of indices repeatedly.

    The ``idx`` passed to ``__getitem__`` is ignored. A list of
    ``batch_size`` indices is drawn at the start of a batch and then served
    slot by slot until the next resampling point.

    Args:
        base_dataset: Indexable dataset supporting ``len()``. When
            ``with_idx`` is true its items must expose ``'input_ids'`` and
            support ``in`` membership tests.
        k: Resampling schedule. ``k >= 1`` redraws the index list every
            ``k`` batches; ``k < 1`` redraws it at the start of a batch with
            probability ``k``.
        batch_size: Number of indices held at a time. Presumably this should
            equal the DataLoader batch size -- verify against the caller.
        with_idx: If true, items are returned as dicts that also contain the
            sampled index under ``'indices'``; otherwise the raw base item
            is returned.
        pipeline_parallel: If true, every rank uses the same seed and draws
            from the whole dataset. Otherwise each rank uses its own seed
            and draws only from its contiguous partition of the dataset.
    """

    def __init__(self, base_dataset, k=1, batch_size=8, with_idx=False, pipeline_parallel=False):
        super().__init__(base_dataset, with_idx, batch_size)
        self.k = k  # Number of batches before resampling
        self.prev_indices = [0] * batch_size
        self.cnt = 0
        self.pipeline_parallel = pipeline_parallel

        # Fall back to a single-process layout when no process group exists,
        # so the dataset is usable outside a distributed launch.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

        self.rng = self._make_rng(0)

    def _make_rng(self, epoch):
        """Build the RNG for ``epoch`` (shared seed in pipeline-parallel mode)."""
        # NOTE(review): in non-pipeline mode the seed for (epoch, rank) equals
        # the seed for (epoch + 1, rank - 1); kept unchanged so existing runs
        # stay reproducible.
        rank_offset = 0 if self.pipeline_parallel else self.rank
        return np.random.RandomState(42 + epoch + rank_offset)

    def _index_bounds(self):
        """Return the half-open ``(low, high)`` index range this rank samples from."""
        total = len(self.base_dataset)
        if self.pipeline_parallel:
            return 0, total
        per_rank = total // self.world_size
        low = self.rank * per_rank
        # The last rank also takes the remainder of the integer division.
        high = total if self.rank == self.world_size - 1 else low + per_rank
        return low, high

    def _should_resample(self, current_batch):
        """Decide, once per batch, whether the index list must be redrawn."""
        if current_batch == 1 and self.epoch == 0:
            # Never serve the placeholder zeros the list is initialised with.
            return True
        if self.k >= 1:
            return (current_batch - 1) % self.k == 0
        return self.rng.random() < self.k

    def update_epoch(self, epoch):
        """Record the new epoch, reset the item counter and reseed the RNG."""
        self.epoch = epoch
        self.cnt = 0
        self.rng = self._make_rng(epoch)

    def __getitem__(self, idx):
        """Return the sample held in the next slot of the current index list.

        Args:
            idx: Ignored; sampling is driven by the internal call counter.

        Returns:
            The base dataset item, or (when ``with_idx`` is true) a dict with
            ``'input_ids'``, ``'attention_mask'``, ``'labels'`` and
            ``'indices'``.
        """
        self.cnt += 1
        current_batch = (self.cnt - 1) // self.batch_size + 1
        # Slots are visited in the rotated order 1, ..., batch_size - 1, 0;
        # every slot is still visited exactly once per full batch.
        slot = self.cnt % self.batch_size

        # Redraw only on the first item of a batch so that every item of the
        # batch, and of the batches that reuse it, comes from the same list.
        starts_batch = (self.cnt - 1) % self.batch_size == 0
        if starts_batch and self._should_resample(current_batch):
            low, high = self._index_bounds()
            self.prev_indices = [
                self.rng.randint(low, high) for _ in range(self.batch_size)
            ]

        sample_idx = self.prev_indices[slot]
        # Index the base dataset once; each lookup may materialise a full row.
        item = self.base_dataset[sample_idx]
        if not self.with_idx:
            return item

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'] if 'attention_mask' in item else None,
            'labels': item['labels'] if 'labels' in item else item['input_ids'],
            'indices': sample_idx,
        }
