import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.init as init
import numpy as np
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import transforms, datasets
from pypapi import events as papi_events
from torchinfo import summary
import time
import argparse

import ctypes
# PAPI helper library (a small C wrapper built separately). If libpapi itself is not
# on the dynamic loader's search path, extend LD_LIBRARY_PATH when launching Python,
# e.g. LD_LIBRARY_PATH=/home/yourname/papi-install/lib/:$LD_LIBRARY_PATH python ...
papi_lib = ctypes.CDLL('code/papi_test.so')
# stopAndRead's counter value comes back as a C long long.
papi_lib.stopAndRead.restype = ctypes.c_longlong
# Configure the wrapper to count single-precision floating-point operations.
event = ctypes.c_int(papi_events.PAPI_SP_OPS)
papi_lib.initializePAPI(event)

# As a small example, we will count the number of floating point operations
# papi_lib.startCounting()

# total = np.float32(0.0)
# increment = np.float32(0.1)
# for _ in range(10000):
#     total += increment

# # Stop counting and read the value
# value = ctypes.c_longlong()
# papi_lib.stopAndRead(ctypes.byref(value))

# from pypapi import papi_high, events, papi_low

from network_v3 import FF_mobilenet_v1, Resnet_ff_new, FF_MLP_Net


# Command-line interface. --dataset is restricted to the datasets Opts and the
# model/data loading code know about, so a typo fails here with a usage message
# instead of as an obscure NameError/AttributeError much later in the script.
parser = argparse.ArgumentParser(description='forward-forward-benchmark inf args')
parser.add_argument('--batch_size', type=int, default=128, help='1/32/128')
parser.add_argument('--dataset', default='vww', choices=['vww', 'cifar10', 'mnist'],
                    help='dataset name vww/cifar10/mnist')
# Two comma-separated flags: the first enables GIFF inference, the second the original FF inference.
parser.add_argument('--run_select', default='0,1', help='select GIFF or FF to run. 1,1 means run both')
parser.add_argument('--epochs', type=int, default=5, help='number of epochs to inf (default: 5)')

class Opts:
    """Run configuration: parsed command-line arguments plus per-dataset constants.

    Everything is a class attribute evaluated when the class body executes, so
    ``parser.parse_args()`` runs exactly once, at definition time. Invalid
    argument combinations are reported through ``parser.error`` (usage message
    and exit status 2) instead of surfacing later as IndexError/AttributeError.

    Per-dataset attributes:
        dataset_dir: root directory handed to the torchvision dataset.
        label_len:   number of classes, i.e. the width of the one-hot label.
        act_shapes:  per-block shape each label-mask activation is ``view``-ed to
                     in the GIFF path (not used for MNIST).
        mask_shapes: per-block ``(kernel_size, stride)`` for the MaxPool2d applied
                     to raw block activations in the GIFF path.
        diff_res:    forwarded to ``model(...)`` in the original FF path; its exact
                     meaning is defined by the model code (TODO confirm).
    """
    args = parser.parse_args()
    batch_size = args.batch_size
    dataset = args.dataset
    device = 'cpu'
    # '1,0' -> ['1', '0']: element 0 enables GIFF inference, element 1 the original FF inference.
    run_select = args.run_select.split(',')
    if len(run_select) != 2:
        parser.error("--run_select expects two comma-separated flags such as '1,1', got {!r}".format(args.run_select))
    epochs = args.epochs

    if dataset == 'vww':
        dataset_dir = 'data/vw_coco2014_96'
        label_len = 2
        act_shapes = [(32,6,6),(128,2,2),(256,2,2)]
        mask_shapes = [(4,4),(3,3),(2,1)]
        diff_res = np.array([1,1,1])
    elif dataset == 'cifar10':
        dataset_dir = 'data/cifar10'
        label_len = 10
        act_shapes = [(32,4,4),(64,3,3)]
        # NOTE(review): no mask_shapes are configured for cifar10 yet, so the GIFF
        # path cannot build its per-block pools for this dataset.
        diff_res = np.array([1,1]) # in ff training, it is each ff block's learning perf
    elif dataset == 'mnist':
        dataset_dir = 'data'
        label_len = 10
        act_shapes = [(32,1,1)]
        # The MNIST GIFF path multiplies raw activations directly; nothing is pooled.
        mask_shapes = []
        diff_res = np.array([1,1])
    else:
        parser.error("unsupported --dataset {!r} (expected vww, cifar10 or mnist)".format(dataset))

