import os
import re
from typing import Optional
import pathlib

class AverageMeter(object):
    """Track the running sum, count, average, minimum and maximum of a value stream.

    Attributes (all reset by :meth:`reset`):
        max_val: Largest value passed to :meth:`update` so far (``-inf`` before any update).
        min_val: Smallest value passed to :meth:`update` so far (``+inf`` before any update).
        sum: Weighted sum of all values (``val * n`` accumulated).
        count: Total weight (sum of all ``n``).
        avg: ``sum / count``; ``0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        # Infinite sentinels guarantee that the first update always sets both
        # extrema, even for all-negative streams or values above any fixed cap.
        self.max_val = float("-inf")
        self.min_val = float("inf")
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update the average meter with a new value.

        This method updates the running statistics (sum, count, average, max, min)
        with a new value. It's commonly used to track metrics during training
        or evaluation phases.

        Args:
            val (float): The new value to be added to the statistics
            n (int, optional): The number of times this value occurs.
                             Defaults to 1. Useful for batch processing where
                             a single value represents multiple samples.
        """
        self.max_val = max(self.max_val, val)
        self.min_val = min(self.min_val, val)
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been seen.
        self.avg = self.sum / self.count if self.count else 0


import numpy as np
import torch
from typing import Tuple, Dict, List, Union
from sklearn.preprocessing import StandardScaler
import faiss


def batch_query_cost_dic(cost_dic, avg_profit, retrieve_index, states, actions, device):
    """
    Batch lookup for profit using cost_dic and faiss nearest neighbor search.

    For each (state, action) pair a 3-component query vector
    ``[state[retrieve_index[0]], state[retrieve_index[1]], action]`` is built and
    its nearest neighbor (L2) is searched among the feature columns of cost_dic.
    If the nearest row equals the query exactly, that row's profit is used;
    otherwise the fallback ``avg_profit[action]`` is used.

    Args:
        cost_dic: torch.Tensor or np.ndarray, shape (N, D+1), last column is profit
        avg_profit: list or np.ndarray, shape (action_dim,)
        retrieve_index: list/tuple of int, indices of features to use for retrieval.
            Only the first two entries are used.
        states: torch.Tensor, shape (batch, feature_dim)
        actions: torch.Tensor, shape (batch,); must hold integer values because
            it is also used to index ``avg_profit``
        device: torch.device
    Returns:
        profit: torch.Tensor, shape (batch,), dtype float32, on ``device``
    """
    # Move the whole lookup table to host memory once, so the feature columns
    # AND the profit column are numpy arrays (np.where cannot consume a CUDA tensor).
    if hasattr(cost_dic, "cpu"):
        table = cost_dic.detach().cpu().numpy()
    else:
        table = np.asarray(cost_dic)
    x = table[:, :-1]
    cost_dic_y = table[:, -1]

    # Build all query vectors in one shot instead of one .item() call per element.
    feature_cols = [retrieve_index[0], retrieve_index[1]]
    state_part = states[:, feature_cols].detach().cpu().numpy()
    action_idx = actions.detach().cpu().numpy().reshape(-1)
    query_vectors = np.column_stack([state_part, action_idx]).astype(np.float32)

    # faiss expects contiguous float32 arrays for both add and search.
    index = faiss.IndexFlatL2(x.shape[1])
    index.add(np.ascontiguousarray(x, dtype=np.float32))
    _, nearest_indices = index.search(np.ascontiguousarray(query_vectors), 1)
    nearest = nearest_indices.reshape(-1)

    # A hit only counts when every component equals the stored row exactly.
    matches = np.all(query_vectors == x[nearest], axis=1)
    avg_profit_np = np.array(avg_profit)
    selected_y = np.where(matches, cost_dic_y[nearest], avg_profit_np[action_idx])

    # Return as torch tensor
    return torch.tensor(selected_y, device=device, dtype=torch.float32)


