import os
import numpy as np
import torch
import torch.nn as nn
from network_v1 import MLP_Net_BP
import argparse

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import MNIST
from torchinfo import summary
from tqdm import tqdm
from util import accuracy, set_seed, save_model, save_plots, SaveBestModel

from network_v3 import mnist_conv2d_bp


from pypapi import events as papi_events
import ctypes
# PAPI floating-point-operation counter setup; runs at import time.
# NOTE(review): the dynamic loader normally reads LD_LIBRARY_PATH once at
# process start, so updating it here probably does not help the CDLL call
# below resolve libpapi (it mainly affects child processes). Confirm; if it
# matters, export it before launching Python or preload libpapi explicitly.
os.environ['LD_LIBRARY_PATH'] = '/home/jlin445/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# The shared-object path is relative to the current working directory, and the
# library is loaded and initialized even when Opts.fopmonitor is '0'.
papi_lib = ctypes.CDLL('code/papi_test.so')
# Counter event: the PAPI_SP_OPS preset (single-precision floating-point ops).
event = ctypes.c_int(papi_events.PAPI_SP_OPS)
papi_lib.initializePAPI(event)
# stopAndRead() yields a C long long; ctypes would otherwise assume a C int.
papi_lib.stopAndRead.restype = ctypes.c_longlong


class MNISTDataset(Dataset):
    """Thin ``Dataset`` wrapper around one split of torchvision's MNIST.

    The wrapped ``MNIST`` object lives on ``self.mnist_data`` (downloaded into
    ``./data`` on first use); items are returned as ``(image, label)`` tuples.
    """

    def __init__(self, train=True, transform=None):
        """Create the wrapped MNIST split.

        Args:
            train (bool): ``True`` selects the training split, ``False`` the test split.
            transform (callable, optional): Passed to MNIST and applied to every
                image; also kept on ``self.transform`` for reference.
        """
        self.transform = transform
        self.mnist_data = MNIST(
            root='./data',
            train=train,
            download=True,
            transform=transform,
        )

    def __len__(self):
        """Number of samples in the wrapped split."""
        size = len(self.mnist_data)
        return size

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at *index*."""
        sample = self.mnist_data[index]
        image, label = sample
        return (image, label)
    
def load_mnist_data(batchsize=128, train=True, val=True):
    """Build MNIST data loaders for the requested splits.

    Args:
        batchsize (int): Batch size used by every returned loader.
        train (bool): Include a shuffled loader over the MNIST training split.
        val (bool): Include an unshuffled loader over the MNIST test split.

    Returns:
        A single ``DataLoader`` when exactly one split is requested; otherwise
        a list ``[train_loader, test_loader]`` in that order (an empty list
        when neither split is requested).
    """
    print("start building MNIST data loaders")
    # One flag per requested split, in train-then-test order. The flag says
    # whether the split is the training one, which is also the only shuffled one.
    splits = [is_train for is_train, wanted in ((True, train), (False, val)) if wanted]
    loaders = [
        DataLoader(
            MNISTDataset(train=is_train, transform=transforms.ToTensor()),
            batch_size=batchsize,
            shuffle=is_train,
        )
        for is_train in splits
    ]
    if len(loaders) == 1:
        return loaders[0]
    return loaders

def load_model(opts=None):
    """Instantiate the network selected by ``opts.cnn``.

    Args:
        opts: Options object. Reads ``opts.cnn`` and, for the MLP path,
            ``opts.nn_strct``. The ``None`` default is not usable in practice:
            an options object must be supplied.

    Returns:
        ``mnist_conv2d_bp(0)`` when ``opts.cnn == 1``; otherwise an
        ``MLP_Net_BP`` built from ``opts.nn_strct``.
    """
    if opts.cnn == 1:
        return mnist_conv2d_bp(0)
    # NOTE(review): flattened MNIST images have 28*28 = 784 features, yet
    # in_dim_feat is 120 here. Confirm what MLP_Net_BP expects for in_dim_feat.
    return MLP_Net_BP(n_neurons=opts.nn_strct, in_dim_feat=120, configure=True)

@torch.no_grad()
def test(network, test_loader, opts):
    all_outputs = []
    all_labels = []
    for (x_test, y_test) in test_loader:
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)
        if opts.cnn != 1:
            x_test = x_test.view(x_test.shape[0], -1)
        acts = network(x_test)
        all_outputs.append(torch.nn.functional.softmax(acts).argmax(dim=-1))
        all_labels.append(y_test)

    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    correct = all_outputs.eq(all_labels).sum().item()
    return correct/len(all_labels)

def train(network, optimizer, train_loader, valid_loader, loss_fcn, opts):
    """Train *network* for one epoch with backprop, then measure accuracy.

    Args:
        network: ``torch.nn.Module`` being trained.
        optimizer: Optimizer over ``network``'s parameters.
        train_loader: ``DataLoader`` of ``(inputs, labels)`` training batches;
            must contain at least one batch.
        valid_loader: Loader evaluated with ``test`` after the epoch.
        loss_fcn: Callable ``loss_fcn(outputs, labels)`` returning the loss.
        opts: Options object. Reads ``device``, ``cnn`` (inputs are flattened
            to ``(batch, -1)`` unless it is 1) and ``fopmonitor`` (``'1'``
            wraps the epoch in PAPI start/stop counting).

    Returns:
        tuple: ``(mean_loss, train_acc, valid_acc)``. ``mean_loss`` is the sum
        of the detached per-batch losses divided by the number of batches (a
        tensor, assuming ``loss_fcn`` returns one). The accuracies are the
        values returned by ``test`` on the train and validation loaders.

    Raises:
        ValueError: If *train_loader* has no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Otherwise the PAPI average and the loss average would divide by zero.
        raise ValueError("train_loader has no batches")

    # Ensure training-mode behaviour (dropout, batch-norm) regardless of the
    # mode a previous evaluation may have left the model in.
    network.train()
    running_loss = 0.

    if opts.fopmonitor == '1':
        papi_lib.startCounting()

    for x, y_ground in tqdm(train_loader, total=num_batches):
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)
        if opts.cnn != 1:
            x = x.view(x.shape[0], -1)
        with torch.enable_grad():
            ys = network(x)
            loss = loss_fcn(ys, y_ground)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        running_loss += loss.detach()

    if opts.fopmonitor == '1':
        value = ctypes.c_longlong()
        result = papi_lib.stopAndRead(ctypes.byref(value))
        # Counter result averaged over the number of training batches.
        print("done", result / num_batches)

    running_loss /= num_batches

    train_acc = test(network, train_loader, opts)
    valid_acc = test(network, valid_loader, opts)
    return running_loss, train_acc, valid_acc


