import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.init as init
import numpy as np
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import transforms, datasets
from pypapi import events as papi_events
import time
import argparse

import ctypes
#os.environ['LD_LIBRARY_PATH'] = '/home/your-name/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# Load the project's PAPI wrapper (a shared object addressed relative to the
# current working directory) and declare up front that its read-out call
# returns a C long long, so ctypes does not truncate the result to an int.
papi_lib = ctypes.CDLL('code/papi_test.so')
papi_lib.stopAndRead.restype = ctypes.c_longlong
# Pass the PAPI_SP_OPS event code to the wrapper's initialiser as a C int.
event = ctypes.c_int(papi_events.PAPI_SP_OPS)
papi_lib.initializePAPI(event)
# set the thread number = 1
#os.environ['OMP_NUM_THREADS'] = '1'

from network_v3 import FF_mobilenet_v1, BP_mobilenet_v1, Resnet_bp
from network_v1 import MLP_Net_BP

def _positive_int(text):
    """Parse *text* as an int and reject values below 1 (argparse ``type=`` hook)."""
    value = int(text)  # a ValueError here is reported by argparse as a usage error
    if value < 1:
        raise argparse.ArgumentTypeError(
            "expected a positive integer, got {!r}".format(text))
    return value


# Command-line interface of the benchmark. Invalid values are rejected here,
# at parse time, instead of surfacing later as unrelated NameErrors.
parser = argparse.ArgumentParser(description='forward-forward-benchmark inf args')
parser.add_argument('--batch_size', type=_positive_int, default=1, help='1/32/128')
parser.add_argument('--dataset', default='vww', choices=('vww', 'cifar10', 'mnist'),
                    help='dataset name vww/cifar10/mnist')
parser.add_argument('--epochs', type=_positive_int, default=5,
                    help='number of epochs to inf (default: 5)')
class Opts:
    """Run configuration, resolved from the command line when the class is defined.

    All values are class attributes:

    * ``args``        - the parsed ``argparse`` namespace.
    * ``batch_size``, ``dataset``, ``epochs`` - copied from ``args``.
    * ``device``      - fixed to ``'cpu'``.
    * ``dataset_dir`` - on-disk location of the selected dataset.
    * ``label_len``   - number of labels of the selected dataset.

    Raises:
        ValueError: if ``--dataset`` is not one of vww/cifar10/mnist.
    """
    args = parser.parse_args()
    act_shapes = [(32,8,8),(128,4,4),(256,2,2)]
    batch_size = args.batch_size
    dataset = args.dataset
    device = 'cpu'
    epochs = args.epochs

    if dataset == 'vww':
        dataset_dir = 'data/vw_coco2014_96'
        label_len = 2
    elif dataset == 'cifar10':
        dataset_dir = 'data/cifar10'
        label_len = 10
    elif dataset == 'mnist':
        dataset_dir = 'data'
        label_len = 10
    else:
        # Without this branch dataset_dir/label_len would silently stay
        # undefined and the script would fail much later with a NameError.
        raise ValueError(
            "unsupported dataset {!r}; expected one of vww/cifar10/mnist".format(dataset))

# Build the network matching the selected dataset and restore its weights.
# map_location pins every tensor of the checkpoint to the configured device,
# so a checkpoint saved from a GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
opts = Opts()
if opts.dataset == 'vww':
    model = BP_mobilenet_v1(combo=0)
    model.load_state_dict(torch.load('Output/model_state_InfTest_VWWbp0.pth',
                                     map_location=opts.device))
elif opts.dataset == 'cifar10':
    model = Resnet_bp(combo=4)
    model.load_state_dict(torch.load('Output/model_state_InfTest_cifarResbp1.pth',
                                     map_location=opts.device))
elif opts.dataset == 'mnist':
    model = MLP_Net_BP(n_neurons=[1000,1000,10], in_dim_feat=784, configure = True)
    model.load_state_dict(torch.load('Output/model_state_InfTest_MNISTBP.pth',
                                     map_location=opts.device))
else:
    # Fail with a clear message instead of a NameError on `model` below.
    raise ValueError("unsupported dataset {!r}".format(opts.dataset))

