from collections.abc import Iterable, Iterator
from typing import Generic, NamedTuple, TypeVar
import numpy as np
from numpy.random import Generator
from numpy.typing import NDArray

from offline.types import BoolArray, IntArray


# Record type carried by the datasets and loaders in this module; bound to
# NamedTuple because the code relies on ``_make`` and on iterating the fields.
T = TypeVar("T", bound=NamedTuple)
# NumPy scalar type used to parameterise ``NDArray`` annotations.
S = TypeVar("S", bound=np.generic)


def collate(batch: Iterable[tuple[T, int]]) -> tuple[T, tuple[int, ...]]:
    """Merge ``(record, length)`` pairs into a single batched record.

    Args:
        batch: Iterable of pairs; every record must be of the same named-tuple
            type, with fields whose shapes ``np.stack`` can combine.

    Returns:
        A pair ``(record, lengths)``. The record has the type of the first
        input record and each of its fields is the ``np.stack`` of that field
        across the batch; ``lengths`` holds the second pair elements in input
        order.

    Raises:
        ValueError: If ``batch`` is empty (there is nothing to unpack).
    """
    records: tuple[T, ...]
    lengths: tuple[int, ...]
    # Transpose the sequence of pairs into (all records, all lengths).
    records, lengths = zip(*batch)  # type: ignore
    record_type = type(records[0])
    # zip(*records) groups the values field by field; stack each group.
    stacked_fields = map(np.stack, zip(*records))
    return record_type._make(stacked_fields), lengths


def get_length(data: NamedTuple) -> int:
    """Return the leading-axis length shared by every field of ``data``.

    Args:
        data: Record whose fields all expose ``shape[0]``.

    Returns:
        The common size of the first axis.

    Raises:
        ValueError: If ``data`` has no fields, or if the fields disagree on
            the size of their first axis.
    """
    lengths = {field.shape[0] for field in data}
    if not lengths:
        # Without this guard ``set.pop`` would raise an unrelated KeyError.
        raise ValueError("Cannot determine the length of data without fields.")
    if len(lengths) > 1:
        raise ValueError("Size mismatch between fields.")
    return lengths.pop()


def pad_data(
    field: NDArray[S],
    start_indices: IntArray,
    filter_trajectories: bool = False,
) -> NDArray[S]:
    """Split a flat array into trajectories and zero-pad them to equal length.

    Trajectory ``i`` is ``field[start_indices[i]:start_indices[i + 1]]``, so
    ``start_indices`` needs one more entry than there are trajectories.

    Args:
        field: Array whose first axis concatenates all trajectories.
        start_indices: Boundaries of the trajectories along the first axis.
        filter_trajectories: When true, trajectories of length 1 are left out
            of the result.

    Returns:
        Array of shape ``(num_kept, max_length) + field.shape[1:]`` with the
        dtype of ``field``. ``max_length`` is the longest trajectory among all
        trajectories (including filtered ones); shorter rows are padded with
        zeros at the end.

    Raises:
        ValueError: If ``start_indices`` delimits no trajectory, or if
            filtering removes every trajectory.
    """
    starts = np.asarray(start_indices)
    lengths = np.diff(starts)
    if lengths.size == 0:
        raise ValueError(
            "start_indices must delimit at least one trajectory."
        )
    max_length = int(lengths.max())

    if filter_trajectories:
        kept = np.flatnonzero(lengths != 1)
        if kept.size == 0:
            raise ValueError(
                "No trajectory is left after filtering those of length 1."
            )
    else:
        kept = np.arange(lengths.size)

    # Allocate the padded result once instead of stacking per-trajectory
    # buffers, which would hold two copies of the data at the same time.
    output = np.zeros(
        (kept.size, max_length) + field.shape[1:], dtype=field.dtype
    )
    for row, trajectory in enumerate(kept):
        begin, end = starts[trajectory], starts[trajectory + 1]
        output[row, : lengths[trajectory]] = field[begin:end]
    return output


class Dataset(Generic[T]):
    """Indexable view over a record whose fields share their first axis.

    Indexing applies the same index to every field and wraps the results in a
    new record of the same type, so integers, slices and index arrays are
    handled by whatever the fields themselves support.
    """

    def __init__(self, data: T):
        self.data = data
        # Also validates that all fields agree on their leading size.
        self.length = get_length(data)

    def __getitem__(self, index) -> T:
        record_type = type(self.data)
        return record_type._make(column[index] for column in self.data)

    def __iter__(self) -> Iterator[T]:
        # Lazily yield one record per position along the first axis.
        yield from map(self.__getitem__, range(self.length))

    def __len__(self) -> int:
        return self.length

    def sample(self, rng: Generator, batch_size: int):
        """Return ``batch_size`` items drawn uniformly with replacement."""
        return self[rng.integers(self.length, size=batch_size)]


