import os
import sys
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import os
import datetime
import wandb
import argparse
from collections import OrderedDict
from tqdm import tqdm
import math
from torch.utils.data import TensorDataset, DataLoader
from model import DS_CNN_BP,DS_CNN_BP_v2
from torchinfo import summary
from pypapi import events as papi_events
import time
import argparse
import ctypes

def load_data(file_path):
    """Load a features/labels archive from ``file_path`` into a ``TensorDataset``.

    The archive is read with ``np.load`` and must contain a ``"data"`` entry
    whose element 0 is a sequence of per-sample feature arrays and whose
    element 1 is the matching sequence of integer labels.

    Args:
        file_path: Path to the archive (the script passes an ``.npz`` file).

    Returns:
        ``TensorDataset(features, labels)``. Features are float32 with axis 3
        moved to position 1 (presumably NHWC -> NCHW for the conv model;
        TODO confirm). Labels are int64.
    """
    # allow_pickle=True is needed because "data" holds Python objects.
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    archive = np.load(file_path, allow_pickle=True)
    try:
        # Read "data" exactly once: every NpzFile lookup re-reads and
        # re-unpickles the member from the zip archive.
        payload = archive["data"]
    finally:
        # np.load returns an open NpzFile for .npz input; release its handle.
        close = getattr(archive, "close", None)
        if close is not None:
            close()

    data = np.array(list(payload[0]))
    label = np.array(list(payload[1]))

    train_tensor = torch.tensor(data, dtype=torch.float32).permute(0, 3, 1, 2)
    label_tensor = torch.tensor(label, dtype=torch.long)
    return TensorDataset(train_tensor, label_tensor)

# Load the PAPI helper library (path is relative to the current working
# directory) and declare the C return type of stopAndRead up front so its
# 64-bit counter value is not truncated to a C int.
papi_lib = ctypes.CDLL('./papi_test.so')
papi_lib.stopAndRead.restype = ctypes.c_longlong
# Hand the PAPI_SP_OPS preset event to the helper's initialisation routine
# (presumably this selects the event to be counted; TODO confirm in papi_test.so).
event = ctypes.c_int(papi_events.PAPI_SP_OPS)
papi_lib.initializePAPI(event)


# Command-line interface for the benchmark; both options are integers.
parser = argparse.ArgumentParser(description='forward-forward-benchmark inf args')
parser.add_argument(
    '--batch_size',
    default=128,
    type=int,
    help='1/32/128',
)
parser.add_argument(
    '--epochs',
    default=5,
    type=int,
    help='number of epochs to inf (default: 5)',
)

class Opts:
    """Benchmark settings resolved from the command line.

    All settings are class attributes, so ``parser.parse_args()`` runs once,
    when this class body executes at import time.
    """

    device = 'cpu'  # inference is pinned to the CPU
    args = parser.parse_args()
    epochs = args.epochs
    batch_size = args.batch_size
    

opts = Opts()

# Build the network and restore its trained weights, deserialising the
# checkpoint onto the CPU, then move it to the configured device and switch it
# to inference mode (nn.Module.to/.eval return the module itself).
# NOTE(review): 'your_model_path.pth' looks like a placeholder to be replaced.
# torch.load unpickles the file, so only load trusted checkpoints.
model = DS_CNN_BP_v2()
model.load_state_dict(
    torch.load('your_model_path.pth', map_location=torch.device('cpu'))
)
model.to(opts.device).eval()

# coherent with the FF inference (cache init)

"""
    Step x: load images from the dataset for test
"""
val_dataset = load_data("./data/mfcc_test_data.npz")
val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)

has_double_params = any(param.dtype != torch.float32 for param in model.parameters())
if has_double_params:
    exit("Error: double precision parameters are not supported!")
    
"""##########################################
##########################################
### Preparation Ends! Start Inference! ###
##########################################
##########################################"""

start_time = time.perf_counter()
papi_lib.startCounting()

for epoch in range(opts.epochs):
    for batch_no, (x_val, y_test) in enumerate(val_loader):
        with torch.no_grad():
            acts = model(x_val)
            y_hat = torch.nn.functional.softmax(acts,dim = 1).argmax(dim=-1)

end_time = time.perf_counter()    
value = ctypes.c_longlong()
print("Test Acc", y_hat.eq(y_test).sum().item()/x_val.shape[0])
print("FF with BP Inference test done in {} batches! It costs:".format(batch_no+1))
res = papi_lib.stopAndRead(ctypes.byref(value))
print("batch average fops:", res/(batch_no+1)/opts.epochs)
print("time elapsed: ", (end_time - start_time)*1000/(batch_no+1)/opts.epochs)
#papi_lib.cleanupPAPI()
