from __future__ import annotations

from typing import Iterable, Iterator, List, Optional

import numpy as np
import torch
from torch.utils.data import IterableDataset

from src.data.tabular_sampler import TabularSampler
from src.utils import DataAttr


class OnlineTabularDataset(IterableDataset):
    """Finite-length iterable that generates tabular batches online via TabularSampler.

    Yields pre-batched DataAttr objects compatible with the existing training loop.
    One epoch consists of exactly ``num_batches`` batches in total. When the
    dataset is consumed by a multi-worker ``DataLoader``, those batches are
    split across the workers so the epoch length still matches
    ``len(dataset)`` instead of being multiplied by ``num_workers``.

    Args:
        batch_size: number of tasks per batch (B)
        num_batches: number of batches per epoch (len(dataset)); must be >= 0
        d_list: non-empty list of feature dims (D) to sample uniformly
        nc_list: non-empty list of context sizes (Nc) to sample uniformly
        num_buffer: fixed buffer size (Nb)
        num_target: fixed target size (Nt)
        normalize_x: whether to apply TabICL-style x normalization (fit on context)
        x_norm_method: normalization method (e.g., "power")
        x_outlier_threshold: outlier z-threshold for normalization
        normalize_y: z-normalize y using context stats
        dtype: torch dtype for generation
        device: device for generation ("cpu" recommended)
        seed: RNG seed for reproducibility (optional). NOTE: this seeds the
            process-global torch and NumPy RNGs at construction time, so it
            also affects any other code drawing from those RNGs.

    Raises:
        ValueError: if ``d_list`` or ``nc_list`` is empty, or if
            ``num_batches`` is negative.
    """

    def __init__(
        self,
        *,
        batch_size: int,
        num_batches: int,
        d_list: List[int],
        nc_list: List[int],
        num_buffer: int,
        num_target: int,
        normalize_x: bool,
        x_norm_method: str,
        x_outlier_threshold: float,
        normalize_y: bool,
        dtype: torch.dtype,
        device: str,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.batch_size = int(batch_size)
        self.num_batches = int(num_batches)
        self.d_list = list(map(int, d_list))
        self.nc_list = list(map(int, nc_list))
        self.nb = int(num_buffer)
        self.nt = int(num_target)
        self.normalize_x = bool(normalize_x)
        self.x_norm_method = str(x_norm_method)
        self.x_outlier_threshold = float(x_outlier_threshold)
        self.normalize_y = bool(normalize_y)
        self.dtype = dtype
        self.device = str(device)

        # Fail fast with a clear message. Otherwise an empty nc_list would only
        # surface at iteration time as an opaque np.random.randint error, and a
        # negative num_batches would make len() raise.
        if self.num_batches < 0:
            raise ValueError(f"num_batches must be >= 0, got {self.num_batches}")
        if not self.d_list:
            raise ValueError("d_list must contain at least one feature dimension")
        if not self.nc_list:
            raise ValueError("nc_list must contain at least one context size")

        if seed is not None:
            # Global side effect: seeds the process-wide torch and NumPy RNGs.
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))

        # Single sampler that samples D uniformly from d_list per batch
        self.sampler = TabularSampler(
            dim_x=self.d_list,
            dim_y=1,
            is_causal=True,
            num_causes=None,
            num_layers=4,
            hidden_dim=64,
            noise_std=0.01,
            sampling="mixed",
            normalize_y=self.normalize_y,
            normalize_x=self.normalize_x,
            x_norm_method=self.x_norm_method,
            x_outlier_threshold=self.x_outlier_threshold,
            device=self.device,
            dtype=self.dtype,
        )

    def __len__(self) -> int:  # type: ignore[override]
        """Return the total number of batches per epoch (across all workers)."""
        return self.num_batches

    @staticmethod
    def _num_batches_for_worker(num_batches: int, worker_id: int, num_workers: int) -> int:
        """Return how many of ``num_batches`` batches worker ``worker_id`` should yield.

        The remainder is spread over the lowest worker ids, so the per-worker
        counts sum to exactly ``num_batches``.
        """
        base, extra = divmod(num_batches, num_workers)
        return base + (1 if worker_id < extra else 0)

    def __iter__(self) -> Iterator[DataAttr]:  # type: ignore[override]
        """Yield this process's share of the epoch's batches.

        In the main process (no DataLoader workers) all ``num_batches`` batches
        are yielded. Inside a DataLoader worker, only that worker's slice is
        yielded, so the combined stream has ``len(self)`` batches.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            local_batches = self.num_batches
        else:
            # Without this split, every worker would run the full loop and the
            # epoch would contain num_workers * num_batches batches.
            local_batches = self._num_batches_for_worker(
                self.num_batches, worker_info.id, worker_info.num_workers
            )

        for _ in range(local_batches):
            # Uniformly sample Nc from the provided list via the global NumPy
            # RNG (seeded in __init__ when a seed is given); D sampling is
            # handled by the sampler.
            # Note: TabularSampler will pick one D per batch from its dim_x list.
            nc = int(self.nc_list[np.random.randint(0, len(self.nc_list))])
            batch = self.sampler.generate_batch(
                batch_size=self.batch_size,
                num_context=nc,
                num_buffer=self.nb,
                num_target=self.nt,
                context_range=None,
            )
            yield batch

