import os
import torch

# Root directories that are scanned recursively for model files.
base_dirs = ["./"]

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous.
# It is a debugging aid that slows GPU work; nothing below appears to need it.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# NOTE(review): no tensors have been allocated yet at this point, so releasing
# the CUDA cache is presumably a no-op; kept to preserve the original behavior.
torch.cuda.empty_cache()
def count_parameters(model_path, map_location="cpu"):
    """Return the total number of tensor elements stored in a saved model file.

    Two on-disk layouts are handled:

    * A dict (e.g. an ``OrderedDict`` state dict) mapping names to tensors:
      the ``numel()`` of every tensor value is summed. A state dict also
      contains buffers (such as normalization running statistics), so this
      count can exceed the module's trainable-parameter count.
    * Any other loaded object is assumed to expose ``.parameters()`` (e.g. a
      whole pickled ``nn.Module``); the elements of its parameters are summed.

    Args:
        model_path: Path to a file written by ``torch.save``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            files saved from GPU tensors load on CPU-only machines and counting
            never allocates GPU memory.

    Returns:
        The total element count as an ``int`` (0 for an empty dict).

    Raises:
        TypeError: If a loaded dict contains a non-tensor value (for example a
            training checkpoint bundling ``epoch`` or optimizer state). Summing
            such a mapping would not be a parameter count.
        AttributeError: If the loaded object is neither a dict nor has
            ``.parameters()``.
        Any error raised by ``torch.load`` for unreadable files.
    """
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # files. Recent PyTorch versions default to weights_only=True, under which
    # whole pickled modules may refuse to load; confirm for your version.
    loaded = torch.load(model_path, map_location=map_location)

    if isinstance(loaded, dict):
        total = 0
        for key, value in loaded.items():
            if not isinstance(value, torch.Tensor):
                raise TypeError(
                    f"entry {key!r} is {type(value).__name__}, not a tensor; "
                    "expected a plain state dict"
                )
            total += value.numel()
        return total

    return sum(p.numel() for p in loaded.parameters())

def log_model_params(base_dirs, log_path="model_params_log.txt"):
    """Scan directories for ``.pth`` files and log each file's parameter count.

    Every file under each entry of *base_dirs* (recursively) whose name ends in
    ``.pth`` is passed to ``count_parameters``. Files whose name contains
    ``"bucket_limits"`` are skipped; presumably these are auxiliary tensors
    rather than models (confirm). Directories and files are visited in sorted
    order so repeated runs produce identical logs.

    For each file, either its count or the error raised while loading it is
    written to *log_path* and echoed to stdout.

    Args:
        base_dirs: Iterable of directory paths to scan.
        log_path: Output file; it is overwritten on every run.
    """
    # backslashreplace keeps undecodable filenames from aborting the whole log.
    with open(log_path, "w", encoding="utf-8", errors="backslashreplace") as log_file:
        log_file.write("Model Parameter Count Log\n")
        log_file.write("=" * 30 + "\n")

        for base_dir in base_dirs:
            for dirpath, dirnames, filenames in os.walk(base_dir):
                # Sorting dirnames in place controls os.walk's descent order.
                dirnames.sort()
                for filename in sorted(filenames):
                    if not filename.endswith(".pth") or "bucket_limits" in filename:
                        continue
                    model_path = os.path.join(dirpath, filename)

                    try:
                        total_params = count_parameters(model_path)
                    except Exception as e:
                        # Per-file boundary: record the failure and keep scanning.
                        log_file.write(f"Failed to load model {model_path}: {e}\n")
                        print(f"Failed to load model {model_path}: {e}")
                        continue

                    log_file.write(f"Model: {model_path}\n")
                    log_file.write(f"Total Parameters: {total_params}\n")
                    log_file.write("-" * 30 + "\n")
                    print(f"Logged: {model_path} - Parameters: {total_params}")

    print("Logging completed!")

# Run the scan only when executed as a script, so importing this module (e.g.
# to reuse count_parameters) does not walk the filesystem and load every model.
if __name__ == "__main__":
    log_model_params(base_dirs)