import sys
import random
import torch
import torch.nn as nn
from tools import feature_list


class S_SimDec(nn.Module):
    """Simulation-Decision (S_SimDec) sequence model for supply chain data.

    The input feature columns are split into four groups (product, order,
    customer, shipping). Each group is projected to ``embed_dim`` and treated
    as one element of a short sequence, together with a batch-pooled feature
    vector and a shipping-mode vector. A bidirectional LSTM encodes that
    sequence and a unidirectional LSTM decodes one step per prediction task,
    each step being scored by its own linear head.

    Attributes:
        env: Environment object supplying ``args`` (``dataset``, ``embed_dim``,
            ``encoder_num_layers``, ``decoder_num_layers``), ``device`` and
            ``feature_classes``.
        c_num (int): Total number of feature columns across the four groups.
        c_transform4p (nn.Linear): Projection of the product feature group.
        c_transform4o (nn.Linear): Projection of the order feature group.
        c_transform4c (nn.Linear): Projection of the customer feature group.
        c_transform4s (nn.Linear): Projection of the shipping feature group.
        pooling_fc (nn.Linear): Projection of the batch-averaged features.
        embedding (nn.Embedding): 5-entry table used for shipping-mode indices
            and for the start-of-sequence vector (index 1).
        fc (nn.Linear): Per-element transformation applied before encoding.
        encoder_lstm (nn.LSTM): Bidirectional encoder.
        decoder_lstm (nn.LSTM): Unidirectional decoder.
        output_layer (nn.ModuleList): One linear head per decoding step, sized
            by ``env.feature_classes``.

    Note:
        The decoder is initialised directly from the (direction-summed)
        encoder state, so ``encoder_num_layers`` must equal
        ``decoder_num_layers`` for ``forward`` to run.
    """

    def __init__(self, env):
        """Build all layers from ``env`` and move the model to ``env.device``.

        Args:
            env: Environment object containing configuration and dataset
                information (see the class docstring for the fields read).
        """
        super().__init__()
        self.env = env

        embed_dim = self.env.args.embed_dim
        product_dim, order_dim, customer_dim, shipping_dim = (
            self._feature_group_sizes()
        )

        # Total number of feature columns across all four groups.
        self.c_num = product_dim + order_dim + customer_dim + shipping_dim

        # NOTE: layers are created in a fixed order on purpose. Every
        # constructor draws from the global RNG, so reordering them would
        # change the weights obtained for a given seed.

        # One projection per feature group, each mapping to embed_dim.
        self.c_transform4p = nn.Linear(product_dim, embed_dim)
        self.c_transform4o = nn.Linear(order_dim, embed_dim)
        self.c_transform4c = nn.Linear(customer_dim, embed_dim)
        self.c_transform4s = nn.Linear(shipping_dim, embed_dim)

        # Projection of the batch-averaged raw features.
        self.pooling_fc = nn.Linear(self.c_num, embed_dim)

        # Embedding table with a vocabulary of 5 entries.
        self.embedding = nn.Embedding(5, embed_dim)

        # Transformation applied to every sequence element before encoding.
        self.fc = nn.Linear(embed_dim, embed_dim)

        # Bidirectional LSTM encoder over the assembled feature sequence.
        self.encoder_lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=embed_dim,
            num_layers=self.env.args.encoder_num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Unidirectional LSTM decoder, run one step at a time in forward().
        self.decoder_lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=embed_dim,
            num_layers=self.env.args.decoder_num_layers,
            batch_first=True,
        )

        # Throwaway layer: it is replaced immediately below and never used.
        # It is kept because constructing it consumes random numbers;
        # dropping it would alter the seeded initialisation of the real heads.
        self.output_layer = nn.Linear(embed_dim, 2)

        # One head per decoding step. The nn.Sequential wrapper is kept so the
        # parameter names in the state dict stay unchanged.
        self.output_layer = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(embed_dim, num_classes),
                )
                for num_classes in self.env.feature_classes
            ]
        )

        # # Decision maker network for supply chain decisions
        # self.decision_maker = nn.Sequential( # may affect the initialization of the model
        #     nn.Linear(self.c_num, self.c_num), nn.ReLU(), nn.Linear(self.c_num, 4)
        # )

        # Move model to the configured device (CPU/GPU).
        self.to(self.env.device)

    def _feature_group_sizes(self):
        """Return the (product, order, customer, shipping) feature counts."""
        dataset = self.env.args.dataset
        return (
            len(feature_list.product_info[dataset]),
            len(feature_list.order_info[dataset]),
            len(feature_list.customer_info[dataset]),
            len(feature_list.shipping_info[dataset]),
        )

    def forward(self, c_input: torch.Tensor, shipping_mode: torch.Tensor, tgt: torch.Tensor):
        """Run the encoder once and decode one step per prediction task.

        Args:
            c_input (torch.Tensor): 2-D feature tensor ``(batch_size, feature_dim)``.
                Product, order and customer columns are read by offset from
                the first column; shipping columns are read from the trailing
                columns; the first ``c_num`` columns feed the pooled vector.
            shipping_mode (torch.Tensor): Either a 1-D tensor of integer
                indices (looked up in ``self.embedding``) or an already
                embedded tensor, which is used as-is and therefore has to be
                ``(batch_size, embed_dim)`` to be concatenated with the rest.
            tgt (torch.Tensor): Only ``tgt.size(-1)`` is read; it sets the
                number of decoding steps. Its values are not used (there is no
                teacher forcing).

        Returns:
            list: ``tgt.size(-1)`` tensors. Element ``t`` holds the raw output
            of head ``t`` with shape ``(batch_size, env.feature_classes[t])``;
            no softmax is applied.

        Raises:
            IndexError: If ``tgt`` asks for more decoding steps than there are
                output heads.
        """
        num_steps = tgt.size(-1)
        # Fail before doing any work; indexing the ModuleList past its end
        # would raise the same exception type, only later and less clearly.
        if num_steps > len(self.output_layer):
            raise IndexError(
                f"tgt requests {num_steps} decoding steps but only "
                f"{len(self.output_layer)} output heads are configured"
            )

        batch_size = c_input.shape[0]
        product_dim, order_dim, customer_dim, shipping_dim = (
            self._feature_group_sizes()
        )
        order_end = product_dim + order_dim
        customer_end = order_end + customer_dim

        # One embed_dim vector per feature group, stacked on a new sequence
        # axis -> (batch_size, 4, embed_dim).
        # NOTE(review): shipping is sliced from the end while the other groups
        # are sliced from the start. The two conventions agree only when
        # c_input has exactly c_num columns - confirm against the data loader.
        group_tokens = torch.stack(
            (
                self.c_transform4p(c_input[:, :product_dim]),
                self.c_transform4o(c_input[:, product_dim:order_end]),
                self.c_transform4c(c_input[:, order_end:customer_end]),
                self.c_transform4s(c_input[:, -shipping_dim:]),
            ),
            dim=1,
        )

        # Mean over dim 0 (the batch), copied back to every row and projected.
        # Each sample's output therefore depends on the whole batch.
        batch_mean = c_input[:, : self.c_num].mean(dim=0, keepdim=True)
        pooled_token = self.pooling_fc(batch_mean.repeat(batch_size, 1))

        # Shipping mode: embed 1-D index tensors, pass anything else through.
        if shipping_mode.dim() == 1:
            shipping_token = self.embedding(shipping_mode)
        else:
            shipping_token = shipping_mode

        # Full encoder input: (batch_size, 6, embed_dim).
        combined = torch.cat(
            (group_tokens, pooled_token.unsqueeze(1), shipping_token.unsqueeze(1)),
            dim=1,
        )
        combined = torch.relu(self.fc(combined))

        _, (h_n, c_n) = self.encoder_lstm(combined)

        # For a bidirectional LSTM the final states alternate forward/backward
        # along dim 0; summing each pair yields one state per layer, which is
        # the layout the unidirectional decoder expects.
        decoder_hidden = (h_n[0::2] + h_n[1::2], c_n[0::2] + c_n[1::2])

        # Start-of-sequence input: embedding index 1, created directly on the
        # target device. This is the same table (and row) used for
        # shipping-mode index 1.
        sos_token = torch.full(
            (batch_size, 1), 1, dtype=torch.long, device=self.env.device
        )
        step_input = self.embedding(sos_token)

        generated_tokens = []
        for step in range(num_steps):
            decoder_output, decoder_hidden = self.decoder_lstm(
                step_input, decoder_hidden
            )
            # Each step has its own head because the class count differs per task.
            generated_tokens.append(
                self.output_layer[step](decoder_output.squeeze(1))
            )
            # The next step consumes the raw decoder output, not a predicted label.
            step_input = decoder_output

        return generated_tokens

    # def decision_process(self, c_input):
    #     # Process input through the decision maker network
    #     # The decision_maker expects input features and outputs decision probabilities
    #     decision_output = self.decision_maker(c_input)
    #     return decision_output
