from fvcore.nn import FlopCountAnalysis
import torch
from models import get_model
from utils import Config
import argparse
import yaml 

def measure_flops(config, input_shape=(1, 5, 64, 64, 64)):
    """Count, print and return the FLOPs of the model described by ``config``.

    Args:
        config: Configuration object. Reads ``config.device`` (CUDA device
            index, used only when CUDA is available), ``config.model.name``
            (resolved to a model class via ``get_model``) and
            ``config.model.params`` (converted with ``to_dict()`` and passed
            to the model constructor as keyword arguments).
        input_shape: Shape of the random dummy input fed to the model. The
            default is a 5-D tensor; presumably (batch, channels, D, H, W) --
            NOTE(review): confirm against the model's expected input.

    Returns:
        The total FLOP count reported by fvcore's ``FlopCountAnalysis``.
    """
    device = f"cuda:{config.device}" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}.")

    # Load model
    model_class = get_model(config.model.name)
    model = model_class(**config.model.params.to_dict()).to(device)
    # Count inference FLOPs: in train mode dropout is traced in its training
    # form and fvcore's batch-norm handler counts based on the training flag.
    model.eval()

    # Allocate directly on the target device instead of creating on CPU and copying.
    inputs = torch.randn(*input_shape, device=device)

    # FlopCountAnalysis traces lazily when .total() is called, so keep that
    # call inside no_grad(); no autograd graph is needed just to count ops.
    with torch.no_grad():
        flops = FlopCountAnalysis(model, inputs)
        total = flops.total()
    print(total)
    return total


if __name__ == "__main__":
    # Command-line interface: a YAML config plus optional dotted-key overrides.
    parser = argparse.ArgumentParser(
        description="Count the FLOPs of the model described by a config file."
    )
    # Required: Path to the YAML configuration file. Declared required so a
    # missing flag gives a usage error instead of Config.from_yaml(None).
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config file.")
    parser.add_argument("--set", metavar="KEY=VAL", action="append",
                        help="Override any config entry, e.g. --set model.params.activation=relu")

    args = parser.parse_args()

    # Load configuration from YAML file
    config = Config.from_yaml(args.config)

    # Apply --set overrides in order. VAL is parsed with yaml.safe_load, so
    # e.g. "3", "true" or "[1, 2]" become int/bool/list rather than strings.
    for item in args.set or []:
        key, sep, raw = item.partition("=")
        key = key.strip()
        if not sep or not key:
            # Report malformed overrides as a CLI usage error (exits non-zero)
            # instead of an opaque tuple-unpacking ValueError.
            parser.error(f"--set expects KEY=VAL, got {item!r}")
        config.set(key, yaml.safe_load(raw))

    # Build the configured model and report its FLOP count.
    measure_flops(config)