from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch


Tensor = torch.Tensor


@dataclass
class BasisOptions:
    """Options controlling how ``horizontal_action_basis`` builds its basis.

    Field order is significant: it defines the positional constructor,
    e.g. ``BasisOptions("random", 2, 7, False)``.
    """

    # Name of the construction strategy read by horizontal_action_basis.
    method: str = "identity"
    # Number of basis columns; a falsy value (None or 0) means "use the action dimension A".
    dim: Optional[int] = None
    # When not None, seed used for the random draw of method="random".
    seed: Optional[int] = None
    # For method="random": orthonormalize the sampled matrix before returning it.
    reorth: bool = True


def _orthonormalize(mat: Tensor) -> Tensor:
                                           
    B, A, D = mat.shape
    out = torch.zeros_like(mat)
    for b in range(B):
        q, r = torch.linalg.qr(mat[b])
        out[b] = q[:, :D]
    return out


def _identity_basis(B: int, A: int, d: int, like: Tensor) -> Tensor:
    """Return the first ``d`` standard basis columns of R^A, broadcast over ``B`` batch entries."""
    eye = torch.eye(A, device=like.device, dtype=like.dtype)
    # expand() creates a view with stride 0 over the batch axis; no copy is made.
    return eye.unsqueeze(0).expand(B, A, A)[..., :d]


def _make_generator(seed: Optional[int], device: torch.device) -> Optional[torch.Generator]:
    """Return a private generator seeded with ``seed``, or ``None`` to draw from the global RNG."""
    if seed is None:
        return None
    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))
    return gen


def horizontal_action_basis(actions: Tensor, opts: Optional[BasisOptions] = None) -> Tensor:
    """Build a batch of ``(A, d)`` basis matrices, one per row of ``actions``.

    Only the shape, device and dtype of ``actions`` are used; its values are not read.

    Supported ``opts.method`` values:

    * ``"identity"`` (default, also used for unrecognized names): the first ``d``
      standard basis vectors, returned as a broadcast ``expand`` view. Do not write
      into it in place.
    * ``"random"``: a Gaussian ``(B, A, d)`` matrix, orthonormalized via
      ``_orthonormalize`` when ``opts.reorth`` is true.
    * ``"qr"``: a Gaussian matrix that is always orthonormalized.
    * ``"svd"``: currently a placeholder that returns the identity basis.

    When ``opts.seed`` is not None, the random methods draw from a private
    ``torch.Generator`` seeded with it, so the result is reproducible and the
    global RNG state is left untouched.

    Args:
        actions: Tensor of shape ``(B, A)``.
        opts: Construction options; ``None`` means ``BasisOptions()``.

    Returns:
        Tensor of shape ``(B, A, d)`` with ``d = opts.dim or A``. For the identity
        path the slice is clamped to at most ``A`` columns.
    """
    if opts is None:
        opts = BasisOptions()
    B, A = actions.shape
    d = opts.dim or A

    if opts.method in ("random", "qr"):
        gen = _make_generator(opts.seed, actions.device)
        M = torch.randn(B, A, d, generator=gen, device=actions.device, dtype=actions.dtype)
        if opts.method == "qr" or opts.reorth:
            return _orthonormalize(M)
        return M

    # "identity", the "svd" placeholder, and any unrecognized method all yield the
    # standard basis. (The former svd branch computed an einsum whose "bb" subscript
    # crashed unless B == A, and then discarded the result.)
    return _identity_basis(B, A, d, actions)


def ensure_orthonormal(basis: Tensor, tol: float = 1e-5) -> bool:
    """Return True if every ``basis[b]`` has orthonormal columns to within ``tol``.

    For each batch entry, every element of ``basis[b].T @ basis[b] - I`` must have
    absolute value at most ``tol``.

    Args:
        basis: Tensor of shape ``(B, A, D)``.
        tol: Maximum allowed absolute deviation from the identity Gram matrix.

    Returns:
        ``True`` when all entries pass. This includes the vacuous cases ``B == 0`` and
        ``D == 0``. NaN or inf anywhere in the Gram deviation gives ``False``.
    """
    _, _, D = basis.shape
    eye = torch.eye(D, device=basis.device, dtype=basis.dtype)
    # Batched Gram matrices U_b^T U_b, shape (B, D, D).
    gram = basis.transpose(1, 2) @ basis
    # Written as "<= tol" rather than "not (> tol)": comparisons with NaN are False,
    # so NaN entries now fail the check instead of silently passing.
    return bool(((gram - eye).abs() <= tol).all())


def projector_from_basis(basis: Tensor) -> Tensor:
    """Return the batch of projectors ``P_b = U_b @ U_b.T`` onto ``span(basis[b])``.

    Args:
        basis: Tensor of shape ``(B, A, D)``. It is assumed to have orthonormal
            columns; otherwise ``U U^T`` is not an orthogonal projector.

    Returns:
        Tensor of shape ``(B, A, A)``.
    """
    # P[b, i, j] = sum_d U[b, i, d] * U[b, j, d]. The previous einsum
    # "bad,bae->bae" only ran when A == D, and even then it did not compute U U^T.
    return basis @ basis.transpose(1, 2)


def project_to_horizontal(vec: Tensor, basis: Tensor) -> Tensor:
    """Project ``vec`` onto the column span of ``basis``: ``U (U^T v)`` per batch entry.

    Args:
        vec: Batch of vectors; presumably shape ``(B, A)``, as in ``_demo``.
        basis: Tensor of shape ``(B, A, D)``. The result is an orthogonal
            projection only when its columns are orthonormal.

    Returns:
        Tensor with the same shape as ``vec``.
    """
    column = vec.unsqueeze(-1)                      # (B, A, 1)
    coords = basis.transpose(1, 2) @ column         # (B, D, 1): coordinates in the basis
    lifted = basis @ coords                         # (B, A, 1): back in ambient space
    return lifted.squeeze(-1)


