from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List

import numpy as np


@dataclass
class RNNParameters:
    """Bundle of the five trainable arrays of the shared tanh RNN.

    The shapes noted on each field are the ones produced by
    ``initialize_rnn_parameters``; the dataclass itself does not enforce them.
    """

    # Input-to-hidden weights, (hidden_size, input_size).
    W_xh: np.ndarray
    # Hidden-to-hidden (recurrent) weights, (hidden_size, hidden_size).
    W_hh: np.ndarray
    # Hidden bias, (hidden_size, 1).
    b_h: np.ndarray
    # Hidden-to-output weights, (output_size, hidden_size).
    W_hy: np.ndarray
    # Output bias, (output_size, 1).
    b_y: np.ndarray

    def clone(self) -> "RNNParameters":
        """Return a new instance whose arrays are independent copies."""
        duplicated = {
            name: getattr(self, name).copy()
            for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
        }
        return RNNParameters(**duplicated)


def initialize_rnn_parameters(
    input_size: int,
    hidden_size: int,
    output_size: int,
    *,
    gain: float = 1.0,
    seed: int = 0,
) -> RNNParameters:
    """Build a reproducible float32 parameter set for the shared RNN.

    Args:
        input_size: Number of columns of ``W_xh``.
        hidden_size: Number of hidden units (rows of ``W_xh``/``W_hh``).
        output_size: Number of rows of ``W_hy``.
        gain: Multiplier applied to the recurrent matrix on top of the
            ``1 / sqrt(hidden_size)`` scaling.
        seed: Seed for ``np.random.default_rng``; the same seed yields the
            same parameters.

    Returns:
        An ``RNNParameters`` with Gaussian weights and zero biases.
    """

    generator = np.random.default_rng(int(seed))

    # The draw order (W_xh, then W_hh, then W_hy) determines the values, so it
    # must stay fixed for results to remain reproducible.
    input_weights = (
        0.1 * generator.standard_normal((hidden_size, input_size))
    ).astype(np.float32)

    recurrent_weights = generator.standard_normal(
        (hidden_size, hidden_size)
    ).astype(np.float32)
    # max(1, ...) guards against division by zero when hidden_size is 0.
    recurrent_scale = float(gain) / np.sqrt(max(1, hidden_size))
    recurrent_weights *= recurrent_scale  # in place, stays float32

    readout_weights = (
        0.1 * generator.standard_normal((output_size, hidden_size))
    ).astype(np.float32)

    return RNNParameters(
        W_xh=input_weights,
        W_hh=recurrent_weights,
        b_h=np.zeros((hidden_size, 1), dtype=np.float32),
        W_hy=readout_weights,
        b_y=np.zeros((output_size, 1), dtype=np.float32),
    )


def apply_rnn_parameters(model: object, params: RNNParameters) -> None:
    """Copy the shared parameters into the provided model (FPTT or e-prop).

    Each array is duplicated before assignment, so the model owns its own
    buffers. Without the copy, every model initialised from the same
    ``params`` would alias the same arrays, and an in-place update in one
    model would silently change ``params`` and the other models.

    Args:
        model: Object exposing ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and
            ``b_y`` attributes.
        params: Source of the parameter arrays.

    Raises:
        AttributeError: If ``model`` lacks one of the five attributes. The
            model is left untouched in that case.
    """

    names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")

    # Validate everything first so a failure never leaves the model half-set.
    for attr in names:
        if not hasattr(model, attr):
            raise AttributeError(f"Model {model} does not expose attribute '{attr}'.")

    for attr in names:
        setattr(model, attr, np.array(getattr(params, attr), copy=True))


def extract_rnn_parameters(model: object) -> RNNParameters:
    """Snapshot a model's parameters as independent float32 arrays.

    Each of ``W_xh``, ``W_hh``, ``b_h``, ``W_hy`` and ``b_y`` is read from
    ``model``, converted to a float32 ndarray and copied, so the returned
    container does not share memory with the model.
    """

    snapshot = {}
    for name in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
        as_float32 = np.asarray(getattr(model, name), dtype=np.float32)
        # asarray may return the model's own buffer; copy to detach from it.
        snapshot[name] = as_float32.copy()
    return RNNParameters(**snapshot)


