import sys
import pandas as pd
import numpy as np
import nninfo
import torch
import os.path
from filelock import FileLock, Timeout

def compute_loss_accuracy(exp_name, n_quantization_levels, quantizer_params, run_id, dataset, rounding_point):
    """Evaluate loss/accuracy of one run on one dataset and append the results to disk.

    Loads the experiment ``exp_name``, evaluates run ``run_id`` on ``dataset`` via
    ``Analysis.compare_acc_loss_multiple_datasets`` (using ``quantizer_params``), and
    appends one row per returned epoch to ``experiments/exp_<exp_name>/performance.pkl``.
    The file is created if it does not exist yet.

    Args:
        exp_name: Name of the experiment to load.
        n_quantization_levels: Written to the ``n_quantization_levels`` column as a label;
            it is not passed to the analysis itself (``quantizer_params`` is).
        quantizer_params: Passed unchanged to the analysis as ``quantizer_params``.
        run_id: ID of the run to evaluate.
        dataset: Dataset name; used both as the analysis input and as the result key.
        rounding_point: Written to the ``rounding_point`` column as a label.
    """
    # Diagnostic output: how many threads torch will use for this evaluation.
    print(f'{torch.get_num_threads()=}')
    print(f'{torch.get_num_interop_threads()=}')

    # Load experiment
    exp = nninfo.exp.Experiment(exp_name, load=True)

    # Initialize analysis
    print('Initializing analysis')
    analysis = nninfo.analysis.Analysis(exp)
    print('Analysis initialized')

    # Compute loss/accuracy
    epochs, perf_dict = analysis.compare_acc_loss_multiple_datasets(
        [dataset], run_ids=[run_id], show_plot=False, quantizer_params=quantizer_params)

    # Collect results
    acc = perf_dict[dataset]['acc']
    loss = perf_dict[dataset]['loss']

    # Guard against an empty result instead of failing with IndexError on acc[-1].
    final_acc = acc[-1] if len(acc) > 0 else None
    print('Run: ', run_id, 'Dataset: ', dataset, 'Final accuracy: ', final_acc)

    # Build the new rows outside the lock to keep the critical section short.
    n_epochs = len(epochs)
    new_rows = pd.DataFrame({
        'run_id': [run_id] * n_epochs,
        'epoch': epochs,
        'n_quantization_levels': [n_quantization_levels] * n_epochs,
        'acc': acc,
        'loss': loss,
        'dataset_name': [dataset] * n_epochs,
        'rounding_point': [rounding_point] * n_epochs,
    })

    file_path = 'experiments/exp_{}/performance.pkl'.format(exp_name)

    # Several instances of this script may append to the same file in parallel, so
    # the read-modify-write is serialized with a file lock (timeout=-1: wait forever).
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    with FileLock(file_path + ".lock", timeout=-1):
        if os.path.isfile(file_path):
            performance = pd.concat([pd.read_pickle(file_path), new_rows])
        else:
            performance = new_rows
        performance.to_pickle(file_path)


def main(argv):
    """Command-line entry point.

    Usage::

        <script> EXP_NAME N_QUANTIZATION_LEVELS RUN_ID DATASET [ROUNDING_POINT]

    ROUNDING_POINT defaults to ``'center_saturating'``. Parsed values are forwarded to
    ``compute_loss_accuracy``.

    Args:
        argv: Argument vector in ``sys.argv`` form (``argv[0]`` is the program name).
    """
    print('Computing loss & Accuracy')

    if len(argv) < 5:
        prog = argv[0] if argv else 'script'
        sys.exit(f'Usage: {prog} exp_name n_quantization_levels run_id dataset [rounding_point]')

    exp_name = str(argv[1])
    n_quantization_levels = int(argv[2])
    run_id = int(argv[3])
    dataset = str(argv[4])
    rounding_point = str(argv[5]) if len(argv) > 5 else 'center_saturating'

    # Presumably one entry per layer: the first 8 are left unquantized (None) and the
    # last 5 are quantized. Use the command-line level count and rounding point so the
    # labels recorded in performance.pkl match the quantization actually applied.
    # A comprehension gives each entry its own dict instead of one shared object.
    quantizer_nodropout = 8 * [None] + [
        {'levels': n_quantization_levels, 'dequant_point': rounding_point}
        for _ in range(5)
    ]

    compute_loss_accuracy(exp_name, n_quantization_levels, quantizer_nodropout, run_id, dataset, rounding_point)

if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main(argv=sys.argv)
