from models.FarSeg.farseg_boost import FarSeg_boost
import torch



def main() -> None:
    """Smoke-test ``FarSeg_boost`` with one forward pass on random input.

    Builds the model with 3 input channels and 16 classes, runs a single
    forward pass on a random batch, and prints what came back so the run
    verifies more than "did not raise".
    """
    # Random batch of 2 images, 3 channels, 896x896 (NCHW layout is assumed
    # by the model -- TODO confirm against FarSeg_boost's forward).
    x = torch.randn(2, 3, 896, 896)
    model = FarSeg_boost(in_ch=3, num_classes=16)

    # No backward pass is performed, so skip autograd bookkeeping; for an
    # input this large the retained activations would cost a lot of memory.
    # NOTE(review): the model is left in its default (training) mode exactly
    # as before, since its forward may branch on ``self.training``.
    with torch.no_grad():
        out = model(x)

    # The return type of FarSeg_boost is not visible here; report a shape when
    # it is a tensor, otherwise just the type name.
    if isinstance(out, torch.Tensor):
        print(f"FarSeg_boost output shape: {tuple(out.shape)}")
    else:
        print(f"FarSeg_boost output type: {type(out).__name__}")


if __name__ == "__main__":
    main()