from torch.utils.data import DataLoader, random_split, Subset
import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10
from torchinfo import summary
from tqdm import tqdm
from util import set_seed, save_model, save_plots, SaveBestModel, cosine_annealing_lr_with_warmup
import torch.nn.functional as F
import numpy as np
import wandb
import argparse
import math
import time
import random

from network_v3 import BP_mobilenet_v1

import ssl
# Swap the factory used for default HTTPS contexts with the unverified one, so
# HTTPS requests made through the default context skip certificate checks.
# NOTE(review): nothing visible in this file downloads data -- confirm this
# override is still required before keeping it.
ssl._create_default_https_context = ssl._create_unverified_context
# Enable autograd anomaly detection for the whole process. PyTorch documents
# this as a debugging aid that slows down execution.
# NOTE(review): consider turning it off for full-length training runs.
torch.autograd.set_detect_anomaly(True)

# from pypapi import events as papi_events
# import ctypes
# os.environ['LD_LIBRARY_PATH'] = '/home/yourname/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# papi_lib = ctypes.CDLL('code/papi_test.so')
# event = ctypes.c_int(papi_events.PAPI_SP_OPS)
# papi_lib.initializePAPI(event)
# papi_lib.stopAndRead.restype = ctypes.c_longlong

# papi_lib.startCounting()
# total = np.float32(0.0)
# increment = np.float32(0.1)
# for _ in range(10000):
#     total += increment

# value = ctypes.c_longlong()
# total_cnt = papi_lib.stopAndRead(ctypes.byref(value))
# print(total_cnt)

# Command-line interface of the training script.
# Help texts that mention a default use argparse's ``%(default)s`` placeholder
# so the printed value can never drift away from the real default again.
parser = argparse.ArgumentParser(description='forward-forward-benchmark vww training args')
parser.add_argument('--device', default='cuda', help='cuda|cpu')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: %(default)s)')
parser.add_argument('--theta', type=float, default=8, help='layer loss param')
parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
# The default below is a placeholder (its own help text calls it "a wrong
# one"); pass a real key on the command line instead of editing this file.
parser.add_argument('--api_key', default='eec626411e7ff3f4c229c1302489a9df4ab713f9', help='wandb api key default is a wrong one')
parser.add_argument('--dataset_dir', default='data/vw_coco2014_96', help='dataset dir')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
parser.add_argument('--adaptive_lr', type=int, default=1, help='enable adaptive learning rate')
parser.add_argument('--start_lr', type=float, default=0.00005, help='start learning rate for warmup')
parser.add_argument('--warmup_epochs', type=int, default=5, help='warmup epochs')
parser.add_argument('--combo', type=int, default=0, help='selection of model structure')
parser.add_argument('--label_len', type=int, default=2, help='label_length')
# Kept as a string flag: the rest of the file compares it against '1'.
parser.add_argument('--fopmonitor', default='0', help='papi 1 enable monitoring')
parser.add_argument('--save_model', type=int, default=1, help='save model')

class Opts:
    """Run configuration exposed as class attributes.

    The command line is parsed while this class body executes, and the parsed
    values are copied onto the class under the names the training code reads.
    """

    args = parser.parse_args()

    # Runtime environment.
    seed = args.seed
    fopmonitor = args.fopmonitor
    # Passing '--fopmonitor 1' overrides whatever device was requested.
    device = 'cpu' if fopmonitor == '1' else args.device

    # Data and model selection.
    dataset_dir = args.dataset_dir
    combo = args.combo
    save_model = args.save_model

    # Optimisation hyper-parameters.
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr
    weight_decay = args.weight_decay

    # Learning-rate schedule; log_lr starts out equal to the base rate.
    adaptive_lr = args.adaptive_lr
    startup_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    log_lr = lr

    # Experiment tracking.
    online_visual = args.online_visual
    api_key = args.api_key

