from __future__ import annotations

from typing import List, Tuple, Optional

import numpy as np
import torch


# Fallback device used when a model is built without an explicit ``device``:
# CUDA when torch reports it as available, otherwise the CPU.
DEFAULT_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _to_tensor(
    array: np.ndarray | torch.Tensor,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    if torch.is_tensor(array):
        return array.to(device=device, dtype=dtype)
    return torch.as_tensor(array, dtype=dtype, device=device)


def _softmax(logits: torch.Tensor) -> torch.Tensor:
    logits_shifted = logits - torch.max(logits, dim=0, keepdim=True).values
    exp_logits = torch.exp(logits_shifted)
    return exp_logits / (torch.sum(exp_logits, dim=0, keepdim=True) + 1e-12)


def _clip_grads(grads: List[torch.Tensor], max_norm: float) -> List[torch.Tensor]:
    if max_norm <= 0:
        return grads
    total_norm_sq = torch.zeros((), device=grads[0].device)
    for g in grads:
        total_norm_sq = total_norm_sq + torch.sum(g ** 2)
    total_norm = torch.sqrt(total_norm_sq)
    if not torch.isfinite(total_norm) or total_norm <= max_norm:
        return grads
    scale = max_norm / (total_norm + 1e-8)
    return [g * scale for g in grads]


class TorchEPropRNN:
    """
    Rate-based E-Prop (specifically the RFLO approximation).

    Features:
    - Uses eligibility traces with scalar decay.
    - Supports Online Learning (weights updated every timestep).
    - Includes Gradient Clipping (Default 5.0).

    Tensor layouts (as indexed by ``train_batch`` and ``forward_cycle``):
    - inputs:       (batch, input_size, time)
    - targets:      (batch, output_size, time)
    - hidden state: (hidden_size, batch)
    - per-step outputs: (output_size, batch)

    ``step_weights`` starts as ``None``; when assigned, ``train_batch`` uses it
    as one loss weight per time step.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        decay_lambda: float = 0.95,
        feedback: str = "symmetric",
        seed: int = 1234,
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        device: torch.device | str | None = None,
    ) -> None:
        """Build the network and draw its initial weights.

        Args:
            input_size: Number of input channels.
            hidden_size: Number of recurrent units.
            output_size: Number of output channels.
            eta: Learning rate of the per-timestep SGD updates.
            decay_lambda: Eligibility-trace decay factor; values <= 0 make
                ``train_batch`` use the trace-free (instantaneous) path.
            feedback: ``"symmetric"`` (learning signal through ``W_hy.T``) or
                ``"random"`` (through a fixed random matrix ``B_fb``);
                case-insensitive.
            seed: Seed of the NumPy generator used for initialization.
            loss_mode: ``"ce"`` selects softmax cross-entropy; any other value
                selects the squared-error loss.
            max_grad_norm: Joint L2 clipping threshold applied to each update;
                values <= 0 disable clipping.
            device: Torch device for all tensors; defaults to ``DEFAULT_DEVICE``.

        Raises:
            ValueError: If ``feedback`` is not ``"random"`` or ``"symmetric"``.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.eta = float(eta)
        self.decay_lambda = float(decay_lambda)
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        self.step_weights: torch.Tensor | None = None
        self.device = torch.device(device) if device is not None else DEFAULT_DEVICE

        feedback = feedback.lower()
        if feedback not in {"random", "symmetric"}:
            raise ValueError("feedback must be 'random' or 'symmetric'.")
        self.feedback_type = feedback

        rng = np.random.default_rng(seed)
        # Input scale shrinks with fan-in but is capped at 0.1.
        input_scale = min(0.1, 1.0 / np.sqrt(max(1, input_size)))
        self.W_xh = _to_tensor(
            rng.standard_normal((hidden_size, input_size)).astype(np.float32) * input_scale,
            self.device,
        )
        self.W_hh = _to_tensor(
            rng.standard_normal((hidden_size, hidden_size)).astype(np.float32) * 0.1,
            self.device,
        )
        self.b_h = torch.zeros((hidden_size, 1), dtype=torch.float32, device=self.device)
        self.W_hy = _to_tensor(
            rng.standard_normal((output_size, hidden_size)).astype(np.float32) * 0.1,
            self.device,
        )
        self.b_y = torch.zeros((output_size, 1), dtype=torch.float32, device=self.device)

        # Fixed (never trained) feedback matrix, only used in "random" mode.
        self.B_fb: Optional[torch.Tensor] = None
        if self.feedback_type == "random":
            self.B_fb = _to_tensor(
                rng.standard_normal((hidden_size, output_size)).astype(np.float32)
                / np.sqrt(max(1, output_size)),
                self.device,
            )

    def initialize_weights_with_gain(self, g: float) -> None:
        """Re-draw the trainable weights with recurrent gain ``g``.

        ``W_hh`` is scaled by ``g / sqrt(hidden_size)``. ``W_xh`` and ``W_hy``
        use the same scales as ``__init__``, and both biases are reset to zero.
        The RNG is seeded from ``int(g * 1000)``, not from the constructor
        ``seed``, so equal gains reproduce identical weights. A negative gain
        gives a negative seed, which NumPy's ``default_rng`` rejects. ``B_fb``
        is left unchanged.
        """
        rng = np.random.default_rng(int(g * 1000))
        input_scale = min(0.1, 1.0 / np.sqrt(max(1, self.input_size)))
        self.W_xh = _to_tensor(
            rng.standard_normal(self.W_xh.shape).astype(np.float32) * input_scale,
            self.device,
        )
        self.W_hh = _to_tensor(
            rng.standard_normal(self.W_hh.shape).astype(np.float32) * (g / np.sqrt(self.hidden_size)),
            self.device,
        )
        self.b_h = torch.zeros(self.b_h.shape, dtype=torch.float32, device=self.device)
        self.W_hy = _to_tensor(
            rng.standard_normal(self.W_hy.shape).astype(np.float32) * 0.1,
            self.device,
        )
        self.b_y = torch.zeros(self.b_y.shape, dtype=torch.float32, device=self.device)

    def _learning_signal(self, dL_dyhat: torch.Tensor) -> torch.Tensor:
        """Project the output error (output_size, B) onto the hidden units (hidden_size, B).

        Uses ``W_hy.T`` in symmetric mode and the fixed ``B_fb`` in random mode.

        Raises:
            RuntimeError: If random feedback is selected but ``B_fb`` is missing.
        """
        if self.feedback_type == "symmetric":
            return self.W_hy.T @ dL_dyhat
        if self.B_fb is None:
            raise RuntimeError("Random feedback matrix B_fb is None.")
        return self.B_fb @ dL_dyhat

    def _resolve_step_weights(self, time_steps: int) -> Tuple[torch.Tensor | None, float]:
        """Return ``(step_weights, weight_norm)`` for a sequence of ``time_steps`` steps.

        ``step_weights`` is ``self.step_weights`` as a float32 tensor on this
        model's device, or ``None``. ``weight_norm`` is the sum of the weights
        actually used (or ``time_steps`` when unweighted), floored at 1.0.

        Raises:
            ValueError: If fewer than ``time_steps`` weights are provided. The
                check runs up front because the online updates would otherwise
                already have modified the weights when indexing failed partway
                through the sequence.
        """
        step_weights = self.step_weights
        if step_weights is None:
            return None, max(float(time_steps), 1.0)
        step_weights = _to_tensor(step_weights, self.device, dtype=torch.float32)
        if step_weights.ndim == 0 or step_weights.shape[0] < time_steps:
            raise ValueError(
                "step_weights must provide at least one weight per time step "
                f"(need {time_steps}, got shape {tuple(step_weights.shape)})."
            )
        # Only the weights of steps that actually run contribute to the normalizer.
        weight_norm = float(step_weights[:time_steps].sum().item())
        return step_weights, max(weight_norm, 1.0)

    def _forward_step(
        self,
        h_prev: torch.Tensor,
        I_t: torch.Tensor,
        y_target_t: torch.Tensor,
        step_weight: torch.Tensor | float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one recurrent step and evaluate its step-weighted loss.

        Args:
            h_prev: Previous hidden state, (hidden_size, B).
            I_t: Input at this step, (input_size, B).
            y_target_t: Target at this step, (output_size, B).
            step_weight: Loss weight of this step.

        Returns:
            ``(h_t, dL_dyhat, loss_t)``: the new hidden state, the step-weighted
            gradient of the loss w.r.t. the output logits (output_size, B), and
            the step-weighted batch-mean loss (scalar tensor).
        """
        eps = 1e-12  # keeps log() finite when a probability underflows to 0
        x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
        h_t = torch.tanh(x_t)
        y_hat_t = self.W_hy @ h_t + self.b_y

        if self.loss_mode == "ce":
            probs = _softmax(y_hat_t)
            dL_dyhat = (probs - y_target_t) * step_weight
            loss_t = step_weight * -torch.mean(torch.sum(y_target_t * torch.log(probs + eps), dim=0))
        else:
            error_out = y_hat_t - y_target_t
            dL_dyhat = error_out * step_weight
            loss_t = step_weight * 0.5 * torch.mean(torch.sum(error_out**2, dim=0))
        return h_t, dL_dyhat, loss_t

    def _apply_update(
        self,
        dW_xh: torch.Tensor,
        dW_hh: torch.Tensor,
        db_h: torch.Tensor,
        dW_hy: torch.Tensor,
        db_y: torch.Tensor,
    ) -> None:
        """Jointly clip the five gradients, then take one in-place SGD step of size ``eta``."""
        dW_xh, dW_hh, db_h, dW_hy, db_y = _clip_grads(
            [dW_xh, dW_hh, db_h, dW_hy, db_y],
            self.max_grad_norm,
        )
        self.W_xh.add_(-self.eta * dW_xh)
        self.W_hh.add_(-self.eta * dW_hh)
        self.b_h.add_(-self.eta * db_h)
        self.W_hy.add_(-self.eta * dW_hy)
        self.b_y.add_(-self.eta * db_y)

    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor,
    ) -> Tuple[float, torch.Tensor]:
        """Train on one batch with online E-Prop/RFLO updates (one SGD step per time step).

        Eligibility traces restart from zero on every call, while the hidden
        state continues from ``h_prev_batch``.

        Args:
            inputs_batch: Inputs, (batch, input_size, time).
            targets_batch: Targets, (batch, output_size, time).
            h_prev_batch: Initial hidden state, (hidden_size, batch).

        Returns:
            ``(loss, h_last)``. ``loss`` is the step-weighted per-step batch-mean
            loss summed over time and divided by ``max(total step weight, 1)``;
            ``h_last`` is the final hidden state, (hidden_size, batch).

        Raises:
            ValueError: If ``step_weights`` is set but has fewer entries than
                there are time steps.
        """
        inputs = _to_tensor(inputs_batch, self.device, dtype=torch.float32)
        targets = _to_tensor(targets_batch, self.device, dtype=torch.float32)
        h_prev = _to_tensor(h_prev_batch, self.device, dtype=torch.float32)
        batch_size, _, time_steps = inputs.shape

        step_weights, weight_norm = self._resolve_step_weights(time_steps)
        # With no temporal memory (decay_lambda <= 0) each trace equals its
        # instantaneous outer product, so the same gradients are computed in 2D
        # without allocating the (B, H, H) / (B, H, I) trace tensors.
        use_traces = self.decay_lambda > 0.0

        total_loss = torch.zeros((), device=self.device)
        with torch.no_grad():
            if use_traces:
                e_W_xh = torch.zeros(
                    (batch_size, self.hidden_size, self.input_size),
                    dtype=torch.float32,
                    device=self.device,
                )
                e_W_hh = torch.zeros(
                    (batch_size, self.hidden_size, self.hidden_size),
                    dtype=torch.float32,
                    device=self.device,
                )
                e_b_h = torch.zeros((batch_size, self.hidden_size), dtype=torch.float32, device=self.device)

            for t in range(time_steps):
                step_weight = step_weights[t] if step_weights is not None else 1.0
                I_t = inputs[:, :, t].T  # (I, B)
                y_target_t = targets[:, :, t].T  # (O, B)

                h_t, dL_dyhat, loss_t = self._forward_step(h_prev, I_t, y_target_t, step_weight)
                total_loss = total_loss + loss_t

                l_batch = self._learning_signal(dL_dyhat).T  # (B, H)
                psi = (1.0 - h_t**2).T  # tanh'(x_t), (B, H)
                h_prev_T = h_prev.T  # (B, H)
                I_t_T = I_t.T  # (B, I)

                if use_traces:
                    e_W_hh = self.decay_lambda * e_W_hh + psi[:, :, None] * h_prev_T[:, None, :]
                    e_W_xh = self.decay_lambda * e_W_xh + psi[:, :, None] * I_t_T[:, None, :]
                    e_b_h = self.decay_lambda * e_b_h + psi
                    dW_hh_t = torch.einsum("bi,bij->ij", l_batch, e_W_hh) / batch_size
                    dW_xh_t = torch.einsum("bi,bij->ij", l_batch, e_W_xh) / batch_size
                    db_h_t = torch.sum(l_batch * e_b_h, dim=0, keepdim=True).T / batch_size
                else:
                    q = l_batch * psi  # (B, H)
                    dW_hh_t = (q.T @ h_prev_T) / batch_size
                    dW_xh_t = (q.T @ I_t_T) / batch_size
                    db_h_t = torch.sum(q, dim=0, keepdim=True).T / batch_size

                # Readout gradients come directly from the output error in both paths.
                dW_hy_t = (dL_dyhat @ h_t.T) / batch_size
                db_y_t = torch.sum(dL_dyhat, dim=1, keepdim=True) / batch_size

                self._apply_update(dW_xh_t, dW_hh_t, db_h_t, dW_hy_t, db_y_t)
                h_prev = h_t

        return float((total_loss / weight_norm).item()), h_prev

    def forward_cycle(
        self, inputs: np.ndarray | torch.Tensor, h_prev: np.ndarray | torch.Tensor
    ) -> Tuple[List[torch.Tensor] | List[np.ndarray], torch.Tensor | np.ndarray]:
        """Run the network forward over all time steps without updating weights.

        Args:
            inputs: Inputs, (batch, input_size, time).
            h_prev: Initial hidden state, (hidden_size, batch).

        Returns:
            ``(outputs, h_last)``, where ``outputs`` holds one (output_size, batch)
            logit tensor per time step. If ``inputs`` is not a torch tensor,
            both results are returned as NumPy arrays on the CPU.
        """
        return_numpy = not torch.is_tensor(inputs)
        inputs_t = _to_tensor(inputs, self.device, dtype=torch.float32)
        h_prev_t = _to_tensor(h_prev, self.device, dtype=torch.float32)
        outputs: List[torch.Tensor] = []
        with torch.no_grad():
            for t in range(inputs_t.shape[2]):
                I_t = inputs_t[:, :, t].T
                x_t = self.W_hh @ h_prev_t + self.W_xh @ I_t + self.b_h
                h_prev_t = torch.tanh(x_t)
                outputs.append(self.W_hy @ h_prev_t + self.b_y)
        if return_numpy:
            outputs_np = [out.detach().cpu().numpy() for out in outputs]
            return outputs_np, h_prev_t.detach().cpu().numpy()
        return outputs, h_prev_t

    def zero_state(self, batch_size: int) -> torch.Tensor:
        """Return an all-zero hidden state of shape (hidden_size, max(1, batch_size))."""
        batch = max(1, int(batch_size))
        return torch.zeros((self.hidden_size, batch), dtype=torch.float32, device=self.device)


# Alias: ``StandardEPropRNN`` refers to the very same class object as ``TorchEPropRNN``.
StandardEPropRNN = TorchEPropRNN