def get_feature_dim(env_args, feature_list) -> int:
    """
    Count the input features configured for the active dataset.

    The product, order, customer and shipping feature groups registered under
    ``env_args.dataset`` are concatenated and the combined length is returned.

    Args:
        env_args: Object exposing the dataset key as ``env_args.dataset``
        feature_list: Object whose ``product_info``, ``order_info``,
            ``customer_info`` and ``shipping_info`` map dataset keys to
            feature sequences

    Returns:
        int: Total number of features across the four groups
    """
    dataset = env_args.dataset
    all_features = (
        feature_list.product_info[dataset]
        + feature_list.order_info[dataset]
        + feature_list.customer_info[dataset]
        + feature_list.shipping_info[dataset]
    )
    return len(all_features)


def get_label_dim(env_args, feature_list) -> int:
    """
    Count the labels configured for the active dataset.

    Args:
        env_args: Object exposing the dataset key as ``env_args.dataset``
        feature_list: Object whose ``label`` attribute maps dataset keys to
            label sequences

    Returns:
        int: Number of labels registered for the dataset
    """
    dataset_labels = feature_list.label[env_args.dataset]
    return len(dataset_labels)


def prepare_input_data(
    input_id: torch.Tensor,
    mode: str,
    scaler: StandardScaler,
    device: torch.device,
    feature_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepare input data including preprocessing and device transfer

    Args:
        input_id: Original input data
        mode: Processing mode ('train', 'val', 'test', 'ori'). 'train' fits the
            scaler and transforms; 'ori' skips scaling; any other value only
            transforms with the already-fitted scaler.
        scaler: StandardScaler instance for normalization (unused in 'ori' mode)
        device: Target device for tensor transfer
        feature_dim: Feature dimension (returned unchanged)

    Returns:
        tuple: (ori_input, processed_input_id, feature_dim) where ori_input is the
        untouched input on ``device`` and processed_input_id is a float32 tensor
        on ``device``. In 'ori' mode with a float32 input no copy is made, so
        both tensors may share storage with ``input_id``.
    """
    # Transfer data to device
    ori_input = input_id.to(device)

    # Apply standardization based on mode (skip for original data)
    if mode != "ori":
        # The scaler works on host arrays; convert explicitly so tensors that
        # live on an accelerator or track gradients are handled as well.
        host_array = input_id.detach().cpu().numpy()
        # Use fit_transform for training, transform for validation/test
        input_id = (
            scaler.fit_transform(host_array)
            if mode == "train"
            else scaler.transform(host_array)
        )

    # torch.as_tensor accepts numpy arrays and tensors of any dtype/device,
    # unlike the legacy torch.FloatTensor constructor which rejects tensors
    # that are not already float32 on CPU.
    processed_input_id = torch.as_tensor(input_id, dtype=torch.float32).to(device)

    return ori_input, processed_input_id, feature_dim


def select_dataset_by_mode(mode: str, val_inputs: torch.Tensor, test_inputs: torch.Tensor) -> torch.Tensor:
    """
    Pick the validation or test inputs according to the requested mode.

    Args:
        mode: 'val' and 'ori' select the validation inputs; every other value
            selects the test inputs
        val_inputs: Validation inputs
        test_inputs: Test inputs

    Returns:
        torch.Tensor: The chosen inputs, returned as-is
    """
    return val_inputs if mode in ("val", "ori") else test_inputs


def init_faiss_index(cost_dic: torch.Tensor) -> Tuple[faiss.IndexFlatL2, torch.Tensor, torch.Tensor]:
    """
    Build a flat L2 FAISS index over the feature columns of the cost dictionary.

    Args:
        cost_dic: 2-D tensor whose last column is the target value and whose
            remaining columns are the features to index

    Returns:
        tuple: (index, cost_dic_data, cost_dic_y) where index holds the
        features as float32, cost_dic_data is the feature slice and cost_dic_y
        the target column (both still tensors sliced from ``cost_dic``)
    """
    # Split into features (all but last column) and targets (last column)
    cost_dic_data, cost_dic_y = cost_dic[:, :-1], cost_dic[:, -1]

    # FAISS works on float32 numpy arrays living on the host
    vectors = cost_dic_data.cpu().numpy().astype("float32")

    # Exact L2 index with one dimension per feature column
    index = faiss.IndexFlatL2(cost_dic_data.shape[1])
    index.add(vectors)

    return index, cost_dic_data, cost_dic_y


def calculate_profit_percentiles(
    profits: List[float],
    thresholds: Union[List[float], Tuple[float, ...]] = (0.1, 0.2, 0.3),
) -> Dict[float, float]:
    """
    Calculate profit percentiles for analysis

    For every threshold ``t`` the value at position ``int(t * len(profits))``
    of the ascending-sorted profits is reported.

    Args:
        profits: List of profits (must not be empty)
        thresholds: Fractions in [0, 1] to evaluate. Defaults to the 10th,
            20th and 30th percentiles.

    Returns:
        dict: Maps each threshold to the profit value at that percentile,
        as a plain Python float

    Raises:
        IndexError: If ``profits`` is empty
    """
    # Sort profits for percentile calculation
    sorted_profits = np.sort(profits)
    total = len(sorted_profits)
    last = total - 1

    # Clamp to the last element so a threshold of 1.0 yields the maximum
    # instead of indexing one past the end.
    return {
        threshold: float(sorted_profits[min(int(threshold * total), last)])
        for threshold in thresholds
    }


def log_training_progress(
    epoch: int,
    total_epochs: int,
    metrics: Dict[str, float],
    train_time: float,
    use_calibration: bool = False,
    wandb_enabled: bool = False,
    wandb_epoch_offset: int = 0
) -> None:
    """
    Emit one epoch's training metrics to the project logger and optionally to wandb.

    Three layouts are supported, chosen in this order:
      1. ``use_calibration`` is true: reads ``classification_loss`` and
         ``calibrated_loss`` from ``metrics``.
      2. ``metrics`` contains ``classification_loss``: predictor session,
         reads only that key.
      3. Otherwise: decision-maker session, reads ``loss``, ``profit_loss``,
         ``late_loss``, ``profit``, ``on_time``, ``mi_loss``, ``ma_loss`` and
         ``group_adv``; every entry of ``metrics`` is sent to wandb.

    Args:
        epoch: Current epoch
        total_epochs: Total number of epochs
        metrics: Dictionary of metrics (required keys depend on the layout above)
        train_time: Training time
        use_calibration: Whether calibration is used
        wandb_enabled: Whether wandb logging is enabled
        wandb_epoch_offset: Added to ``epoch`` to form the wandb step
    """
    from tools.logger import info

    info("-" * 50)

    # --- Layout 1: calibrated predictor -------------------------------------
    if use_calibration:
        info(
            f'TRAIN:epoch = {epoch}/{total_epochs} classification_loss = {metrics["classification_loss"]:.5f} calibrated_loss = {metrics["calibrated_loss"]:.5f} train_time = {train_time:.2f}'
        )
        if not wandb_enabled:
            return
        import wandb
        # Two separate log calls, one per loss, both at the same step
        wandb.log(
            {"loss/classification_loss": metrics["classification_loss"]},
            epoch + wandb_epoch_offset,
        )
        wandb.log(
            {"loss/calibrated_loss": metrics["calibrated_loss"]},
            epoch + wandb_epoch_offset,
        )
        return

    # --- Layout 2: plain predictor -------------------------------------------
    if "classification_loss" in metrics:
        info(
            f'TRAIN:epoch = {epoch}/{total_epochs} classification_loss = {metrics["classification_loss"]:.5f} train_time = {train_time:.2f}'
        )
        if wandb_enabled:
            import wandb
            wandb.log(
                {"loss/classification_loss": metrics["classification_loss"]},
                epoch + wandb_epoch_offset,
            )
        return

    # --- Layout 3: decision maker ----------------------------------------------
    info(
        f'TRAIN:epoch = {epoch}/{total_epochs} loss = {metrics["loss"]:.5f} profit_loss = {metrics["profit_loss"]:.5f} late_loss = {metrics["late_loss"]:.5f} train_time = {train_time:.2f}'
    )
    info(
        f'profit = {metrics["profit"]:.5f} on_time = {metrics["on_time"]:.5f} mi_loss = {metrics["mi_loss"]:.5f} ma_loss = {metrics["ma_loss"]:.5f} group_adv = {metrics["group_adv"]:.5f}'
    )
    if wandb_enabled:
        import wandb
        # One log call per metric, all at the same step
        for metric_name, metric_value in metrics.items():
            wandb.log({f"loss/{metric_name}": metric_value}, epoch + wandb_epoch_offset)


def log_validation_results(
    accuracies: Union[List[float], Tuple[float, float, Dict[float, float]]],
    val_time: float,
    epoch: int = None,
    phase: str = "sim",
    env_args = None,
    feature_list = None,
    wandb_enabled: bool = False,
    wandb_epoch_offset: int = 0
) -> None:
    """
    Emit validation results to the project logger and optionally to wandb.

    Args:
        accuracies: For phase 'sim', one accuracy per label of the active
            dataset. For any other phase, a 3-tuple
            ``(profit, on_time_ratio, profit_min_percent)`` where the last item
            is a dict keyed by 0.1, 0.2 and 0.3.
        val_time: Validation time
        epoch: Current epoch; wandb logging is skipped when this is None
        phase: Validation phase ('sim' or 'dm')
        env_args: Environment arguments (``.dataset`` is read in phase 'sim')
        feature_list: Feature list configuration (``.label`` is read in phase 'sim')
        wandb_enabled: Whether wandb logging is enabled
        wandb_epoch_offset: Added to ``epoch`` to form the wandb step
    """
    from tools.logger import info

    info("-" * 10)

    if phase != "sim":
        # Decision-maker results: aggregate profit / punctuality figures
        profit, on_time_ratio, profit_min_percent = accuracies
        info(
            f"avg_profit = {profit:.5f} on_time_ratio = {on_time_ratio:.5f} overall = {profit+on_time_ratio:.5f} val_time = {val_time:.2f}"
        )
        info(
            f"profit_min_percent_10 = {profit_min_percent[0.1]:.5f} profit_min_percent_20 = {profit_min_percent[0.2]:.5f} profit_min_percent_30 = {profit_min_percent[0.3]:.5f}"
        )

        if not wandb_enabled or epoch is None:
            return
        import wandb
        headline = {"eval/avg_profit": profit, "eval/on_time_ratio": on_time_ratio}
        wandb.log(headline, epoch + wandb_epoch_offset)
        low_tail = {
            "eval/profit_min_percent_10": profit_min_percent[0.1],
            "eval/profit_min_percent_20": profit_min_percent[0.2],
            "eval/profit_min_percent_30": profit_min_percent[0.3],
        }
        wandb.log(low_tail, epoch + wandb_epoch_offset)
        return

    # Simulator results: one accuracy line per label of the dataset
    for position, accuracy in enumerate(accuracies):
        label_name = feature_list.label[env_args.dataset][position]
        info(
            f"{label_name} Accuracy: {accuracy * 100:.2f}% val_time = {val_time:.2f}"
        )
        if wandb_enabled and epoch is not None:
            import wandb
            wandb.log({f"eval/{label_name}": accuracy}, epoch + wandb_epoch_offset)


def save_model_checkpoint(
    model: torch.nn.Module,
    path: str,
    current_epoch: int,
    best_epoch: int,
    ckpt_path: str,
    suffix: str,
    model_type: str = ""
) -> None:
    """
    Save the model's state dict and delete the previous best checkpoint.

    The state dict is written to ``path``. When ``current_epoch`` differs from
    ``best_epoch``, the file
    ``{ckpt_path}/{suffix}_{model_type}_epoch{best_epoch}.pth`` is removed if
    it exists.

    Args:
        model: Model to save
        path: Destination file for the new state dict
        current_epoch: Current epoch number
        best_epoch: Epoch number of the previously saved best checkpoint
        ckpt_path: Checkpoint directory path holding the previous checkpoint
        suffix: Model suffix
        model_type: Type of model ('dm' for decision maker)
    """
    import os
    import warnings

    torch.save(model.state_dict(), path)

    if current_epoch == best_epoch:
        return

    old_model_state_file = os.path.join(
        ckpt_path, f"{suffix}_{model_type}_epoch{best_epoch}.pth"
    )
    # os.remove instead of a shell "rm": portable, safe for paths containing
    # spaces or shell metacharacters, and no subprocess is spawned.
    try:
        os.remove(old_model_state_file)
    except FileNotFoundError:
        # Nothing to clean up.
        pass
    except OSError as exc:
        # Cleanup is best-effort; a stale checkpoint must not abort training.
        warnings.warn(
            f"Could not remove old checkpoint '{old_model_state_file}': {exc}"
        )
            
            
import os
import re
from datetime import datetime
from typing import Optional, Dict

class ModelNameParser:
    """
    Utility class for generating and parsing model checkpoint filenames.

    Filename format:
        {date}_{dataset}_epoch{epoch}_{mode}_{use_flag}_{eta}.pth

    Components:
        - date (str): Date of the run in format MM-DD-YY (e.g., "07-26-20")
        - dataset (str): Dataset name, letters and digits only (e.g., "OAS", "dataco")
        - epoch (int): Epoch number (e.g., 41)
        - mode (str): Either 'sim' (simulator) or 'dm' (decision-maker)
        - use_flag (bool): Indicates whether calibration or perturbation was used
        - eta (float): Value for calibration/perturbation (e.g., 0.1); may be
          written in scientific notation (e.g., "1e-05"), which is how Python
          formats small floats in generate_name
    """

    # Compiled once and applied with fullmatch, so anything after ".pth"
    # (e.g. ".pth.bak") is rejected. The eta group accepts an optional sign and
    # exponent so every name produced by generate_name can be parsed back.
    _NAME_PATTERN = re.compile(
        r'(?P<date>\d{2}-\d{2}-\d{2})_(?P<dataset>[A-Za-z0-9]+)'
        r'_epoch(?P<epoch>\d+)_(?P<mode>sim|dm)_(?P<use_flag>True|False)'
        r'_(?P<eta>[-+]?[\d.]+(?:[eE][-+]?\d+)?)\.pth'
    )
    _DATE_FORMAT = "%m-%d-%y"

    @staticmethod
    def generate_name(
        date: str,
        dataset: str,
        epoch: int,
        mode: str,
        use_flag: bool,
        eta: float,
        folder: Optional[pathlib.Path] = None
    ) -> pathlib.Path:
        """
        Generate a standardized model filename using pathlib for cross-platform compatibility.

        Args:
            date (str): Date string in format "MM-DD-YY"
            dataset (str): Dataset name (e.g., "OAS")
            epoch (int): Epoch number (e.g., 41)
            mode (str): Either 'sim' or 'dm'
            use_flag (bool): Whether calibration or perturbation was used
            eta (float): Calibration/Perturbation value
            folder (Optional[Path]): Optional directory as a pathlib.Path or str

        Returns:
            Path: A pathlib.Path object representing the full file path

        Example:
            >>> ModelNameParser.generate_name("07-26-20", "OAS", 41, "sim", True, 0.1)
            PosixPath('07-26-20_OAS_epoch41_sim_True_0.1.pth')

            >>> ModelNameParser.generate_name("07-26-20", "OAS", 41, "sim", True, 0.1, folder=pathlib.Path("./ckpt"))
            PosixPath('ckpt/07-26-20_OAS_epoch41_sim_True_0.1.pth')
        """
        filename = f"{date}_{dataset}_epoch{epoch}_{mode}_{use_flag}_{eta}.pth"
        if folder is not None:
            return pathlib.Path(folder) / filename
        return pathlib.Path(filename)

    @staticmethod
    def parse_name(filename: str) -> Dict:
        """
        Parse a filename into its components.

        Only the basename is inspected, and it must match the documented
        format in full (no extra prefix or suffix).

        Args:
            filename (str): Model checkpoint filename (can include path)

        Returns:
            dict: Dictionary with keys:
                - date (str)
                - dataset (str)
                - epoch (int)
                - mode (str)
                - use_flag (bool)
                - eta (float)
                - suffix (str): "{date}_{dataset}"

        Raises:
            ValueError: If filename does not match expected pattern

        Example:
            >>> ModelNameParser.parse_name("07-26-20_OAS_epoch41_sim_True_0.1.pth")
            {
                'date': '07-26-20',
                'dataset': 'OAS',
                'epoch': 41,
                'mode': 'sim',
                'use_flag': True,
                'eta': 0.1,
                'suffix': '07-26-20_OAS'
            }
        """
        basename = os.path.basename(filename)
        match = ModelNameParser._NAME_PATTERN.fullmatch(basename)
        if not match:
            raise ValueError(f"Filename '{filename}' does not match expected format.")
        try:
            eta = float(match.group("eta"))
        except ValueError:
            # The character class admits strings such as "1.2.3"; report them
            # with the same message as any other malformed name.
            raise ValueError(
                f"Filename '{filename}' does not match expected format."
            ) from None
        return {
            "date": match.group("date"),
            "dataset": match.group("dataset"),
            "epoch": int(match.group("epoch")),
            "mode": match.group("mode"),
            "use_flag": match.group("use_flag") == "True",
            "eta": eta,
            "suffix": f"{match.group('date')}_{match.group('dataset')}"
        }

    @staticmethod
    def find_latest_model(
        folder: str,
        dataset: str,
        mode: str,
        use_flag: bool,
        eta: float,
        date: Optional[str] = None
    ) -> Optional[Dict]:
        """
        Find the model with the largest epoch for a given dataset and condition.
        If date is not provided, the most recent date will be used.

        Entries of ``folder`` whose names do not follow the checkpoint format,
        or whose date component is not a real MM-DD-YY date, are ignored.

        Args:
            folder (str): Directory containing model files
            dataset (str): Dataset name (e.g., "OAS")
            mode (str): 'sim' or 'dm'
            use_flag (bool): True if calibration or perturbation is used
            eta (float): Eta value used during training (compared with a
                tolerance of 1e-6)
            date (Optional[str]): Date string ("MM-DD-YY") or None for most recent

        Returns:
            dict: Parsed metadata of the matched model file, or None if not found

        Raises:
            OSError: If ``folder`` cannot be listed (e.g. it does not exist)

        Example:
            >>> ModelNameParser.find_latest_model(
                    folder="./ckpt",
                    dataset="OAS",
                    mode="sim",
                    use_flag=True,
                    eta=0.1
                )
            {
                'date': '07-26-20',
                'dataset': 'OAS',
                'epoch': 41,
                'mode': 'sim',
                'use_flag': True,
                'eta': 0.1,
                'suffix': '07-26-20_OAS'
            }
        """
        best_key = None
        best_parsed = None

        for fname in os.listdir(folder):
            # Only the two calls that can legitimately fail are guarded, and
            # only for the exception they raise on malformed names/dates.
            try:
                parsed = ModelNameParser.parse_name(fname)
                run_date = datetime.strptime(
                    parsed["date"], ModelNameParser._DATE_FORMAT
                )
            except ValueError:
                continue

            wanted = (
                parsed["dataset"] == dataset
                and parsed["mode"] == mode
                and parsed["use_flag"] == use_flag
                and abs(parsed["eta"] - eta) < 1e-6
                and (date is None or parsed["date"] == date)
            )
            if not wanted:
                continue

            # Most recent date wins; the highest epoch breaks ties within a date.
            key = (run_date, parsed["epoch"])
            if best_key is None or key > best_key:
                best_key, best_parsed = key, parsed

        return best_parsed

if __name__ == "__main__":
    # Small demonstration of ModelNameParser; runs only when executed as a script.
    demo_folder = "./exp_report/OAS/ckpt"

    # Generate model filename (the folder is prepended to the file name)
    name = ModelNameParser.generate_name("07-26-20", "OAS", 41, "sim", True, 0.1, folder=demo_folder)
    print(name)
    # → exp_report/OAS/ckpt/07-26-20_OAS_epoch41_sim_True_0.1.pth

    # Parse existing filename
    parsed = ModelNameParser.parse_name("07-26-20_OAS_epoch41_sim_True_0.1.pth")
    print(parsed)

    # Automatically find latest model (latest date + max epoch).
    # Listing a missing directory raises, so the lookup is skipped when the
    # demo folder has not been created yet.
    if os.path.isdir(demo_folder):
        latest = ModelNameParser.find_latest_model(
            folder=demo_folder,
            dataset="OAS",
            mode="sim",
            use_flag=True,
            eta=0.1
        )
    else:
        latest = None
    print(latest)
