import torch
import torch.nn as nn

class IncrementalClassifier(nn.Module):
    """Linear classification head whose number of output classes can change.

    Args:
        d_model: Size of the last dimension of the input hidden state.
        n_output_classes: Initial number of output classes.
    """

    def __init__(self,
        d_model: int,
        n_output_classes: int
    ):
        super().__init__()
        self.d_model = d_model
        self.n_output_classes = n_output_classes
        self.classification_head = nn.Linear(self.d_model, self.n_output_classes)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_state`` to per-class scores.

        Args:
            hidden_state: Tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``n_output_classes``. This is the raw linear projection; no
            softmax is applied.
        """
        return self.classification_head(hidden_state)

    def adaptation(self, n_output_classes: int) -> None:
        """Resize the head to ``n_output_classes`` outputs, keeping learned rows.

        The first ``min(current, n_output_classes)`` rows of the weight and
        bias are copied from the current head. Any additional rows keep
        ``nn.Linear``'s default initialization. The new head is created on
        the same device and with the same dtype as the current one.

        Note: this replaces ``classification_head`` with a new module. An
        optimizer that holds references to the old parameters must be rebuilt
        (or have its parameter groups updated) by the caller.

        Args:
            n_output_classes: Desired number of output classes. If it equals
                the current count, this is a no-op.
        """
        if self.n_output_classes == n_output_classes:
            return

        old_head = self.classification_head
        # Match the old head's device AND dtype. Moving only the device would
        # silently turn a half/double-precision head into float32.
        new_head = nn.Linear(
            self.d_model,
            n_output_classes,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype,
        )
        # Number of class rows shared by the old and new heads. Using min()
        # lets the head shrink instead of failing with a shape mismatch.
        n_kept = min(self.n_output_classes, n_output_classes)
        with torch.no_grad():
            new_head.weight[:n_kept].copy_(old_head.weight[:n_kept])
            new_head.bias[:n_kept].copy_(old_head.bias[:n_kept])

        self.classification_head = new_head
        self.n_output_classes = n_output_classes