opts = Opts()
# Build the architecture that matches the dataset, then load its trained weights.
if opts.dataset == 'vww':
    model = FF_mobilenet_v1(combo=0,pool_list=[6,2,2])
    ckpt_path = 'Output/model_state_InfTest_VWWmb0.pth'
elif opts.dataset == 'cifar10':
    model = Resnet_ff_new(combo=0)
    ckpt_path = 'Output/model_state_InfTest_cifarResff1.pth'
    # Alternative cifar10 checkpoint: 'Output/model_state_InfTest_cifarResbp1.pth'
elif opts.dataset == 'mnist':
    model = FF_MLP_Net(nn_strct=[784,1000,1000])
    ckpt_path = 'Output/model_state_InfTest_MNISTFF.pth'
else:
    raise ValueError("unsupported dataset: {}".format(opts.dataset))
# map_location remaps stored tensors onto opts.device, so a checkpoint saved from
# GPU tensors still loads on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model.load_state_dict(torch.load(ckpt_path, map_location=opts.device))
model.to(opts.device)
model.eval()

"""
    Step 1: Retrieve Features layers to derive raw activations for each FF block. e.g. Features_seqs[0] is the first FF block
"""
# For each FF block keep every layer except the nn.Linear ones. Chained in order,
# they turn the input into that block's raw activation (Features_seqs[0] is block 0).
Features = [
    [layer for layer in ff_block if not isinstance(layer, nn.Linear)]
    for ff_block in model.blocks
]
Features_seqs = [nn.Sequential(*layers) for layers in Features]
#print(Features_seqs)

"""
    Step 2: Retrieve Masks' layers for each FF block. e.g. Masks_seqs[0] is the first FF block
"""
# For each FF block keep only its nn.Linear layers. These are applied to the one-hot
# label to produce that block's mask activation (Masks_seqs[0] is block 0).
Label_FC = []
for ff_block in model.blocks:
    linears = [layer for layer in ff_block if isinstance(layer, nn.Linear)]
    Label_FC.append(linears)
Masks_seqs = [nn.Sequential(*linears) for linears in Label_FC]

#print(Masks_seqs)

"""
    Step 3: Preparation work: get the masks' activation ahead of time + PAPI counting to estimate the cost
"""
#papi_lib.startCounting()

# mask_act_list[label][block] is the activation the one-hot `label` produces through
# that block's Linear layers. It is precomputed once so GIFF inference never runs the
# label branch. The results are only used for inference, so no autograd graph is
# recorded for them.
mask_act_list = []
with torch.no_grad():
    for label in range(opts.label_len):
        test_label = F.one_hot(torch.full((1,), label, dtype=torch.long),
                               num_classes=opts.label_len).float()
        mask_act = []
        for i, Masks in enumerate(Masks_seqs):
            acts = Masks(test_label)
            # Non-MNIST models: reshape to the block's activation shape so it can be
            # multiplied against that block's pooled feature map during inference.
            if opts.dataset != 'mnist':
                acts = acts.view(opts.act_shapes[i])
            mask_act.append(acts)
        mask_act_list.append(mask_act)
#print(mask_act_list)

# Passed by reference to papi_lib.stopAndRead in the inference sections below.
value = ctypes.c_longlong()
# print("Preparation: masks results have been stored!!! It costs:")
# papi_lib.stopAndRead(ctypes.byref(value))

