import torch
import torch.nn as nn
import torch.nn.functional as F

class Ensemble(nn.Module):
    """Average the outputs of several sub-models applied to the same input.

    Each sub-model is registered as a child module named ``model{i}`` so that
    ``nn.Module`` tracks it: ``parameters()``, ``.to()``/``.eval()`` and
    ``state_dict()`` (keys such as ``model0.weight``) all include it.

    Args:
        models: Sub-models to ensemble. Each is called as ``model(x)``. The
            outputs are added to a ``(batch, num_classes)`` accumulator, so
            they must be broadcast-compatible with that shape.
        num_classes: Width of the accumulator, i.e. of the returned tensor.

    Raises:
        ValueError: If ``models`` is empty (the average would be 0/0 -> NaN).
    """

    def __init__(self, models, num_classes):
        super().__init__()
        models = list(models)
        if not models:
            raise ValueError("Ensemble requires at least one model")
        self.num_models = len(models)
        # Keep the ``model{i}`` attribute names (instead of an nn.ModuleList)
        # so state_dict keys stay compatible with existing checkpoints.
        for i, model in enumerate(models):
            setattr(self, f"model{i}", model)
        self.num_classes = num_classes

    def _members(self):
        """Yield the registered sub-models in construction order."""
        for i in range(self.num_models):
            yield getattr(self, f"model{i}")

    def forward(self, x):
        """Return the element-wise mean of all sub-model outputs for ``x``.

        Args:
            x: Input passed unchanged to every sub-model; its first dimension
                is taken as the batch size.

        Returns:
            Tensor of shape ``(x.size(0), num_classes)`` on ``x``'s device.
        """
        # Allocate on the input's device rather than calling ``.cuda()``.
        # A hard-coded ``.cuda()`` crashed on CPU-only hosts and misplaced the
        # accumulator when ``x`` was on a non-default GPU.
        output = torch.zeros(x.size(0), self.num_classes, device=x.device)
        for model in self._members():
            output = output + model(x)
        return output / self.num_models
