import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchinfo import summary
from tqdm import tqdm
from util import set_seed, save_model, save_plots, SaveBestModel, cosine_annealing_lr_with_warmup
from torch.utils.data import DataLoader
import math
from network_v2 import MLP_Net_Receptive, Conv_block_bp, VGG16
from network_v3 import Resnet_bp

# from pypapi import events as papi_events
# import ctypes
# os.environ['LD_LIBRARY_PATH'] = '/home/jlin445/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# papi_lib = ctypes.CDLL('code/papi_test.so')
# event = ctypes.c_int(papi_events.PAPI_SP_OPS)
# papi_lib.initializePAPI(event)
# papi_lib.stopAndRead.restype = ctypes.c_longlong


class Opts:
    """Hard-coded hyper-parameters and runtime settings for the CIFAR-10 run.

    Every setting is a class attribute, so all ``Opts()`` instances share these
    values until an attribute is reassigned on an instance.
    """

    # Data / optimisation
    batch_size = 256
    lr = 0.001                # base learning rate (Adam init + schedule input)
    weight_decay = 0          # forwarded to torch.optim.Adam
    epochs = 500
    warmup_epochs = 5
    startup_lr = 0.0001       # presumably the LR at the start of warmup; see util.cosine_annealing_lr_with_warmup
    seed = 0                  # forwarded to util.set_seed

    # Model selection: combo >= 4 builds Resnet_bp, anything lower Conv_block_bp.
    combo = 4

    # Runtime. '1' turns on the (currently commented-out) PAPI FLOP-counting
    # hooks and pins execution to the CPU; any other value trains on CUDA.
    fopmonitor = '0'
    device = 'cpu' if fopmonitor == '1' else 'cuda'
    
def train(network_bp, optimizer, train_loader, loss_fcn, opts):
    """Run one training epoch over ``train_loader``.

    The caller is responsible for putting ``network_bp`` into train mode;
    this function does not change the module mode.

    Args:
        network_bp: model to optimise; called as ``network_bp(x)``.
        optimizer: optimizer over ``network_bp``'s parameters, stepped once per batch.
        train_loader: sized iterable of ``(inputs, labels)`` batches.
        loss_fcn: criterion called as ``loss_fcn(outputs, labels)``.
        opts: configuration object; only ``opts.device`` is read.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the per-batch
        losses as a detached tensor on ``opts.device``. ``accuracy`` is the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    running_loss = 0.
    # Kept as device tensors so the loop does not force a host sync every batch.
    correct = 0
    seen = 0

    # if opts.fopmonitor == '1':
    #     papi_lib.startCounting()

    for batch_no, (x, y_ground) in tqdm(enumerate(train_loader), total=num_batches):
        x, y_ground = x.float().to(opts.device), y_ground.to(opts.device)
        # x = x.view(opts.batch_size, -1)

        # Force autograd on even if the caller happens to be inside a no_grad block.
        with torch.enable_grad():
            ys = network_bp(x)
            loss = loss_fcn(ys, y_ground)
            loss.backward()
        running_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

        # softmax is monotonic, so the arg-max of the raw outputs is the same
        # prediction; detaching avoids recording an extra autograd op.
        preds = ys.detach().argmax(dim=-1)
        correct += preds.eq(y_ground).sum()
        seen += len(y_ground)

    # if opts.fopmonitor == '1':
    #     value = ctypes.c_longlong()
    #     total_cnt = papi_lib.stopAndRead(ctypes.byref(value))
    #     print("done, averge fops per batch", total_cnt/(batch_no+1))

    return running_loss / num_batches, correct.item() / seen

@torch.no_grad()
def test(network_bp, test_loader, loss_fcn, opts):
    """Evaluate ``network_bp`` on ``test_loader`` with autograd disabled.

    The caller is responsible for putting ``network_bp`` into eval mode.

    Args:
        network_bp: model to evaluate; called as ``network_bp(x)``.
        test_loader: sized iterable of ``(inputs, labels)`` batches.
        loss_fcn: criterion called as ``loss_fcn(outputs, labels)``.
        opts: configuration object; only ``opts.device`` is read.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the per-batch
        losses as a tensor on ``opts.device``. ``accuracy`` is the fraction of
        samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``test_loader`` contains no batches.
    """
    num_batches = len(test_loader)
    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")

    test_loss = 0.
    correct = 0
    total = 0
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.float().to(opts.device), y_test.to(opts.device)
        # x_test = x_test.view(x_test.shape[0], -1)
        acts = network_bp(x_test)
        test_loss += loss_fcn(acts, y_test)
        # softmax does not change the arg-max, so predict from the raw outputs.
        correct += acts.argmax(dim=-1).eq(y_test).sum()
        total += len(y_test)

    return test_loss / num_batches, correct.item() / total


# A module-level callable named ``test`` matches pytest's default collection
# pattern; opt out so importing this module into a test file does not make
# pytest try to run it with unresolvable fixtures.
test.__test__ = False

def load_model(opts):
    """Instantiate the network selected by ``opts.combo``.

    ``combo`` values of 4 and above pick ``Resnet_bp``; smaller values pick
    ``Conv_block_bp``. The chosen class receives ``combo`` as a keyword
    argument. The returned model is not moved to any device.
    """
    model_cls = Resnet_bp if opts.combo >= 4 else Conv_block_bp
    return model_cls(combo=opts.combo)

def _build_cifar10_loaders(opts):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ``data/cifar10``."""
    # Evaluation images are only converted to tensors (ToTensor scales pixels
    # to [0, 1]); normalization is currently disabled.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        #transforms.Normalize(*stats, inplace=True)
    ])

    # Training images additionally get light geometric augmentation.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        CIFAR10(root='data/cifar10', train=True, download=True, transform=train_transform),
        batch_size=opts.batch_size,
        shuffle=True,
        num_workers=4,
        #drop_last=True
    )
    test_loader = DataLoader(
        CIFAR10(root='data/cifar10', train=False, download=True, transform=test_transform),
        batch_size=opts.batch_size,
        shuffle=False,
        num_workers=4,
        #drop_last=True
    )
    return train_loader, test_loader


