"""Data transformation utilities.

This module provides functions for transforming and scaling target properties
and other data preprocessing operations.
"""

from __future__ import annotations
from typing import List
from dataclasses import dataclass
from pathlib import Path
import json
import numpy as np


def identity(smiles: List[str]) -> List[str]:
    """Pass SMILES strings through without modifying them.

    This is a no-op transformation: the argument is handed back as-is,
    so the caller receives the very same list object (no copy is made).

    Args:
        smiles: List of SMILES strings

    Returns:
        The list that was passed in
    """
    unchanged = smiles
    return unchanged


@dataclass
class TargetScaler:
    """Standardization scaler for target properties.

    Performs z-score normalization: (y - mean) / std
    Stores statistics for inverse transformation.
    NaN values are ignored during fitting and propagate unchanged (as NaN)
    through ``transform`` / ``inverse_transform``.

    Attributes:
        mean: Mean values for each target property
        std: Standard deviation for each target property (never zero after
            ``fit``: degenerate columns use 1.0 instead)
    """

    mean: np.ndarray
    std: np.ndarray

    @staticmethod
    def _as_2d(y: np.ndarray) -> np.ndarray:
        """Coerce array-like input to an ndarray, turning 1-D data into a column."""
        arr = np.asarray(y)
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        return arr

    @classmethod
    def fit(cls, y: np.ndarray) -> "TargetScaler":
        """Fit scaler to target data, ignoring NaN values.

        Columns whose standard deviation is (near) zero get ``std = 1.0`` so
        that ``transform`` never divides by zero. Columns with no observed
        value at all (every entry NaN, or zero rows) fall back to
        ``mean = 0.0`` and ``std = 1.0`` instead of NaN statistics, which
        would otherwise turn every later transform result into NaN.

        Args:
            y: Target array (or array-like) of shape (n_samples, n_properties),
                may contain NaNs. 1-D input is treated as a single property.

        Returns:
            Fitted TargetScaler instance
        """
        y = cls._as_2d(y)

        # Start from identity statistics (mean 0, std 1); floating inputs keep
        # their own dtype, anything else (e.g. integers) uses float64.
        stat_dtype = y.dtype if np.issubdtype(y.dtype, np.floating) else np.float64
        mean = np.zeros(y.shape[1:], dtype=stat_dtype)
        std = np.ones(y.shape[1:], dtype=stat_dtype)

        # Only compute statistics for columns that have at least one observed
        # value; this also avoids NumPy's "empty slice" RuntimeWarning.
        has_data = ~np.isnan(y).all(axis=0)
        if has_data.any():
            observed = y[:, has_data]
            # Use nanmean and nanstd to ignore NaN values
            mean[has_data] = np.nanmean(observed, axis=0)
            std[has_data] = np.nanstd(observed, axis=0)

        # Avoid division by zero
        std = np.where(std < 1e-8, 1.0, std)

        return cls(mean=mean, std=std)

    def transform(self, y: np.ndarray) -> np.ndarray:
        """Standardize target data.

        Args:
            y: Target array (or array-like) of shape (n_samples, n_properties)

        Returns:
            Standardized array. 2-D input keeps its shape; 1-D input is
            returned as a column of shape (n_samples, 1).
        """
        y = self._as_2d(y)
        return (y - self.mean) / self.std

    def inverse_transform(self, y: np.ndarray) -> np.ndarray:
        """Reverse standardization.

        Args:
            y: Standardized array (or array-like) of shape
                (n_samples, n_properties)

        Returns:
            Array on the original scale. 2-D input keeps its shape; 1-D input
            is returned as a column of shape (n_samples, 1).
        """
        y = self._as_2d(y)
        return y * self.std + self.mean

    @classmethod
    def fit_transform(cls, y: np.ndarray) -> tuple[np.ndarray, "TargetScaler"]:
        """Fit a new scaler and transform data in one step.

        This is a classmethod so it can be called without an existing
        instance (``TargetScaler.fit_transform(y)``); calling it on an
        instance still works and, as before, ignores that instance's
        statistics in favour of the freshly fitted ones.

        Args:
            y: Target array of shape (n_samples, n_properties)

        Returns:
            Tuple of (transformed data, fitted scaler)
        """
        scaler = cls.fit(y)
        y_scaled = scaler.transform(y)
        return y_scaled, scaler

    def save(self, path: str | Path) -> None:
        """Save scaler statistics to JSON file.

        Missing parent directories are created. The file holds two keys,
        ``"mean"`` and ``"std"``, each a list of numbers.

        Args:
            path: File path for saving
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "mean": self.mean.tolist(),
            "std": self.std.tolist(),
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> "TargetScaler":
        """Load scaler statistics from JSON file.

        Args:
            path: File path to load from

        Returns:
            TargetScaler instance with loaded statistics

        Raises:
            KeyError: If the file lacks the ``"mean"`` or ``"std"`` key.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Force a floating dtype so integer-valued JSON yields float arrays.
        return cls(
            mean=np.array(data["mean"], dtype=float),
            std=np.array(data["std"], dtype=float),
        )
