"""
Utility functions for MDShortcut molecular dynamics diffusion models.

This module provides various utility functions for:
- Atomic position manipulation within periodic boundary conditions
- String parsing for slice operations
- Notification services for long-running experiments
- TensorBoard logging for training visualization
"""
import os
import warnings

import requests
import torch
from torch.utils.tensorboard import SummaryWriter


def positions_into_cell(pos, cell):
    """Wrap atomic positions into the unit cell using periodic boundaries.

    Converts Cartesian positions to fractional coordinates, wraps every
    fractional coordinate into the half-open interval [0, 1), and converts
    back to Cartesian. Because the result is computed as
    ``fractional @ cell``, the rows of ``cell`` are the lattice vectors.

    Args:
        pos (torch.Tensor): Atomic positions, shape (n_atoms, 3).
        cell (torch.Tensor): Unit cell matrix, shape (3, 3), one lattice
            vector per row.

    Returns:
        torch.Tensor: Wrapped Cartesian positions, same shape as ``pos``.
    """
    invlat = torch.linalg.inv(cell)
    relpos = torch.remainder(pos @ invlat, 1.0)
    # remainder() of a tiny negative coordinate is 1 - eps, which can round
    # to exactly 1.0 in floating point (e.g. -1e-9 in float32). Fold it back
    # to 0.0 so every fractional coordinate really lies in [0, 1).
    relpos = torch.where(relpos >= 1.0, relpos - 1.0, relpos)
    return relpos @ cell


def string2slice(s):
    """Parse a slice-notation string into a ``slice`` or an integer index.

    Strings without a colon are treated as a single integer index. Otherwise
    the string is split on ``':'`` and each field becomes an ``int``, or
    ``None`` when the field is empty, mirroring Python's own slice syntax
    (``':10'``, ``'5:'``, ``'1:10:2'``, ``'::-1'``).

    Args:
        s (str): Slice or index text, e.g. ``':'``, ``':10'``, ``'5:15'``, ``'3'``.

    Returns:
        slice or int: The parsed slice object, or the integer index.
    """
    if ':' not in s:
        return int(s)
    bounds = [int(field) if field else None for field in s.split(':')]
    return slice(*bounds)


def send_notification(title, message, priority=5, timeout=10.0):
    """Send notification via external service for experiment monitoring.

    POSTs a JSON body with ``title``, ``priority`` and ``message`` to the URL
    in the ``NOTIFY_URL`` environment variable. The value of ``NOTIFY_TOKEN``
    is sent as the ``token`` query parameter.

    Args:
        title (str): Notification title/subject.
        message (str): Notification message body.
        priority (int, optional): Notification priority level. Defaults to 5.
        timeout (float, optional): Seconds to wait for the service before
            giving up. Defaults to 10.0. Without a timeout, ``requests`` can
            block indefinitely on an unresponsive server.

    Returns:
        dict: Decoded JSON body of the service's response. The HTTP status is
        not checked, so this may be an error payload from the service.

    Raises:
        requests.exceptions.RequestException: On connection errors, timeouts,
            or an invalid/missing ``NOTIFY_URL``.

    Note:
        Requires NOTIFY_URL and NOTIFY_TOKEN environment variables to be set.
        If NOTIFY_TOKEN is unset, ``requests`` omits the ``token`` parameter.
    """
    data = {
        "title": title,
        "priority": priority,
        "message": message,
    }
    # Let requests build and URL-encode the query string. Tokens containing
    # reserved characters then survive intact. A NOTIFY_URL that already has
    # a query string gets "&token=..." instead of a second "?".
    response = requests.post(
        os.getenv("NOTIFY_URL"),
        params={"token": os.getenv("NOTIFY_TOKEN")},
        json=data,
        timeout=timeout,
    )

    return response.json()


class TensorBoardLogger:
    """TensorBoard logger for visualizing training metrics.

    Wraps TensorBoard SummaryWriter to provide convenient logging of scalars,
    histograms, and model parameters during training and inference.

    The logger can be used as a context manager. The writer is then closed
    on exit, even if the body raises.

    Attributes:
        writer (SummaryWriter): TensorBoard writer instance.
    """

    def __init__(self, log_dir):
        """Initialize TensorBoard logger.

        Args:
            log_dir (str): Directory path for TensorBoard logs.
        """
        # Create tensorboard directory if it doesn't exist
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=log_dir)

    def __enter__(self):
        """Return the logger itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the writer on leaving the ``with`` block.

        Returns:
            bool: Always False, so exceptions from the body propagate.
        """
        self.close()
        return False

    def log_scalar(self, tag, value, step):
        """Log a scalar value.

        Args:
            tag (str): Name of the scalar metric.
            value (float): Scalar value to log.
            step (int): Step/iteration number.
        """
        self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag, tag_scalar_dict, step):
        """Log multiple scalars under the same main tag.

        Args:
            main_tag (str): Main category name for the scalars.
            tag_scalar_dict (dict): Dictionary mapping scalar names to values.
            step (int): Step/iteration number.
        """
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag, values, step):
        """Log histogram of values.

        Args:
            tag (str): Name of the histogram.
            values (torch.Tensor): Tensor values to create histogram from.
            step (int): Step/iteration number.
        """
        self.writer.add_histogram(tag, values, step)

    def log_model_params(self, model, step):
        """Log histograms of model parameters and gradients.

        Only parameters with ``requires_grad`` set and at least one element
        are logged: values go under ``params/<name>`` and, when a non-empty
        gradient exists, gradients go under ``grads/<name>``. A histogram the
        writer rejects with ``ValueError`` is skipped with a RuntimeWarning.
        This is best-effort: it never aborts training, and a failure on one
        histogram does not suppress the others.

        Args:
            model (nn.Module): PyTorch model to log parameters from.
            step (int): Step/iteration number.
        """
        for name, param in model.named_parameters():
            # Frozen and empty parameters have nothing useful to plot.
            if not param.requires_grad or param.numel() == 0:
                continue
            self._try_log_histogram(f"params/{name}", param.detach(), step)
            grad = param.grad
            if grad is not None and grad.numel() > 0:
                self._try_log_histogram(f"grads/{name}", grad.detach(), step)

    def _try_log_histogram(self, tag, values, step):
        """Log one histogram, warning instead of raising on ``ValueError``."""
        try:
            self.log_histogram(tag, values, step)
        except ValueError as err:
            # Presumably triggered by NaN/Inf values the histogram binning
            # cannot handle (TODO confirm). Surface it rather than hide it.
            warnings.warn(
                f"Skipping TensorBoard histogram {tag!r} at step {step}: {err}",
                RuntimeWarning,
                stacklevel=3,
            )

    def close(self):
        """Close the TensorBoard writer and flush remaining data."""
        self.writer.close()
