import gzip
import json
from typing import List
import torch


class JsonlDataset(torch.utils.data.IterableDataset):
    """Stream JSON records from one or more JSON Lines files.

    Each non-blank line of every file is parsed with ``json.loads`` and
    yielded as one item, in file order. Files whose path ends in ``.gz`` are
    read through gzip; all others are read as plain UTF-8 text.

    When iterated inside a multi-worker ``DataLoader``, each worker gets its
    own copy of this dataset. Records are distributed round-robin across
    workers so that every record is produced exactly once overall.
    """

    def __init__(
        self,
        file_paths: List[str],
        start_lines=None,
    ):
        """
        Initialize the dataset.

        :param file_paths: List of paths to .jsonl or .jsonl.gz files.
        :param start_lines: Optional sequence with one entry per file giving
            the number of leading physical lines to skip in that file
            (e.g. to resume part-way through). ``None`` reads every file
            from its first line.
        :raises ValueError: If ``start_lines`` is given but its length does
            not match ``file_paths``.
        """
        super().__init__()

        self.file_paths = file_paths
        self.start_lines = start_lines

        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if self.start_lines is not None and len(self.start_lines) != len(
            file_paths
        ):
            raise ValueError(
                f"start_lines has {len(self.start_lines)} entries but "
                f"file_paths has {len(file_paths)}"
            )

    def _open_file(self, file_path: str):
        """
        Open a single file in text mode as UTF-8.

        Paths ending in ``.gz`` are decompressed with gzip; anything else is
        opened as a regular text file.
        """
        if file_path.endswith(".gz"):
            return gzip.open(file_path, "rt", encoding="utf-8")
        return open(file_path, "r", encoding="utf-8")

    def __iter__(self):
        """
        Yield one parsed JSON record per non-blank line across all files.

        For each file, the first ``start_lines[i]`` physical lines (counting
        blank ones) are skipped. Remaining lines that are empty or
        whitespace-only are ignored instead of being passed to ``json.loads``,
        which would raise on them.
        """
        # Outside a DataLoader worker (or with num_workers=0) this is None
        # and the dataset is read in full by the calling process.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        if self.start_lines is None:
            start_lines = [0] * len(self.file_paths)
        else:
            start_lines = self.start_lines

        # Counts records across all files so round-robin sharding is global.
        record_idx = 0
        for file_path, start_line in zip(self.file_paths, start_lines):
            with self._open_file(file_path) as f:
                # line_no is per file: start offsets apply to each file separately.
                for line_no, line in enumerate(f):
                    if line_no < start_line:
                        continue
                    stripped = line.strip()
                    if not stripped:
                        continue
                    # Only parse the records this worker owns.
                    if record_idx % num_workers == worker_id:
                        yield json.loads(stripped)
                    record_idx += 1


class JsonDataset(torch.utils.data.Dataset):
    """Map-style dataset over the concatenated top-level arrays of JSON files.

    All files are loaded eagerly into memory at construction time.
    """

    def __init__(
        self,
        file_paths: List[str],
        start_line=0,
        end_line=-1,
    ):
        """
        Load every file and concatenate the selected elements.

        :param file_paths: List of paths to .json files (or gzip-compressed
            .json.gz files), each containing a top-level JSON array.
        :param start_line: Index of the first array element kept from each
            file. It is applied to each file separately, not to the
            concatenation.
        :param end_line: Exclusive end index, applied to each file separately.
            Any value ``<= 0`` (including the default ``-1``) means "through
            the end of the array".
        :raises ValueError: If a file's top-level JSON value is not an array.
        """
        super().__init__()

        self.file_paths = file_paths

        # A non-positive end_line is a sentinel for "no upper bound".
        stop = end_line if end_line > 0 else None

        self.data = []
        for file_path in self.file_paths:
            opener = gzip.open if file_path.endswith(".gz") else open
            # Decode explicitly as UTF-8 rather than relying on the locale default.
            with opener(file_path, "rt", encoding="utf-8") as f:
                cur_data = json.load(f)
            # Slicing a str/dict would silently produce garbage or an obscure error.
            if not isinstance(cur_data, list):
                raise ValueError(
                    f"{file_path}: expected a top-level JSON array, "
                    f"got {type(cur_data).__name__}"
                )
            self.data.extend(cur_data[start_line:stop])

    def __getitem__(self, idx):
        """
        Return the record at position ``idx`` of the concatenated data.
        """
        return self.data[idx]

    def __len__(self):
        """Return the total number of records loaded across all files."""
        return len(self.data)
