"""
Unified tensor API for torch and numpy backends
"""

import torch 
import numpy as np


class TensorAPI:
    """Backend-agnostic facade over a small set of PyTorch / NumPy operations.

    The active backend lives on the class itself in ``backend`` and is shared
    by every caller. ``detect_backend`` sets it to the ``torch`` or ``numpy``
    module based on a sample tensor. Every operation raises ``ValueError``
    when ``backend`` is neither of those two modules (e.g. still ``None``).
    """

    # Active backend module (``torch`` or ``np``); ``None`` until detect_backend runs.
    backend = None

    @staticmethod
    def _select(torch_impl, numpy_impl):
        """Pick the implementation that matches the active backend.

        Args:
            torch_impl: callable to use when ``TensorAPI.backend`` is ``torch``.
            numpy_impl: callable to use when ``TensorAPI.backend`` is ``np``.

        Returns:
            The chosen callable (not yet invoked).

        Raises:
            ValueError: if the backend is neither ``torch`` nor ``np``.
        """
        active = TensorAPI.backend
        if active is torch:
            return torch_impl
        if active is np:
            return numpy_impl
        raise ValueError(f"Unknown backend {TensorAPI.backend}")

    @staticmethod
    def detect_backend(tensor):
        """Set the shared backend from the concrete type of ``tensor``.

        Args:
            tensor: a ``torch.Tensor`` or a ``numpy.ndarray``.

        Raises:
            ValueError: for any other type; ``backend`` is left untouched.
        """
        # Checked in order: torch first, then numpy.
        for tensor_type, module in ((torch.Tensor, torch), (np.ndarray, np)):
            if isinstance(tensor, tensor_type):
                TensorAPI.backend = module
                return
        raise ValueError(f"Unknown tensor type: {type(tensor)}")

    @staticmethod
    def device(tensor):
        """Return ``tensor.device`` under torch, or the string ``"cpu"`` under numpy."""
        getter = TensorAPI._select(lambda t: t.device, lambda t: "cpu")
        return getter(tensor)

    @staticmethod
    def zeros(shape, device=None, **kwargs):
        """Allocate a zero-filled tensor of ``shape``.

        ``device`` is forwarded only to ``torch.zeros``; the numpy branch
        does not pass it on. Extra keyword arguments (e.g. ``dtype``) go to
        the backend's ``zeros``.
        """
        builder = TensorAPI._select(
            lambda: torch.zeros(shape, device=device, **kwargs),
            lambda: np.zeros(shape, **kwargs),
        )
        return builder()

    @staticmethod
    def dim(tensor):
        """Return the number of dimensions of ``tensor``."""
        counter = TensorAPI._select(lambda t: t.dim(), lambda t: len(t.shape))
        return counter(tensor)

    @staticmethod
    def einsum(eq, *operands):
        """Evaluate the Einstein-summation expression ``eq`` over ``operands``."""
        return TensorAPI._select(torch.einsum, np.einsum)(eq, *operands)

    @staticmethod
    def det(tensor):
        """Return the matrix determinant via ``torch.det`` / ``np.linalg.det``."""
        return TensorAPI._select(torch.det, np.linalg.det)(tensor)

    @staticmethod
    def abs(tensor):
        """Return the element-wise absolute value of ``tensor``."""
        return TensorAPI._select(torch.abs, np.abs)(tensor)

    @staticmethod
    def inv(tensor):
        """Return the matrix inverse via ``torch.inverse`` / ``np.linalg.inv``."""
        return TensorAPI._select(torch.inverse, np.linalg.inv)(tensor)

# Backward-compatible alias: ``API`` is bound to the very same class object as
# ``TensorAPI`` (not a copy), so both names share the class-level ``backend``.
API = TensorAPI