import os
import sys
import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import os
import datetime
import wandb
import argparse
from collections import OrderedDict
from tqdm import tqdm
import math
from torch.utils.data import TensorDataset, DataLoader
from model_v4 import DS_CNN_FF
from torchinfo import summary
from pypapi import events as papi_events
import time
import argparse
import ctypes

def load_data(file_path):
    """Load the test set stored under the ``"data"`` key of a NumPy archive.

    ``ds["data"]`` is indexed as a pair: item 0 is an iterable of per-sample
    feature arrays, item 1 an iterable of per-sample labels.

    Args:
        file_path: path of the archive (this script passes an ``.npz`` file).

    Returns:
        A ``TensorDataset`` of ``(features, label)`` pairs. Features are
        float32 with the last axis of the stacked 4-D array moved to position 1
        (presumably channels-last -> channels-first; confirm against the data);
        labels are ``torch.long``.

    Note:
        ``allow_pickle=True`` lets ``np.load`` unpickle object arrays, which can
        execute arbitrary code -- only load trusted files.
    """
    archive = np.load(file_path, allow_pickle=True)
    try:
        # Read the entry exactly once: every NpzFile["data"] access re-reads
        # and re-unpickles the member from the zip archive.
        payload = archive["data"]
    finally:
        # For .npz inputs np.load returns an NpzFile that keeps the file handle
        # open until closed; a plain ndarray (from .npy) has no close().
        close = getattr(archive, "close", None)
        if close is not None:
            close()

    data = np.array(list(payload[0]))
    label = np.array(list(payload[1]))
    train_tensor = torch.tensor(data, dtype=torch.float32).permute(0, 3, 1, 2)
    label_tensor = torch.tensor(label, dtype=torch.long)

    # Create a PyTorch TensorDataset
    return TensorDataset(train_tensor, label_tensor)

# Load the PAPI helper shim (a shared library resolved relative to the current
# working directory) and initialise it with the PAPI_SP_OPS event, i.e. it is
# set up to count single-precision floating-point operations.
papi_lib = ctypes.CDLL('./papi_test.so')
event = ctypes.c_int(papi_events.PAPI_SP_OPS)
papi_lib.initializePAPI(event)
# ctypes assumes a C `int` return type by default; declare the 64-bit return of
# stopAndRead so large operation counts are not truncated.
papi_lib.stopAndRead.restype = ctypes.c_longlong


def _positive_int(text):
    """argparse ``type=``: parse *text* as an int and require it to be >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError("expected a positive integer, got %r" % text)
    return value


def _parse_run_select(text):
    """argparse ``type=``: validate ``--run_select`` as two comma-separated 0/1 flags.

    The string is returned unchanged (it is split on ',' later). Flag 0 gates
    the GIFF run and flag 1 the original FF run, so anything other than two
    flags would otherwise fail only after all the setup work, with an IndexError.
    """
    flags = text.split(',')
    if len(flags) != 2 or any(flag not in ('0', '1') for flag in flags):
        raise argparse.ArgumentTypeError(
            "expected two comma-separated 0/1 flags (GIFF,FF), e.g. '1,0'; got %r" % text)
    return text


parser = argparse.ArgumentParser(description='forward-forward-benchmark inf args')
parser.add_argument('--batch_size', type=_positive_int, default=1, help='1/32/128')
parser.add_argument('--epochs', type=_positive_int, default=5, help='number of epochs to inf (default: 5)')
# argparse also runs `type` on string defaults, so '1,1' is validated as well.
parser.add_argument('--run_select', type=_parse_run_select, default='1,1', help='select GIFF or FF to run. 1,1 means run both')

class Opts:
    """Run options: command-line arguments plus fixed benchmark constants.

    Everything is a class attribute; the command line is parsed once, when
    this class body executes (i.e. at import time of the script).
    """
    args = parser.parse_args()
    device = 'cpu'
    batch_size, epochs = args.batch_size, args.epochs
    # e.g. '1,0' -> ['1', '0']; element 0 gates the GIFF run, element 1 the
    # original FF run.
    run_select = args.run_select.split(',')
    theta = 8
    # Number of classes: used as `num_classes` for the one-hot label vectors.
    label_len = 12


opts = Opts()

# Architecture description handed to DS_CNN_FF; presumably it must match the
# checkpoint loaded below (load_state_dict is strict by default).
config = {
    'ds_channels': [[64, 64], [64, 64], [64, 64], [64, 64]],
    'ds_pooling_2': [],
    'ds_pooling_3': [[1, 1], [1, 1], [1, 1]],
    'ds_pooling_4': [],
    'ds_pooling_5': [],
    'FF_block_nums': 3.5,
}

#2-1-1
# Per-FF-block channel counts, used when reshaping the precomputed mask
# activations (channels[i] x ds_pooling_3[i]).
channels = [64, 64, 64]

model = DS_CNN_FF(config)
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files
# (newer PyTorch versions offer weights_only=True).
model.load_state_dict(
    torch.load('your_model_path.pth', map_location=torch.device('cpu'))
)
model.to(opts.device)
model.eval()

"""
    Step 1: Retrieve Features layers to derive raw activations for each FF block. e.g. Features_seqs[0] is the first FF block