def load_model(opts):
    """Construct the network variant selected by ``opts.combo``.

    Args:
        opts: Configuration object; only its ``combo`` attribute is read.

    Returns:
        A new ``BP_mobilenet_v1`` instance built with ``opts.combo``.
    """
    return BP_mobilenet_v1(opts.combo)


def train(network_bp, optimizer, train_loader, loss_fcn, opts):
    """Run one optimisation pass over ``train_loader``.

    Args:
        network_bp: Model to optimise; invoked as ``network_bp(x)``.
        optimizer: Optimizer that updates the model parameters.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        loss_fcn: Callable ``loss_fcn(outputs, labels)`` returning a scalar
            loss tensor.
        opts: Configuration object; only ``opts.device`` is read.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the mean of the
        per-batch losses (a tensor) and ``accuracy`` is the fraction of
        correctly classified samples (a float).

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    # The caller puts the model into eval mode for validation after every
    # epoch. Without switching back here, every epoch after the first would
    # be trained with eval-mode layers.
    network_bp.train()

    running_loss = 0.
    correct = 0
    seen = 0

    for x, y_ground in tqdm(train_loader, total=num_batches):
        x, y_ground = x.float().to(opts.device), y_ground.to(opts.device)

        # Start every step from clean gradients, even if the caller left
        # stale ones on the parameters.
        optimizer.zero_grad()
        # Explicitly enabled so the step still works when the caller is
        # running under a no-grad context.
        with torch.enable_grad():
            ys = network_bp(x)
            loss = loss_fcn(ys, y_ground)
            loss.backward()
        optimizer.step()
        running_loss += loss.detach()

        # Softmax is monotonic, so the arg-max of the raw outputs is the
        # predicted class; detaching keeps the bookkeeping out of autograd.
        predictions = ys.detach().argmax(dim=-1)
        # Accumulate on-device to avoid a host sync on every batch.
        correct = correct + predictions.eq(y_ground).sum()
        seen += y_ground.shape[0]

    running_loss /= num_batches
    return running_loss, int(correct) / seen

@torch.no_grad()
def test(network_bp, test_loader, loss_fcn, opts):
    """Score ``network_bp`` on every batch of ``test_loader`` without gradients.

    Args:
        network_bp: Model to evaluate; invoked as ``network_bp(x)``.
        test_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        loss_fcn: Callable ``loss_fcn(outputs, labels)`` returning a scalar
            loss tensor.
        opts: Configuration object; only ``opts.device`` is read.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of samples whose arg-max prediction equals the label.
    """
    target_device = opts.device
    loss_total = 0.
    predicted, expected = [], []

    for inputs, labels in test_loader:
        inputs = inputs.float().to(target_device)
        labels = labels.to(target_device)

        logits = network_bp(inputs)
        loss_total += loss_fcn(logits, labels).detach()

        # Class decision per sample: index of the largest softmax probability.
        predicted.append(logits.softmax(dim=-1).argmax(dim=-1))
        expected.append(labels)

    # Average over batches (not over samples).
    mean_loss = loss_total / len(test_loader)

    predicted = torch.cat(predicted)
    expected = torch.cat(expected)
    hits = (predicted == expected).sum().item()
    return mean_loss, hits / len(expected)

def _build_dataloaders(dataset_dir, batch_size, image_size=96, train_fraction=0.9):
    """Create the augmented training loader and the plain validation loader.

    Args:
        dataset_dir: Root folder in ``ImageFolder`` layout.
        batch_size: Batch size used by both loaders.
        image_size: Side length produced by the random-resized crop.
        train_fraction: Share of the samples assigned to the training split.

    Returns:
        ``(train_loader, val_loader)``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Extra augmentation for the training data only.
    transform_train = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(image_size, scale=(0.9, 1.1), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Two independent views of the same folder. Both subsets returned by
    # random_split point at ONE underlying dataset, so assigning
    # `.dataset.transform` on each of them makes the last assignment win for
    # both splits. Separate ImageFolder objects keep the transforms apart.
    train_view = datasets.ImageFolder(root=dataset_dir, transform=transform_train)
    val_view = datasets.ImageFolder(root=dataset_dir, transform=transform_val)

    train_size = int(train_fraction * len(train_view))
    val_size = len(train_view) - train_size

    # Split once, then reuse the validation indices on the un-augmented view
    # so the two splits stay disjoint.
    train_dataset, val_split = random_split(train_view, [train_size, val_size])
    val_dataset = Subset(val_view, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, val_loader


def main(opts):
    """Train the model, validate after every epoch and keep the best weights.

    Args:
        opts: Configuration object. Reads ``seed``, ``dataset_dir``,
            ``batch_size``, ``device``, ``combo`` (via ``load_model``), ``lr``,
            ``weight_decay``, ``epochs``, ``adaptive_lr``, ``warmup_epochs``,
            ``startup_lr`` and ``save_model``; writes ``log_lr`` every epoch.

    Side effects:
        Creates the ``Output`` folder, prints progress, optionally stores the
        best-scoring state dict there and finally calls ``save_plots`` with
        the per-epoch loss/accuracy histories.
    """
    set_seed(opts.seed)

    output_folder = "Output"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_folder, exist_ok=True)
    checkpoint_path = os.path.join(output_folder, 'model_state_InfTest_VWWbp0.pth')

    train_loader, val_loader = _build_dataloaders(opts.dataset_dir, opts.batch_size)

    model = load_model(opts).to(opts.device)
    summary(model)
    num_params = sum(p.numel() for p in model.parameters())

    # Rough memory estimate: parameters once, plus the model-reported
    # gradient/activation/error counts once per sample in the batch.
    # NOTE(review): the "* 4 / 1024" presumably means 4-byte values reported
    # in KB -- confirm against the network definition.
    per_sample = model.gradient_num + model.activation_num + model.error_num
    for est_batchsize in (1, 32, 128):
        peak_mem = (num_params + per_sample * est_batchsize) * 4 / 1024
        print("~~~~~estimated peak mem~~~~~", est_batchsize)
        print(">>>:", peak_mem, "KB")

    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
    train_loss_hist, valid_loss_hist = [], []
    train_acc_hist, valid_acc_hist = [], []
    best_acc = 0

    for step in range(opts.epochs):
        # Honour --adaptive_lr: only reschedule when the flag is enabled,
        # otherwise keep the optimizer's constant base learning rate.
        if opts.adaptive_lr:
            scheduled_lr = cosine_annealing_lr_with_warmup(
                step, opts.epochs, opts.lr, opts.warmup_epochs, opts.startup_lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = scheduled_lr

        opts.log_lr = optimizer.param_groups[0]['lr']

        # Validation below leaves the model in eval mode, so switch back to
        # train mode at the start of every epoch.
        model.train()
        train_loss, train_acc = train(
            model, optimizer, train_loader, loss_fcn, opts)

        model.eval()
        valid_loss, valid_acc = test(model, val_loader, loss_fcn, opts)
        print(f"Epoch {step:04d} train_loss: {train_loss:.3f} train_acc: {train_acc:.3f} \
                valid_loss: {valid_loss:.3f} valid_acc: {valid_acc:.3f} log_lr: {opts.log_lr:.6f}")
        if valid_acc > best_acc:
            best_acc = valid_acc
            if opts.save_model == 1:
                print("model saved!!!")
                torch.save(model.state_dict(), checkpoint_path)

        # Record every curve (as plain floats) so the plots are not empty.
        train_loss_hist.append(float(train_loss))
        valid_loss_hist.append(float(valid_loss))
        train_acc_hist.append(train_acc)
        valid_acc_hist.append(valid_acc)

    save_plots(train_loss_hist, valid_loss_hist, train_acc_hist, valid_acc_hist)
    print('TRAINING COMPLETE! Best Acc:', best_acc)
    print('='*50)
    
if __name__ == '__main__':
    # Script entry point: hand a configuration object to the training driver.
    main(Opts())