import os
import numpy as np
import jax
import jax.numpy as jnp
from typing import Any
# from learned_optimization.tasks.datasets.base import ThreadSafeIterator, LazyIterator
from learned_optimization.tasks.datasets.base import Datasets, ThreadSafeIterator, LazyIterator
# DataLoaderLite copied from fineweb_loading_test.py (with JAX output)
class DataLoaderLite:
    """Stream (input, target) token batches from sharded ``.npy`` files.

    Shards are the files in ``data_root`` whose names contain ``split``, visited
    in sorted order and cycled forever. Each call to ``next()`` returns
    ``{'image': x, 'label': y}`` where ``x`` and ``y`` are ``int32`` JAX arrays
    of shape ``(B, T)`` and ``y`` is ``x`` shifted one token ahead.

    Rank ``process_rank`` starts at offset ``B*T*process_rank`` and advances by
    ``B*T*num_processes`` tokens per batch. When the next such window would run
    past the end of the current shard, it moves on to the next shard (wrapping
    around) and resets to its rank offset.
    """

    def __init__(self, B, T, process_rank, num_processes, split, data_root="data/fineweb_edu_10B"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}
        # Shards are selected by substring match on the split name and visited
        # in lexicographic order.
        self.shards = [
            os.path.join(data_root, name)
            for name in sorted(os.listdir(data_root))
            if split in name
        ]
        assert len(self.shards) > 0, f"no shards found for split {split}"
        # Resolve the placement device once instead of on every batch.
        self._device = self._default_device()
        self.reset()

    @staticmethod
    def _default_device():
        """Return this process's first local GPU, or its first local device if there is none.

        The previous lookup, ``jax.devices('gpu')[jax.process_index()]``, raised on
        CPU-only hosts. With several GPUs per process it could also select a device
        owned by another process. For one GPU per process, this helper picks the
        same device.
        """
        try:
            gpus = jax.local_devices(backend='gpu')
        except RuntimeError:
            # No GPU backend is available on this host.
            gpus = []
        return gpus[0] if gpus else jax.local_devices()[0]

    def _load_shard(self, index):
        """Load shard ``index`` fully into host memory as an int32 array."""
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
        # only point data_root at trusted shards (plain token arrays should not need it).
        return np.load(self.shards[index], allow_pickle=True).astype(np.int32)

    def reset(self):
        """Rewind to the first shard at this rank's starting offset."""
        self.current_shard = 0
        self.tokens = self._load_shard(self.current_shard)
        self.current_position = self.B * self.T * self.process_rank

    def __iter__(self):
        return self

    def __next__(self):
        return self.next_batch()

    def next_batch(self):
        """Return the next ``{'image': x, 'label': y}`` batch and advance the cursor.

        Raises:
          ValueError: if the current shard has fewer than ``B*T + 1`` tokens
            remaining at this rank's position.
        """
        B, T = self.B, self.T
        window = B * T + 1  # one extra token so targets can be shifted by one
        start = self.current_position
        buf = self.tokens[start : start + window]
        if buf.shape[0] < window:
            raise ValueError(
                f"shard {self.shards[self.current_shard]} has {len(self.tokens)} tokens, "
                f"too few for a batch of {window} tokens starting at position {start}"
            )
        x = buf[:-1].reshape(B, T)
        y = buf[1:].reshape(B, T)
        # Step over every rank's window so ranks read different slices of a shard.
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self._load_shard(self.current_shard)
            self.current_position = B * T * self.process_rank
        # device_put straight from NumPy: a single host->device transfer.
        x = jax.device_put(x, self._device)
        y = jax.device_put(y, self._device)
        return {'image': x, 'label': y}