"""
# Split every FF block into its label-independent feature layers (everything
# that is not nn.Linear) and its nn.Linear layers, which are applied to the
# one-hot label. Layer order inside each block is preserved.
Features = []
Label_FC = []
for block in model.blocks:
    Features.append([layer for layer in block if not isinstance(layer, nn.Linear)])
    Label_FC.append([layer for layer in block if isinstance(layer, nn.Linear)])

# Features_seqs[k] / Masks_seqs[k] wrap the same module instances (no copies)
# of FF block k as standalone Sequentials.
Features_seqs = [nn.Sequential(*layers) for layers in Features]

# Step 2: Masks' layers for each FF block, e.g. Masks_seqs[0] is the first FF block.
Masks_seqs = [nn.Sequential(*layers) for layers in Label_FC]

"""
    Step 3: Preparation work: get the masks' activation ahead of time + PAPI counting to estimate the cost
"""

papi_lib.startCounting()

# mask_act_list[label][k] is the output of FF block k's nn.Linear layers for
# the one-hot encoding of `label`, reshaped to channels[k] x ds_pooling_3[k].
# Computed under no_grad, like both inference loops: this is pure inference
# and the stored tensors should not carry autograd graphs.
mask_act_list = []
with torch.no_grad():
    for label in range(opts.label_len):
        test_label = torch.ones((1,), dtype=torch.long).fill_(label)
        test_label = F.one_hot(test_label, num_classes=opts.label_len)
        test_label = test_label.float()
        mask_act = []
        for i, Masks in enumerate(Masks_seqs):
            acts = Masks(test_label)
            mask_act.append(acts.view(channels[i], config['ds_pooling_3'][i][0], config['ds_pooling_3'][i][1]))
        mask_act_list.append(mask_act)

value = ctypes.c_longlong()
print("Preparation: masks results have been stored!!! It costs:")
# NOTE(review): the count returned here is discarded; presumably the C helper
# prints it itself -- confirm.
papi_lib.stopAndRead(ctypes.byref(value))

# One adaptive average pool per FF block, shrinking that block's feature map to
# the same spatial size as its mask activation.
pools = [nn.AdaptiveAvgPool2d(shape) for shape in config['ds_pooling_3']]




"""
    Step 4: load the test set (./data/mfcc_test_data.npz) used for inference
"""
val_dataset = load_data("./data/mfcc_test_data.npz")
# NOTE(review): without persistent_workers the 4 worker processes are respawned
# for every epoch, inside the timed inference loops below.
val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)

# PAPI is set up to count single-precision operations only, so any parameter
# that is not float32 would make the operation counts meaningless.
has_double_params = any(param.dtype != torch.float32 for param in model.parameters())
if has_double_params:
    # sys.exit, not the site-provided `exit` builtin (absent under `python -S`).
    sys.exit("Error: double precision parameters are not supported!")
    


"""##########################################
##########################################
### Preparation Ends! Start Inference! ###
##########################################
##########################################"""



# Re-check, right before inference, every tensor that takes part in it:
# - the parameters of all feature layers;
# - the precomputed mask activations of every label.
# PAPI is set up to count single-precision operations only.
has_double_params = any(
    param.dtype != torch.float32
    for seq in Features_seqs
    for param in seq.parameters()
) or any(
    act.dtype != torch.float32
    for per_label in mask_act_list
    for act in per_label
)

if has_double_params:
    print("WARNING: The model has double parameters, but the PAPI library is not configured to count double precision operations.")
    sys.exit(1)