# Command-line interface. "%(default)s" keeps the help text in sync with the
# actual defaults instead of hard-coding numbers that can drift.
parser = argparse.ArgumentParser(description='forward-forward-benchmark training args')
parser.add_argument('--device', default='cuda', help='cuda|cpu')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training (default: %(default)s)')
parser.add_argument('--nn_strct', default='100,100,100', help='neural network structure')
parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
parser.add_argument('--api_key', default='xxxxxxxxxxxxxxxxxx', help='wandb api key default is a wrong one')

class Opts:
    """Run configuration resolved from the command line.

    ``parser.parse_args()`` runs when this class body executes, i.e. at import
    time. Every setting below is therefore a class attribute shared by all
    instances.
    """
    args = parser.parse_args()
    # Settings driven by CLI flags (defaults come from the parser).
    lr = args.lr
    epochs = args.epochs
    nn_strct = [int(i) for i in args.nn_strct.split(',')]
    online_visual = args.online_visual
    api_key = args.api_key  # exported as WANDB_API_KEY when online_visual == 1
    device = args.device
    batch_size = args.batch_size
    # Fixed settings.
    weight_decay = 0
    seed = 0
    cnn = 1           # 1 selects the conv net in load_model, otherwise the MLP
    fopmonitor = '0'  # '1' enables PAPI counting in train()

    # PAPI counting is only done on the CPU run.
    if fopmonitor == '1':
        device = 'cpu'

def main(opts):
    """Train the selected model on MNIST and checkpoint the best one.

    Steps:
      1. Seed the RNGs and build the MNIST train/test loaders.
      2. Print a model summary and a rough peak-memory estimate.
      3. Train for ``opts.epochs`` epochs with Adam and cross-entropy.
      4. Save the model's ``state_dict`` to
         ``Output/model_state_InfTest_MNISTBP.pth`` whenever validation
         accuracy improves.

    Args:
        opts: Run configuration. Reads ``seed``, ``batch_size``, ``device``,
            ``lr``, ``epochs`` and ``online_visual`` here, plus whatever
            ``load_model`` and ``train`` read.
    """
    set_seed(opts.seed)
    train_loader, test_loader = load_mnist_data(batchsize=opts.batch_size)

    model = load_model(opts).to(opts.device)
    summary(model)

    num_params = sum(p.numel() for p in model.parameters())
    # Rough estimate: element counts of parameters plus the gradient,
    # activation and error counts the model class exposes, at 4 bytes each
    # (presumably float32), reported in KB.
    peak_mem = (num_params + model.gradient_num + model.activation_num
                + model.error_num) * 4 / 1024
    print("~~~~~estimated peak mem~~~~~")
    print(">>>:", peak_mem, "KB")

    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr)

    # torch.save does not create missing directories.
    os.makedirs('Output', exist_ok=True)

    # Per-epoch histories, kept for plotting (e.g. with util.save_plots).
    train_loss_hist = []
    train_acc_hist, valid_acc_hist = [], []
    best_acc = 0

    for step in range(opts.epochs):
        train_loss, train_acc, valid_acc = train(model, optimizer, train_loader, test_loader, loss_fcn, opts)
        loss_value = train_loss.item()
        print(f"Epoch {step:04d} train_loss: {loss_value:.3f} "
              f"train_acc: {train_acc:.3f} "
              f"valid_acc: {valid_acc:.3f}")
        train_loss_hist.append(train_loss.cpu())
        train_acc_hist.append(train_acc)
        valid_acc_hist.append(valid_acc)
        if opts.online_visual == 1:
            wandb.log({"train_loss": loss_value, "train_acc": train_acc, "valid_acc": valid_acc})

        # Checkpoint whenever validation accuracy improves.
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), 'Output/model_state_InfTest_MNISTBP.pth')

    print('TRAINING COMPLETE!')
    print('='*50)

if __name__ == '__main__':
    opts = Opts()
    use_wandb = opts.online_visual == 1

    # Optional Weights & Biases logging; wandb is only imported when enabled.
    if use_wandb:
        import wandb
        os.environ["WANDB_API_KEY"] = opts.api_key
        wandb_prjname = "FFbenchmark_MNIST"
        wandb_runname = (f"MINST_BP_lr{opts.lr}_epochs_{opts.epochs}"
                         f"_bs_{opts.batch_size}_nn_strct_{opts.nn_strct}")
        wandb.init(project=wandb_prjname, name=wandb_runname)

    main(opts)

    if use_wandb:
        wandb.finish()
