from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
import numpy as np


class StateBuffer:
    """Fixed-capacity ring buffer of observation vectors.

    Observations are written into a preallocated ``(capacity, obs_dim)``
    float32 array. Once the write index wraps around, every new observation
    overwrites the oldest stored one.
    """

    def __init__(self, capacity: int, obs_dim: int):
        """Allocate storage for ``capacity`` observations of width ``obs_dim``.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        # A zero-row buffer can never accept a write; fail here with a clear
        # message instead of an IndexError on the first ``add``.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.ptr = 0  # index of the next row to write
        self.full = False  # becomes True once ``ptr`` has wrapped around

    def add(self, obs: np.ndarray) -> None:
        """Store ``obs`` in the next slot, overwriting the oldest when full."""
        self.obs[self.ptr] = obs
        self.ptr = (self.ptr + 1) % len(self.obs)
        if self.ptr == 0:
            # Wrapping back to row 0 means every row has been written once.
            self.full = True

    def sample(self, n: int) -> np.ndarray:
        """Return up to ``n`` stored observations drawn uniformly at random.

        Indices are drawn independently, so a row may appear more than once.
        At most ``len(self)`` rows are returned; an empty buffer yields an
        array of shape ``(0, obs_dim)``.
        """
        size = len(self)
        if size == 0:
            # Nothing stored yet: build the empty index directly rather than
            # handing the RNG an empty range.
            idx = np.zeros(0, dtype=np.intp)
        else:
            idx = np.random.randint(0, size, size=min(n, size))
        return self.obs[idx]

    def __len__(self) -> int:
        """Return the number of observations currently stored."""
        return len(self.obs) if self.full else self.ptr


@dataclass
class Transition:
    """One transition record, as consumed by ``TransitionDataset.add``.

    Each field is copied into the dataset column of the same name and is
    returned by ``TransitionDataset.sample`` under the key noted beside it.
    """
    s: np.ndarray  # returned under "obs"
    a: np.ndarray  # returned under "act"
    r: float  # returned under "rew"
    s2: np.ndarray  # returned under "next_obs"
    d: float  # returned under "done" (presumably a 0/1 terminal flag; confirm with callers)
    c: float  # returned under "cost"


class TransitionDataset:
    """Fixed-capacity ring buffer of transitions stored column-wise.

    Every field of a transition lives in its own preallocated float32 array
    with ``capacity`` rows. Once the write index wraps around, each new
    transition overwrites the oldest stored one.
    """

    def __init__(self, capacity: int, obs_dim: int, act_dim: int):
        """Allocate the per-field storage arrays.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        # A zero-row dataset can never accept a write; fail here with a clear
        # message instead of an IndexError on the first ``add``.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.s = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.a = np.zeros((capacity, act_dim), dtype=np.float32)
        self.r = np.zeros((capacity, 1), dtype=np.float32)
        self.s2 = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.d = np.zeros((capacity, 1), dtype=np.float32)
        self.c = np.zeros((capacity, 1), dtype=np.float32)
        self.ptr = 0  # index of the next row to write
        self.full = False  # becomes True once ``ptr`` has wrapped around

    def add(self, tr: Transition) -> None:
        """Copy the fields of ``tr`` into the next row, overwriting the oldest when full."""
        i = self.ptr
        self.s[i] = tr.s
        self.a[i] = tr.a
        self.r[i] = tr.r
        self.s2[i] = tr.s2
        self.d[i] = tr.d
        self.c[i] = tr.c
        self.ptr = (self.ptr + 1) % len(self.s)
        if self.ptr == 0:
            # Wrapping back to row 0 means every row has been written once.
            self.full = True

    def sample(self, n: int) -> Dict[str, np.ndarray]:
        """Return up to ``n`` stored transitions drawn uniformly at random.

        Indices are drawn independently, so a row may appear more than once.
        The same indices are applied to every column, keeping rows aligned
        across the returned arrays. At most ``len(self)`` rows are returned;
        an empty dataset yields zero-row arrays.

        Returns:
            Dict with keys ``"obs"``, ``"act"``, ``"rew"``, ``"next_obs"``,
            ``"done"`` and ``"cost"``.
        """
        size = len(self)
        if size == 0:
            # Nothing stored yet: build the empty index directly rather than
            # handing the RNG an empty range.
            idx = np.zeros(0, dtype=np.intp)
        else:
            idx = np.random.randint(0, size, size=min(n, size))
        return {
            "obs": self.s[idx],
            "act": self.a[idx],
            "rew": self.r[idx],
            "next_obs": self.s2[idx],
            "done": self.d[idx],
            "cost": self.c[idx],
        }

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.s) if self.full else self.ptr


