#!/usr/bin/env python3
# test.py

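"""
Taylor-expansion parameter script built on saved k-means models.

Loads a trained network, then for each requested k loads the k-means model
saved as ``kmeans_model_{k}.joblib`` under ``-kpath`` and calls the network's
``Taylor_Learn`` with the ``-o`` output directory.

Background: the data for k-means is the intermediate input collected by
running a trained model on the training dataset and gathering it through
``self.kmeans_data = []``. This script does not collect that data itself;
it loads the saved k-means models.
"""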

import argparse

from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from conf import settings
from utils import get_network, get_training_dataloader
import os
def _build_parser():
    """Return the command-line parser for this script.

    Options:
        -net      network type, passed to ``get_network``.
        -weights  path of the saved state dict to load into the network.
        -gpu      flag; when absent, checkpoint tensors are mapped to the CPU.
        -b        batch size (not used directly in this file; presumably read
                  from ``args`` by project helpers -- TODO confirm).
        -kpath    directory holding ``kmeans_model_{k}.joblib`` files.
        -k        one or more k values to process (e.g. ``-k 5`` for a
                  quick single-k run).
        -o        directory passed to ``Taylor_Learn`` for its outputs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-weights', type=str, required=True, help='the weights file you want to test')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-b', type=int, default=16, help='batch size for dataloader')
    parser.add_argument('-kpath', type=str, required=True, help='the path to the saved kmeans model')
    parser.add_argument('-k', '--k_values', type=int, nargs='+',
                        default=[5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100,
                                 120, 140, 160, 180, 200, 240, 280, 320],
                        help='List of k values for k-means clustering')
    parser.add_argument('-o', '--output_dir', type=str, default='taylor_parameters',
                        help='Directory to save taylor expansion related parameters')
    return parser


def _load_weights(net, weights_path, use_gpu):
    """Load the state dict stored at ``weights_path`` into ``net``.

    Args:
        net: the module whose parameters are overwritten.
        weights_path: file written by ``torch.save(state_dict, ...)``.
        use_gpu: when false, checkpoint tensors are mapped to the CPU so a
            checkpoint saved from a GPU run can still be deserialized on a
            CPU-only machine; when true, tensors keep their saved devices.

    Returns:
        ``net``, for convenience.
    """
    # load_state_dict copies values into the module's existing parameters,
    # so the device the checkpoint is deserialized to only matters for
    # whether torch.load itself succeeds.
    map_location = None if use_gpu else 'cpu'
    state_dict = torch.load(weights_path, map_location=map_location)
    net.load_state_dict(state_dict)
    return net


def main():
    """Build the network, load its weights, and run Taylor_Learn for each k."""
    args = _build_parser().parse_args()

    net = get_network(args)
    _load_weights(net, args.weights, args.gpu)
    print(net)
    net.eval()

    # For every k: swap in the matching saved k-means model, then let the
    # network compute and write its Taylor-expansion parameters.
    for k in args.k_values:
        file_name = os.path.join(args.kpath, f'kmeans_model_{k}.joblib')
        net.Load_Kmeans(file_name)
        net.Taylor_Learn(args.output_dir)


if __name__ == '__main__':
    main()