# GIFF path: one MaxPool2d per FF block, configured from Opts.mask_shapes as
# (kernel_size, stride). Only the non-MNIST GIFF path indexes these pools.
# Opts does not define mask_shapes for every dataset. Fall back to no pools when
# GIFF will not need them (e.g. the default --run_select 0,1). Otherwise stop
# with a clear message.
mask_shapes = getattr(opts, 'mask_shapes', None)
if mask_shapes is None:
    if opts.run_select[0] == '1' and opts.dataset != 'mnist':
        sys.exit("GIFF inference needs Opts.mask_shapes for dataset '{}'".format(opts.dataset))
    mask_shapes = []
pools = [nn.MaxPool2d(kernel_size=shape[0], stride=shape[1]) for shape in mask_shapes]
# Commented-out alternative: adaptive average pooling straight to each act_shapes[i][1:].
# pools = [nn.AdaptiveAvgPool2d(output_size=shape[1:]) for shape in opts.act_shapes]

"""
    Step 4: load images from the dataset for test
"""
BASE_DIR = opts.dataset_dir
if opts.dataset == 'vww':
    full_dataset = datasets.ImageFolder(root=BASE_DIR)
elif opts.dataset == 'cifar10':
    full_dataset = datasets.CIFAR10(root=BASE_DIR, train=False, download=True)
elif opts.dataset == 'mnist':
    full_dataset = datasets.MNIST(root=BASE_DIR, train=False, download=True)
else:
    raise ValueError("unsupported dataset: {}".format(opts.dataset))

# Hold out 10% of the samples for evaluation. The split is seeded so every run sees
# the same subset; the GIFF and original FF sections are meant to report the same
# accuracy even when they are run in separate processes.
SPLIT_SEED = 0
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

