import torch, math
from thop import profile, clever_format
from sfcnn import SFCNN

if __name__ == "__main__":
    # Profile one SFCNN configuration: print its structure, its MAC count
    # (via thop) and its parameter count.

    # Extra thop counting rules, keyed by module class. Left empty here, so
    # thop only uses its built-in counters.
    custom_ops = {}
    # Dummy batch that drives the single forward pass thop needs for counting.
    # Named to avoid shadowing the builtin `input`.
    dummy_input = torch.randn(1, 3, 224, 224)

    # Alternative SFCNN presets (uncomment one and comment out the active
    # model below to profile it instead):
    # model = SFCNN(dims=[32,64,128,256], layers=[4,8,14,4], group_widths=1, mlp_ratio=3.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.00)
    # model = SFCNN(dims=[40,80,160,320], layers=[4,8,24,4], group_widths=1, mlp_ratio=3.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.05)
    # model = SFCNN(dims=[48,96,192,384], layers=[6,12,24,6], group_widths=1, mlp_ratio=3.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.10)
    # model = SFCNN(dims=[64,128,256,512], layers=[6,12,28,6], group_widths=1, mlp_ratio=3.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.20)
    # model = SFCNN(dims=[80,160,320,640], layers=[8,15,35,8], group_widths=1, mlp_ratio=3.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.35)

    model = SFCNN(dims=[48,96,192,384], layers=[4,8,20,4], group_widths=1, mlp_ratio=4.0, act_type="silu", s_act_type="gsilu", use_le=True, drop_path_rate=0.10)

    model.eval()
    print(model)

    # thop.profile returns (MACs, params). The MAC count is printed below
    # under the "Flops" label, following the common convention of reporting
    # MACs as FLOPs.
    macs, thop_params = profile(model, inputs=(dummy_input,), custom_ops=custom_ops)
    # Only the formatted MAC string is used. The parameter count is taken
    # from the model directly below.
    macs_str, _ = clever_format([macs, thop_params], "%.3f")

    # Count parameters straight from the model. thop's own tally may miss
    # parameters that its hooks do not attribute to a module (NOTE(review):
    # presumably the reason for recounting here; confirm).
    n_params = sum(p.numel() for p in model.parameters())
    print('Flops:  ', macs_str)
    # Report in millions, with the same precision as the MAC string.
    print('Params: ', f"{n_params / 1e6:.3f}M")

