#!/usr/bin/env python3
# test.py

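""" Test neural network performance.

For each k in --k_values, load the matching k-means model and Taylor
parameters, evaluate the model on the test dataset, and print the top-1
and top-5 error and the average MSE. Results are saved to a CSV file (and
to an Excel file when pandas can write one).
"""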

import argparse

from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from conf import settings
from utils import get_network, get_test_dataloader
import os
import csv
import pandas as pd

def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-weights', type=str, required=True, help='the weights file you want to test')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-b', type=int, default=1, help='batch size for dataloader')
    parser.add_argument('-kpath', type=str, required=True, help='the path to the saved kmeans model')
    parser.add_argument('-k', '--k_values', type=int, nargs='+',
                        default=[5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100,
                                 120, 140, 160, 180, 200, 240, 280, 320],
                        help='List of k values for k-means clustering')
    parser.add_argument('-tpath', type=str, default='taylor_parameters',
                        help='Directory to load taylor expansion related parameters')
    parser.add_argument('-o', '--output', type=str, default='accuracy_results.csv',
                        help='Output CSV file name for results')
    return parser.parse_args(argv)


def count_correct(output, label):
    """Count top-1 and top-5 hits in one batch.

    Args:
        output: network scores; classes are ranked along dim 1, and at least
            5 classes are required by ``topk(5, 1)``.
        label: ground-truth class indices, one per row of ``output``.

    Returns:
        (top1_hits, top5_hits) as Python ints.
    """
    _, pred = output.topk(5, 1, largest=True, sorted=True)
    # Repeat each label across its 5 ranked predictions so eq() marks the hit.
    target = label.view(label.size(0), -1).expand_as(pred)
    correct = pred.eq(target)
    top1 = int(correct[:, :1].sum().item())
    top5 = int(correct[:, :5].sum().item())
    return top1, top5


def evaluate(net, loader, use_gpu):
    """Run ``net`` over every batch of ``loader`` without gradients.

    Args:
        net: the model to evaluate (expected to already be in eval mode).
        loader: iterable of (image, label) batches with a ``len()``.
        use_gpu: move each batch to CUDA and print a memory summary per batch.

    Returns:
        (top1_hits, top5_hits) summed over all batches, as Python ints.
    """
    correct_1 = 0
    correct_5 = 0
    num_batches = len(loader)
    with torch.no_grad():
        for n_iter, (image, label) in enumerate(loader):
            print("iteration: {}\ttotal {} iterations".format(n_iter + 1, num_batches))

            if use_gpu:
                image = image.cuda()
                label = label.cuda()
                print('GPU INFO.....')
                print(torch.cuda.memory_summary(), end='')

            top1, top5 = count_correct(net(image), label)
            correct_1 += top1
            correct_5 += top5
    return correct_1, correct_5


def excel_path_for(csv_path):
    """Return the .xlsx path that sits next to ``csv_path``.

    Only the final extension is swapped, so directory names containing
    ".csv" are left alone. If swapping would reproduce ``csv_path`` itself
    (e.g. ``-o results.xlsx``), ".xlsx" is appended instead so the Excel
    export never overwrites the CSV.
    """
    root, _ = os.path.splitext(csv_path)
    excel_path = root + '.xlsx'
    if excel_path == csv_path:
        excel_path = csv_path + '.xlsx'
    return excel_path


def write_results(results, output_path):
    """Write ``results`` (a list of same-keyed dicts) to CSV, then try Excel.

    The CSV file is always created; it stays empty when ``results`` is empty.
    The Excel export is best-effort: pandas raises ImportError when no
    Excel writer engine is installed, and then only the CSV is reported.
    """
    with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
        if results:
            writer = csv.DictWriter(csvfile, fieldnames=list(results[0].keys()))
            writer.writeheader()
            writer.writerows(results)

    excel_path = excel_path_for(output_path)
    try:
        pd.DataFrame(results).to_excel(excel_path, index=False)
    except ImportError:
        print(f"\nsaved to: {output_path}")
    else:
        print("\nsaved to:")
        print(f"CSV file: {output_path}")
        print(f"Excel file: {excel_path}")


def print_summary(results):
    """Print a fixed-width table of k, top-1/top-5 accuracy and MSE."""
    print("\n=== summary ===")
    print(f"{'k':<6} {'Top1 acc':<12} {'Top5 acc':<12} {'MSE':<12}")
    print("-" * 45)
    for result in results:
        print(f"{result['k']:<6} {result['top1_accuracy']:<12.4f} "
              f"{result['top5_accuracy']:<12.4f} {result['mse']:<12.6f}")


def main():
    """Evaluate the network once per k value and save a per-k results table."""
    args = parse_args()
    net = get_network(args)

    cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
    )

    # Load onto the CPU first: load_state_dict copies each tensor onto the
    # device of the matching parameter, and GPU-saved checkpoints still
    # load on machines without CUDA.
    net.load_state_dict(torch.load(args.weights, map_location='cpu'))
    print(net)
    net.eval()

    total_samples = len(cifar100_test_loader.dataset)
    if total_samples == 0:
        raise SystemExit('test dataset is empty; nothing to evaluate')

    results = []
    for k in args.k_values:
        print(f'------------------------k={k}------------------------')
        # reset / Load_Kmeans / Taylor_Load are project-specific model hooks.
        # NOTE(review): presumably reset() also clears the accumulated net.mse
        # from the previous k — confirm in the model implementation.
        net.reset()
        net.Load_Kmeans(os.path.join(args.kpath, f'kmeans_model_{k}.joblib'))
        net.taylor_enable = True
        net.Taylor_Load(args.tpath, k)
        net.mse_enable = True

        correct_1, correct_5 = evaluate(net, cifar100_test_loader, args.gpu)

        if args.gpu:
            print('GPU INFO.....')
            print(torch.cuda.memory_summary(), end='')

        top1_acc = correct_1 / total_samples
        top5_acc = correct_5 / total_samples
        top1_err = 1 - top1_acc
        top5_err = 1 - top5_acc
        # float() accepts both a one-element tensor and a plain Python number.
        avg_mse = float(net.mse) / total_samples

        print()
        print("Top 1 err: ", top1_err)
        print("Top 5 err: ", top5_err)
        print("MSE: ", avg_mse)

        results.append({
            'k': k,
            'top1_accuracy': top1_acc,
            'top5_accuracy': top5_acc,
            'top1_error': top1_err,
            'top5_error': top5_err,
            'mse': avg_mse,
        })

    write_results(results, args.output)
    print_summary(results)


if __name__ == '__main__':
    main()