class TrajectoryDataset(Generic[T]):
    """Dataset of zero-padded trajectories cut from flat transition data.

    A trajectory ends at every non-zero entry of ``dones``. Transitions after
    the last such entry are not part of any trajectory and are discarded.
    Every field is padded with ``pad_data`` so that all trajectories share the
    same length; the true lengths are kept in ``trajectory_lengths``.

    Args:
        data: Record of arrays whose first axis indexes transitions.
        dones: Flags marking the last transition of each trajectory; must
            contain as many elements as ``data`` has transitions.
        filter_trajectories: When true, trajectories of length 1 are removed.

    Raises:
        ValueError: If ``data`` and ``dones`` differ in size.
    """

    def __init__(
        self,
        data: T,
        dones: BoolArray,
        filter_trajectories: bool = False,
    ):
        if get_length(data) != np.size(dones):
            raise ValueError("Size mismatch between data and endpoints.")
        # A trajectory starts at 0 and right after every finished transition.
        start_indices = np.insert(np.flatnonzero(dones) + 1, 0, 0)
        lengths = np.diff(start_indices)
        self.trajectory_lengths = (
            lengths[lengths > 1] if filter_trajectories else lengths
        )
        padded_fields = [
            pad_data(field, start_indices, filter_trajectories)
            for field in data
        ]
        self.data = data._make(padded_fields)
        self.length = get_length(self.data)
        # The padded rows and the stored lengths must describe the same set
        # of trajectories.
        assert self.length == self.trajectory_lengths.size

    def __getitem__(self, index: int) -> tuple[T, int]:
        """Return the padded trajectory at ``index`` and its true length."""
        true_length = self.trajectory_lengths[index]
        record_type = type(self.data)
        trajectory = record_type._make(
            column[index] for column in self.data
        )
        return trajectory, true_length

    def __iter__(self) -> Iterator[tuple[T, int]]:
        yield from map(self.__getitem__, range(self.length))

    def __len__(self) -> int:
        return self.length


