#!/usr/bin/env python3
# test.py

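"""Benchmark per-sample inference timings of a network on the CIFAR-100 test set.

For each requested k-means cluster count, load the matching k-means model and
Taylor-expansion parameters, run a fixed number of test samples through the
network, and record the average classifier, k-means prediction, Taylor forward
and convolution time per sample. The results are written to an .xlsx file.
"""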

import argparse

from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from conf import settings
from utils import get_network, get_test_dataloader
import os
import pandas as pd
# Default number of test samples timed for each k value.
DEFAULT_MAX_SAMPLES = 1024


def parse_args():
    """Parse the command-line options of the timing benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-weights', type=str, required=True, help='the weights file you want to test')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    # NOTE(review): -b is parsed but not used; the test loader below is built
    # with batch_size=1.
    parser.add_argument('-b', type=int, default=1, help='batch size for dataloader')
    parser.add_argument('-kpath', type=str, required=True, help='the path to the saved kmeans model')
    parser.add_argument('-k', '--k_values', type=int, nargs='+',
                        default=[5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100,
                                 120, 140, 160, 180, 200, 240, 280, 320],
                        help='List of k values for k-means clustering')
    parser.add_argument('-tpath', type=str, default='taylor_parameters',
                        help='Directory to load taylor expansion related parameters')
    parser.add_argument('-o', '--output', type=str, default='time_results.xlsx',
                        help='Output xlsx filename for timing results')
    parser.add_argument('--max_samples', type=int, default=DEFAULT_MAX_SAMPLES,
                        help='Maximum number of test samples to time for each k')
    return parser.parse_args()


def run_samples(net, loader, max_samples, use_gpu=False):
    """Push whole batches from ``loader`` through ``net`` until ``max_samples`` is reached.

    Each batch is forwarded twice. The first pass has the Taylor path disabled
    and the timer flag enabled. The second pass has the Taylor path enabled and
    the timer flag disabled. A batch that would push the running count past
    ``max_samples`` is not run.

    Args:
        net: callable network exposing ``taylor_enable`` / ``timer_enbale`` flags.
        loader: iterable of ``(image, label)`` pairs; ``image.size(0)`` is the
            batch size.
        max_samples: upper bound on the number of samples to run.
        use_gpu: move each image batch to CUDA before the forward passes.

    Returns:
        The number of samples actually run.
    """
    processed = 0
    for image, _label in loader:
        batch_size = image.size(0)
        if processed + batch_size > max_samples:
            break
        processed += batch_size
        if use_gpu:
            # Keep inputs on the same device as a network built with -gpu.
            image = image.cuda()

        # NOTE(review): the attribute is spelled 'timer_enbale'; it must match
        # the name the network class reads. Confirm before renaming.
        net.taylor_enable = False
        net.timer_enbale = True
        net(image)

        net.taylor_enable = True
        net.timer_enbale = False
        net(image)
    return processed


def per_sample_timing(net, k, num_samples):
    """Return the network's accumulated timers averaged over ``num_samples``.

    Args:
        net: object exposing the ``classifier_time``, ``kmeans_predict_time``,
            ``taylor_forward_time`` and ``conv_time`` accumulators.
        k: k-means cluster count of this run (stored under ``'k_value'``).
        num_samples: number of samples actually run through the network.

    Returns:
        dict with ``'k_value'`` plus each timer divided by ``num_samples``.

    Raises:
        ValueError: if ``num_samples`` is not positive (e.g. an empty loader).
    """
    if num_samples <= 0:
        raise ValueError(f'num_samples must be positive, got {num_samples}')
    return {
        'k_value': k,
        'classifier_time': net.classifier_time / num_samples,
        'kmeans_predict_time': net.kmeans_predict_time / num_samples,
        'taylor_forward_time': net.taylor_forward_time / num_samples,
        'conv_time': net.conv_time / num_samples,
    }


def save_results(timing_data, output_path):
    """Write the timing rows to ``output_path`` as .xlsx and print a summary."""
    df = pd.DataFrame(timing_data)

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    df.to_excel(output_path, index=False, engine='openpyxl')
    print(f"\nsaved to: {output_path}")

    # Print a summary of the collected data.
    print("\nsummary:")
    print(df.describe())


def main():
    """Time the network for every requested k and save the results."""
    args = parse_args()

    net = get_network(args)
    # Limit intra-op parallelism to a single thread (presumably so timings are stable).
    torch.set_num_threads(1)
    cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=1,
    )

    # Load onto the CPU first. load_state_dict copies into the parameters'
    # own device, so GPU-saved weights also work on CPU-only hosts.
    net.load_state_dict(torch.load(args.weights, map_location='cpu'))
    print(net)
    net.eval()

    timing_data = []
    with torch.no_grad():
        for k in args.k_values:
            print(f'------------------------k={k}------------------------')
            net.Load_Kmeans(os.path.join(args.kpath, f'kmeans_model_{k}.joblib'))
            net.Taylor_Load(args.tpath, k)
            net.mse_enable = True
            net.reset()

            num_samples = run_samples(net, cifar100_test_loader,
                                      args.max_samples, use_gpu=args.gpu)
            timing_info = per_sample_timing(net, k, num_samples)
            timing_data.append(timing_info)

            print()
            print(f"classifier_time: {timing_info['classifier_time']}")
            print(f"kmeans_predict_time: {timing_info['kmeans_predict_time']}")
            print(f"taylor_forward_time: {timing_info['taylor_forward_time']}")
            print(f"conv time {timing_info['conv_time']}")

    save_results(timing_data, args.output)


if __name__ == '__main__':
    main()