from spaghettini import quick_register

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Linear


@quick_register
class FullSort(nn.Module):
    """Activation that sorts its input in ascending order along one dimension.

    Sorting is a permutation of the input entries, so gradients flow back
    unchanged to the positions the sorted values came from.

    Args:
        dim: Dimension to sort along. Defaults to 1 (e.g. the feature axis of
            a ``(batch, features)`` tensor), matching the original behaviour.
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` sorted in ascending order along ``self.dim``.

        The permutation indices produced by ``torch.sort`` are discarded;
        only the sorted values are returned.
        """
        return torch.sort(x, dim=self.dim).values

    def extra_repr(self) -> str:
        # Shown inside ``repr(module)`` so the configured axis is visible.
        return f"dim={self.dim}"


if __name__ == "__main__":
    # Manual smoke test for the activations in this module.
    # Run from the repository root:
    #     python -m src.dl.models.activations
    tests = ["test_fullsort"]
    if "test_fullsort" in tests:
        # Check sorting dimension: each row of `b` should be the matching row
        # of `a` in ascending order.
        sort_layer = FullSort()

        # Sample directly instead of the deprecated `torch.nn.init.normal`
        # (the non-underscore alias of `normal_`).
        a = torch.randn((3, 4), requires_grad=True)
        b = sort_layer(a)

        print(f"unsorted: \n {a}")
        print(f"sorted: \n {b}")
        assert torch.all(b[:, 1:] >= b[:, :-1]), "FullSort output is not sorted along dim 1"

        # Check differentiability. Sorting only permutes entries, so the
        # gradient of sum(sort(a)) wrt. a must be a tensor of ones.
        c = b.sum()
        c.backward()
        # Derivative of c wrt. a
        print(f"Derivative of c wrt. a: {a.grad}")
        assert torch.equal(a.grad, torch.ones_like(a)), "unexpected gradient through FullSort"