class DataLoader(Generic[T]):
    """Serve mini-batches from a dataset.

    One iteration makes ``repeat`` passes over the dataset. When ``rng`` is
    set, every pass draws a fresh permutation and batches are gathered through
    it; otherwise batches are contiguous slices in storage order.

    Args:
        dataset: Source of the batches; indexed with slices when ``rng`` is
            ``None`` and with integer index arrays otherwise.
        batch_size: Number of items per batch; must be positive.
        drop_last: Whether to skip a final batch smaller than ``batch_size``.
        repeat: Number of passes made by a single iteration.
        rng: Generator used to shuffle each pass; ``None`` keeps the order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        dataset: Dataset[T],
        batch_size: int,
        drop_last: bool,
        repeat: int = 1,
        rng: Generator | None = None,
    ):
        if batch_size <= 0:
            # Fail here rather than with a ZeroDivisionError in ``__len__`` or
            # an error from ``range`` in the middle of an iteration.
            raise ValueError("batch_size must be a positive integer.")
        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last
        self.repeat = repeat
        self.rng = rng

    def __len__(self) -> int:
        """Return the number of batches produced by one iteration."""
        length = len(self.dataset)
        if not self.drop_last:
            # Round up so that a trailing partial batch is counted.
            length += self.batch_size - 1
        return (length // self.batch_size) * self.repeat

    def __iter__(self) -> Iterator[T]:
        size = len(self.dataset)
        # Position after the last item served in each pass: everything, or
        # only the part covered by full batches when drop_last is set.
        stop = size - size % self.batch_size if self.drop_last else size
        for _ in range(self.repeat):
            order = None if self.rng is None else self.rng.permutation(size)
            for start in range(0, stop, self.batch_size):
                end = min(start + self.batch_size, stop)
                if order is None:
                    yield self.dataset[start:end]
                else:
                    yield self.dataset[order[start:end]]

    def repeat_forever(self) -> Iterator[T]:
        """Yield batches endlessly by restarting the iteration when it ends.

        Raises:
            ValueError: On the first ``next`` call if one iteration yields no
                batch; looping would otherwise spin forever without ever
                producing a value.
        """
        if len(self) == 0:
            raise ValueError(
                "Cannot repeat a data loader that yields no batches."
            )
        while True:
            yield from self


class ArrayDataLoader(Generic[S]):
    """Serve mini-batches taken along the first axis of an array.

    One iteration makes ``repeat`` passes over the array. When ``rng`` is set,
    every pass draws a fresh permutation and batches are gathered through it;
    otherwise batches are contiguous slices in storage order.

    Args:
        data: Array to batch along its first axis.
        batch_size: Number of rows per batch; must be positive.
        drop_last: Whether to skip a final batch smaller than ``batch_size``.
        repeat: Number of passes made by a single iteration.
        rng: Generator used to shuffle each pass; ``None`` keeps the order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        data: NDArray[S],
        batch_size: int,
        drop_last: bool,
        repeat: int = 1,
        rng: Generator | None = None,
    ):
        if batch_size <= 0:
            # Fail here rather than with a ZeroDivisionError in ``__len__`` or
            # an error from ``range`` in the middle of an iteration.
            raise ValueError("batch_size must be a positive integer.")
        self.batch_size = batch_size
        self.data = data
        self.drop_last = drop_last
        self.repeat = repeat
        self.rng = rng

    def __len__(self) -> int:
        """Return the number of batches produced by one iteration."""
        length = self.data.shape[0]
        if not self.drop_last:
            # Round up so that a trailing partial batch is counted.
            length += self.batch_size - 1
        return (length // self.batch_size) * self.repeat

    def __iter__(self) -> Iterator[NDArray[S]]:
        size = self.data.shape[0]
        # Position after the last row served in each pass: everything, or
        # only the part covered by full batches when drop_last is set.
        stop = size - size % self.batch_size if self.drop_last else size
        for _ in range(self.repeat):
            order = None if self.rng is None else self.rng.permutation(size)
            for start in range(0, stop, self.batch_size):
                end = min(start + self.batch_size, stop)
                if order is None:
                    yield self.data[start:end]
                else:
                    yield self.data[order[start:end]]

    def repeat_forever(self) -> Iterator[NDArray[S]]:
        """Yield batches endlessly by restarting the iteration when it ends.

        Raises:
            ValueError: On the first ``next`` call if one iteration yields no
                batch; looping would otherwise spin forever without ever
                producing a value.
        """
        if len(self) == 0:
            raise ValueError(
                "Cannot repeat a data loader that yields no batches."
            )
        while True:
            yield from self


class TrajectoryDataLoader(Generic[T]):
    """Serve collated batches of trajectories from a trajectory dataset.

    One iteration makes ``repeat`` passes over ``indices``, the list of
    trajectory positions to visit. When ``rng`` is set, every pass visits a
    fresh permutation of ``indices``. Each batch is built with ``collate``.

    Args:
        dataset: Dataset providing ``(trajectory, length)`` items and a
            ``trajectory_lengths`` array.
        batch_size: Number of trajectories per batch; must be positive.
        drop_last: Whether to skip a final batch smaller than ``batch_size``.
        repeat: Number of passes made by a single iteration.
        reweight: When true and the trajectories differ in length, each
            trajectory index is repeated as many times as its length, so
            longer trajectories are visited proportionally more often.
        rng: Generator used to shuffle each pass; ``None`` keeps the order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        dataset: TrajectoryDataset[T],
        batch_size: int,
        drop_last: bool,
        repeat: int = 1,
        reweight: bool = False,
        rng: Generator | None = None,
    ):
        if batch_size <= 0:
            # Fail here rather than with a ZeroDivisionError in ``__len__`` or
            # an error from ``range`` in the middle of an iteration.
            raise ValueError("batch_size must be a positive integer.")
        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last
        self.repeat = repeat
        self.rng = rng
        self.indices = np.arange(len(self.dataset))
        if reweight:
            lengths = self.dataset.trajectory_lengths
            # With equal lengths the repetition would only inflate every
            # index by the same factor, so it is skipped.
            if np.max(lengths) != np.min(lengths):
                self.indices = np.repeat(self.indices, lengths)

    def __len__(self) -> int:
        """Return the number of batches produced by one iteration."""
        if self.drop_last:
            length = self.indices.size // self.batch_size
        else:
            # Ceiling division so that a trailing partial batch is counted.
            length = (
                self.indices.size + self.batch_size - 1
            ) // self.batch_size
        return length * self.repeat

    def __iter__(self) -> Iterator[tuple[T, tuple[int, ...]]]:
        size = self.indices.size
        # Position after the last index served in each pass: everything, or
        # only the part covered by full batches when drop_last is set.
        stop = size - size % self.batch_size if self.drop_last else size
        for _ in range(self.repeat):
            order = (
                self.indices
                if self.rng is None
                else self.rng.permutation(self.indices)
            )
            for start in range(0, stop, self.batch_size):
                end = min(start + self.batch_size, stop)
                yield collate(self.dataset[k] for k in order[start:end])

    def repeat_forever(self) -> Iterator[tuple[T, tuple[int, ...]]]:
        """Yield batches endlessly by restarting the iteration when it ends.

        Raises:
            ValueError: On the first ``next`` call if one iteration yields no
                batch; looping would otherwise spin forever without ever
                producing a value.
        """
        if len(self) == 0:
            raise ValueError(
                "Cannot repeat a data loader that yields no batches."
            )
        while True:
            yield from self