if opts.run_select[0] == '1':
    # GIFF inference. The label-free feature layers run once per batch; every
    # label is then scored by combining the pooled activations with that
    # label's precomputed mask activations (mask_act_list).
    start_time = time.perf_counter()
    papi_lib.startCounting()
    num_batches = len(val_loader)
    # Only collect predictions inside the timed loop; accuracy is computed
    # after the timer stops so it does not distort the measured time.
    predictions = []
    targets = []
    with torch.no_grad():
        for epoch in range(opts.epochs):
            for x_in, y_test in val_loader:
                batch_len = x_in.shape[0]
                # this part is similar to BP without the dense layers
                activation_list = []
                for index, block in enumerate(Features_seqs):
                    x_in = block(x_in)
                    activation_list.append(pools[index](x_in))

                # iterate on each label, aggregating the goodness from each block
                all_goodness = []
                for seq in mask_act_list:
                    final_goodness = torch.zeros((batch_len,), dtype=torch.float32)
                    for i, act in enumerate(seq):
                        temp = (activation_list[i] * act).pow(2)
                        final_goodness += temp.sum(dim=list(range(1, temp.dim())))
                    all_goodness.append(final_goodness)
                goodness_for_labels = torch.stack(all_goodness, dim=1)
                # The label with the lowest goodness is predicted.
                y_hat = goodness_for_labels.argmin(dim=-1)
                predictions.append(y_hat)
                targets.append(y_test)

    end_time = time.perf_counter()
    # Accuracy over every batch of every epoch. These are integer ops only, so
    # they should not add to the PAPI_SP_OPS count.
    all_targets = torch.cat(targets)
    correct = torch.cat(predictions).eq(all_targets).sum().item()
    print("Test Acc", correct / all_targets.numel())
    print("FF with GIFF Inference test done in {} batches! It costs:".format(num_batches))
    res = papi_lib.stopAndRead(ctypes.byref(value))
    print("batch average fops:", res/num_batches/opts.epochs)
    print("Avg time elapsed: {:f} ms".format((end_time - start_time)*1000/num_batches/opts.epochs))

"""
    The section below runs the original FF inference, which feeds each batch through the full model once per class label.
    To time it in isolation, pass --run_select 0,1 so the GIFF section above is skipped. When both run in one process,
    cache and memory-initialization effects may skew the wall-clock timings, although the operation counts should not be affected.
    Both inference paths live in one file so their results can be cross-checked: they should report the same accuracy.
"""

if opts.run_select[1] == '1':
    # Original FF inference: run the whole model once per candidate label.
    start_time = time.perf_counter()
    papi_lib.startCounting()
    print("original FF's inference!!!")
    num_batches = len(val_loader)
    # Only collect predictions inside the timed loop; accuracy is computed
    # after the timer stops.
    predictions = []
    targets = []
    with torch.no_grad():
        for epoch in range(opts.epochs):
            for x_in, y_test in val_loader:
                # send image with label $num_class times just as original FF does
                goodness_for_labels = []
                for label in range(opts.label_len):
                    test_label = torch.ones((1,), dtype=torch.long).fill_(label)
                    test_label = F.one_hot(test_label, num_classes=opts.label_len)
                    test_label = test_label.float()
                    test_label_repeated = test_label.repeat(x_in.shape[0], 1)
                    acts = model(x_in, test_label_repeated)
                    goodness = acts.sum(dim=[1])
                    goodness_for_labels.append(goodness)
                goodness_for_labels = torch.stack(goodness_for_labels, dim=1)
                # The label with the lowest goodness is predicted.
                y_hat = goodness_for_labels.argmin(dim=-1)
                predictions.append(y_hat)
                targets.append(y_test)

    # Stop the clock before printing, as in the GIFF section.
    end_time = time.perf_counter()
    print("original FF's inference done!!! it costs:")
    # Accuracy over every batch of every epoch (integer ops only).
    all_targets = torch.cat(targets)
    correct = torch.cat(predictions).eq(all_targets).sum().item()
    print("Acc", correct / all_targets.numel())
    res = papi_lib.stopAndRead(ctypes.byref(value))
    print("there are {} batches".format(num_batches))
    print("batch average fops:", res/num_batches/opts.epochs)
    print("time elapsed: ", (end_time - start_time)*1000/num_batches/opts.epochs)


#papi_lib.cleanupPAPI()