class MiniBatchIterator:
    """Iterate over a dict of equally indexed arrays in mini-batches.

    Each batch is a dict with the same keys as ``data``, holding the rows
    selected by one slice of a shared index array. The object is its own
    iterator: calling ``iter()`` rewinds it, so it can be reused across
    passes but does not support two simultaneous iterations.
    """

    def __init__(self, data: Dict[str, np.ndarray], batch_size: int, shuffle: bool = True):
        """Set up iteration over ``data``.

        Args:
            data: Mapping from name to array. The row count is taken from the
                first value; an empty mapping is treated as zero rows.
            batch_size: Maximum number of rows per batch; must be positive.
            shuffle: If true, the row order is reshuffled at the start of
                every pass.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # A non-positive step never moves ``pos`` past ``size``, which would
        # make iteration yield empty batches forever.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Use a default so an empty dict gives an empty iterator instead of
        # leaking StopIteration out of the constructor.
        first = next(iter(data.values()), None)
        self.size = 0 if first is None else len(first)
        self.indices = np.arange(self.size)
        if self.shuffle:
            # Shuffle up front so that calling next() directly, without
            # iter(), still walks a shuffled order.
            np.random.shuffle(self.indices)
        self.pos = 0

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        """Rewind to the first batch, reshuffling if enabled, and return self."""
        self.pos = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self) -> Dict[str, np.ndarray]:
        """Return the next batch; the final batch may be shorter than ``batch_size``."""
        if self.pos >= self.size:
            raise StopIteration
        idx = self.indices[self.pos : self.pos + self.batch_size]
        self.pos += self.batch_size
        return {k: v[idx] for k, v in self.data.items()}


def normalize_dataset(data: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, Tuple[np.ndarray, np.ndarray]]]:
    """Standardize the multi-column 2-D arrays of ``data`` column by column.

    Arrays that are 2-D with more than one column are shifted by their
    per-column mean and divided by their per-column standard deviation plus
    ``1e-6``. Every other array is passed through as a copy.

    Returns:
        A pair ``(normalized, stats)``. ``normalized`` has the same keys as
        ``data``; ``stats`` maps each standardized key to the ``(mean, std)``
        pair used, where ``std`` already includes the ``1e-6`` offset.
    """
    normalized: Dict[str, np.ndarray] = {}
    moments: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
    for name, arr in data.items():
        multi_column = arr.ndim == 2 and arr.shape[1] > 1
        if not multi_column:
            # Single-column and non-2-D arrays are left unscaled.
            normalized[name] = arr.copy()
            continue
        center = arr.mean(axis=0, keepdims=True)
        # The small offset keeps constant columns from dividing by zero.
        scale = arr.std(axis=0, keepdims=True) + 1e-6
        moments[name] = (center, scale)
        normalized[name] = (arr - center) / scale
    return normalized, moments


def denormalize_array(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Undo a standardization by scaling ``x`` by ``std`` and then adding ``mean``."""
    rescaled = x * std
    return rescaled + mean


def _demo():
    """Smoke-test the module: fill a dataset, batch a sample, normalize it, print a summary."""
    obs_dim, act_dim = 4, 2
    dataset = TransitionDataset(capacity=100, obs_dim=obs_dim, act_dim=act_dim)
    for _ in range(30):
        # Fields are drawn in this fixed order (s, a, r, s2, d, c) so the
        # global RNG is consumed in the same sequence on every step.
        state = np.random.randn(obs_dim).astype(np.float32)
        action = np.random.randn(act_dim).astype(np.float32)
        reward = float(np.random.randn())
        next_state = np.random.randn(obs_dim).astype(np.float32)
        done = float(np.random.rand() > 0.9)
        cost = float(np.random.rand() < 0.1)
        dataset.add(Transition(s=state, a=action, r=reward, s2=next_state, d=done, c=cost))

    batch = dataset.sample(16)

    # Walk every mini-batch once; the contents are intentionally discarded.
    for _minibatch in MiniBatchIterator(batch, batch_size=8):
        pass

    normed, stats = normalize_dataset(batch)
    print(len(dataset), normed.keys(), list(stats.keys()))


# Run the smoke-test demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