# ToTensor gives values in [0, 1]; Normalize with mean = std = 0.5 maps them to
# [-1, 1] for 3-channel images.
transform_val = transforms.Compose([
        #transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    ])
# Both Subsets wrap the same full_dataset, so this also sets train_dataset's transform.
if opts.dataset == 'mnist':
    val_dataset.dataset.transform = transforms.ToTensor()
else:
    val_dataset.dataset.transform = transform_val
val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)



"""##########################################
##########################################
### Preparation Ends! Start Inference! ###
##########################################
##########################################"""



# PAPI is configured for PAPI_SP_OPS (single precision), so every tensor that takes
# part in inference must be float32 for the counts to be meaningful. Check the
# feature layers' parameters, the label layers' parameters and every precomputed
# mask activation. Any non-float32 tensor anywhere aborts the run.
has_double_params = (
    any(p.dtype != torch.float32 for seq in Features_seqs for p in seq.parameters())
    or any(p.dtype != torch.float32 for seq in Masks_seqs for p in seq.parameters())
    or any(act.dtype != torch.float32 for label_acts in mask_act_list for act in label_acts)
)

if has_double_params:
    print("WARNING: The model has double parameters, but the PAPI library is not configured to count double precision operations.")
    sys.exit(1)


if opts.run_select[0] == '1':
    # GIFF inference: run only the feature layers once per batch, then score every
    # label against the precomputed mask activations instead of re-running the
    # network once per label.
    batches_per_epoch = len(val_loader)
    total_batches = batches_per_epoch * opts.epochs
    correct = 0
    seen = 0
    start_time = time.perf_counter()
    papi_lib.startCounting()
    with torch.no_grad():
        for _ in range(opts.epochs):
            for x_in, y_test in val_loader:
                if opts.dataset == 'mnist':
                    x_in = x_in.view(x_in.shape[0], -1)
                n_in = x_in.shape[0]
                # Raw activation of every FF block (max-pooled for conv models), i.e.
                # a BP-style forward pass without the dense label layers.
                activation_list = []
                feat = x_in
                for index, block in enumerate(Features_seqs):
                    feat = block(feat)
                    activation_list.append(feat if opts.dataset == 'mnist' else pools[index](feat))
                # Goodness of each label: over all blocks, sum of (activation * mask)^2
                # across every non-batch dimension.
                all_goodness = []
                for seq in mask_act_list:
                    goodness = torch.zeros((n_in,), dtype=torch.float32)
                    for act_in, mask in zip(activation_list, seq):
                        prod = (act_in * mask).pow(2)
                        goodness += prod.sum(dim=list(range(1, prod.dim())))
                    all_goodness.append(goodness)
                goodness_for_labels = torch.stack(all_goodness, dim=1)
                # Conv models predict the lowest-goodness label, the MNIST MLP the highest.
                if opts.dataset != 'mnist':
                    y_hat = goodness_for_labels.argmin(dim=-1)
                else:
                    y_hat = goodness_for_labels.argmax(dim=-1)
                # Integer compare/sum on label tensors; presumably not counted as SP ops.
                correct += y_hat.eq(y_test).sum().item()
                seen += n_in
    end_time = time.perf_counter()
    # Stop the counter before reporting so printing is not attributed to inference.
    res = papi_lib.stopAndRead(ctypes.byref(value))
    if total_batches == 0:
        print("GIFF inference: no batches were run (empty validation set or --epochs 0).")
    else:
        print("Test Acc", correct / seen)
        print("FF with GIFF Inference test done in {} batches! It costs:".format(batches_per_epoch))
        print("batch average fops:", res / total_batches)
        print("Avg time elapsed: {:f} ms".format((end_time - start_time) * 1000 / total_batches))

"""
    To measure the original FF's inference time cleanly, run it in a separate process from the GIFF part above
    (e.g. --run_select 0,1 for this part and --run_select 1,0 for GIFF): cache and memory-initialisation effects
    from the first section can distort the second section's elapsed time. Operation counts should not be affected.
    Both inference types live in one file so their results can be cross-checked: they should report the same accuracy!
"""

if opts.run_select[1] == '1':
    # Original FF inference: every batch is pushed through the full model once per
    # candidate label (label_len forward passes), exactly as original FF does.
    batches_per_epoch = len(val_loader)
    total_batches = batches_per_epoch * opts.epochs
    # Same value for every forward call; presumably one entry per FF block -- TODO confirm against the model.
    diff_res = opts.diff_res
    correct = 0
    seen = 0
    # Printed before the timer/counter start so it is not part of the measurement.
    print("original FF's inference!!!")
    start_time = time.perf_counter()
    papi_lib.startCounting()
    with torch.no_grad():
        for _ in range(opts.epochs):
            for x_in, y_test in val_loader:
                if opts.dataset == 'mnist':
                    x_in = x_in.view(x_in.shape[0], -1)
                goodness_for_labels = []
                # Send the image together with each label, one label at a time.
                for label in range(opts.label_len):
                    test_label = F.one_hot(torch.full((1,), label, dtype=torch.long),
                                           num_classes=opts.label_len).float()
                    test_label_repeated = test_label.repeat(x_in.shape[0], 1)
                    acts = model(x_in, test_label_repeated, opts, diff_res)
                    goodness_for_labels.append(acts.sum(dim=[1]))
                goodness_for_labels = torch.stack(goodness_for_labels, dim=1)
                # Conv models predict the lowest-goodness label, the MNIST MLP the highest.
                if opts.dataset != 'mnist':
                    y_hat = goodness_for_labels.argmin(dim=-1)
                else:
                    y_hat = goodness_for_labels.argmax(dim=-1)
                correct += y_hat.eq(y_test).sum().item()
                seen += x_in.shape[0]
    # Stop the timer and counter before any reporting, matching the GIFF section.
    end_time = time.perf_counter()
    res = papi_lib.stopAndRead(ctypes.byref(value))
    print("original FF's inference done!!! it costs:")
    if total_batches == 0:
        print("original FF inference: no batches were run (empty validation set or --epochs 0).")
    else:
        print("Acc", correct / seen)
        print("there are {} batches".format(batches_per_epoch))
        print("batch average fops:", res / total_batches)
        print("time elapsed: ", (end_time - start_time) * 1000 / total_batches)


#papi_lib.cleanupPAPI()

