from __future__ import annotations

from typing import Dict, Iterable, List, Tuple

import numpy as np
import torch
from torch import nn, optim
from torch.nn.utils import clip_grad_norm_


def build_chunk_schedule(time_steps: int, parts: int) -> List[Tuple[int, int]]:
    """
    Split ``[0, time_steps)`` into contiguous ``(start, end)`` chunks.

    The base recipe follows the FPTT reference code::

        step   = time_steps // parts
        _PARTS = parts (+1 if parts * step < time_steps)
        chunk p covers [p * step, min(time_steps, (p + 1) * step))

    That recipe adds at most one extra chunk, which can still leave a tail
    uncovered. For example, time_steps=11 and parts=4 give step=2 and five
    chunks ending at 10. A final chunk is appended to close any such gap, so
    the returned schedule always covers ``[0, time_steps)`` completely.

    Args:
        time_steps: sequence length (coerced with ``int``).
        parts:      requested number of chunks (coerced with ``int``,
                    clamped to >= 1).

    Returns:
        List of ``(start, end)`` pairs. For ``time_steps <= 0`` the result is
        ``[(0, 0)]``.
    """
    total = int(time_steps)
    n_parts = max(1, int(parts))
    if total <= 0:
        return [(0, 0)]

    width = max(1, total // n_parts)
    n_chunks = n_parts + 1 if n_parts * width < total else n_parts

    # Chunk starts are strictly increasing, so filtering on `begin < total`
    # is the same as stopping at the first start past the end.
    chunks: List[Tuple[int, int]] = [
        (begin, min(total, begin + width))
        for begin in range(0, n_chunks * width, width)
        if begin < total
    ]

    if not chunks:
        return [(0, total)]
    # Close the tail the reference recipe can leave uncovered (see docstring).
    if chunks[-1][1] < total:
        chunks.append((chunks[-1][1], total))
    return chunks


class ClassOracleBuffer:
    """
    Class-conditional oracle buffer.

    This mirrors the *online* part of the official FPTT implementation. For
    each class ``y`` and part ``p`` it keeps a probability vector over
    classes. That vector is estimated from the softmax output of
    misclassified examples of class ``y`` at part ``p``.

    Storage has shape ``(num_classes, n_parts, num_classes)`` and every entry
    starts out as the uniform distribution.
    """

    def __init__(
        self,
        num_classes: int,
        max_parts: int,
        momentum: float = 1.0,
    ) -> None:
        """
        Args:
            num_classes: number of classes.
            max_parts:   initial number of parts (chunks). Clamped to >= 1 so
                         that get()/update() always have a valid part slot.
            momentum:    interpolation factor when updating stored
                         distributions, clipped to [1e-4, 1.0].
                         momentum=1.0 overwrites the stored vector, which is
                         closest to the reference implementation.
        """
        self.num_classes = int(num_classes)
        self.momentum = float(np.clip(momentum, 1e-4, 1.0))
        # An empty part axis would make every get()/update() raise IndexError.
        self._storage = self._uniform(max(1, int(max_parts)))

    def _uniform(self, parts: int) -> np.ndarray:
        """Return a ``(num_classes, parts, num_classes)`` block of uniform distributions."""
        return np.full(
            (self.num_classes, parts, self.num_classes),
            1.0 / float(self.num_classes),
            dtype=np.float32,
        )

    def _part_index(self, idx: int) -> int:
        """Clamp ``idx`` into ``[0, n_parts - 1]``; indices past the end share the last slot."""
        return min(max(0, int(idx)), self._storage.shape[1] - 1)

    def ensure(self, required_parts: int) -> None:
        """Grow the part axis (with uniform rows) so at least `required_parts` parts exist."""
        extra = int(required_parts) - self._storage.shape[1]
        if extra <= 0:
            return
        self._storage = np.concatenate([self._storage, self._uniform(extra)], axis=1)

    def get(self, labels: np.ndarray, idx: int) -> np.ndarray:
        """
        Fetch oracle distributions for a batch of labels at a given part index.

        Args:
            labels: 1D array-like of class indices, shape (batch,).
            idx:    part index; clamped into the valid range.

        Returns:
            Array of shape (classes, batch) with per-example oracle probabilities.

        Raises:
            ValueError: if `labels` is not one-dimensional.
        """
        labels = np.asarray(labels)
        if labels.ndim != 1:
            raise ValueError("labels must be a 1D array of class indices.")
        labels = labels.astype(np.int64)
        oracle = self._storage[labels, self._part_index(idx)]  # (batch, classes)
        return oracle.T  # (classes, batch)

    def update(
        self,
        labels: np.ndarray,
        idx: int,
        probs: np.ndarray,
        preds: np.ndarray,
    ) -> None:
        """
        Update the class-conditional oracle.

        This follows the logic in the official ``train()`` (simplified)::

            filled_class = [0] * n_classes
            for j in range(B):
                y = target[j]
                if filled_class[y] == 0 and (argmax(prob_out[j]) != target[j]):
                    estimate_class_distribution[y, p] = prob_out[j]
                    filled_class[y] = 1

        Additionally, a momentum can be used for smoothing. With
        momentum=1.0 (effectively >= 0.999) this is the original behavior.

        Args:
            labels: (batch,) class indices.
            idx:    part index; clamped into the valid range.
            probs:  (classes, batch) softmax probabilities for this chunk.
            preds:  (batch,) predicted class indices.

        Raises:
            ValueError: if `probs` is not 2D or its class axis does not match
                ``num_classes``. A mismatch would otherwise be broadcast
                silently into storage when the axis has length 1.
        """
        probs = np.asarray(probs)
        if probs.ndim != 2:
            raise ValueError("Expected probs with shape (classes, batch).")
        if probs.shape[0] != self.num_classes:
            raise ValueError(
                f"Expected probs with {self.num_classes} classes on axis 0, "
                f"got {probs.shape[0]}."
            )
        labels = np.asarray(labels).astype(np.int64)
        preds = np.asarray(preds).astype(np.int64)
        idx = self._part_index(idx)

        filled = np.zeros(self.num_classes, dtype=bool)
        for col, (y, y_hat) in enumerate(zip(labels, preds)):
            if y < 0 or y >= self.num_classes:
                continue  # ignore labels outside the buffer
            if filled[y]:
                continue  # only the first qualifying example per class is used
            if y_hat == y:
                continue  # only use misclassified examples

            current = probs[:, col]  # (classes,)
            if self.momentum >= 0.999:
                new = current
            else:
                old = self._storage[y, idx]
                new = (1.0 - self.momentum) * old + self.momentum * current
            self._storage[y, idx] = new
            filled[y] = True


class OracleBufferStore:
    """
    Process-wide registry of ClassOracleBuffer objects, keyed by dataset id.

    It plays the role of the global ``estimate_class_distribution`` tensor in
    the reference code. Because the registry is a class attribute, every
    model instance that asks for the same key gets the same buffer.
    """

    # Class-level dict: shared across all callers in the process.
    _buffers: Dict[str, ClassOracleBuffer] = {}

    @classmethod
    def get(
        cls,
        key: str,
        num_classes: int,
        max_parts: int,
        momentum: float,
    ) -> ClassOracleBuffer:
        """
        Return the buffer stored under `key`, creating or replacing it if needed.

        A new buffer is created when none exists for `key`, or when the stored
        one has a different class count. Otherwise the existing buffer gets
        its momentum refreshed (clipped to [1e-4, 1.0]) and is grown to at
        least `max_parts` parts.
        """
        existing = cls._buffers.get(key)
        if existing is not None and existing.num_classes == num_classes:
            existing.momentum = float(np.clip(momentum, 1e-4, 1.0))
            existing.ensure(max_parts)
            return existing

        fresh = ClassOracleBuffer(num_classes, max_parts, momentum)
        cls._buffers[key] = fresh
        return fresh

    @classmethod
    def reset(cls, key: str | None = None) -> None:
        """Drop the buffer for `key`, or every buffer when `key` is None."""
        if key is not None:
            cls._buffers.pop(key, None)
            return
        cls._buffers.clear()


class FPTTRegularizer:
    """
    Alpha/beta/rho proximal dynamics from the FPTT paper.

    Corresponds to the non-debiased branch of the official implementation's
    ``post_optimizer_updates(...)`` and
    ``get_regularizer_named_params(..., _lambda)``.

    For every parameter ``w`` it keeps a shadow copy ``sm`` (initialised to
    ``w``) and a dual variable ``lm`` (initialised to zero).
    """

    def __init__(
        self,
        named_params: Iterable[Tuple[str, nn.Parameter]],
        alpha: float,
        beta: float,
        rho: float,
        lmbda: float = 1.0,
    ) -> None:
        """
        Args:
            named_params: iterable of (name, param) pairs, e.g. model.named_parameters().
            alpha:        proximal strength; floored at 1e-8 (it is also a divisor in step()).
            beta:         shadow-weight mixing rate; floored at 0.
            rho:          dual-term coefficient, used as (rho - 1).
            lmbda:        default scale of the quadratic term (lambda in the paper).
        """
        self.alpha = max(1e-8, float(alpha))
        self.beta = max(0.0, float(beta))
        self.rho = float(rho)
        self.lmbda = float(lmbda)

        self._state: Dict[str, Dict[str, torch.Tensor]] = {
            name: {
                "param": p,
                "sm": p.detach().clone(),
                "lm": torch.zeros_like(p),
            }
            for name, p in named_params
        }

        if self._state:
            first_entry = next(iter(self._state.values()))
            self._device = first_entry["param"].device
        else:
            self._device = torch.device("cpu")

    def reset(self) -> None:
        """Re-anchor shadows to the current parameters and zero the dual variables."""
        for entry in self._state.values():
            shadow, dual = entry["sm"], entry["lm"]
            shadow.copy_(entry["param"].detach())
            dual.zero_()

    def loss(self, lmbda: float | None = None) -> torch.Tensor:
        """
        Return the scalar FPTT regularization term::

            sum_w  (rho - 1) * <w, lm>  +  lambda * 0.5 * alpha * ||w - sm||^2

        `lambda` is the `lmbda` argument when given, else ``self.lmbda``.
        With no tracked parameters the result is a zero scalar.
        """
        if not self._state:
            return torch.zeros((), device=self._device)

        weight = float(lmbda) if lmbda is not None else float(self.lmbda)
        dual_coef = self.rho - 1.0
        prox_coef = weight * 0.5 * self.alpha

        total = torch.zeros((), device=self._device)
        for entry in self._state.values():
            w = entry["param"]
            total = total + dual_coef * torch.sum(w * entry["lm"])
            total = total + prox_coef * torch.sum((w - entry["sm"]) ** 2)
        return total

    def step(self) -> None:
        """
        Apply one in-place proximal update (non-debiased branch)::

            lm <- lm - alpha * (w - sm)
            sm <- (1 - beta) * sm + beta * w - (beta / alpha) * lm

        The sm update uses the already-updated lm.
        """
        if not self._state:
            return
        ratio = self.beta / self.alpha
        with torch.no_grad():
            for entry in self._state.values():
                shadow, dual = entry["sm"], entry["lm"]
                w = entry["param"].detach()
                dual.add_(-self.alpha * (w - shadow))
                shadow.mul_(1.0 - self.beta)
                shadow.add_(self.beta * w - ratio * dual)


class StrictFPTTBase(nn.Module):
    """
    Base module implementing a simple tanh-RNN with FPTT regularization.

    The recurrent cell is a simplified tanh-RNN rather than the LSTM of the
    official implementation. The following pieces are intended to match the
    official implementation mathematically:

    - the chunking of the sequence (``build_chunk_schedule``),
    - the proximal (alpha/beta/rho) regularizer (``FPTTRegularizer``),
    - the oracle-mixed loss (implemented by subclasses).

    Subclasses provide ``train_batch``. This base supplies the parameters,
    optimizer, regularizer, an evaluation loop (``forward_cycle``) and a
    numpy get/set interface for the weights.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float,
        parts: int,
        clip: float,
        alpha: float,
        beta: float,
        rho: float,
        lmbda: float = 1.0,
        optimizer_cls: type[optim.Optimizer] = optim.SGD,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Args:
            input_size:    number of input features per time step.
            hidden_size:   number of hidden units.
            output_size:   number of outputs (logits or regression targets).
            eta:           learning rate, passed as ``lr`` to `optimizer_cls`.
            parts:         requested number of chunks per sequence (clamped to >= 1).
            clip:          max gradient norm; values <= 0 disable clipping.
            alpha, beta, rho, lmbda: FPTT regularizer hyper-parameters.
            optimizer_cls: optimizer class, built as ``optimizer_cls(params, lr=eta)``.
            device:        torch device; defaults to CPU.
        """
        super().__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)
        self.output_size = int(output_size)
        self.eta = float(eta)
        self.parts = max(1, int(parts))
        self.grad_clip = max(0.0, float(clip))
        self.device = torch.device(device or "cpu")

        # Simple tanh-RNN parameters
        input_scale = min(0.1, 1.0 / np.sqrt(max(1, self.input_size)))
        self._W_xh = nn.Parameter(
            torch.randn(self.hidden_size, self.input_size, device=self.device) * input_scale
        )
        # NOTE: W_hh is drawn from N(0, 1) without 1/sqrt(hidden_size) scaling;
        # initialize_weights_with_gain() rescales it.
        self._W_hh = nn.Parameter(
            torch.randn(self.hidden_size, self.hidden_size, device=self.device)
        )
        self._b_h = nn.Parameter(torch.zeros(self.hidden_size, device=self.device))
        self._W_hy = nn.Parameter(
            torch.randn(self.output_size, self.hidden_size, device=self.device) * 0.1
        )
        self._b_y = nn.Parameter(torch.zeros(self.output_size, device=self.device))

        self.optimizer = optimizer_cls(self.parameters(), lr=self.eta)
        # The regularizer snapshots the current weights as its proximal anchor.
        self._regularizer = FPTTRegularizer(
            list(self.named_parameters()), alpha, beta, rho, lmbda=lmbda
        )
        self.current_epoch = 0

    # ------------------------------------------------------------------ #
    # Public API expected by comparison scripts
    # ------------------------------------------------------------------ #
    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw W_hh with std ``g / sqrt(hidden_size)`` and reset the regularizer state."""
        std_dev = float(g) / np.sqrt(self.hidden_size)
        with torch.no_grad():
            self._W_hh.copy_(torch.randn_like(self._W_hh) * std_dev)
        self.reset_state_buffers()

    def reset_state_buffers(self) -> None:
        """Re-anchor the regularizer shadows to the current weights and zero its duals."""
        self._regularizer.reset()

    def set_epoch(self, epoch: int) -> None:
        """Record the current epoch (subclasses may use it, e.g. for warmup)."""
        self.current_epoch = int(epoch)

    def forward_cycle(
        self, inputs: np.ndarray | torch.Tensor, h_prev: np.ndarray | torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Run the RNN over a whole sequence without tracking gradients.

        Args:
            inputs: (batch, input_size, time_steps), as a numpy array or tensor.
            h_prev: (hidden_size, batch), as a numpy array or tensor.

        Returns:
            outputs: list of length time_steps with per-step logits of shape
                     (output_size, batch).
            h_last:  final hidden state, (hidden_size, batch).
            When `inputs` is a numpy array, both are returned as numpy
            arrays; otherwise they are tensors.
        """
        return_numpy = not torch.is_tensor(inputs)
        inputs_tensor = self._torch_inputs(inputs)
        if inputs_tensor.dim() == 3 and inputs_tensor.shape[1] == self.input_size:
            inputs_tensor = inputs_tensor.permute(0, 2, 1)  # -> (B, T, F)
        state = self._torch_state(h_prev)
        outputs: List[torch.Tensor] = []
        with torch.no_grad():
            for t in range(inputs_tensor.size(1)):
                state = self._step_once(state, inputs_tensor[:, t, :])
                logits = torch.matmul(state, self._W_hy.t()) + self._b_y  # (B, C)
                outputs.append(logits.T)  # (C, B)
        if return_numpy:
            outputs_np = [out.detach().cpu().numpy() for out in outputs]
            return outputs_np, state.T.detach().cpu().numpy()
        return outputs, state.T

    # ------------------------------------------------------------------ #
    # Helpers shared by subclasses
    # ------------------------------------------------------------------ #
    def _step_once(self, h_prev: torch.Tensor, inputs_t: torch.Tensor) -> torch.Tensor:
        """One tanh-RNN step: ``tanh(h W_hh^T + x W_xh^T + b_h)`` on (B, H) / (B, F) tensors."""
        hidden_term = torch.matmul(h_prev, self._W_hh.t())
        input_term = torch.matmul(inputs_t, self._W_xh.t())
        return torch.tanh(hidden_term + input_term + self._b_h)

    def _clip_and_step(self) -> None:
        """Clip gradients (if enabled), take an optimizer step, then update the regularizer."""
        if self.grad_clip > 0:
            clip_grad_norm_(self.parameters(), self.grad_clip)
        self.optimizer.step()
        self._regularizer.step()

    def _torch_state(self, h_prev: np.ndarray | torch.Tensor) -> torch.Tensor:
        """
        Convert a hidden state to a float32 (batch, hidden_size) tensor on the model device.

        A 2D tensor whose first axis equals hidden_size is treated as
        (hidden_size, batch) and transposed. Other tensors are used as-is.
        Numpy input is always treated as (hidden_size, batch).
        """
        if torch.is_tensor(h_prev):
            # Cast like _torch_inputs does; float64 states would otherwise
            # fail in the matmul against float32 weights.
            if h_prev.dim() == 2 and h_prev.shape[0] == self.hidden_size:
                return h_prev.T.to(self.device, dtype=torch.float32)
            return h_prev.to(self.device, dtype=torch.float32)
        return torch.from_numpy(h_prev.T.astype(np.float32)).to(self.device)

    def _torch_inputs(self, array: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert an array or tensor to a float32 tensor on the model device (no reshaping)."""
        if torch.is_tensor(array):
            return array.to(self.device, dtype=torch.float32)
        return torch.from_numpy(array.astype(np.float32)).to(self.device)

    def _reg_loss(self, lmbda: float | None = None) -> torch.Tensor:
        """FPTT regularization term; `lmbda` overrides the configured quadratic scale."""
        return self._regularizer.loss(lmbda=lmbda)

    # ------------------------------------------------------------------ #
    # Numpy compatibility interface
    # ------------------------------------------------------------------ #
    def _tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        """Detached CPU numpy copy of `tensor`."""
        return tensor.detach().cpu().numpy()

    def _assign_param(
        self,
        param: nn.Parameter,
        value: np.ndarray,
        reshape: Tuple[int, ...] | None = None,
    ) -> None:
        """
        Copy `value` (optionally reshaped) into `param` in place.

        This does not touch the FPTT regularizer. Call reset_state_buffers()
        afterwards if the new weights should become the proximal anchor.
        """
        arr = np.asarray(value, dtype=np.float32)
        if reshape is not None:
            arr = arr.reshape(reshape)
        tensor = torch.as_tensor(arr, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            param.copy_(tensor)

    @property
    def W_xh(self) -> np.ndarray:
        """Input-to-hidden weights, (hidden_size, input_size)."""
        return self._tensor_to_numpy(self._W_xh)

    @W_xh.setter
    def W_xh(self, value: np.ndarray) -> None:
        self._assign_param(self._W_xh, value)

    @property
    def W_hh(self) -> np.ndarray:
        """Hidden-to-hidden weights, (hidden_size, hidden_size)."""
        return self._tensor_to_numpy(self._W_hh)

    @W_hh.setter
    def W_hh(self, value: np.ndarray) -> None:
        self._assign_param(self._W_hh, value)

    @property
    def W_hy(self) -> np.ndarray:
        """Hidden-to-output weights, (output_size, hidden_size)."""
        return self._tensor_to_numpy(self._W_hy)

    @W_hy.setter
    def W_hy(self, value: np.ndarray) -> None:
        self._assign_param(self._W_hy, value)

    @property
    def b_h(self) -> np.ndarray:
        """Hidden bias as a column vector, (hidden_size, 1)."""
        bias = self._tensor_to_numpy(self._b_h)
        return bias.reshape(-1, 1)

    @b_h.setter
    def b_h(self, value: np.ndarray) -> None:
        # np.asarray lets lists/tuples be assigned as well as ndarrays.
        self._assign_param(self._b_h, np.asarray(value).reshape(-1))

    @property
    def b_y(self) -> np.ndarray:
        """Output bias as a column vector, (output_size, 1)."""
        bias = self._tensor_to_numpy(self._b_y)
        return bias.reshape(-1, 1)

    @b_y.setter
    def b_y(self, value: np.ndarray) -> None:
        self._assign_param(self._b_y, np.asarray(value).reshape(-1))


class StrictFPTTClassifier(StrictFPTTBase):
    """
    Strict FPTT classifier with shared oracle buffers.

    On classification tasks (MNIST-10 / CIFAR-10), the following are intended
    to be mathematically equivalent to the official implementation, apart
    from the recurrent cell itself (a tanh-RNN here instead of an LSTM):

    - the form of the loss,
    - the chunk weights ( (p+1)/_PARTS ),
    - the oracle mixing scheme,
    - the regularization term.

    Oracle distributions are kept in ``OracleBufferStore`` under
    ``oracle_id``, so model instances using the same id share them.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        parts: int = 10,
        clip: float = 1.0,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lmbda: float = 1.0,
        oracle_momentum: float = 1.0,
        warmup_epochs: int = 20,
        oracle_id: str = "default",
        label_mode: str = "last",
        use_oracle: bool = True,
        optimizer_cls: type[optim.Optimizer] = optim.SGD,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Args:
            input_size, hidden_size, output_size, eta, parts, clip, alpha,
            beta, rho, lmbda, optimizer_cls, device: forwarded to
                StrictFPTTBase.
            oracle_momentum: smoothing factor for oracle updates, clipped to
                [1e-4, 1.0].
            warmup_epochs:   while ``current_epoch < warmup_epochs`` the oracle
                target is the uniform distribution and the oracle buffer is
                not updated.
            oracle_id:       key of the shared buffer in OracleBufferStore.
            label_mode:      'last' (targets from the last time step) or 'all'
                (per-time-step targets); validated in train_batch.
            use_oracle:      enable oracle mixing; train_batch raises
                NotImplementedError if this is set with label_mode='all'.
        """
        super().__init__(
            input_size,
            hidden_size,
            output_size,
            eta,
            parts,
            clip,
            alpha,
            beta,
            rho,
            lmbda=lmbda,
            optimizer_cls=optimizer_cls,
            device=device,
        )
        self.oracle_momentum = float(np.clip(oracle_momentum, 1e-4, 1.0))
        self.warmup_epochs = max(0, int(warmup_epochs))
        self.oracle_id = oracle_id

        self.label_mode = label_mode  # 'last' (classification) or 'all' (seq2seq)
        self.use_oracle = use_oracle

        # Uniform distribution template for warmup
        self._uniform_template = torch.full(
            (1, self.output_size),
            1.0 / float(self.output_size),
            dtype=torch.float32,
            device=self.device,
        )

    def train_batch(
        self,
        inputs_batch: np.ndarray,
        targets_batch: np.ndarray,
        h_prev_batch: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Train on a single batch, taking one optimizer step per chunk.

        Args:
            inputs_batch:  (batch, input_size, time_steps)
            targets_batch: (batch, num_classes, time_steps). With
                           label_mode='last' only the last time step is used;
                           with label_mode='all' every time step is used
                           (sequence-to-sequence).
            h_prev_batch:  (hidden_size, batch)

        Returns:
            The loss summed over chunks and divided by the number of chunks
            (float), and the final hidden state (hidden_size, batch) as numpy.

        Raises:
            NotImplementedError: for an unknown label_mode, or for
                label_mode='all' combined with use_oracle=True.
        """
        batch_size, _, time_steps = inputs_batch.shape
        schedule = build_chunk_schedule(time_steps, self.parts)
        total_chunks = len(schedule)
        # Defensive only: build_chunk_schedule always returns at least one chunk.
        if total_chunks == 0:
            return 0.0, h_prev_batch

        inputs_tensor = self._torch_inputs(inputs_batch).permute(0, 2, 1)  # (B, T, F)
        state = self._torch_state(h_prev_batch)  # (B, H)
        warmup = self.current_epoch < self.warmup_epochs

        if self.label_mode == "last":
            # Use the last time step's one-hot labels
            labels = self._torch_inputs(targets_batch[:, :, -1])  # (B, C)
            # Integer class ids (argmax over the class axis) index the oracle buffer.
            if torch.is_tensor(targets_batch):
                label_indices = (
                    torch.argmax(targets_batch[:, :, -1], dim=1)
                    .detach()
                    .cpu()
                    .numpy()
                    .astype(np.int64)
                )
            else:
                label_indices = np.argmax(targets_batch[:, :, -1], axis=1).astype(np.int64)

            # Oracle is active only on early chunks; the last ones use pure labels
            oracle_cutoff = max(0, min(self.parts - 1, total_chunks - 1))

            oracle: ClassOracleBuffer | None = None
            if self.use_oracle:
                oracle = OracleBufferStore.get(
                    self.oracle_id,
                    self.output_size,
                    self.parts,
                    self.oracle_momentum,
                )
                # The schedule may contain more chunks than `parts`; grow the buffer to match.
                oracle.ensure(total_chunks)

            total_loss = 0.0
            for chunk_idx, (start, end) in enumerate(schedule):
                chunk_inputs = inputs_tensor[:, start:end, :]
                if chunk_inputs.size(1) == 0:
                    continue

                # The hidden state is carried across chunks but detached
                # (see _run_classification_chunk), so gradients stay within a chunk.
                state, loss_value = self._run_classification_chunk(
                    chunk_inputs,
                    state,
                    labels,
                    label_indices,
                    chunk_idx,
                    total_chunks,
                    oracle_cutoff,
                    warmup,
                    oracle,
                )
                total_loss += loss_value

            return total_loss / max(1, len(schedule)), state.detach().cpu().numpy().T

        if self.label_mode == "all":
            # Sequence-to-sequence targets (e.g. language modeling): (B, T, C)
            if self.use_oracle:
                raise NotImplementedError(
                    "label_mode='all' currently does not support oracle mixing; set use_oracle=False."
                )

            targets_tensor = self._torch_inputs(targets_batch).permute(0, 2, 1)  # (B, T, C)
            total_loss = 0.0

            for chunk_idx, (start, end) in enumerate(schedule):
                chunk_inputs = inputs_tensor[:, start:end, :]
                chunk_targets = targets_tensor[:, start:end, :]
                if chunk_inputs.size(1) == 0:
                    continue
                # Later chunks weigh more: (p + 1) / total_chunks.
                chunk_weight = float(chunk_idx + 1) / float(max(1, total_chunks))
                state, loss_value = self._run_sequence_chunk(
                    chunk_inputs,
                    chunk_targets,
                    state,
                    chunk_weight,
                )
                total_loss += loss_value

            return total_loss / max(1, len(schedule)), state.detach().cpu().numpy().T

        raise NotImplementedError(
            "StrictFPTTClassifier only supports label_mode='last' or label_mode='all'."
        )

    def _run_classification_chunk(
        self,
        chunk_inputs: torch.Tensor,  # (B, T_chunk, F)
        state: torch.Tensor,  # (B, H)
        labels: torch.Tensor,  # (B, C)
        label_indices: np.ndarray,  # (B,)
        chunk_idx: int,
        total_chunks: int,
        oracle_cutoff: int,
        warmup: bool,
        oracle: ClassOracleBuffer | None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Unroll one chunk, take one optimizer step, and optionally update the oracle.

        The loss is the soft-target cross-entropy against
        ``alpha * labels + (1 - alpha) * oracle`` with
        ``alpha = (chunk_idx + 1) / total_chunks``. Oracle mixing applies only
        when an oracle is given and ``chunk_idx < oracle_cutoff``; otherwise
        the plain labels are used. The FPTT regularizer is added to the loss.

        Returns:
            The detached final state (B, H), and the cross-entropy value
            excluding the regularizer.
        """
        # Forward unroll over the chunk
        for t in range(chunk_inputs.size(1)):
            state = self._step_once(state, chunk_inputs[:, t, :])

        logits = torch.matmul(state, self._W_hy.t()) + self._b_y  # (B, C)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.softmax(logits, dim=1)

        # Relative progress of this chunk within the sequence
        alpha = float(chunk_idx + 1) / float(max(1, total_chunks))

        # Oracle is only active on early parts; last parts use pure labels
        oracle_active = (oracle is not None) and (chunk_idx < oracle_cutoff)
        oracle_weight = (1.0 - alpha) if oracle_active else 0.0

        surrogate: torch.Tensor | None = None
        if oracle_active:
            if warmup:
                # Warmup: uniform distribution per sample
                surrogate = self._uniform_template.expand_as(labels)
            else:
                # Class-conditional oracle: (classes, batch) -> (batch, classes)
                surrogate_np = oracle.get(label_indices, chunk_idx)
                surrogate = torch.from_numpy(surrogate_np.T).to(self.device)

        # Mixed target: alpha * one-hot + (1-alpha) * oracle_prob
        if surrogate is None:
            mix_target = labels
        else:
            mix_target = alpha * labels + oracle_weight * surrogate

        # Cross-entropy with soft targets (equivalent to clf_loss + oracle_loss)
        loss_ce = torch.sum(-mix_target * log_probs, dim=1).mean()
        loss = loss_ce + self._reg_loss()

        self.optimizer.zero_grad()
        loss.backward()
        self._clip_and_step()
        next_state = state.detach()

        # Update the oracle buffer from misclassified examples.
        # `probs` was computed before the optimizer step above.
        if oracle_active and (not warmup) and (surrogate is not None) and (oracle is not None):
            probs_np = probs.detach().cpu().numpy().T  # (C, B)
            preds = np.argmax(probs_np, axis=0)  # (B,)
            oracle.update(label_indices, chunk_idx, probs_np, preds)

        return next_state, float(loss_ce.item())

    def _run_sequence_chunk(
        self,
        chunk_inputs: torch.Tensor,   # (B, T_chunk, F)
        chunk_targets: torch.Tensor,  # (B, T_chunk, C) one-hot/soft targets
        state: torch.Tensor,          # (B, H)
        chunk_weight: float,
    ) -> Tuple[torch.Tensor, float]:
        """
        Sequence-to-sequence chunk training (no oracle):
          - unroll over the chunk
          - compute CE over every time step
          - weight by chunk progress (FPTT-style)

        Returns:
            The detached final state (B, H), and the weighted cross-entropy
            excluding the regularizer.
        """
        logits_steps: List[torch.Tensor] = []
        for t in range(chunk_inputs.size(1)):
            state = self._step_once(state, chunk_inputs[:, t, :])
            logits_steps.append(torch.matmul(state, self._W_hy.t()) + self._b_y)  # (B, C)

        if not logits_steps:
            return state.detach(), 0.0

        logits = torch.stack(logits_steps, dim=1)  # (B, T_chunk, C)
        log_probs = torch.log_softmax(logits, dim=2)
        # Mean over batch and time of per-step soft-target cross-entropy.
        loss_ce = torch.sum(-chunk_targets * log_probs, dim=2).mean()

        loss_core = float(chunk_weight) * loss_ce
        loss = loss_core + self._reg_loss()
        self.optimizer.zero_grad()
        loss.backward()
        self._clip_and_step()
        return state.detach(), float(loss_core.item())


class StrictFPTTRegressor(StrictFPTTBase):
    """
    Strict FPTT for regression benchmarks (e.g. the adding task).

    The official repository has no ready-made regression variant of FPTT.
    This class extends the same chunk schedule and proximal regularizer to a
    per-time-step MSE loss.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        parts: int = 8,
        clip: float = 1.0,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lmbda: float = 1.0,
        optimizer_cls: type[optim.Optimizer] = optim.SGD,
        device: str | torch.device | None = None,
    ) -> None:
        """All arguments are forwarded unchanged to StrictFPTTBase."""
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            eta=eta,
            parts=parts,
            clip=clip,
            alpha=alpha,
            beta=beta,
            rho=rho,
            lmbda=lmbda,
            optimizer_cls=optimizer_cls,
            device=device,
        )

    def train_batch(
        self,
        inputs_batch: np.ndarray,
        targets_batch: np.ndarray,
        h_prev_batch: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run one regression training pass over a batch, one optimizer step per chunk.

        Args:
            inputs_batch:  (batch, input_size, time_steps)
            targets_batch: (batch, output_size, time_steps)
            h_prev_batch:  (hidden_size, batch)

        Returns:
            The per-chunk losses summed and divided by the number of chunks,
            and the final hidden state (hidden_size, batch) as numpy.
        """
        _, _, time_steps = inputs_batch.shape
        chunks = build_chunk_schedule(time_steps, self.parts)
        if not chunks:
            return 0.0, h_prev_batch

        x_seq = self._torch_inputs(inputs_batch).permute(0, 2, 1)   # (B, T, F)
        y_seq = self._torch_inputs(targets_batch).permute(0, 2, 1)  # (B, T, O)
        hidden = self._torch_state(h_prev_batch)
        n_chunks = len(chunks)
        running = 0.0

        # Chunk p (1-based position) is weighted by p / n_chunks.
        for position, (lo, hi) in enumerate(chunks, start=1):
            x_chunk = x_seq[:, lo:hi, :]
            if x_chunk.size(1) == 0:
                continue
            hidden, chunk_loss = self._run_regression_chunk(
                x_chunk,
                y_seq[:, lo:hi, :],
                hidden,
                float(position) / float(n_chunks),
            )
            running += chunk_loss

        return running / max(1, n_chunks), hidden.detach().cpu().numpy().T

    def _run_regression_chunk(
        self,
        chunk_inputs: torch.Tensor,
        chunk_targets: torch.Tensor,
        state: torch.Tensor,
        chunk_weight: float,
    ) -> Tuple[torch.Tensor, float]:
        """
        Unroll one chunk and take a step on weighted MSE plus the FPTT regularizer.

        Each step contributes ``chunk_weight * 0.5 * mean((pred - target)^2)``.
        These are averaged over the chunk's time steps.

        Returns:
            The detached final state, and the weighted MSE excluding the
            regularizer. For an empty chunk the input state and 0.0 are
            returned.
        """
        step_losses: List[torch.Tensor] = []
        for t in range(chunk_inputs.size(1)):
            state = self._step_once(state, chunk_inputs[:, t, :])
            residual = torch.matmul(state, self._W_hy.t()) + self._b_y - chunk_targets[:, t, :]
            step_losses.append(chunk_weight * (0.5 * torch.mean(residual ** 2)))

        if not step_losses:
            return state, 0.0

        objective = torch.stack(step_losses).mean()
        self.optimizer.zero_grad()
        (objective + self._reg_loss()).backward()
        self._clip_and_step()
        return state.detach(), float(objective.item())


# Alternate public names bound to the same two model classes.
TorchStrictFPTTClassifier, TorchStrictFPTTRegressor = StrictFPTTClassifier, StrictFPTTRegressor