def build_epoch_seeds(base_seed: int, epochs: int) -> List[int]:
    """Derive one deterministic 32-bit seed per epoch from ``base_seed``.

    A ``np.random.SeedSequence`` rooted at ``base_seed`` spawns one child per
    epoch, and the first ``uint32`` word of each child's state becomes that
    epoch's seed. ``epochs`` is clamped to at least 1.
    """

    root = np.random.SeedSequence(int(base_seed))
    epoch_count = max(1, int(epochs))
    return [
        int(child.generate_state(1, dtype=np.uint32)[0])
        for child in root.spawn(epoch_count)
    ]


def build_batch_seed_grid(base_seed: int, epochs: int, steps_per_epoch: int) -> List[List[int]]:
    """Return seeds indexed as ``[epoch][step]`` so every algorithm sees identical batches.

    Each epoch seed from ``build_epoch_seeds`` initialises its own generator,
    which draws ``steps_per_epoch`` integers in ``[0, 2**31)``.
    ``steps_per_epoch`` is clamped to at least 1.
    """

    step_count = max(1, int(steps_per_epoch))
    return [
        np.random.default_rng(epoch_seed)
        .integers(low=0, high=2**31, size=step_count, dtype=np.int64)
        .tolist()  # converts the int64 draws to plain Python ints
        for epoch_seed in build_epoch_seeds(base_seed, epochs)
    ]


class EpochBatchScheduler:
    """Deterministic index scheduler shared by every algorithm.

    The whole schedule is computed once in the constructor: a single
    generator seeded with ``base_seed`` re-shuffles the same index array at
    the start of every epoch, and the shuffled indices are cut into
    consecutive batches of ``batch_size``.
    """

    def __init__(
        self,
        num_examples: int,
        batch_size: int,
        epochs: int,
        *,
        base_seed: int,
        drop_last: bool = False,
    ) -> None:
        """Precompute the batch indices for every epoch.

        Args:
            num_examples: Size of the dataset; indices run over
                ``range(num_examples)``.
            batch_size: Number of indices per batch.
            epochs: Number of epochs to schedule (clamped to at least 1).
            base_seed: Seed for the shuffling generator.
            drop_last: If true, a trailing batch shorter than ``batch_size``
                is discarded.

        Raises:
            ValueError: If ``num_examples`` or ``batch_size`` is not positive
                after integer conversion.
        """
        # Coerce before validating so the check sees the value actually used
        # (e.g. 0.5 truncates to 0 and must be rejected).
        num_examples = int(num_examples)
        if num_examples <= 0:
            raise ValueError("num_examples must be positive.")
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError("batch_size must be positive.")

        self.num_examples = num_examples
        self.batch_size = batch_size
        self.epochs = max(1, int(epochs))
        self.drop_last = bool(drop_last)

        # With drop_last, stop at the last multiple of batch_size so only
        # full batches are produced; otherwise run to the end of the data.
        if self.drop_last:
            limit = self.num_examples - self.num_examples % self.batch_size
        else:
            limit = self.num_examples

        rng = np.random.default_rng(int(base_seed))
        indices = np.arange(self.num_examples)
        self._schedule: List[List[np.ndarray]] = []
        for _ in range(self.epochs):
            # Shuffle in place: each epoch permutes the previous epoch's order.
            rng.shuffle(indices)
            # Copy each slice so later shuffles do not rewrite stored batches.
            self._schedule.append(
                [
                    indices[start:start + self.batch_size].copy()
                    for start in range(0, limit, self.batch_size)
                ]
            )

    def epoch_batches(self, epoch: int) -> Iterable[np.ndarray]:
        """Return an iterator over copies of the batch indices for `epoch`.

        The range check runs when the method is called, not when the result
        is first iterated.

        Raises:
            IndexError: If ``epoch`` is outside ``[0, self.epochs)``.
        """

        if epoch < 0 or epoch >= self.epochs:
            raise IndexError(f"epoch {epoch} out of range [0, {self.epochs}).")
        return (batch.copy() for batch in self._schedule[epoch])

    def as_list(self) -> List[List[np.ndarray]]:
        """Return the full schedule (copies) for debugging/plotting."""

        return [[batch.copy() for batch in epoch] for epoch in self._schedule]