def make_fineweb_datasets(
    batch_size=(2, 1, 1, 1),
    sequence_length=64,
    prefetch_batches=(1, 1, 1, 1),
    process_rank=0,
    num_processes=1,
    batch_shape=None,
    data_root="data/fineweb_edu_10B",
    **kwargs
):
    """Create a Datasets object for FineWeb, with JAX arrays and GPT2 vocab size.

    Args:
      batch_size: Four per-split batch sizes ordered (train, inner_valid,
        outer_valid, test). Any length-4 sequence is accepted.
      sequence_length: Tokens per example (``T``).
      prefetch_batches: Four per-split prefetch counts. Only its length is
        checked here; this function performs no prefetching.
      process_rank: This process's rank, forwarded to ``DataLoaderLite``.
      num_processes: Number of data-parallel processes, forwarded to
        ``DataLoaderLite``.
      batch_shape: Optional leading shape for *train* batches only, e.g.
        ``(perturbations, workers, local_steps, local_batch)``. Its product must
        equal ``batch_size[0]``. Other splits always yield
        ``(batch_size[i], sequence_length)``.
      data_root: Directory containing the ``train`` / ``val`` shard files.
      **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
      A ``Datasets`` whose iterators yield ``{'image': ..., 'label': ...}`` token
      batches built lazily on first use. train, inner_valid and outer_valid all
      read the ``train`` shards; test reads the ``val`` shards.

    Raises:
      AssertionError: if ``batch_size`` and ``prefetch_batches`` do not both
        have four entries.
      ValueError: if ``batch_shape`` does not multiply out to ``batch_size[0]``.
    """

    print("Making fineweb datasets with args:")
    print(f"batch_size: {batch_size}")
    print(f"sequence_length: {sequence_length}")
    print(f"prefetch_batches: {prefetch_batches}")
    print(f"process_rank: {process_rank}")
    print(f"num_processes: {num_processes}")
    print(f"batch_shape: {batch_shape}")

    split_names = ['train', 'inner_valid', 'outer_valid', 'test']
    # DataLoaderLite only knows 'train' and 'val' shards, so every training-time
    # split reads 'train' and only the held-out test split reads 'val'.
    split_map = {
        'train': 'train',
        'inner_valid': 'train',
        'outer_valid': 'train',
        'test': 'val',
    }
    # Check lengths before indexing so a short list fails here with a clear
    # assertion rather than an IndexError further down.
    assert len(split_names) == len(batch_size) == len(prefetch_batches)
    vocab_size = 50257  # GPT2 vocab size

    # Leading (batch) dimensions per split; only 'train' may be hierarchical.
    leading_shapes = {name: (bs,) for name, bs in zip(split_names, batch_size)}
    if batch_shape is not None:
        batch_shape = tuple(batch_shape)
        n_elems = int(np.prod(batch_shape))
        if n_elems != batch_size[0]:
            raise ValueError(
                f"batch_shape {batch_shape} has {n_elems} elements "
                f"but batch_size[0] is {batch_size[0]}"
            )
        leading_shapes['train'] = batch_shape

    def make(split, B, T, shape):
        """Wrap a lazily built DataLoaderLite for ``split`` whose batches are reshaped to ``shape``."""
        # The loader natively yields (B, T); reshape only when another shape is requested.
        needs_reshape = tuple(shape) != (B, T)

        def iterator_fn():
            loader = DataLoaderLite(B, T, process_rank, num_processes, split_map[split], data_root=data_root)
            for batch in loader:
                if needs_reshape:
                    batch = {
                        'image': jnp.reshape(batch['image'], shape),
                        'label': jnp.reshape(batch['label'], shape),
                    }
                yield batch

        return ThreadSafeIterator(LazyIterator(iterator_fn))

    iters = [
        make(name, bs, sequence_length, leading_shapes[name] + (sequence_length,))
        for name, bs in zip(split_names, batch_size)
    ]
    # NOTE(review): abstract_batch advertises a (1, sequence_length) batch no matter
    # what batch_size / batch_shape are; presumably consumers only use it for
    # batch-size-agnostic initialization — confirm before relying on its shape.
    abstract_batch = {
        'image': jax.core.ShapedArray((1, sequence_length), jnp.int32),
        'label': jax.core.ShapedArray((1, sequence_length), jnp.int32),
    }
    extra_info = {
        'vocab_size': vocab_size,
        'vocab': None,
        'name': f'fineweb-s{sequence_length}-gpt2',
    }
    return Datasets(
        train=iters[0],
        inner_valid=iters[1],
        outer_valid=iters[2],
        test=iters[3],
        extra_info=extra_info,
        abstract_batch=abstract_batch,
    )

if __name__ == "__main__":
    # Smoke test: build every split with a hierarchical train batch shape and
    # print the first batch of each. Expects shards under the default data_root.
    batch_shape = (2, 3, 4, 5)  # (perturbations, workers, local_steps, local_batch_size)
    sequence_length = 8
    batch_size = [2*3*4*5, 3*4*5, 4*5, 5]  # train, inner_valid, outer_valid, test
    prefetch_batches = [1, 1, 1, 1]
    process_rank = 0
    num_processes = 1

    ds = make_fineweb_datasets(
        batch_size=batch_size,
        sequence_length=sequence_length,
        prefetch_batches=prefetch_batches,
        process_rank=process_rank,
        num_processes=num_processes,
        batch_shape=batch_shape,
    )

    split_names = ["train", "inner_valid", "outer_valid", "test"]
    splits = [ds.train, ds.inner_valid, ds.outer_valid, ds.test]

    # Visit every split. A trailing `break` used to stop after 'train'.
    for name, split in zip(split_names, splits):
        batch = next(iter(split))
        print(f"Split: {name}")
        for k, v in batch.items():
            print(f"  {k}: shape={v.shape}, dtype={v.dtype}, type={type(v)}")
            print(v)
        print("-")
    print("FineWeb DataLoader test complete.")