from typing import Optional, Tuple

import torch


class Adapter:
    """Base adapter that passes its inputs through unchanged.

    ``adapt`` may transform an input before further processing and
    ``reconstruct`` may undo that transformation afterwards; this base
    implementation does neither, returning its arguments as-is.
    """

    def __init__(self):
        # Not assigned anywhere in this class; presumably set by subclasses or
        # callers elsewhere — TODO confirm.
        self.shape: Optional[torch.Size] = None

    def adapt(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (identity adaptation)."""
        return x

    def reconstruct(
        self, x: torch.Tensor, match: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x, match)`` unchanged.

        Args:
            x: tensor to reconstruct.
            match: passed through untouched; its meaning is defined by the
                caller (NOTE(review): presumably a matching/assignment tensor).

        Returns:
            The pair ``(x, match)``. The annotation is a tuple because both
            values are returned, not a single tensor.
        """
        return x, match


class ViTAdapter(Adapter):
    """Adapter that sets the first token aside and restores it later.

    ``adapt`` removes the first entry along ``token_dim`` and caches it as
    ``cls_token``. ``reconstruct`` concatenates that cached entry back in
    front of ``x`` along the same axis.

    Args:
        token_dim: axis along which tokens are laid out. Defaults to 0, which
            reproduces the original hard-coded behaviour. Pass another axis
            (e.g. 1) when tokens are not on the leading dimension.
    """

    def __init__(self, token_dim: int = 0):
        super().__init__()
        self.token_dim = token_dim
        # Filled in by adapt(); None until then so reconstruct() can report
        # misuse clearly instead of failing with an AttributeError.
        self.cls_token: Optional[torch.Tensor] = None

    @staticmethod
    def _slice_along(x: torch.Tensor, dim: int, start, stop) -> torch.Tensor:
        """Return ``x`` sliced to ``start:stop`` on axis ``dim`` only."""
        index = [slice(None)] * x.dim()
        index[dim] = slice(start, stop)
        return x[tuple(index)]

    def adapt(self, x: torch.Tensor) -> torch.Tensor:
        """Cache the first token along ``token_dim`` and return the rest."""
        self.cls_token = self._slice_along(x, self.token_dim, None, 1)
        return self._slice_along(x, self.token_dim, 1, None)

    def reconstruct(
        self, x: torch.Tensor, match: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepend the token cached by ``adapt`` to ``x``.

        Args:
            x: tensor to restore the cached token into.
            match: passed through untouched.

        Returns:
            ``(x_with_token, match)``.

        Raises:
            RuntimeError: if called before ``adapt`` has cached a token.
        """
        if self.cls_token is None:
            raise RuntimeError("ViTAdapter.reconstruct() called before adapt()")
        x = torch.cat((self.cls_token, x), dim=self.token_dim)
        return x, match


# Registry of adapter names to the adapter class instantiated for that name.
__REGISTERED_ADAPTERS__ = dict(
    vit=ViTAdapter,
)


def get_adapter(name: str, **kwargs) -> Adapter:
    """Instantiate the adapter registered under ``name``.

    Args:
        name: key in ``__REGISTERED_ADAPTERS__`` (e.g. ``"vit"``).
        **kwargs: forwarded unchanged to the adapter class's constructor.

    Returns:
        A new adapter instance.

    Raises:
        KeyError: if ``name`` is not registered. The message lists the
            registered names.
    """
    try:
        adapter_cls = __REGISTERED_ADAPTERS__[name]
    except KeyError:
        available = ", ".join(sorted(__REGISTERED_ADAPTERS__))
        # Keep KeyError so existing callers still catch it, but say what is valid.
        raise KeyError(
            f"Unknown adapter {name!r}; available adapters: {available}"
        ) from None
    return adapter_cls(**kwargs)

