import sys, pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent))
import os
import time
import random
import faiss
import torch
import wandb
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union, Any

# import faiss.contrib.torch_utils
from torch import autograd
from tqdm import tqdm
from tools.utils import AverageMeter
from tools import feature_list
from evaluations.metric import (
    compute_rec_loss,
    compute_error_rates,
    weighted_label_smoothing_loss,
    loss_function,
    focal_loss,
)
from torch.utils.data import DataLoader
from tools.logger import info
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F
from models.recalibrator import Recalibrator
from models.v_model import ValueNetwork
from models.s_model import S_SimDec
from loaders.s_loader import S_Loader
from models.perturbator import Perturbator
from tools.utils import batch_query_cost_dic


class CB_Session(object):
    def __init__(self, env: Any, model: S_SimDec, dataset: S_Loader) -> None:
        """
        CB_Session: Context-Based Session for Supply Chain Optimization

        A session class that manages the training and evaluation of the S_SimDec model
        for supply chain optimization tasks. This class handles both simulator training
        and decision-maker training phases, implementing various loss functions and
        optimization strategies.

        The session supports:
        - Simulator training with multiple loss components (reconstruction, classification)
        - Decision-maker training with reinforcement learning components
        - Early stopping and model checkpointing
        - Performance tracking and evaluation metrics
        - Value network integration for RL-based decision making

        Args:
            env: Environment object; this method reads ``env.args`` (batch_size,
                dataset, lr, decay_coeff, use_calibration and, optionally,
                use_perturbation)
            model: S_SimDec model instance for training
            dataset: S_Loader dataset; it is wrapped in a shuffling DataLoader and
                also supplies train/val/test inputs, ``cost_mrp`` and ``avg_profit``

        Attributes:
            env: Environment configuration and dataset information
            model: S_SimDec model for training
            value_network: Value network for reinforcement learning (None until
                ``init_value_network`` is called)
            optimizer: Adam optimizer for simulator training (model parameters plus
                recalibrator parameters when calibration is enabled)
            optimizer_dm: Optimizer for decision-maker training (None until
                ``init_value_network`` is called)
            loader: DataLoader for batch processing
            train_inputs: Training input data
            val_inputs: Validation input data
            test_inputs: Test input data
            action_dim: Number of possible actions for decision making
            epsilon: Exploration rate for epsilon-greedy policy
            early_stop: Early stopping counter
            best_epoch: Best epoch for simulator training
            best_dm_epoch: Best epoch for decision-maker training
            total_epoch: Total number of training epochs
            best_overall_accuracy: Best overall accuracy achieved
            best_acc1, best_acc2, best_acc3: Best accuracies for different tasks
            best_dm_accuracy: Best decision-maker accuracy
            cost_dic: Cost dictionary for MRP calculations
            avg_profit: Average profit for normalization
            test_rec_loss: Test reconstruction loss
            scaler: StandardScaler for feature normalization
            best_p, best_o: Best profit and on-time delivery metrics
            best_pmp1, best_pmp2, best_pmp3: Best performance metrics
            min_profit, max_profit: Profit range for normalization
            min_on_time, max_on_time: On-time delivery range for normalization
            use_calibration: Whether the recalibrator is trained with the simulator
            use_perturbation: Whether a Perturbator is built in ``init_value_network``
                (defaults to False when the argument is absent)
            feature_dim: Number of input feature columns for the configured dataset
            label_dim: Number of label columns for the configured dataset
            perturbator: Perturbator instance (None until ``init_value_network``)
            recalibrator: Recalibrator instance; only set when calibration is enabled
        """
        self.env = env
        self.model = model
        self.value_network = None
        self.optimizer_dm = None
        self.loader = DataLoader(
            dataset, batch_size=self.env.args.batch_size, shuffle=True
        )
        self.train_inputs = dataset.train_inputs
        self.val_inputs = dataset.val_inputs
        self.test_inputs = dataset.test_inputs

        self.action_dim = 4
        self.epsilon = 0.1

        self.early_stop = 0
        self.best_epoch = 0
        self.best_dm_epoch = 0
        self.total_epoch = 0
        self.best_overall_accuracy = 0
        self.best_acc1, self.best_acc2, self.best_acc3 = 0, 0, 0
        self.best_dm_accuracy = 0
        self.cost_dic = dataset.cost_mrp
        self.avg_profit = dataset.avg_profit
        self.test_rec_loss = 99999
        self.scaler = StandardScaler()
        self.best_p = 0
        self.best_o = 0
        self.best_pmp1 = 0
        self.best_pmp2 = 0
        self.best_pmp3 = 0

        self.min_profit, self.max_profit = float("inf"), float("-inf")
        self.min_on_time, self.max_on_time = float("inf"), float("-inf")
        self.use_calibration = self.env.args.use_calibration
        # Read once, with a default: the previous unguarded attribute access ran
        # first and made the getattr() fallback unreachable.
        self.use_perturbation = getattr(self.env.args, "use_perturbation", False)
        self.perturbator = None

        # Reuse the helpers so the dimension logic lives in a single place.
        self.feature_dim = self._get_feature_dim()
        self.label_dim = self._get_label_dim()

        param_groups = [
            {
                "params": [p for p in self.model.parameters() if p.requires_grad],
                "lr": self.env.args.lr,
            }
        ]

        # If the recalibrator is enabled, optimise its parameters together with
        # the simulator's.
        if self.use_calibration:
            # NOTE(review): the recalibrator is not moved to self.env.device here;
            # confirm device placement is handled elsewhere.
            self.recalibrator = Recalibrator(
                input_dim=self.feature_dim, K=self.action_dim
            )
            param_groups.append(
                {
                    "params": [
                        p for p in self.recalibrator.parameters() if p.requires_grad
                    ],
                    "lr": self.env.args.lr,
                }
            )

        self.optimizer = torch.optim.Adam(
            param_groups, weight_decay=self.env.args.decay_coeff
        )

    # ==================== Utility Methods ====================

    def _get_feature_dim(self) -> int:
        """
        Count the input feature columns configured for the current dataset

        The count covers the product, order, customer and shipping column
        groups that ``feature_list`` registers under ``self.env.args.dataset``.

        Returns:
            int: Total number of feature columns across the four groups
        """
        dataset_name = self.env.args.dataset
        product_cols = feature_list.product_info[dataset_name]
        order_cols = feature_list.order_info[dataset_name]
        customer_cols = feature_list.customer_info[dataset_name]
        shipping_cols = feature_list.shipping_info[dataset_name]
        all_cols = product_cols + order_cols + customer_cols + shipping_cols
        return len(all_cols)

    def _get_label_dim(self) -> int:
        """
        Count the label columns configured for the current dataset

        Returns:
            int: Number of entries in ``feature_list.label`` for
                ``self.env.args.dataset``
        """
        label_names = feature_list.label[self.env.args.dataset]
        return len(label_names)

    def _prepare_input_data(
        self, input_id: torch.Tensor, mode: str = "train"
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Prepare input data including preprocessing and device transfer

        Args:
            input_id: Original input data
            mode: Processing mode ('train', 'val', 'test', 'ori')

        Returns:
            tuple: (ori_input, processed_input_id, feature_dim)
                - ori_input: Original input data (transferred to device)
                - processed_input_id: Preprocessed input data
                - feature_dim: Feature dimension
        """
        # Transfer data to device
        ori_input = input_id.to(self.env.device)

        # Apply standardization based on mode (skip for original data)
        if mode != "ori":
            # Use fit_transform for training, transform for validation/test
            input_id = (
                self.scaler.fit_transform(input_id)
                if mode == "train"
                else self.scaler.transform(input_id)
            )

        # Convert to tensor and transfer to device
        processed_input_id = torch.FloatTensor(input_id).to(self.env.device)
        feature_dim = self._get_feature_dim()

        return ori_input, processed_input_id, feature_dim

    def _init_faiss_index(self) -> faiss.IndexFlatL2:
        """
        Build a flat L2 FAISS index over the cost dictionary

        Side effects:
            Stores every column of ``self.cost_dic`` except the last in
            ``self.cost_dic_data`` and the last column in ``self.cost_dic_y``.

        Returns:
            faiss.IndexFlatL2: Index populated with the rows of
                ``self.cost_dic_data`` (as float32)
        """
        # Split the table: leading columns are indexed, the last one is kept aside.
        lookup_columns = self.cost_dic[:, :-1]
        target_column = self.cost_dic[:, -1]
        self.cost_dic_data = lookup_columns
        self.cost_dic_y = target_column

        # FAISS consumes float32 numpy arrays, hence the host copy and cast.
        vectors = lookup_columns.cpu().numpy().astype("float32")

        index = faiss.IndexFlatL2(lookup_columns.shape[1])
        index.add(vectors)
        return index

    def _make_decision(
        self, state: torch.Tensor, mode: str = "value_network"
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Make decisions based on specified mode

        Args:
            state: State features
            mode: Decision mode ('value_network', 'ori')

        Returns:
            tuple: (decision_prob, decision_prob_value, action)
                - decision_prob: Decision probabilities
                - decision_prob_value: Decision probability values
                - action: Selected actions
        """
        if mode == "ori":
            # Use original decision indices from data
            decision_indices = state[:, self.feature_dim].long()
            decision_prob = (
                F.one_hot(decision_indices, num_classes=4).float().to(self.env.device)
            )
            decision_prob_value = decision_prob
        else:
            # Get value network output and convert to decision probabilities
            assert self.value_network is not None, ValueError(
                "Value network not initialized"
            )
            value_network_output = self.value_network(state)
            decision_prob_value = (
                F.softmax(value_network_output, dim=1)
                == F.softmax(value_network_output, dim=1)
                .max(dim=1, keepdim=True)
                .values
            ).float()
            decision_prob = decision_prob_value

        # Extract action indices from decision probabilities
        action = decision_prob_value.argmax(dim=1).squeeze()

        return decision_prob, decision_prob_value, action

    @staticmethod
    def _calculate_profit_percentiles(profits: List[float]) -> Dict[float, float]:
        """
        Calculate profit percentiles for analysis

        Args:
            profits: List of profits

        Returns:
            dict: Profit values at different percentiles
        """
        # Sort profits for percentile calculation
        sorted_profits = np.sort(profits)
        thresholds = [0.1, 0.2, 0.3]  # 10th, 20th, 30th percentiles
        profit_min_percent: Dict[float, float] = {}

        # Calculate percentile values
        for threshold in thresholds:
            idx = int(threshold * len(sorted_profits))
            profit_min_percent[threshold] = sorted_profits[idx]

        return profit_min_percent

    def _log_training_progress(
        self,
        epoch: int,
        total_epochs: int,
        metrics: Dict[str, float],
        train_time: float,
        phase: str = "sim",
    ) -> None:
        """
        Write one epoch's training metrics to the logger and, if enabled, wandb

        Args:
            epoch: Current epoch
            total_epochs: Total number of epochs
            metrics: Metric values. The 'sim' phase reads 'classification_loss'
                (plus 'calibrated_loss' when calibration is on); any other
                phase reads 'loss', 'profit_loss', 'late_loss', 'profit',
                'on_time', 'mi_loss', 'ma_loss' and 'group_adv'
            train_time: Training time
            phase: 'sim' for the simulator; anything else is treated as the
                decision-maker phase
        """
        info("-" * 50)

        if phase != "sim":
            # Decision-maker phase: two summary lines, then every metric to wandb.
            info(
                f'TRAIN:epoch = {epoch}/{total_epochs} loss = {metrics["loss"]:.5f} profit_loss = {metrics["profit_loss"]:.5f} late_loss = {metrics["late_loss"]:.5f} train_time = {train_time:.2f}'
            )
            info(
                f'profit = {metrics["profit"]:.5f} on_time = {metrics["on_time"]:.5f} mi_loss = {metrics["mi_loss"]:.5f} ma_loss = {metrics["ma_loss"]:.5f} group_adv = {metrics["group_adv"]:.5f}'
            )
            if self.env.args.wandb:
                for key, value in metrics.items():
                    # Steps are offset so they follow the simulator's epochs.
                    wandb.log({f"loss/{key}": value}, self.env.args.epochs + 1 + epoch)
            return

        # Simulator phase: the calibrated loss is reported only when it exists.
        if self.use_calibration:
            info(
                f'TRAIN:epoch = {epoch}/{total_epochs} classification_loss = {metrics["classification_loss"]:.5f} calibrated_loss = {metrics["calibrated_loss"]:.5f} train_time = {train_time:.2f}'
            )
            tracked_keys = ("classification_loss", "calibrated_loss")
        else:
            info(
                f'TRAIN:epoch = {epoch}/{total_epochs} classification_loss = {metrics["classification_loss"]:.5f} train_time = {train_time:.2f}'
            )
            tracked_keys = ("classification_loss",)

        if self.env.args.wandb:
            for key in tracked_keys:
                wandb.log({f"loss/{key}": metrics[key]}, epoch)

    def _log_validation_results(
        self,
        accuracies: Union[List[float], Tuple[float, float, Dict[float, float]]],
        val_time: float,
        epoch: Optional[int] = None,
        phase: str = "sim",
    ) -> None:
        """
        Write validation metrics to the logger and, if enabled, wandb

        Args:
            accuracies: For the 'sim' phase, one accuracy per label (in the
                order of ``feature_list.label`` for the dataset); otherwise a
                triple ``(profit, on_time_ratio, profit_min_percent)`` where
                the last item is indexed by 0.1, 0.2 and 0.3
            val_time: Validation time
            epoch: Current epoch; wandb logging is skipped when it is None
            phase: 'sim' for the simulator; anything else is treated as the
                decision-maker phase
        """
        info("-" * 10)
        log_to_wandb = self.env.args.wandb and epoch is not None

        if phase != "sim":
            # Decision-maker phase
            profit, on_time_ratio, profit_min_percent = accuracies
            info(
                f"avg_profit = {profit:.5f} on_time_ratio = {on_time_ratio:.5f} overall = {profit+on_time_ratio:.5f} val_time = {val_time:.2f}"
            )
            info(
                f"profit_min_percent_10 = {profit_min_percent[0.1]:.5f} profit_min_percent_20 = {profit_min_percent[0.2]:.5f} profit_min_percent_30 = {profit_min_percent[0.3]:.5f}"
            )
            if log_to_wandb:
                # Steps are offset so they follow the simulator's epochs.
                step = self.env.args.epochs + 1 + epoch
                wandb.log(
                    {"eval/avg_profit": profit, "eval/on_time_ratio": on_time_ratio},
                    step,
                )
                wandb.log(
                    {
                        "eval/profit_min_percent_10": profit_min_percent[0.1],
                        "eval/profit_min_percent_20": profit_min_percent[0.2],
                        "eval/profit_min_percent_30": profit_min_percent[0.3],
                    },
                    step,
                )
            return

        # Simulator phase: one line (and one wandb entry) per label.
        label_names = feature_list.label[self.env.args.dataset]
        for position, accuracy in enumerate(accuracies):
            label_name = label_names[position]
            info(
                f"{label_name} Accuracy: {accuracy * 100:.2f}% val_time = {val_time:.2f}"
            )
            if log_to_wandb:
                wandb.log({f"eval/{label_name}": accuracy}, epoch)

    def _check_and_save_best_model(self, current_metrics, epoch, phase="sim"):
        """
        Check if current model is best and save if necessary

        Args:
            current_metrics: Current performance metrics
            epoch: Current epoch
            phase: Training phase ('sim' or 'dm')

        Returns:
            bool: Whether a better model was found
        """
        if phase == "sim":
            # Check simulator model performance
            current_accuracy = (
                sum(current_metrics) / len(current_metrics) if current_metrics else 0
            )
            if current_accuracy > self.best_overall_accuracy:
                # Update best performance metrics
                self.best_overall_accuracy = current_accuracy
                self.best_acc1 = current_metrics[0]
                self.best_acc2 = current_metrics[1]
                self.best_acc3 = current_metrics[2]

                if self.env.args.wandb:
                    wandb.log(
                        {f"eval/best_overall_accuracy": self.best_overall_accuracy},
                        epoch,
                    )

                info(f"best_overall_accuracy: {self.best_overall_accuracy * 100:.2f}% ")
                self.early_stop = 0  # Reset early stopping counter

                if self.env.args.save:
                    self.save_model(epoch, "sim")

                self.best_epoch = epoch
                return True
        else:  # dm phase
            # Check decision maker model performance
            profit, on_time_ratio, _ = current_metrics
            current_score = on_time_ratio + profit
            if current_score > self.best_dm_accuracy:
                # Update best decision maker metrics
                self.best_dm_accuracy = current_score
                self.best_p = profit
                self.best_o = on_time_ratio
                self.best_pmp1 = current_metrics[2][0.1]
                self.best_pmp2 = current_metrics[2][0.2]
                self.best_pmp3 = current_metrics[2][0.3]

                if self.env.args.wandb:
                    wandb.log(
                        {f"eval/best_on_time_ratio": on_time_ratio},
                        self.env.args.epochs + 1 + epoch,
                    )
                    wandb.log(
                        {f"eval/best_profit": profit}, self.env.args.epochs + 1 + epoch
                    )
                    wandb.log(
                        {f"eval/best_dm_accuracy": self.best_dm_accuracy},
                        self.env.args.epochs + 1 + epoch,
                    )

                info(f"best_dm_accuracy: {self.best_dm_accuracy:.5f} ")
                self.early_stop = 0  # Reset early stopping counter

                if self.env.args.save:
                    self.save_model(self.env.args.epochs + 1 + epoch, "dm")

                self.best_dm_epoch = self.env.args.epochs + 1 + epoch
                return True

        return False

    def _select_dataset_by_mode(self, mode):
        """
        Select dataset based on mode

        Args:
            mode: Mode ('val', 'test', 'ori')

        Returns:
            torch.Tensor: Selected dataset
        """
        if mode == "val" or mode == "ori":
            return self.val_inputs
        else:
            return self.test_inputs

    def init_value_network(self, value_network: ValueNetwork):
        """
        Initialize the value network for decision-maker training.

        This method sets up the value network and its optimizer for the decision-maker
        training phase. The value network is used to estimate the expected future rewards
        for different actions in the reinforcement learning framework.

        Args:
            value_network (ValueNetwork): The value network model to be initialized

        Note:
            - The value network is moved to the appropriate device (CPU/GPU)
            - The optimizer is configured with the decision-maker learning rate
            - Weight decay is applied for regularization
        """
        self.value_network = value_network
        self.optimizer_dm = torch.optim.Adam(
            [{"params": self.value_network.parameters(), "lr": self.env.args.dm_lr}],
            weight_decay=self.env.args.dm_decay_coeff,
        )
        if self.use_perturbation:
            self.perturbator = Perturbator(
                predictor=self.model,
                policy=self.value_network,
                env=self.env,
                M=getattr(self.env.args, "perturb_M", 8),
                device=self.env.device,
                cost_dic=self.cost_dic,
                avg_profit=self.avg_profit,
                feature_list=feature_list,
                otr_reward_coeff=getattr(self.env.args, "otr_reward_coeff", 1.0),
                retrieve_index=feature_list.retrieve_index[self.env.args.dataset],
                action_dim=self.action_dim,
            )

    def train_epoch(self):
        """
        Run one simulator training pass over ``self.loader``

        Each batch is scaled, fed through the model, scored with the mean
        cross-entropy over all label heads (plus ``args.eta_c`` times the mean
        recalibrator cross-entropy when calibration is enabled) and used for
        one optimizer step.

        Returns:
            tuple: ``(classification_loss_avg, elapsed_seconds)``, or
                ``(classification_loss_avg, calibrated_loss_avg,
                elapsed_seconds)`` when calibration is enabled
        """
        started_at = time.time()
        self.model.train()
        self.total_epoch += 1

        classification_meter = AverageMeter()
        calibration_meter = AverageMeter()

        for batch in tqdm(self.loader):
            raw, scaled, n_features = self._prepare_input_data(batch, mode="train")
            features = scaled[:, :n_features]
            batch_size = scaled.size(0)
            # Column n_features holds the shipping mode; labels start right after it.
            first_label_col = n_features + 1

            head_logits = self.model(
                c_input=features,
                shipping_mode=raw[:, n_features].long(),
                tgt=raw[:, first_label_col:].long(),
            )

            head_targets = [
                raw[:, first_label_col + head].long()
                for head in range(self.label_dim)
            ]

            # Mean cross-entropy over the label heads.
            head_losses = [
                F.cross_entropy(head_logits[head], head_targets[head])
                for head in range(self.label_dim)
            ]
            loss = sum(head_losses) / self.label_dim
            classification_meter.update(loss.item(), batch_size)

            if self.use_calibration:
                calibration_losses = []
                for head in range(self.label_dim):
                    logits = head_logits[head]  # [B, C]
                    target = head_targets[head]
                    target_onehot = F.one_hot(
                        target, num_classes=logits.shape[1]
                    ).float()
                    prior_probs = torch.softmax(logits, dim=-1)

                    calibrated, _ = self.recalibrator(
                        x=features,
                        p_prev=prior_probs,
                        y_true=target_onehot,
                    )  # [B, C]

                    # NOTE(review): cross_entropy applies log-softmax itself; if
                    # the recalibrator returns probabilities rather than logits,
                    # softmax is applied twice here. Confirm.
                    calibration_loss = F.cross_entropy(calibrated, target)
                    calibration_meter.update(calibration_loss.item(), batch_size)
                    calibration_losses.append(calibration_loss)

                mean_calibration = sum(calibration_losses) / self.label_dim
                loss = loss + self.env.args.eta_c * mean_calibration

            self.optimizer.zero_grad()
            loss.backward()
            # Only the simulator's gradients are clipped; the recalibrator's
            # parameters are not included in this call.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
            self.optimizer.step()

        elapsed = time.time() - started_at
        if self.use_calibration:
            return classification_meter.avg, calibration_meter.avg, elapsed
        return classification_meter.avg, elapsed

    def train(self):
        """
        Train the simulation model using the training data.

        This method orchestrates the training process for the simulation model by:
        1. Initializing training parameters and early stopping mechanism
        2. Running training epochs with forward/backward passes
        3. Evaluating model performance on validation data
        4. Tracking best performance metrics and saving checkpoints
        5. Implementing early stopping to prevent overfitting

        The training process includes:
        - Classification loss optimization for multiple label dimensions
        - Regular validation evaluation at specified intervals
        - Performance tracking for individual label accuracies
        - Best model checkpoint saving when performance improves
        - Early stopping when no improvement is observed

        Returns:
            None: Training results are stored in instance variables
        """
        # Initialize early stopping counter
        self.early_stop = 0

        # Main training loop over epochs
        for epoch in range(self.env.args.ckpt_start_epoch, self.env.args.epochs):
            # Train for one epoch and get loss and time
            if self.use_calibration:
                classification_loss, calibrated_loss, train_time = self.train_epoch()
                metrics = {
                    "classification_loss": classification_loss,
                    "calibrated_loss": calibrated_loss,
                }
            else:
                classification_loss, train_time = self.train_epoch()
                metrics = {"classification_loss": classification_loss}

            # Log training progress using utility method
            self._log_training_progress(
                epoch, self.env.args.epochs, metrics, train_time, phase="sim"
            )

            # Run validation test
            self.test("val")

            # Evaluate model performance at specified intervals
            if epoch % self.env.args.eva_interval == 0:
                # Increment early stopping counter
                self.early_stop += 1

                # Get validation accuracies and time
                accuracies, val_time = self.test("val")

                # Log validation results using utility method
                self._log_validation_results(accuracies, val_time, epoch, phase="sim")

                # Check and save best model using utility method
                self._check_and_save_best_model(accuracies, epoch, phase="sim")

            # Early stopping: break if no improvement for specified number of epochs
            if self.early_stop > self.env.args.early_stop:
                break

    def dm_train_epoch(self):
        """
        Train decision maker for one epoch using advanced RL approach with Perturbator support

        Returns:
            tuple: Average values of various losses and metrics, plus training time
        """
        t = time.time()
        self.model.train()
        if self.value_network is not None:
            self.value_network.train()
        self.total_epoch += 1

        # Initialize meters for tracking various metrics
        all_mi_loss = AverageMeter()
        all_ma_loss = AverageMeter()
        all_profit = AverageMeter()
        all_on_time = AverageMeter()
        all_profit_loss = AverageMeter()
        all_late_loss = AverageMeter()
        all_loss = AverageMeter()
        all_group_adv_loss = AverageMeter()

        # Get feature dimension using utility method
        feature_dim = self._get_feature_dim()

        # Initialize FAISS index using utility method
        self._init_faiss_index()

        # Freeze simulator model parameters during decision maker training
        for param in self.model.parameters():
            param.requires_grad = False

        for input_id in tqdm(self.loader):
            # Prepare data using utility method
            ori_input, processed_input_id, _ = self._prepare_input_data(
                input_id, mode="train"
            )
            state = processed_input_id[:, :feature_dim]

            # Generate decision probabilities using value network
            assert self.value_network is not None, ValueError(
                "Value network not initialized"
            )

            # Get value network output and apply softmax
            decision_prob_value = F.softmax(self.value_network(state), dim=1)
            # Use Gumbel-Softmax for discrete action selection
            decision_prob = F.gumbel_softmax(decision_prob_value, tau=1, hard=True)

            # Select shipping mode embedding based on decision
            selected_embedding = torch.sum(
                decision_prob.unsqueeze(2) * self.model.embedding.weight[:4, :], dim=1
            )

            # Forward pass through simulator with selected embedding
            predicted_tokens = self.model(
                c_input=processed_input_id[:, :feature_dim],
                shipping_mode=selected_embedding,
                tgt=ori_input[:, feature_dim + 1 :],
            )

            # Calculate profit-based weights for decision making
            profits = torch.tensor(self.avg_profit)
            weights = profits / profits.max()

            # Set target classes for profit and on-time prediction
            target_class_decision = (
                torch.argmax(profits).expand(decision_prob.size(0)).to(self.env.device)
            )
            decision_weights = weights.to(self.env.device)

            # Calculate profit loss (encourages selection of high-profit actions)
            profit_loss = F.cross_entropy(
                decision_prob, target_class_decision, weight=decision_weights
            )

            # Calculate late delivery loss (encourages on-time delivery prediction)
            target_class_predicted = torch.ones(
                predicted_tokens[-1].size(0), dtype=torch.long
            ).to(self.env.device)
            late_loss = F.cross_entropy(predicted_tokens[-1], target_class_predicted)

            # Combine profit and late delivery losses
            mi_loss = (
                self.env.args.mip_coeff * profit_loss
                + self.env.args.mil_coeff * late_loss
            )

            # Add Perturbator group-wise perturbation loss if enabled
            group_adv_loss = 0.0
            if self.use_perturbation and self.perturbator is not None:
                # Get perturbation output for adversarial training
                perturb_out = self.perturbator.forward(state)
                group_adv_loss_val = perturb_out["group_adv_loss"]

                # Convert to float for loss calculation
                if (
                    isinstance(group_adv_loss_val, torch.Tensor)
                    and group_adv_loss_val.dim() == 0
                ):
                    group_adv_loss = float(group_adv_loss_val.item())
                elif isinstance(group_adv_loss_val, (int, float)):
                    group_adv_loss = float(group_adv_loss_val)

                # Add group adversarial loss to main loss
                mi_loss = mi_loss - self.env.args.eta_p * group_adv_loss
                all_group_adv_loss.update(group_adv_loss, processed_input_id.size(0))

            # Update loss meters
            all_profit_loss.update(self.env.args.mip_coeff * profit_loss)
            all_late_loss.update(self.env.args.mil_coeff * late_loss)

            # Extract actions from decision probabilities
            action = decision_prob.argmax(dim=1).squeeze()

            # Get profit for actions using utility method
            # selected_y = self._get_profit_for_actions(ori_input, action)
            selected_y = batch_query_cost_dic(
                self.cost_dic,
                self.avg_profit,
                feature_list.retrieve_index[self.env.args.dataset],
                ori_input[:, : self.feature_dim],
                action,
                self.env.device,
            )

            # Convert actions to one-hot encoding for aggregation
            one_hot_action = F.one_hot(action, num_classes=self.action_dim).float()
            action_profit_sum = torch.matmul(
                one_hot_action.T, selected_y.unsqueeze(1)
            ).squeeze(1)
            action_profit_count = one_hot_action.sum(dim=0)

            # Calculate on-time delivery metrics per action
            on_time = predicted_tokens[-1].argmax(dim=1)
            action_on_time_sum = torch.matmul(
                one_hot_action.T, on_time.unsqueeze(1).float()
            ).squeeze(1)
            action_on_time_count = action_profit_count

            # Calculate average metrics per action
            avg_profit_per_action = action_profit_sum / (action_profit_count + 1e-8)
            avg_on_time_per_action = action_on_time_sum / (action_on_time_count + 1e-8)

            # Combine profit and on-time metrics into reward
            reward_per_action = (
                avg_profit_per_action
                + getattr(self.env.args, "otr_reward_coeff", 1.0)
                * avg_on_time_per_action
            )

            # Prepare for backpropagation
            assert self.optimizer_dm is not None, ValueError(
                "Optimizer not initialized"
            )
            self.optimizer_dm.zero_grad()

            # Initialize smoothed reward if not exists
            if not hasattr(self, "smoothed_reward"):
                self.smoothed_reward = reward_per_action

            # Apply reward smoothing for stability
            reward_smoothing_factor = getattr(
                self.env.args, "reward_smoothing_factor", 0.9
            )
            self.smoothed_reward = (
                reward_smoothing_factor * reward_per_action
                + (1 - reward_smoothing_factor) * self.smoothed_reward
            )

            # Predict rewards using value network
            predicted_rewards = self.value_network(state).mean(dim=0)

            # Calculate mean squared error loss between predicted and actual rewards
            ma_loss = F.mse_loss(predicted_rewards, self.smoothed_reward)

            # Combine all losses with configurable coefficients
            mi_coeff = getattr(self.env.args, "mi_coeff", 1.0)
            ma_coeff = getattr(self.env.args, "ma_coeff", 1.0)
            loss = mi_coeff * mi_loss + ma_coeff * ma_loss

            # Backpropagate and update parameters
            loss.backward()
            self.optimizer_dm.step()

            # Update metric trackers
            all_mi_loss.update(mi_coeff * mi_loss)
            all_ma_loss.update(ma_coeff * ma_loss)
            all_loss.update(loss, len(processed_input_id))

        return (
            all_loss.avg,
            all_profit_loss.avg,
            all_late_loss.avg,
            all_profit.avg,
            all_on_time.avg,
            all_mi_loss.avg,
            all_ma_loss.avg,
            all_group_adv_loss.avg,
            time.time() - t,
        )

    def dm_train(self):
        """
        Run the decision-maker training loop with periodic validation.

        For each of ``env.args.dm_epochs`` epochs this method:

        1. Runs one pass of ``dm_train_epoch`` and forwards the returned
           losses/metrics and the epoch duration to ``_log_training_progress``.
        2. Every ``env.args.eva_interval`` epochs (epoch 0 included) it
           increments ``self.early_stop``, evaluates on the validation split
           with ``dm_test("val")``, logs the outcome and passes it to
           ``_check_and_save_best_model``.
        3. Leaves the loop as soon as ``self.early_stop`` is greater than
           ``env.args.early_stop``.

        Validation results are reported under the epoch index
        ``env.args.epochs + 1 + epoch`` (presumably so that decision-maker
        epochs are numbered after the simulator epochs -- confirm).

        NOTE(review): ``self.early_stop`` is only ever incremented in this
        method; any reset on improvement must happen inside
        ``_check_and_save_best_model``, which is not visible here.

        Returns:
            None
        """
        cfg = self.env.args
        # Keys under which the leading values of dm_train_epoch() are logged,
        # in the order that method returns them.
        metric_keys = (
            "loss",
            "profit_loss",
            "late_loss",
            "profit",
            "on_time",
            "mi_loss",
            "ma_loss",
            "group_adv",
        )

        self.early_stop = 0
        for epoch in range(cfg.dm_epochs):
            # The last element of the epoch result is the elapsed time.
            *epoch_values, train_time = self.dm_train_epoch()
            metrics = dict(zip(metric_keys, epoch_values))
            self._log_training_progress(
                epoch, cfg.dm_epochs, metrics, train_time, phase="dm"
            )

            # Periodic validation pass
            if epoch % cfg.eva_interval == 0:
                self.early_stop += 1
                profit, on_time_ratio, profit_min_percent, val_time = self.dm_test(
                    "val"
                )
                validation_metrics = (profit, on_time_ratio, profit_min_percent)
                global_epoch = cfg.epochs + 1 + epoch

                self._log_validation_results(
                    validation_metrics, val_time, global_epoch, phase="dm"
                )
                self._check_and_save_best_model(
                    validation_metrics, global_epoch, phase="dm"
                )

            # Checked every epoch, not only after an evaluation
            if self.early_stop > cfg.early_stop:
                break

    def dm_test(self, mode, epsilon_p=0.0, random_noise=False, return_detail=False):
        """
        Evaluate the decision maker on one dataset split.

        The state features are (optionally) perturbed, a decision is made with
        ``_make_decision``, the profit of each chosen action is looked up with
        ``batch_query_cost_dic`` on the original input, and the simulator
        (``self.model``) predicts the delivery outcome for the chosen action
        from the *unperturbed* processed features.

        Args:
            mode (str): Split selector, forwarded to ``_select_dataset_by_mode``,
                ``_prepare_input_data`` and ``_make_decision`` (the original
                documentation lists 'val', 'test' and 'ori').
            epsilon_p (float): Perturbation strength. Scales the Gaussian noise
                when ``random_noise`` is True; otherwise it is passed to the
                perturbator, which is only used when ``epsilon_p > 0``.
            random_noise (bool): If True, perturb the state with Gaussian noise
                instead of the learned perturbator.
            return_detail (bool): If True, also return a per-sample DataFrame.

        Returns:
            tuple: ``(profit, on_time_ratio, profit_min_percent, test_time)``,
            plus ``detail_df`` as a fifth element when ``return_detail`` is True.
                - profit: sum of per-sample profits divided by the sample
                  count, or 0 when there are no samples
                - on_time_ratio: mean of the argmax of the simulator's last
                  output head, or 0 when there are no samples
                - profit_min_percent: result of
                  ``_calculate_profit_percentiles`` on the per-sample profits
                - test_time: elapsed wall-clock seconds
                - detail_df: DataFrame with columns
                  ['index', 'action', 'on_time', 'profit']
        """
        # Set models to evaluation mode
        self.model.eval()
        assert self.value_network is not None, ValueError("Value network not initialized")
        self.value_network.eval()
        t = time.time()

        # Select and prepare the data for the requested split
        input_id = self._select_dataset_by_mode(mode)
        ori_input, processed_input_id, feature_dim = self._prepare_input_data(
            input_id, mode
        )

        # Initialize FAISS index using utility method
        self._init_faiss_index()

        # State features for decision making. This slice is a VIEW of
        # processed_input_id, so it must never be modified in place.
        state = processed_input_id[:, :feature_dim]

        # === perturbation ===
        if random_noise:
            # Out-of-place add: an in-place `state += noise` would also write
            # the noise into processed_input_id, silently perturbing the
            # simulator input below (the perturbator branch leaves it clean).
            state = state + torch.randn_like(state) * epsilon_p
        elif self.perturbator is not None and epsilon_p > 0.0:
            with torch.no_grad():
                z, sigma = self.perturbator.encode(state)
                zs = self.perturbator.sample_perturbations(
                    z, sigma, epsilon_p=epsilon_p
                )
                # Only one perturbation sample is used.
                # NOTE(review): the original comment says "first sample" but
                # the index is 1 -- confirm whether zs[0] is the unperturbed
                # latent or this is an off-by-one.
                perturbed_state = self.perturbator.decode(zs[1])
                state = self.perturbator.inverse_transform(perturbed_state)

        with torch.no_grad():
            # Make decisions from the (possibly perturbed) state
            decision_prob, _, action = self._make_decision(state, mode)

            # Profit of each chosen action, looked up on the original input
            selected_y = batch_query_cost_dic(
                self.cost_dic,
                self.avg_profit,
                feature_list.retrieve_index[self.env.args.dataset],
                ori_input[:, : self.feature_dim],
                action,
                self.env.device,
            )

            # Vectorized aggregation: one reduction and one host transfer
            # instead of a Python loop with an `.item()` call per sample.
            profits = selected_y.reshape(-1)
            profit_count = len(profits)
            profit_sum = profits.sum()
            local_profits = profits.tolist()

            # Calculate profit percentiles using utility method
            profit_min_percent = self._calculate_profit_percentiles(local_profits)

            # Probability-weighted mix of the action embeddings. The number of
            # embedding rows follows the decision width (previously a
            # hard-coded 4, which is identical whenever there are 4 actions).
            num_actions = decision_prob.size(1)
            selected_embedding = torch.sum(
                decision_prob.unsqueeze(2)
                * self.model.embedding.weight[:num_actions, :],
                dim=1,
            )

            # Forward pass through simulator on the unperturbed features
            predicted_tokens = self.model(
                c_input=processed_input_id[:, :feature_dim],
                shipping_mode=selected_embedding,
                tgt=ori_input[:, feature_dim + 1 :],
            )

            # On-time prediction is the argmax of the last output head
            on_time_preds = predicted_tokens[-1].argmax(dim=1)
            time_sum = on_time_preds.sum().item()
            time_count = len(on_time_preds)

            # Per-sample details, built column-wise in one shot
            if return_detail:
                detail_df = pd.DataFrame(
                    {
                        "index": list(range(profit_count)),
                        "action": action.reshape(-1).tolist(),
                        "on_time": on_time_preds.tolist(),
                        "profit": local_profits,
                    }
                )

        # Calculate final metrics
        profit = profit_sum / profit_count if profit_count > 0 else 0
        on_time_ratio = time_sum / time_count if time_count > 0 else 0
        elapsed = time.time() - t

        if return_detail:
            return profit, on_time_ratio, profit_min_percent, elapsed, detail_df
        return profit, on_time_ratio, profit_min_percent, elapsed


    def test(self, mode, use_b_weighted_error=False):
        """
        Evaluate the simulator's classification accuracy per label dimension.

        The selected split is processed in chunks of
        ``int(batch_size // 1.5)`` rows (at least 1). For every label
        dimension ``j`` the argmax of ``predicted_tokens[j]`` is compared with
        the ground-truth class taken from the last ``label_dim`` columns of
        the original input.

        Args:
            mode (str): Split selector forwarded to ``_select_dataset_by_mode``
                and ``_prepare_input_data`` (the original documentation lists
                'val' and 'test').
            use_b_weighted_error (bool): If True and ``self.use_calibration``
                is truthy, weight each sample's correctness by the ``"b"``
                values returned by ``self.recalibrator``.

        Returns:
            tuple: ``(scores, elapsed_seconds)`` where ``scores`` is a list
            with one entry per label dimension:
                - plain accuracy (correct / total), or
                - b-weighted accuracy (sum of b over correct samples divided
                  by the sum of all b) when b-weighting is active.
            A dimension with no samples / zero total weight scores 0.0.
            Note: despite the parameter name, the b-weighted values are
            accuracies, not error rates.
        """
        # batch_size == 1 would give a step of 0 and make range() raise.
        chunk_size = max(1, int(self.env.args.batch_size // 1.5))
        self.model.eval()
        t = time.time()

        # Select and prepare the data for the requested split
        input_id = self._select_dataset_by_mode(mode)
        ori_input, processed_input_id, feature_dim = self._prepare_input_data(
            input_id, mode
        )
        label_dim = self._get_label_dim()

        # b-weighting needs both the caller's request and an enabled recalibrator
        b_weighted = bool(
            use_b_weighted_error and getattr(self, "use_calibration", False)
        )

        # Per-label accumulators
        correct_preds = [0] * label_dim
        total_samples = [0] * label_dim
        b_weighted_correct_sum = [0.0] * label_dim  # sum of b over correct samples
        b_weighted_total_weight = [0.0] * label_dim  # sum of all b weights

        with torch.no_grad():
            for start in range(0, len(processed_input_id), chunk_size):
                input_chunk = processed_input_id[start : start + chunk_size]
                ori_chunk = ori_input[start : start + chunk_size]

                # The column at feature_dim holds the shipping mode; the
                # columns after it are passed as the target sequence.
                predicted_tokens = self.model(
                    c_input=input_chunk[:, :feature_dim],
                    shipping_mode=ori_chunk[:, feature_dim].long(),
                    tgt=ori_chunk[:, feature_dim + 1 :],
                )

                # True class labels are the last label_dim columns
                class_labels = ori_chunk[:, -label_dim:].long().to(self.env.device)

                for j in range(label_dim):
                    logits = predicted_tokens[j]
                    y_true = class_labels[:, j]
                    predicted = torch.argmax(logits, dim=1)
                    is_correct = predicted == y_true

                    correct_preds[j] += int(is_correct.sum().item())
                    total_samples[j] += len(y_true)

                    if not b_weighted:
                        continue

                    y_true_onehot = F.one_hot(
                        y_true, num_classes=logits.shape[1]
                    ).float()
                    p_prev = torch.softmax(logits, dim=-1)
                    # Named calib_info so the module-level logger `info`
                    # is not shadowed inside this method.
                    _, calib_info = self.recalibrator(
                        x=input_chunk[:, :feature_dim],
                        p_prev=p_prev,
                        y_true=y_true_onehot,
                    )
                    b_xa = calib_info["b"]  # [B, K] per the original comment
                    # A correct sample contributes all of its b weights,
                    # an incorrect one contributes nothing.
                    b_weighted_correct = (
                        b_xa * is_correct.float().unsqueeze(1)
                    ).sum(dim=1)
                    b_weighted_correct_sum[j] += b_weighted_correct.sum().item()
                    b_weighted_total_weight[j] += b_xa.sum(dim=1).sum().item()

        if b_weighted:
            b_weighted_accuracy = [
                (
                    b_weighted_correct_sum[j] / b_weighted_total_weight[j]
                    if b_weighted_total_weight[j] > 0
                    else 0.0
                )
                for j in range(label_dim)
            ]
            return b_weighted_accuracy, time.time() - t

        # Plain accuracy; an empty split yields 0.0 instead of dividing by zero
        accuracies = [
            correct_preds[j] / total_samples[j] if total_samples[j] > 0 else 0.0
            for j in range(label_dim)
        ]
        return accuracies, time.time() - t

    def save_ckpt(self, path, mode):
        if mode == "sim":
            torch.save(self.model.state_dict(), path)
        elif mode == "dm":
            torch.save(self.value_network.state_dict(), path)

    def save_model(self, current_epoch, mode):
        if mode == "sim":
            model_state_file = os.path.join(
                self.env.CKPT_PATH,
                f"{self.env.suffix}_epoch{current_epoch}_{mode}_{self.env.args.use_calibration}_{self.env.args.eta_c}.pth",
            )
        elif mode == "dm":
            model_state_file = os.path.join(
                self.env.CKPT_PATH,
                f"{self.env.suffix}_epoch{current_epoch}_{mode}_{self.env.args.use_perturbation}_{self.env.args.eta_p}.pth",
            )
        self.save_ckpt(model_state_file, mode)
        if mode == "sim":
            best_epoch = self.best_epoch
        else:
            best_epoch = self.best_dm_epoch
        if current_epoch != best_epoch:
            if mode == "sim":
                old_model_state_file = os.path.join(
                    self.env.CKPT_PATH,
                    f"{self.env.suffix}_epoch{best_epoch}_{mode}_{self.env.args.use_calibration}_{self.env.args.eta_c}.pth",
                )
            elif mode == "dm":
                old_model_state_file = os.path.join(
                    self.env.CKPT_PATH,
                    f"{self.env.suffix}_epoch{best_epoch}_{mode}_{self.env.args.use_perturbation}_{self.env.args.eta_p}.pth",
                )
            if os.path.exists(old_model_state_file):
                os.system("rm {}".format(old_model_state_file))
