import torch.nn as nn
import torch.nn.functional as F
from tools import feature_list
import torch.nn.init as init
import torch


class ValueNetwork(nn.Module):
    """Small fully connected network mapping state features to value estimates.

    Architecture:
    - ``fc1``: linear layer from the state feature vector to ``hidden_dim``
      units, followed by a ReLU activation.
    - ``fc3``: linear output layer producing ``output_dim`` values
      (presumably one value per action -- confirm against the caller).

    Attributes:
        env: Environment object supplied by the caller. ``env.device`` is read
            during construction, and ``env.args.dataset`` is read when the
            input size has to be derived from ``feature_list``.
        feature_dim (int): Size of the input feature vector.
        fc1 (nn.Linear): Hidden layer.
        fc3 (nn.Linear): Output layer.
    """

    def __init__(self, env, *, feature_dim=None, hidden_dim: int = 16,
                 output_dim: int = 4) -> None:
        """
        Build the layers, initialize their weights and move them to the device.

        Args:
            env: Environment object exposing ``device`` and ``args.dataset``.
            feature_dim (int, optional): Size of the input feature vector.
                When omitted it is the combined length of the product, order,
                customer and shipping feature lists registered in
                ``feature_list`` for ``env.args.dataset``.
            hidden_dim (int): Number of units in the hidden layer.
            output_dim (int): Number of values produced by the output layer.
        """
        super().__init__()
        self.env = env

        if feature_dim is None:
            dataset = self.env.args.dataset
            feature_dim = len(
                feature_list.product_info[dataset]
                + feature_list.order_info[dataset]
                + feature_list.customer_info[dataset]
                + feature_list.shipping_info[dataset]
            )
        self.feature_dim = feature_dim

        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Apply the documented initialization scheme; without this call the
        # layers would silently keep nn.Linear's default initialization.
        self._initialize_weights()
        self.to(self.env.device)

    def _initialize_weights(self) -> None:
        """
        Initialize the weights and biases of both linear layers in place.

        - ``fc1``: Kaiming uniform weights (configured for ReLU), zero bias.
        - ``fc3``: Xavier uniform weights, zero bias.
        """
        init.kaiming_uniform_(self.fc1.weight, nonlinearity="relu")
        init.zeros_(self.fc1.bias)
        init.xavier_uniform_(self.fc3.weight)
        init.zeros_(self.fc3.bias)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the value network.

        Args:
            state (torch.Tensor): Input state features; the last dimension
                must match the input size of ``fc1``.

        Returns:
            torch.Tensor: Output of ``fc3`` applied to the ReLU-activated
            hidden representation (last dimension equals the output size).
        """
        hidden = F.relu(self.fc1(state))
        return self.fc3(hidden)