# Standard library
import math
import gc
import glob
from pathlib import Path
from typing import Literal

# Third party
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Local
import core


class ShadedDataset(Dataset):
    """
    A PyTorch Dataset that efficiently loads tokenized data from .npy shards,
    distributing tokens among DDP workers without excessive memory usage.
    """

    def __init__(
        self,
        shards_path: str,
        token_budget: int,
        sequence_length: int,
        rank: int,
        world_size: int,
        device: torch.device,
        shard_token_size: int = 1024**3,  # Expected tokens per .npy shard
        split: Literal["train", "validation"] = "train",
        pin_to_gpu: bool = False,
    ):
        """
        Args:
            shards_path (str): Path to the directory containing .npy token shards.
            token_budget (int): Total number of tokens to be used across all workers.
            sequence_length (int): The length of each sequence to be returned.
            rank (int): The rank of the current DDP process.
            world_size (int): The total number of DDP processes.
            device (torch.device): The CUDA device for this rank (e.g., torch.device("cuda:0")).
            shard_token_size (int): Expected number of tokens in each .npy shard file.
                                    Used to calculate how many shards to load for the budget.
            split (str): Train/validation split to load ("train" or "validation").
        """
        super().__init__()

        if split not in ["train", "validation"]:
            raise ValueError(
                f"Invalid split '{split}'. Must be either 'train' or 'validation'."
            )

        self.shards_path = Path(shards_path)
        self.token_budget = token_budget
        self.sequence_length = sequence_length
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.pin_to_gpu = pin_to_gpu
        self.shard_token_size = shard_token_size
        self.shard_filename_prefix = f"{split}_"

        if not self.shards_path.is_dir():
            raise FileNotFoundError(f"Shards directory not found: {self.shards_path}")

        # 1. Discover and sort shard files
        shard_files = sorted(
            glob.glob(str(self.shards_path / f"{self.shard_filename_prefix}*.npy"))
        )
        if not shard_files:
            raise FileNotFoundError(
                f"No shard files found with prefix '{self.shard_filename_prefix}' in {self.shards_path}"
            )

        # 2. Calculate how many shards to load
        core.logger.info(
            f"self.token_budget: {self.token_budget}, self.shard_token_size: {self.shard_token_size}"
        )
        num_shards_to_load = math.ceil(self.token_budget / self.shard_token_size)

        if num_shards_to_load > len(shard_files):
            core.logger.warning(
                f"[Rank {self.rank}]: Requested to load {num_shards_to_load} shards, "
                f"but only {len(shard_files)} are available. Using all available shards."
            )
            num_shards_to_load = len(shard_files)

        core.logger.info(
            f"[Rank {self.rank}] Loading {num_shards_to_load} shards incrementally..."
        )

        # 3. Load shards incrementally and extract worker's portion
        worker_token_chunks = []
        total_tokens_processed = 0
        tokens_per_worker = self.token_budget // self.world_size
        worker_start_token = self.rank * tokens_per_worker
        worker_end_token = (self.rank + 1) * tokens_per_worker

        if self.rank == self.world_size - 1:
            worker_end_token = self.token_budget

        for i, shard_file_path in enumerate(shard_files[:num_shards_to_load]):
            if total_tokens_processed >= self.token_budget:
                break

            try:
                shard_data_np = np.load(shard_file_path, allow_pickle=False).astype(np.int32)
                shard_tokens = len(shard_data_np)

                shard_start_global = total_tokens_processed
                shard_end_global = min(
                    total_tokens_processed + shard_tokens, self.token_budget
                )

                if (
                    shard_end_global > worker_start_token
                    and shard_start_global < worker_end_token
                ):
                    local_start = max(0, worker_start_token - shard_start_global)
                    local_end = min(shard_tokens, worker_end_token - shard_start_global)

                    if local_start < local_end:
                        worker_portion = shard_data_np[local_start:local_end]
                        worker_token_chunks.append(
                            torch.tensor(worker_portion, dtype=torch.long)
                        )
                        core.logger.debug(
                            f"[Rank {self.rank}] Shard {i}: extracted tokens {local_start}:{local_end} "
                            f"(global {shard_start_global + local_start}:{shard_start_global + local_end})"
                        )

                total_tokens_processed = shard_end_global
                del shard_data_np

            except Exception as e:
                raise IOError(f"Error loading shard file {shard_file_path}: {e}")

        # 4. Concatenate worker's token chunks
        if not worker_token_chunks:
            raise ValueError(
                f"[Rank {self.rank}] No tokens loaded for this worker. "
                "Check token budget and worker assignment."
            )

        worker_tokens_cpu = torch.cat(worker_token_chunks, dim=0)
        del worker_token_chunks
        gc.collect()

        # 5. Move to GPU if requested
        if self.pin_to_gpu:
            self.worker_tokens = worker_tokens_cpu.to(self.device)
            del worker_tokens_cpu
        else:
            self.worker_tokens = worker_tokens_cpu

        # Number of full sequences (samples)
        self.num_samples = len(self.worker_tokens) // self.sequence_length

        actual_tokens = len(self.worker_tokens)
        core.logger.info(
            f"[Rank {self.rank}] Loaded {actual_tokens:,} tokens, "
            f"creating {self.num_samples:,} samples of length {self.sequence_length}"
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        start = idx * self.sequence_length
        end = start + self.sequence_length
        return self.worker_tokens[start:end]


def get_dataloader(
    dataset: ShadedDataset,
    batch_size: int,
    shuffle: bool = True,
) -> DataLoader:
    """Create DataLoader for ShadedDataset with appropriate memory settings."""
    if not isinstance(dataset, ShadedDataset):
        raise TypeError("dataset must be an instance of ShadedDataset")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=not dataset.pin_to_gpu,
    )