model.to(opts.device)
# Inference mode: switches off training-only behaviour of the modules.
model.eval()

# coherent with the FF inference (cache init)

"""
    Step x: load images from the dataset for test
"""
# Open the dataset selected on the command line. CIFAR-10 and MNIST use their
# test partition (train=False) and are downloaded on demand.
BASE_DIR = opts.dataset_dir
if opts.dataset == 'vww':
    full_dataset = datasets.ImageFolder(root=BASE_DIR)
elif opts.dataset == 'cifar10':
    full_dataset = datasets.CIFAR10(root=BASE_DIR, train=False, download=True)
elif opts.dataset == 'mnist':
    full_dataset = datasets.MNIST(root=BASE_DIR, train=False, download=True)

# Random 90/10 split; only the 10% part is fed to the benchmark.
n_samples = len(full_dataset)
train_size = int(0.9 * n_samples)
val_size = n_samples - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Three-channel pipeline: tensor conversion, then per-channel (x - 0.5) / 0.5.
channel_stats = (0.5,) * 3
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats, inplace=True),
])

# The transform is installed on the dataset underlying the split; MNIST only
# gets the tensor conversion.
val_dataset.dataset.transform = (
    transforms.ToTensor() if opts.dataset == 'mnist' else transform_val
)

val_loader = DataLoader(
    val_dataset,
    batch_size=opts.batch_size,
    shuffle=False,
    num_workers=4,
)

# Abort unless every parameter is float32. Despite the name and the message,
# the check rejects any non-float32 dtype, not just float64 - presumably
# because the PAPI_SP_OPS event configured above only counts single-precision
# work (TODO confirm).
has_double_params = any(param.dtype != torch.float32 for param in model.parameters())
if has_double_params:
    # sys.exit rather than the interactive-only `exit` helper from `site`,
    # which is missing when Python runs with -S.
    sys.exit("Error: double precision parameters are not supported!")
    
"""##########################################
##########################################
### Preparation Ends! Start Inference! ###
##########################################
##########################################"""
# Timed inference benchmark: run the model over the validation loader
# `opts.epochs` times while the PAPI counter is active, then report accuracy,
# the per-batch counter average and the per-batch wall-clock average.
num_batches = len(val_loader)  # batches per epoch
if opts.epochs <= 0 or num_batches == 0:
    # Nothing would run, and the averages below would divide by zero.
    sys.exit("Error: nothing to run (epochs and the validation set must be non-empty)")

# Predictions and labels of the final epoch. They are only collected inside
# the measured region and scored after it, so the reported accuracy covers
# the whole validation split instead of just the last batch.
final_preds = []
final_labels = []

start_time = time.perf_counter()
papi_lib.startCounting()
for epoch in range(opts.epochs):
    is_last_epoch = epoch == opts.epochs - 1
    for batch_no, (x_in, y_test) in enumerate(val_loader):
        with torch.no_grad():
            if opts.dataset == 'mnist':
                # The MLP takes one flat feature vector per sample.
                x_in = x_in.view(x_in.shape[0], -1)
            acts = model(x_in)
            y_hat = torch.nn.functional.softmax(acts,dim=-1).argmax(dim=-1)
        if is_last_epoch:
            final_preds.append(y_hat)
            final_labels.append(y_test)

end_time = time.perf_counter()
# Read the counter right away so the reporting below is not measured.
# The count is taken from the return value; `value` is handed over as an
# out-parameter - presumably required by the wrapper, verify against
# papi_test.so.
value = ctypes.c_longlong()
res = papi_lib.stopAndRead(ctypes.byref(value))

all_preds = torch.cat(final_preds)
all_labels = torch.cat(final_labels)
print("Test Acc", all_preds.eq(all_labels).sum().item()/all_labels.shape[0])
print("FF with BP Inference test done in {} batches! It costs:".format(num_batches))
print("batch average fops:", res/num_batches/opts.epochs)
print("Avg time elapsed: {:f} ms".format((end_time - start_time)*1000/num_batches/opts.epochs))
#papi_lib.cleanupPAPI()