def main(opts):
    """Train the selected model on CIFAR-10 and checkpoint the best weights.

    Args:
        opts: configuration object (see ``Opts``). ``opts.log_lr`` is assigned
            every epoch with the learning rate in use.

    Side effects:
        Creates ``Output/`` if needed and downloads CIFAR-10 into ``data/cifar10``.
        Prints a model summary, an estimated peak memory figure and per-epoch
        metrics. Saves the state dict with the highest validation accuracy to
        ``Output/model_state_InfTest_cifarResbp1.pth``.
    """
    set_seed(opts.seed)

    output_folder = "Output"
    # makedirs(exist_ok=True) is safe if the folder already exists or is
    # created concurrently, unlike an exists()-then-mkdir() pair.
    os.makedirs(output_folder, exist_ok=True)
    checkpoint_path = os.path.join(output_folder, 'model_state_InfTest_cifarResbp1.pth')

    train_loader, test_loader = _build_cifar10_loaders(opts)

    model = load_model(opts).to(opts.device)
    summary(model)

    num_params = sum(p.numel() for p in model.parameters())
    # gradient_num / activation_num / error_num are attributes provided by the
    # project's network classes. The "* 4 / 1024" presumably assumes 4-byte
    # (fp32) elements and reports KB; confirm against network_v2 / network_v3.
    peak_mem = (num_params + model.gradient_num + model.activation_num + model.error_num) * 4 / 1024
    #peak_mem = (num_params + 0 + model.activation_num)*4/1024
    print("~~~~~estimated peak mem~~~~~")
    print(">>>:", peak_mem, "KB")

    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
    # save_best_model = SaveBestModel(
    #     name=f'CIFAR10-BP_lr={str(opts.lr)}_quant={str(opts.quantized)}_W{str(opts.weight_precision)}_B{str(opts.bias_precision)}_A{str(opts.act_precision)}')
    train_loss_hist, valid_loss_hist = [], []
    train_acc_hist, valid_acc_hist = [], []
    # -inf guarantees a checkpoint is written after the first epoch, even at 0% accuracy.
    best_acc = float('-inf')
    for step in range(opts.epochs):
        lr = cosine_annealing_lr_with_warmup(step, opts.epochs, opts.lr, opts.warmup_epochs, opts.startup_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        opts.log_lr = optimizer.param_groups[0]['lr']

        model.train()
        train_loss, train_acc = train(model, optimizer, train_loader, loss_fcn, opts)

        model.eval()
        valid_loss, valid_acc = test(model, test_loader, loss_fcn, opts)
        # Implicit string concatenation keeps source indentation out of the log line.
        print(f"Epoch {step:04d} train_loss: {train_loss:.3f} train_acc: {train_acc:.3f} "
              f"valid_loss: {valid_loss:.3f} valid_acc: {valid_acc:.3f} log_lr: {opts.log_lr:.6f}")

        # The losses come back as tensors on opts.device; store CPU copies.
        train_loss_hist.append(train_loss.cpu())
        train_acc_hist.append(train_acc)
        valid_loss_hist.append(valid_loss.cpu())
        valid_acc_hist.append(valid_acc)
        # save the best model till now if we have the least loss in the current epoch
        # save_best_model(valid_loss, step, model, optimizer, loss_fcn)
        # save the loss and accuracy plots
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), checkpoint_path)

    print('TRAINING COMPLETE!')
    print('='*50)
    # save_plots(f'CIFAR10-BP_lr={str(opts.lr)}',
    #            train_acc_hist, valid_acc_hist,
    #            train_loss_hist, valid_loss_hist)
    # best_model_cp = torch.load(
    #     f'Output/CIFAR10-BP_lr={str(opts.lr)}_best_model.pth')
    # best_model_epoch = best_model_cp['epoch']
    # print(f"Best model was saved at {best_model_epoch} epochs\n")
    # model.load_state_dict(best_model_cp['model_state_dict'])
    # test_loss, test_acc = test(model, test_loader, loss_fcn, opts)
    # print(f"Best model test accuracy: {test_acc:.3f}")
    # print('='*50)


if __name__ == "__main__":
    # Script entry point: build the default configuration and start training.
    opts = Opts()
    main(opts)
