import torch
import torch.nn as nn
import lightning.pytorch as pl

class CameraPoseGenerator(pl.LightningModule):
    """Map a 3-element feature to a ``(number_of_views, 3)`` offset tensor.

    Only ``mode == 'size'`` is implemented: a two-layer MLP
    (Linear -> ReLU -> Linear) maps 3 input values to
    ``number_of_views * 3`` outputs, which ``forward`` reshapes to
    ``(number_of_views, 3)``.

    Args:
        mode: Generator variant. Only ``'size'`` builds an encoder; any
            other value constructs a module whose ``forward`` raises
            ``ValueError``.
        number_of_views: Number of rows in the output tensor.
        input_channel: Accepted but currently unused -- the encoder's
            input width is hard-coded to 3. NOTE(review): confirm whether
            this parameter was meant to replace the literal 3.
        hidden_size: Width of the MLP's hidden layer.
    """

    def __init__(
        self,
        mode: str,
        number_of_views: int,
        input_channel: int,
        hidden_size: int,
    ):
        super().__init__()
        self.mode = mode
        self.number_of_views = number_of_views
        if self.mode == 'size':
            self.pose_encoder = nn.Sequential(
                nn.Linear(3, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, number_of_views * 3),
            )

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        """Encode ``feature`` and return it as ``(number_of_views, 3)``.

        Args:
            feature: Tensor whose last dimension is 3 (the encoder input
                width). Because the result is reshaped to
                ``(number_of_views, 3)``, ``feature`` must contain exactly
                3 elements in total (a single sample, or a batch of one).

        Returns:
            Tensor of shape ``(number_of_views, 3)``.

        Raises:
            ValueError: If ``mode`` is not ``'size'`` (no encoder exists).
        """
        if self.mode != 'size':
            # Without this guard the method would hit ``return`` with the
            # result name never bound, surfacing as an UnboundLocalError.
            raise ValueError(
                f"Unsupported CameraPoseGenerator mode: {self.mode!r}; "
                "only 'size' is implemented"
            )
        # ``.float()`` casts the input to float32 before the Linear layers.
        pose_offset = self.pose_encoder(feature.float())
        # Requires exactly number_of_views * 3 elements, i.e. a 3-element input.
        return pose_offset.reshape(self.number_of_views, 3)