from __future__ import annotations

import os
import re
from collections import defaultdict
import torch
import numpy as np
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Sequence, Union

from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset, Subset
from torch.utils.data.distributed import DistributedSampler

from src.run.distributed import get_rank, get_world_size


class TokenDataset(Dataset):
    """Map-style dataset that cuts a memory-mapped ``uint16`` token file into sequences.

    Sequence ``i`` starts at token ``i * T``. For ordinary files each item holds
    ``T + 1`` tokens, so consecutive items overlap by one token (presumably so
    inputs and shifted targets can be taken from the same item -- confirm
    against the training code). For files whose name contains ``"_labeled"``
    each item holds exactly ``T`` tokens.

    Args:
        filename: Path of the binary token file (raw ``uint16`` values).
        T: Sequence length in tokens.
    """

    def __init__(self, filename: str, T: int):
        self.filename = filename
        self.T = T

        # Map the file read-only; token data is only read when sliced.
        self.tokens = np.memmap(filename, dtype=np.uint16, mode="r")
        self.num_tokens = len(self.tokens)

        # One token is held back so the last sequence can still carry its
        # extra (T+1-th) token. Clamp so the length can never be negative.
        self.num_sequences = max(0, (self.num_tokens - 1) // self.T)

    def __len__(self) -> int:
        return self.num_sequences

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return sequence ``idx`` as an ``int64`` tensor.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here (instead of returning a short or empty slice) is
                also what lets plain ``for item in dataset`` iteration stop.
        """
        size = self.num_sequences
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"sequence index {idx} out of range for {size} sequences")

        begin = idx * self.T
        stop = begin + self.T

        # "_labeled" files yield exactly T tokens so that label tokens at the
        # end are not repeated; every other file yields T + 1 tokens.
        if "_labeled" not in self.filename:
            stop += 1

        buf = self.tokens[begin:stop]

        # astype copies out of the memmap and widens to int64 before wrapping.
        return torch.from_numpy(buf.astype(np.int64))


class DataLoader(ABC):
    """Interface shared by every batch provider in this module.

    Concrete loaders expose the batch size ``B`` and sequence length ``T``
    as attributes and implement ``reset``, ``next_batch`` and ``__len__``.
    """

    def __init__(self, B: int, T: int):
        # Batch size and sequence length, stored for subclasses and callers.
        self.B, self.T = B, T

    @abstractmethod
    def reset(self, epoch: int) -> None:
        """Restart iteration for the given epoch."""

    @abstractmethod
    def next_batch(self) -> tuple[torch.Tensor, str | None]:
        """Return the next ``(batch, label)`` pair."""

    @abstractmethod
    def __len__(self) -> int:
        """Return how many batches the loader provides."""


class SingleDataLoader(DataLoader):
    """Serve shuffled batches of token sequences from a single ``.bin`` file.

    The file is wrapped in a ``TokenDataset``, sharded across ranks by a
    shuffling ``DistributedSampler`` (always ``drop_last=True`` at the sampler
    level) and batched by a torch ``DataLoader``. Optionally only a contiguous
    range of batches of the file is exposed (see ``start`` / ``end``).
    """

    def __init__(
        self,
        filename: str,
        B: int,
        T: int,
        process_rank: int = 0,
        num_processes: int = 1,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device]] = None,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        label: str | None = None,
        seed: int = 42,
        drop_last: bool = True,
        start: int = 0,
        end: int = -1,
    ):
        """Build the dataset, sampler and torch loader.

        Args:
            filename: Path of the token file.
            B: Batch size (sequences per batch, per rank).
            T: Sequence length passed to ``TokenDataset``.
            process_rank: Rank of this process for the sampler.
            num_processes: Number of ranks the sampler splits the data across.
            num_workers: Worker count for the torch loader.
            device: If given, every batch is moved there in ``next_batch``.
            pin_memory: Forwarded to the torch loader.
            persistent_workers: Forwarded to the torch loader, but only
                honoured when ``num_workers > 0``.
            label: Returned alongside every batch by ``next_batch``.
            seed: Shuffling seed for the sampler.
            drop_last: Forwarded to the torch loader; also decides whether the
                batch count used for ``start`` / ``end`` rounds down or up.
            start: First per-rank batch index to keep (inclusive). Negative
                values count from the end.
            end: Per-rank batch index to stop at (exclusive). Negative values
                count from the end, with ``-1`` meaning "through the last
                batch".

        Raises:
            AssertionError: If ``start`` / ``end`` fall outside the file's
                batch range.
        """
        super().__init__(B, T)
        self.device = device
        self.filename = filename
        self.label = label
        self.seed = seed

        base_dataset = TokenDataset(filename, T)

        # Apply partition if start/end are specified.
        if start != 0 or end != -1:

            dataset_size = len(base_dataset)

            # One "global" batch step consumes B sequences on every rank, so
            # per-rank batch indices are scaled by num_processes when they are
            # converted to sequence counts.
            seqs_per_step = B * num_processes

            # Total number of per-rank batches (accounting for drop_last).
            if drop_last:
                total_batches = dataset_size // seqs_per_step
            else:
                total_batches = (dataset_size + seqs_per_step - 1) // seqs_per_step

            # start and end are in batch indices - normalize negative indices
            if start < 0:
                start += total_batches

            if end < 0:
                end += total_batches + 1

            # start may be any existing batch, including the last one
            # (index total_batches - 1); end may be at most total_batches.
            assert 0 <= start < total_batches, f"start out of range 0, {total_batches - 1}"
            assert 0 <= end < total_batches + 1, f"end out of range 0, {total_batches}"

            # Convert batch indices to sequence indices. With drop_last=False
            # the batch count is rounded up, so cap the end at the real
            # dataset size to avoid selecting sequences that do not exist.
            start_seq = start * seqs_per_step
            end_seq = min(end * seqs_per_step, dataset_size)

            # Define indices selector
            indices = list(range(start_seq, end_seq))
            self.dataset = Subset(base_dataset, indices)

        else:
            self.dataset = base_dataset

        self.sampler = DistributedSampler(
            self.dataset,
            num_replicas=num_processes,
            rank=process_rank,
            shuffle=True,
            seed=seed,
            drop_last=True,
        )

        self.dataloader = TorchDataLoader(
            self.dataset,
            batch_size=B,
            sampler=self.sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers and num_workers > 0,
            drop_last=drop_last,
        )

        self.iterator = iter(self.dataloader)

        self.last_epoch = 0

    def reset(self, epoch: int = 0) -> None:
        """Restart iteration, reshuffling the sampler for ``epoch``."""
        self.sampler.set_epoch(epoch)
        self.iterator = iter(self.dataloader)
        self.last_epoch = epoch

    def next_batch(self) -> tuple[torch.Tensor, str | None]:
        """Return ``(batch, label)``, starting the next epoch when exhausted.

        If the loader holds no batches at all, the retry after the reset is
        exhausted as well and ``StopIteration`` propagates to the caller.
        """
        try:
            batch = next(self.iterator)

        except StopIteration:
            self.reset(self.last_epoch + 1)
            batch = next(self.iterator)

        if self.device is not None:
            batch = batch.to(self.device, non_blocking=True)

        return batch, self.label

    def __len__(self) -> int:
        return len(self.dataloader)

    def partition(self, start: int, end: int) -> DataLoader:
        """Create a new SingleDataLoader restricted to batches ``[start, end)``.

        The new loader is rebuilt from ``self.filename`` with this loader's
        settings, so the indices refer to batches of the whole file, not to
        any partition this loader itself may already represent.

        Args:
            start: Start batch index (inclusive) relative to the original dataset
            end: End batch index (exclusive) relative to the original dataset
        """

        # Create new instance with partition parameters
        return SingleDataLoader(
            filename=self.filename,
            B=self.B,
            T=self.T,
            process_rank=self.sampler.rank,
            num_processes=self.sampler.num_replicas,
            num_workers=self.dataloader.num_workers,
            device=self.device,
            pin_memory=self.dataloader.pin_memory,
            persistent_workers=self.dataloader.persistent_workers,
            label=self.label,
            seed=self.sampler.seed,
            drop_last=self.dataloader.drop_last,
            start=start,
            end=end,
        )


class InterleavedDataLoader(DataLoader):
    """Randomly interleave batches drawn from several component loaders.

    Every call to ``next_batch`` picks one component loader using a private,
    seeded ``torch.Generator`` and returns that loader's next batch.

    * ``weighted=True``: a loader is picked with probability proportional to
      the number of batches it still has left in the current cycle, and that
      count is decremented after each pick. Over one cycle of ``len(self)``
      batches each loader is therefore picked exactly ``len(loader)`` times.
    * ``weighted=False``: a loader is picked uniformly at random on every
      call, regardless of how many batches it has left.

    In both modes a cycle ends after ``len(self)`` batches, at which point all
    component loaders are reset for the next epoch.
    """

    def __init__(self, loaders: Sequence[DataLoader], weighted: bool = True) -> None:
        """Validate the component loaders and initialise the sampling state.

        Args:
            loaders: Component loaders; all must share the same ``B`` and
                ``T``. The first loader's ``seed`` attribute seeds the
                selection generator.
            weighted: Selection mode, see the class docstring.

        Raises:
            AssertionError: If ``loaders`` is empty or the ``B`` / ``T``
                values differ.
        """

        assert len(loaders) > 0, "No loaders provided"

        # Validate that all loaders share the same batch (B) and sequence (T) sizes
        base_B, base_T = loaders[0].B, loaders[0].T
        for ld in loaders:
            assert ld.B == base_B and ld.T == base_T, "All loaders must have the same B and T values."

        super().__init__(base_B, base_T)
        self.loaders: List[DataLoader] = list(loaders)
        self.weighted = weighted

        # Internal state for proportional sampling. ``remaining`` tracks how
        # many batches each loader still has left to contribute in the current
        # cycle; it is refilled from the loader lengths on every reset.
        seed = self.loaders[0].seed
        self.remaining: List[int] = []  # remaining batch counts per loader
        self.seed = seed
        self.rng = torch.Generator(device="cpu")
        self.rng.manual_seed(seed)
        self.last_epoch = 0
        self.reset()

    # ------------------------------------------------------------------ #
    # Public API                                                         #
    # ------------------------------------------------------------------ #

    def reset(self, epoch: int = 0) -> None:
        """Reset all component loaders and the selection state for ``epoch``."""
        for ld in self.loaders:
            ld.reset(epoch)
        self.rng.manual_seed(self.seed + epoch)  # ensure different order per epoch
        self.remaining = [len(ld) for ld in self.loaders]
        self.last_epoch = epoch

    def next_batch(self) -> tuple[torch.Tensor, str | None]:
        """Return the next batch according to the random interleaving schedule.

        All randomness comes from ``self.rng``, which is re-seeded with
        ``seed + epoch`` on every reset. Instances built with the same seed
        and the same component lengths therefore pick the same sequence of
        loaders.
        NOTE(review): the original comment states this is what keeps all
        ranks on the same loader at each step (avoiding NCCL deadlocks) --
        confirm that every rank really constructs the loader identically.
        """
        # The summed counters drop by one per call, so they reach zero after
        # exactly len(self) batches; start the next epoch at that point.
        if sum(self.remaining) == 0:
            self.reset(self.last_epoch + 1)

        # Select loader
        if self.weighted:
            # Convert to probabilities based on counts
            probs = torch.tensor(self.remaining, dtype=torch.float32)
            probs = probs / probs.sum()
            # Sample
            idx = torch.multinomial(probs, 1, generator=self.rng).item()
        else:
            # Sample uniformly
            idx = torch.randint(0, len(self.loaders), (1,), generator=self.rng).item()

        # Update counters
        self.remaining[idx] -= 1

        # Get batch
        batch, label = self.loaders[idx].next_batch()
        return batch, label

    def __len__(self) -> int:
        return sum(len(ld) for ld in self.loaders)

    def partition(self, start: int, end: int) -> DataLoader:
        """Return a new interleaved loader covering the batch range ``[start, end)``.

        ``start`` and ``end`` are clamped to ``[0, len(self)]`` (negative
        values count from the end, with ``end=-1`` meaning "through the last
        batch") and converted to fractions of the total length. Each component
        loader is then partitioned at the same fractions of its own length.
        The selection mode (``weighted``) of this loader is kept.
        """
        total = len(self)

        if start < 0:
            start += total
        start = min(max(0, start), total)

        if end < 0:
            end += total + 1
        end = min(max(0, end), total)

        start_prc = start / total
        end_prc = end / total

        new_loaders = []
        for loader in self.loaders:
            temp_start = int(len(loader) * start_prc)
            temp_end = int(len(loader) * end_prc)
            new_loaders.append(loader.partition(temp_start, temp_end))

        # Pass the mode on explicitly; otherwise the partition would silently
        # fall back to the constructor default.
        return InterleavedDataLoader(new_loaders, weighted=self.weighted)

# --------------------------------------------------------------------------- #
# Helper functions for auto-detecting categories and shards                   #
# --------------------------------------------------------------------------- #


def parse_bin_filename(filename: str) -> Optional[tuple[str, str, Optional[int]]]:
    """
    Split a ``.bin`` filename into its category, split and shard index.

    Recognised layout is ``{category}_{split}[_{shard}].bin`` where split is
    ``train`` or ``test`` and the optional shard is a run of digits:

    - category_train.bin -> ("category", "train", None)
    - category_test.bin -> ("category", "test", None)
    - category_train_00.bin -> ("category", "train", 0)
    - category_test_03.bin -> ("category", "test", 3)

    Args:
        filename: Bare filename to parse (e.g., "biology_train_00.bin")

    Returns:
        ``(category, split, shard_idx)``, with ``shard_idx`` set to None when
        the name carries no shard suffix; None when the name does not follow
        the layout at all.
    """
    found = re.match(r'^(.+?)_(train|test)(?:_(\d+))?\.bin$', filename)
    if found is None:
        return None

    # The third group is None when the optional shard suffix is absent.
    category, split, shard_digits = found.groups()
    if shard_digits is None:
        return category, split, None
    return category, split, int(shard_digits)


def get_bin_file_batch_count(bin_paths: list[Path], B: int, T: int, num_processes: int = 1) -> int:
    """
    Predict the total number of per-rank batches the given bin files provide.

    The count is derived from file sizes only, using the same arithmetic as
    TokenDataset (``(num_tokens - 1) // T`` sequences per file) and
    SingleDataLoader with ``drop_last=True`` (``num_sequences // (B *
    num_processes)`` batches per rank), so no loader has to be created.

    Args:
        bin_paths: List of .bin file paths
        B: Batch size
        T: Sequence length
        num_processes: Number of distributed processes (for DistributedSampler calculation)

    Returns:
        Total number of batches per rank, summed over all files. Files too
        small to hold a single sequence (including empty files) contribute 0.
    """
    # Sequences consumed by one batch step across all ranks.
    seqs_per_step = B * num_processes

    total_batches = 0
    for bin_path in bin_paths:
        # uint16 tokens: 2 bytes each.
        num_tokens = os.path.getsize(bin_path) // 2

        # Floor division of a negative numerator would yield -1 for an empty
        # file and wrongly subtract a batch from the total, so clamp at zero.
        num_sequences = max(0, (num_tokens - 1) // T)

        total_batches += num_sequences // seqs_per_step

    return total_batches


def get_labels_batch_count(
    data_dirs: list[Path],
    labels: list[str],
    B: int,
    T: int,
    num_processes: int = 1,
) -> int:
    """
    Count the per-rank training batches available for the given labels.

    Categories are detected in every directory of ``data_dirs``; when the
    same category name shows up in more than one directory, the entry from
    the later directory replaces the earlier one. Only the "train" shards of
    the requested labels are counted, and labels that were not detected are
    skipped silently.

    The original motivation is pre-computing core_batch_num when
    core_batch_limit=None but aux_batch_limit is a percentage that depends
    on it.

    Args:
        data_dirs: List of directories containing .bin files
        labels: List of category labels to count batches for
        B: Batch size
        T: Sequence length
        num_processes: Number of distributed processes

    Returns:
        Total number of batches per rank for the specified labels
    """
    detected = {}
    for directory in data_dirs:
        detected.update(auto_detect_categories(directory))

    # Flatten the train shards of every requested label that was detected.
    train_paths = [
        shard
        for name in labels
        if name in detected
        for shard in detected[name]["train"]
    ]

    return get_bin_file_batch_count(train_paths, B, T, num_processes)


def auto_detect_categories(data_dir: Path) -> dict[str, dict[str, list[Path]]]:
    """
    Group the .bin files of a directory by category and split.

    Only files directly inside ``data_dir`` whose names are accepted by
    ``parse_bin_filename`` are considered; everything else is ignored.

    Args:
        data_dir: Directory containing .bin files

    Returns:
        Dict of {category: {"train": [shard_paths], "test": [shard_paths]}}.
        Each list is ordered by shard index; files without a shard suffix
        sort as shard 0.

    Example:
        {
            "fineweb": {
                "train": [Path("fineweb_train_00.bin"), Path("fineweb_train_01.bin")],
                "test": [Path("fineweb_test.bin")]
            },
            "biology": {
                "train": [Path("biology_train.bin")],
                "test": [Path("biology_test.bin")]
            }
        }
    """
    # category -> split -> list of (sort_key, path) pairs
    grouped = defaultdict(lambda: {"train": [], "test": []})

    for candidate in data_dir.glob("*.bin"):
        info = parse_bin_filename(candidate.name)
        if info is not None:
            name, split_name, shard = info
            # Unsharded files take sort key 0.
            order = 0 if shard is None else shard
            grouped[name][split_name].append((order, candidate))

    # Order each split by its sort key, then drop the key.
    return {
        name: {
            "train": [entry[1] for entry in sorted(by_split["train"])],
            "test": [entry[1] for entry in sorted(by_split["test"])],
        }
        for name, by_split in grouped.items()
    }


def make_loaders(
    data_dirs: list[Path],
    aux_labels: list[str],
    core_labels: list[str] | None,
    B: int,
    T: int,
    seed: int,
    device: torch.device,
    core_batch_num: int | None,
    aux_batch_num: int | None,
    max_num_test: int | None = None,
) -> tuple[dict[str, DataLoader], list[str]]:
    """
    Build train / test / test_ood loaders for every detected category.

    For each category found in ``data_dirs`` three loaders are created:
    "train" and "test" from the matching shards, and "test_ood" from the test
    shards with a sequence length of ``4 * T``. Two aggregate entries,
    "core" (all core labels) and "all" (core followed by aux labels), are
    added on top; they reuse the per-label loader objects.

    Args:
        data_dirs: List of directories containing .bin files. When a category
            appears in several directories, the last directory wins.
        aux_labels: List of auxiliary labels to create loaders for
        core_labels: List of core labels to create loaders for; None is
            treated as an empty list
        B: Batch size
        T: Sequence length
        seed: Random seed
        device: Device to load data to
        core_batch_num: Number of batches to sample from the core dataset,
            split evenly across the core labels (None = no limit)
        aux_batch_num: Number of batches to sample from the auxiliary
            datasets, split evenly across the aux labels (None = no limit)
        max_num_test: Optional cap on the number of batches of each "test"
            loader (not applied to "train" or "test_ood")

    Returns:
        Tuple of (loaders_dict, core_labels)
        - loaders_dict: {label: {"train": loader, "test": loader, "test_ood": loader}, ...}
          plus the aggregate keys "core" and "all"
        - core_labels: the core label list that was used
        NOTE(review): the previous docstring announced the aux label list as
        second element, but the code has always returned ``core_labels``;
        the return value is left unchanged for existing callers.

    Raises:
        AssertionError: If a requested label was not detected, or a detected
            category lacks train or test shards.
    """
    process_rank = get_rank()
    num_processes = get_world_size()

    # The signature allows None; treat it as "no core labels" so membership
    # tests, len() and list concatenation below are well defined.
    if core_labels is None:
        core_labels = []

    categories = {}
    for data_dir in data_dirs:
        categories.update(auto_detect_categories(data_dir))

    all_labels = sorted(categories.keys())

    # Both label lists must refer to detected categories.
    missing_aux = set(aux_labels) - set(all_labels)
    missing_core = set(core_labels) - set(all_labels)
    assert not missing_aux, f"aux_labels {missing_aux} not found in all_labels"
    assert not missing_core, f"core_labels {missing_core} not found in all_labels"

    loaders = {}

    for label in all_labels:

        loaders[label] = {}

        is_aux = label in aux_labels
        is_core = label in core_labels

        for split in ["train", "test", "test_ood"]:

            # test_ood reads the regular test shards, only with longer sequences.
            if split == "test_ood":
                shard_paths = categories[label]["test"]
            else:
                shard_paths = categories[label][split]

            assert len(shard_paths) > 0, f"No {split} data found for label {label}"

            # -1 means "no limit"; otherwise the group budget is shared
            # evenly between the labels of the group (aux takes precedence
            # when a label is in both groups).
            batch_limit = -1
            if is_aux and aux_batch_num is not None:
                batch_limit = round(aux_batch_num / len(aux_labels))
            elif is_core and core_batch_num is not None:
                batch_limit = round(core_batch_num / len(core_labels))

            if max_num_test is not None and split == "test":
                # Handle -1 (no limit) case: use max_num_test directly
                if batch_limit == -1:
                    batch_limit = max_num_test
                else:
                    batch_limit = min(batch_limit, max_num_test)

            shard_loaders = []
            batch_count = 0
            for shard_path in shard_paths:

                loader = SingleDataLoader(
                    filename=str(shard_path),
                    B=B,
                    T=T if split != "test_ood" else 4*T,
                    process_rank=process_rank,
                    num_processes=num_processes,
                    label=label if is_aux else "core",
                    seed=seed,
                    device=device,
                    drop_last=True,
                )

                # This shard would exceed the budget: keep only as many of its
                # batches as are still allowed and stop adding shards.
                if batch_limit != -1 and batch_count + len(loader) > batch_limit:
                    end = batch_limit - batch_count
                    loader = loader.partition(0, end)
                    shard_loaders.append(loader)
                    break

                batch_count += len(loader)
                shard_loaders.append(loader)

            if len(shard_loaders) > 1:
                loaders[label][split] = InterleavedDataLoader(shard_loaders)
            else:
                loaders[label][split] = shard_loaders[0]

    # make output "core" and "all" loaders
    for agg_name, agg_labels in [("core", core_labels), ("all", core_labels + aux_labels)]:
        loaders[agg_name] = {}
        for split in ["train", "test", "test_ood"]:
            agg_loaders = [loaders[label][split] for label in agg_labels if split in loaders[label]]
            if len(agg_loaders) > 1:
                loaders[agg_name][split] = InterleavedDataLoader(agg_loaders)
            elif len(agg_loaders) == 1:
                loaders[agg_name][split] = agg_loaders[0]

    return loaders, core_labels
