import torch
import numpy as np
import os
import json

class MaxScaler:
    """Scale data by dividing by its global maximum absolute value.

    A single scalar factor is computed over *all* elements of the data passed
    to ``fit`` (every axis, including channels), so the fitted data lies in
    ``[-1, 1]`` after ``transform``.
    """

    def __init__(self):
        # Scalar scaling factor; ``None`` until ``fit`` or ``load`` is called.
        self.scale = None

    def fit(self, data):
        """Compute the scaling factor as the max absolute value of ``data``.

        Args:
            data: Tensor or array-like (e.g. a NumPy array or nested list).
                The intended layout is presumably ``(batch, channel, space)``,
                but the reduction is over every element, so any shape works.

        Returns:
            ``self``, to allow chaining (e.g. ``MaxScaler().fit(x)``).

        Raises:
            RuntimeError: raised by torch if ``data`` is empty.

        Note:
            If ``data`` contains NaN, the resulting scale is NaN.
        """
        # as_tensor returns tensors unchanged and lets NumPy arrays / lists work
        # too (torch.abs would otherwise reject them with a TypeError).
        values = torch.as_tensor(data)
        self.scale = torch.max(torch.abs(values)).item()
        # Floor the factor so all-zero data does not cause division by zero.
        self.scale = max(self.scale, 1e-8)
        return self

    def transform(self, data):
        """Return ``data`` divided by the fitted scale.

        Raises:
            ValueError: if the scaler has not been fitted or loaded.
        """
        if self.scale is None:
            raise ValueError("Scaler has not been fitted yet.")
        return data / self.scale

    def inverse_transform(self, data):
        """Return ``data`` multiplied by the fitted scale (undoes ``transform``).

        Raises:
            ValueError: if the scaler has not been fitted or loaded.
        """
        if self.scale is None:
            raise ValueError("Scaler has not been fitted yet.")
        return data * self.scale

    def fit_transform(self, data):
        """Fit on ``data`` and return it scaled."""
        return self.fit(data).transform(data)

    def to(self, device):
        """Return ``self`` unchanged; the scale is a plain Python scalar.

        Exists so the scaler can be moved alongside tensors/modules without
        special-casing.
        """
        return self

    def save(self, path):
        """Write the scaling factor to ``path`` as JSON (``{"scale": ...}``).

        Parent directories are created as needed. ``path`` may be a bare file
        name, in which case the file is written to the current directory.
        An unfitted scaler is saved with ``"scale": null``.
        """
        directory = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when ``path`` actually contains one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        stats = {'scale': self.scale}
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(stats, f)

    def load(self, path):
        """Restore the scaling factor from a JSON file written by ``save``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            KeyError: if the file has no ``"scale"`` entry.
        """
        with open(path, 'r', encoding='utf-8') as f:
            stats = json.load(f)
        self.scale = stats['scale']
        return self