def project_to_vertical(vec: Tensor, basis: Tensor) -> Tensor:
    """Return the component of ``vec`` orthogonal to ``span(basis)``: ``(I - U U^T) v``.

    Args:
        vec: Batch of vectors; presumably shape ``(B, A)``, as in ``_demo``.
        basis: Tensor of shape ``(B, A, D)``, assumed to have orthonormal columns.

    Returns:
        Tensor with the same shape as ``vec``.
    """
    # Evaluated as v - U (U^T v), which equals (I - U U^T) v algebraically. This
    # avoids materializing (B, A, A) identity and projector tensors, and it does not
    # depend on projector_from_basis, whose einsum produced wrong results.
    coords = basis.transpose(1, 2) @ vec.unsqueeze(-1)  # (B, D, 1)
    return vec - (basis @ coords).squeeze(-1)


def decompose(vec: Tensor, basis: Tensor) -> Tuple[Tensor, Tensor]:
    """Split ``vec`` into ``(horizontal, vertical)`` parts with ``vec == horizontal + vertical``.

    The horizontal part is ``U (U^T v)``. The vertical part is the remainder
    ``v - U (U^T v)``.

    Args:
        vec: Batch of vectors; presumably shape ``(B, A)``.
        basis: Tensor of shape ``(B, A, D)``.

    Returns:
        Tuple of two tensors, each with the same shape as ``vec``.
    """
    coeffs = basis.transpose(1, 2) @ vec.unsqueeze(-1)
    horizontal = (basis @ coeffs).squeeze(-1)
    return horizontal, vec - horizontal


def basis_from_mask(mask: Tensor, dim: Optional[int] = None) -> Tensor:
    """Return a batch of standard-basis matrices sized from ``mask``.

    NOTE(review): the *values* of ``mask`` are never read. Only its shape
    ``(B, A)``, device and dtype are used, so the result has ``mask``'s dtype.
    The name suggests masked coordinates were meant to select columns; confirm
    the intended behavior with callers.

    Args:
        mask: Tensor of shape ``(B, A)``.
        dim: Number of leading columns to keep; a falsy value (None or 0) keeps all ``A``.

    Returns:
        A broadcast (``expand``) view of shape ``(B, A, k)`` holding the first
        columns of the ``A x A`` identity, with ``k`` given by slice semantics
        ``[:dim or A]``.
    """
    batch, width = mask.shape
    n_cols = dim if dim else width
    eye = torch.eye(width, device=mask.device, dtype=mask.dtype)
    return eye[:, :n_cols].unsqueeze(0).expand(batch, -1, -1)


def verify_basis(basis: Tensor, tol: float = 1e-5) -> dict:
    """Report how far each ``basis[b]`` is from having orthonormal columns.

    For each batch entry, ``err_b = max |basis[b].T @ basis[b] - I|``.

    Args:
        basis: Tensor of shape ``(B, A, D)``.
        tol: Not used by this function; accepted for signature parity with
            ``ensure_orthonormal``.

    Returns:
        Dict with Python floats ``"max_orth_err"`` (max over batch of ``err_b``) and
        ``"mean_orth_err"`` (mean over batch). Both are ``0.0`` when ``B == 0`` or
        ``D == 0``. NaN in any ``err_b`` propagates to both values.
    """
    B, _, D = basis.shape
    if B == 0 or D == 0:
        # Nothing to measure; an empty Gram matrix deviates from I by nothing.
        return {"max_orth_err": 0.0, "mean_orth_err": 0.0}
    eye = torch.eye(D, device=basis.device, dtype=basis.dtype)
    gram = basis.transpose(1, 2) @ basis                        # (B, D, D)
    per_batch = (gram - eye).abs().flatten(1).amax(dim=1)       # (B,)
    errs = per_batch.tolist()
    return {
        # torch's max propagates NaN. Builtin max() over floats is order-dependent
        # with NaN and could silently hide a non-finite basis.
        "max_orth_err": float(per_batch.max()),
        "mean_orth_err": sum(errs) / len(errs),
    }


def random_horizontal_components(actions: Tensor, k: int) -> Tensor:
    """Return a random ``(B, A, min(A, k))`` basis via ``horizontal_action_basis``.

    Uses ``method="random"`` with the default ``reorth`` setting from
    ``BasisOptions``.

    NOTE(review): ``k == 0`` yields ``dim=0``. ``horizontal_action_basis`` treats a
    falsy dim as ``A``, so the result then has ``A`` columns, not 0.

    Args:
        actions: Tensor of shape ``(B, A)``.
        k: Requested number of columns, capped at ``A``.
    """
    _, n_actions = actions.shape
    opts = BasisOptions(method="random", dim=min(n_actions, k))
    return horizontal_action_basis(actions, opts)


def _demo():
    """Smoke-run the basis and projection helpers on a small random batch and print the results."""
    batch, n_act = 4, 3
    # The draw order (actions, then basis, then vec) matches the original demo.
    actions = torch.randn(batch, n_act)
    opts = BasisOptions(method="random", dim=n_act, reorth=True)
    basis = horizontal_action_basis(actions, opts)
    vec = torch.randn(batch, n_act)
    horiz = project_to_horizontal(vec, basis)
    vert = project_to_vertical(vec, basis)
    report = verify_basis(basis)
    print(horiz.shape, vert.shape, report)


if __name__ == "__main__":
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
