import torch, math
from thop import profile, clever_format
from racnn import RaCNN, Attention


def count_attention_cell(m: Attention, x: torch.Tensor, y: torch.Tensor):
    """Add the operation count of one ``Attention`` call to ``m.total_ops``.

    Intended to be passed to ``thop.profile`` through ``custom_ops``. Only the
    shape of ``x[0]`` is inspected; no tensor data is read.

    Args:
        m: Module whose ``total_ops`` counter is incremented in place.
        x: Indexed as ``x[0]``; that element must expose a 5-element
            ``shape``. NOTE(review): the ``torch.Tensor`` annotation looks
            inconsistent with the ``x[0]`` indexing -- forward hooks normally
            receive the positional inputs as a tuple; confirm.
        y: Unused by this counter.
    """
    # Presumably (batch, groups, heads, tokens, head_dim) -- TODO confirm
    # against Attention.forward.
    batch, groups, heads, tokens, head_dim = x[0].shape

    # Number of entries in the tokens x tokens maps over all leading dims.
    pair_entries = batch * groups * heads * tokens * tokens

    # Two terms scale with head_dim (presumably the two matrix products of
    # the attention) and one charges a flat 4 ops per map entry (presumably
    # the normalisation) -- verify against the module's forward pass.
    head_dim_scaled_ops = pair_entries * head_dim
    per_entry_ops = pair_entries * 4

    m.total_ops += head_dim_scaled_ops + per_entry_ops + head_dim_scaled_ops


if __name__ == "__main__":
    # Profile one RaCNN configuration with thop and print its operation and
    # parameter counts.

    # Tell thop to count Attention modules with the hand-written counter
    # instead of its built-in rules.
    custom_ops = {Attention: count_attention_cell}

    # Random probe tensor; only its shape matters for counting.
    # Presumably (batch, channels, height, width) -- TODO confirm.
    # NOTE(review): 225 rather than the more common 224 -- confirm intended.
    dummy_input = torch.randn(1, 3, 225, 225)

    # Active configuration. The commented lines below are alternative
    # configurations; enable exactly one of them at a time.
    model = RaCNN(dims=[24,48,96,192], layers=[2,3,8,2], mlp_ratio=8.0, split_sizes=[8,4,2,1], num_heads=[2,4,8,16], drop_path_rate=0.00)
    # model = RaCNN(dims=[32,64,128,256], layers=[3,5,8,3], mlp_ratio=6.0, split_sizes=[8,4,2,1], num_heads=[2,4,8,16], drop_path_rate=0.05)
    # model = RaCNN(dims=[48,96,192,384], layers=[3,5,10,3], mlp_ratio=4.0, split_sizes=[8,4,2,1], num_heads=[2,4,8,16], drop_path_rate=0.10)
    # model = RaCNN(dims=[64,128,256,512], layers=[3,6,14,3], mlp_ratio=3.0, split_sizes=[8,4,2,1], num_heads=[2,4,8,16], drop_path_rate=0.20)
    # model = RaCNN(dims=[96,192,384,768], layers=[4,8,16,4], mlp_ratio=2.0, split_sizes=[8,4,2,1], num_heads=[2,4,8,16], drop_path_rate=0.40)

    model.eval()
    print(model)

    raw_macs, raw_params = profile(model, inputs=(dummy_input, ), custom_ops=custom_ops)
    # Only the formatted operation count is printed; thop's own parameter
    # figure is discarded in favour of the direct count below.
    macs, _ = clever_format([raw_macs, raw_params], "%.3f")

    # Parameter count taken directly from the model, expressed in millions.
    params = sum(p.numel() for p in model.parameters()) / 1e6

    # NOTE(review): the value is named "macs" but printed under the "Flops"
    # label -- confirm which unit is meant before comparing with other
    # reported numbers.
    print('Flops:  ', macs)
    print('Params: ', params)


