from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Self


class MLPAgent(nn.Module):
    """
    MLP-based agent.
    Actions are normalized to be in [-1, 1].
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        layers: Optional[List[int]] = None,
        activation: str = "relu",
    ):
        super().__init__()

        if layers is None:
            layers = []

        # Configure activation function
        self.activation = {"relu": nn.ReLU(), "tanh": nn.Tanh()}[activation.lower()]

        # Build layer dimensions
        layer_dims = [state_dim] + layers + [action_dim]

        # Create neural network layers
        self.layers = nn.ModuleList()
        for i in range(len(layer_dims) - 1):
            self.layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights using orthogonal initialization with proper scaling.
        This initialization has been shown to work particularly well for deep RL.
        """
        if isinstance(module, nn.Linear):
            gain = np.sqrt(2.0) if isinstance(self.activation, nn.ReLU) else 1.0
            nn.init.orthogonal_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network"""
        x = state
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Apply activation to all layers except the last
            if i < len(self.layers) - 1:
                x = self.activation(x)
            else:
                x = torch.tanh(x)
        return x

    def to_numpy(self) -> np.ndarray:
        """Flatten all parameters into a single numpy array"""
        params = []
        for param in self.parameters():
            params.append(param.data.cpu().numpy().flatten())
        return np.concatenate(params)

    def from_numpy(self, params: np.ndarray) -> Self:
        """Load parameters from a flattened numpy array"""
        start_idx = 0
        for param in self.parameters():
            param_size = np.prod(param.shape)
            param_data = params[start_idx : start_idx + param_size]
            param.data = (
                torch.from_numpy(param_data.reshape(param.shape).copy())
                .to(param.device)
                .to(param.dtype)
            )
            start_idx += param_size
        return self  # So that we can chain calls

    def num_params(self) -> int:
        """Return the total number of parameters in the network"""
        return sum(p.numel() for p in self.parameters())

    def act(self, state: np.ndarray) -> np.ndarray:
        """
        Get action for a given state.
        Handles conversion between numpy and torch tensors.
        """
        with torch.no_grad():
            state_tensor = (
                torch.from_numpy(state).float().to(next(iter(self.parameters())).device)
            )
            action_tensor = self.forward(state_tensor)
            return action_tensor.cpu().numpy()


class ToeplitzAgent(nn.Module):
    """
    Neural network policy using Toeplitz matrices instead of fully-connected layers.

    A Toeplitz matrix has constant diagonals (each diagonal contains the same value).
    This structure is more parameter-efficient, requiring only d_in + d_out - 1 parameters
    instead of d_in * d_out for a standard layer.
    See https://arxiv.org/abs/1804.02395

    The dense matrices are derived from the parameter vectors and cached in
    ``self.toeplitz_matrices``. The cache is refreshed automatically when the
    vectors change device, dtype, storage, or are modified in place. In-place
    edits made through ``param.data`` are not detected; call
    ``_build_toeplitz_matrices()`` after those.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        layers: Optional[List[int]] = None,
        activation: str = "relu",
    ):
        """Build the network.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Size of the action vector.
            layers: Hidden layer widths, in order. ``None`` or an empty
                sequence gives a single Toeplitz layer from state to action.
            activation: ``"relu"`` or ``"tanh"`` (case-insensitive).

        Raises:
            KeyError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        # Copy so the caller's list is never aliased; an empty list is valid.
        self.hidden_dims = list(layers) if layers is not None else []

        # Configure activation function
        self.activation = {"relu": nn.ReLU(), "tanh": nn.Tanh()}[activation.lower()]

        # (rows, cols) == (d_out, d_in) of every layer's matrix, input to output
        dims = [self.state_dim, *self.hidden_dims, self.action_dim]
        self._layer_shapes = [
            (d_out, d_in) for d_in, d_out in zip(dims[:-1], dims[1:])
        ]

        # One parameter vector of length rows + cols - 1 and one bias per layer
        self.toeplitz_vectors = nn.ParameterList()
        self.biases = nn.ParameterList()
        for rows, cols in self._layer_shapes:
            self.toeplitz_vectors.append(
                nn.Parameter(self._weight_init(rows + cols - 1))
            )
            self.biases.append(nn.Parameter(torch.zeros(rows)))

        # Pre-build Toeplitz matrices
        self.toeplitz_matrices = []
        self._cached_state = None
        self._build_toeplitz_matrices()

    def _weight_init(self, size):
        """Return ``size`` standard-normal samples scaled by 1 / sqrt(size)."""
        return torch.randn(size) / torch.sqrt(torch.tensor(size, dtype=torch.float32))

    def _vector_state(self):
        """Fingerprint of the parameter vectors, used to spot a stale cache."""
        return tuple(
            (str(v.device), str(v.dtype), v.data_ptr(), v._version)
            for v in self.toeplitz_vectors
        )

    def _build_toeplitz_matrices(self):
        """Rebuild the cached Toeplitz matrices from the parameter vectors.

        The cache is built without gradient tracking so the stored tensors are
        graph-free leaves; that keeps the module deep-copyable. Forward passes
        that need gradients build their own matrices instead of using it.
        """
        with torch.no_grad():
            self.toeplitz_matrices = [
                self._build_toeplitz_matrix(rows, cols, vector)
                for (rows, cols), vector in zip(
                    self._layer_shapes, self.toeplitz_vectors
                )
            ]
        self._cached_state = self._vector_state()

    def _build_toeplitz_matrix(self, rows, cols, vector):
        """
        Build a ``rows x cols`` Toeplitz matrix from a parameter vector.

        Entry ``(i, j)`` is ``vector[i - j]``. The first ``rows`` elements of
        the vector are therefore the first column, top to bottom. Negative
        offsets wrap around to the end of the vector, so the rest of the first
        row (left to right) is the vector's tail read backwards.
        """
        device = vector.device

        row_idx = torch.arange(rows, device=device).view(-1, 1)
        col_idx = torch.arange(cols, device=device).view(1, -1)

        # The diagonal offset i - j selects the shared value for each diagonal
        return vector[row_idx - col_idx]

    def _current_matrices(self):
        """Return matrices that match the current parameter values.

        When gradients are being tracked, fresh matrices are built so that
        every forward pass gets its own autograd graph. Otherwise the cache is
        used, after rebuilding it if the vectors changed since it was built.
        """
        vectors = list(self.toeplitz_vectors)
        if torch.is_grad_enabled() and any(v.requires_grad for v in vectors):
            return [
                self._build_toeplitz_matrix(rows, cols, vector)
                for (rows, cols), vector in zip(self._layer_shapes, vectors)
            ]
        if self._vector_state() != self._cached_state:
            self._build_toeplitz_matrices()
        return self.toeplitz_matrices

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Map a state tensor to an action tensor with entries in [-1, 1]."""
        matrices = self._current_matrices()
        last = len(matrices) - 1

        x = state
        for i, (matrix, bias) in enumerate(zip(matrices, self.biases)):
            x = torch.matmul(x, matrix.T) + bias
            # Hidden layers use the configured activation; the output layer
            # uses tanh to bound the actions.
            x = torch.tanh(x) if i == last else self.activation(x)
        return x

    def to_numpy(self) -> np.ndarray:
        """Flatten all parameters into a single 1-D numpy array.

        Parameters appear in ``self.parameters()`` order, which is the order
        ``from_numpy`` expects.
        """
        return np.concatenate(
            [param.detach().cpu().numpy().ravel() for param in self.parameters()]
        )

    def from_numpy(self, params: np.ndarray) -> Self:
        """Load parameters from a flattened numpy array and rebuild matrices.

        Args:
            params: Array-like holding at least ``num_params()`` values in
                ``to_numpy()`` order. It is flattened first; entries beyond
                ``num_params()`` are ignored.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            ValueError: If ``params`` holds fewer than ``num_params()``
                values. The check runs before anything is written, so a
                failed call leaves the network untouched.
        """
        flat = np.asarray(params).reshape(-1)
        expected = self.num_params()
        if flat.size < expected:
            raise ValueError(
                f"Expected at least {expected} parameter values, got {flat.size}"
            )

        offset = 0
        for param in self.parameters():
            count = param.numel()
            # .copy() so the parameter never shares memory with the caller's array
            chunk = flat[offset : offset + count].reshape(tuple(param.shape)).copy()
            param.data = torch.from_numpy(chunk).to(param.device).to(param.dtype)
            offset += count

        # Rebuild Toeplitz matrices with the new parameters
        self._build_toeplitz_matrices()

        return self

    def num_params(self) -> int:
        """Return the total number of parameters in the network"""
        return sum(p.numel() for p in self.parameters())

    def act(self, state: np.ndarray) -> np.ndarray:
        """Get action for a given state, handling numpy/torch conversion.

        The state is converted to the device and dtype of the network's
        parameters. Runs without gradient tracking.
        """
        reference = next(self.parameters())
        with torch.no_grad():
            state_tensor = torch.as_tensor(
                state, dtype=reference.dtype, device=reference.device
            )
            return self.forward(state_tensor).cpu().numpy()