import argparse
import torch
import torch.nn as nn
from thop import profile
from model import NetworkCIFAR
import genotypes

def compute_flops(model, input_shape=(1, 3, 32, 32), device="cpu"):
    """Estimate the per-sample FLOPs of ``model`` with ``thop.profile``.

    Side effects: ``model`` is put into eval mode and moved to ``device``.
    Both changes persist after this function returns.

    A random tensor of shape ``input_shape`` is run through ``thop.profile``.
    Its multiply-accumulate count is doubled (the usual 1 MAC = 2 FLOPs
    convention) and divided by the batch size.

    Args:
        model: module to profile.
        input_shape: shape of the dummy input. Element 0 is treated as the
            batch size; values below 1 are clamped to 1 for the division.
        device: device on which the model and the dummy input are placed.

    Returns:
        Estimated FLOPs for a single sample.
    """
    batch = input_shape[0]
    model.eval()
    model.to(device)

    sample = torch.randn(*input_shape, device=device)
    mac_count = profile(model, inputs=(sample,), verbose=False)[0]

    per_batch = 2 * mac_count
    # Guard against dividing by a zero/negative batch dimension.
    return per_batch / (batch if batch > 1 else 1)

def main():
    """Command-line entry point: build a NetworkCIFAR and print its FLOPs.

    Steps:
      1. Parse the CLI arguments.
      2. Look up the requested genotype by name in the ``genotypes`` module.
      3. Build the network with ``drop_path_prob`` set to 0.
      4. Print, as an integer, the per-sample FLOP estimate for a single
         ``(1, 3, height, width)`` input.

    Unknown genotype names and non-positive input sizes are reported as CLI
    usage errors via ``parser.error``, which exits with status 2, instead of
    raw tracebacks.
    """
    parser = argparse.ArgumentParser(
        description="Print the per-sample FLOPs of a NetworkCIFAR built from a genotype."
    )
    parser.add_argument("--gpu", type=int, default=0,
                        help="CUDA device index (CPU is used when CUDA is unavailable)")
    parser.add_argument("--init_channels", type=int, default=32,
                        help="initial channel count passed to NetworkCIFAR")
    parser.add_argument("--layers", type=int, default=10,
                        help="layer count passed to NetworkCIFAR")
    parser.add_argument("--height", type=int, default=32,
                        help="height of the dummy input image")
    parser.add_argument("--width", type=int, default=32,
                        help="width of the dummy input image")
    parser.add_argument("--genotype", type=str, default="search_cifar10_338_0",
                        help="name of a genotype defined in the genotypes module")
    args = parser.parse_args()

    if args.height <= 0 or args.width <= 0:
        parser.error(f"--height and --width must be positive (got {args.height}x{args.width})")

    # Look the name up with a default so a typo becomes a clean usage error
    # rather than an AttributeError traceback.
    genotype = getattr(genotypes, args.genotype, None)
    if genotype is None:
        parser.error(f"unknown genotype {args.genotype!r}: no such name in the genotypes module")

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the positional arguments presumably follow the DARTS
    # signature (C, num_classes, layers, auxiliary, genotype). Confirm
    # against model.py.
    model = NetworkCIFAR(args.init_channels, 10, args.layers, False, genotype)
    # Presumably read by NetworkCIFAR.forward; disabled for profiling.
    # Confirm in model.py.
    model.drop_path_prob = 0.0

    flops = compute_flops(model, input_shape=(1, 3, args.height, args.width), device=device)
    print(f"{int(flops)}")

if __name__ == "__main__":
    main()
