# Just the FF/BP baseline following Hinton's paper

from torch.utils.data import DataLoader,random_split
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchinfo import summary
from tqdm import tqdm
from util import set_seed, cosine_annealing_lr_with_warmup
import torch.nn.functional as F
import numpy as np
import wandb
import argparse
import math
import time
import random

from torchtoolbox.transform import Cutout

from network_v2 import MLP_FF_Receptive, Conv_FF_model_v2, set_mu_delta
from network_v3 import Resnet_ff_new

import ssl
# Swap the default HTTPS context factory for the unverified one, so HTTPS
# connections opened through the standard library in this process skip
# certificate verification.
# NOTE(review): presumably a workaround for the dataset download on hosts with
# a broken CA bundle — confirm; this weakens TLS checks for the whole process.
ssl._create_default_https_context = ssl._create_unverified_context
# Enable autograd anomaly detection globally for the run.
# NOTE(review): PyTorch describes this as a debugging aid that slows training;
# confirm it is meant to stay on outside of debugging sessions.
torch.autograd.set_detect_anomaly(True)

# from pypapi import events as papi_events
# import ctypes
# os.environ['LD_LIBRARY_PATH'] = '/home/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# papi_lib = ctypes.CDLL('code/papi_test.so')
# event = ctypes.c_int(papi_events.PAPI_SP_OPS)
# papi_lib.initializePAPI(event)
# papi_lib.stopAndRead.restype = ctypes.c_longlong

# Command-line interface of the training script. Every option has a default,
# so the script can be started without arguments. Help strings that mention a
# default use argparse's %(default)s placeholder so the text always matches the
# real value.
parser = argparse.ArgumentParser(description='forward-forward-benchmark training args')
parser.add_argument('--device', default='cuda', help='cuda|cpu')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=500,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=256,
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--theta', type=float, default=16, help='layer loss param')
parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
parser.add_argument('--api_key', default='xxxxxxxxxx',
                    help='wandb api key (the default is a placeholder, not a valid key)')
parser.add_argument('--project_name', default='forward-forward-benchmark-cifarCNN',
                    help='wandb project name')
parser.add_argument('--weight_decay', type=float, default=3e-1, help='weight decay')
parser.add_argument('--seed', type=int, default=0, help='random seed (default: %(default)s)')
parser.add_argument('--label_ext', type=int, default=1,
                    help='label extension width for embedding')
parser.add_argument('--adaptive_lr', type=int, default=1, help='enable adaptive learning rate')
parser.add_argument('--start_lr', type=float, default=0.00005,
                    help='start learning rate for warmup')
parser.add_argument('--warmup_epochs', type=int, default=10, help='warmup epochs')
parser.add_argument('--combo', type=int, default=0, help='selection of model structure')
# Kept as a string flag: the options class compares it against '1'.
parser.add_argument('--fopmonitor', default='0', help='papi 1 enable training cost monitoring')
# Positional meaning of the values, per the original author's note:
# stride padding size1 size2
parser.add_argument('--downsample_list', nargs='+', type=int, default=[5, 2, 3, 7],
                    help='List of downsample config')

class Opts:
    """Run settings, exposed as class attributes filled from the command line.

    The attributes are computed while the class body executes, so
    ``parser.parse_args()`` runs as soon as this class statement is reached.
    """

    args = parser.parse_args()

    # --- experiment identity / model selection ---
    seed = args.seed
    combo = args.combo
    downsample_list = args.downsample_list
    label_ext = args.label_ext
    encoded = False

    # --- optimisation ---
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    log_lr = lr  # initialised to the base learning rate
    start_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    adaptive_lr = args.adaptive_lr
    weight_decay = args.weight_decay
    theta = args.theta

    # --- wandb logging ---
    online_visual = args.online_visual
    api_key = args.api_key
    project_name = args.project_name

    # --- cost monitoring: the flag value '1' overrides the device with 'cpu' ---
    fopmonitor = args.fopmonitor
    device = 'cpu' if fopmonitor == '1' else args.device


def train(network_ff, optimizer, test_loader, train_loader, opts, linear_clf=None,lc_loss=None, optimizer_cf=None, warmup=False):
    """Run one epoch of layer-local training, then evaluate on both loaders.

    Every block in ``network_ff.blocks`` is updated by its own optimizer from
    a loss computed on that block's output only; the tensors passed on to the
    next block are detached, so no gradient crosses block boundaries.

    The loss pushes the per-sample sum of squared activations below
    ``opts.theta`` for correctly labelled inputs and above it for wrongly
    labelled ones.

    Args:
        network_ff: model exposing ``blocks``; each block is called as
            ``block(x, one_hot_label, opts)`` and must return ``(z, x_next)``.
            ``z`` is reduced over dims 1-3, so it is expected to be 4-D.
        optimizer: sequence of optimizers, indexed like the blocks.
        test_loader: loader used for the validation accuracy.
        train_loader: loader used for training and for the train accuracy.
        opts: options object; ``device``, ``theta`` and ``online_visual`` are
            read, ``diff_res`` is overwritten with the absolute per-block gap
            between the mean positive and mean negative squared sums.
        linear_clf, lc_loss, optimizer_cf, warmup: accepted for call
            compatibility; not used by this function.

    Returns:
        ``(mean_loss, train_acc, valid_acc)`` where ``mean_loss`` is a 0-dim
        tensor (summed block losses divided by the number of batches).
    """
    num_blocks = len(network_ff.blocks)
    pos_stats = [[] for _ in range(num_blocks)]
    neg_stats = [[] for _ in range(num_blocks)]
    # Tensor from the start so the return type does not depend on the loop.
    running_loss = torch.zeros((), device=opts.device)

    network_ff.train()
    for x, y_ground in tqdm(train_loader, total=len(train_loader), desc="Training"):
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)

        # Negative labels must differ from the true ones. Adding an offset in
        # [1, 9] modulo 10 guarantees that and stays uniform over the nine
        # wrong classes (a plain randint(0, 10) hits the true label ~10% of
        # the time).
        offsets = torch.randint(1, 10, y_ground.shape, device=y_ground.device)
        y_random = (y_ground + offsets) % 10

        y_ground_ce = F.one_hot(y_ground, num_classes=10).float()
        y_random_ce = F.one_hot(y_random, num_classes=10).float()

        # ----- FF pass ----- #
        pos_input, neg_input = x, x
        for layer_idx, layer in enumerate(network_ff.blocks.children()):
            # Explicitly enable grad so the local update also works if the
            # caller happens to be inside a no_grad context.
            with torch.enable_grad():
                z_pos, x_pos = layer(pos_input, y_ground_ce, opts)
                z_neg, x_neg = layer(neg_input, y_random_ce, opts)

                pos_sq_sum = z_pos.pow(2).sum(dim=[1, 2, 3])
                neg_sq_sum = z_neg.pow(2).sum(dim=[1, 2, 3])

                # softplus(s - theta) shrinks positive sums, softplus(theta - s)
                # grows negative sums.
                positive_loss = F.softplus(pos_sq_sum - opts.theta, beta=1, threshold=20).mean()
                negative_loss = F.softplus(opts.theta - neg_sq_sum, beta=1, threshold=20).mean()

                actor_loss = positive_loss + negative_loss
                actor_loss.backward()

            running_loss += actor_loss.detach()
            optimizer[layer_idx].step()
            optimizer[layer_idx].zero_grad()

            # Cut the graph: the next block trains on constants.
            pos_input = x_pos.detach()
            neg_input = x_neg.detach()
            pos_stats[layer_idx].append(pos_sq_sum.detach().mean().item())
            neg_stats[layer_idx].append(neg_sq_sum.detach().mean().item())

    layer_res = [np.mean(values) for values in pos_stats]
    layer_res2 = [np.mean(values) for values in neg_stats]
    print("----pos square_sums----", layer_res, "\n----neg square_sums----", layer_res2)
    diff_res = np.array(layer_res) - np.array(layer_res2)
    print("----diff square_sums----", diff_res)
    # Only the magnitude of the gap is handed to the evaluation pass.
    diff_res = np.abs(diff_res)
    opts.diff_res = diff_res
    running_loss /= len(train_loader)

    network_ff.eval()
    train_acc = test(network_ff, train_loader, opts, diff_res)
    valid_acc = test(network_ff, test_loader, opts, diff_res)

    if opts.online_visual == 1:
        wandb.log({
            "train_loss": running_loss,
            "train_acc": train_acc,
            "valid_acc": valid_acc,
            "positive_res": wandb.Histogram(layer_res),
            "negative_res": wandb.Histogram(layer_res2),
        })

    return running_loss, train_acc, valid_acc

@torch.no_grad()
def test(network, test_loader, opts, diff_res, num_classes=10):
    """Return classification accuracy obtained by trying every candidate label.

    For each batch the network is run once per class with that class as the
    (one-hot) label input; the class whose summed output is smallest is the
    prediction.

    Args:
        network: model called as ``network(x, one_hot_labels, opts, diff_res)``.
            Its output is summed over dim 1.
            NOTE(review): this assumes a 2-D output (batch, k) so that one
            score per sample remains — confirm against the model.
        test_loader: loader yielding ``(inputs, integer_labels)``; must expose
            ``.dataset`` because accuracy is normalised by its length.
        opts: options object; ``device`` is read here and the object is also
            forwarded to the network.
        diff_res: per-block values forwarded unchanged to the network.
        num_classes: number of candidate labels to try (one-hot width).

    Returns:
        Fraction of correct predictions, ``correct / len(test_loader.dataset)``.
    """
    correct = 0
    for x_test, y_test in tqdm(test_loader, total=len(test_loader), desc="Testing"):
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)
        batch = x_test.shape[0]

        scores_per_label = []
        for label in range(num_classes):
            # Same candidate label for every sample of the batch.
            candidate = torch.full((batch,), label, dtype=torch.long, device=x_test.device)
            candidate_ce = F.one_hot(candidate, num_classes=num_classes).float()
            acts = network(x_test, candidate_ce, opts, diff_res)
            scores_per_label.append(acts.sum(dim=[1]))

        # One column per candidate label; the lowest score wins.
        scores = torch.stack(scores_per_label, dim=1)
        correct += scores.argmin(dim=-1).eq(y_test).sum().item()

    return correct / len(test_loader.dataset)

def load_model(combo, downsample_list=None):
    """Build the network for the given structure selector.

    Args:
        combo: structure selector forwarded to ``Resnet_ff_new``.
        downsample_list: downsample configuration forwarded as
            ``downsample_en``. When omitted, it is taken from the module-level
            ``opts`` object if one exists (script execution), otherwise from
            the ``Opts`` class attributes, so the function no longer raises
            ``NameError`` when this module is imported instead of run.

    Returns:
        The constructed ``Resnet_ff_new`` instance.
    """
    if downsample_list is None:
        config = globals().get('opts', Opts)
        downsample_list = config.downsample_list
    # Alternative architecture kept for reference: Conv_FF_model_v2(combo=combo)
    return Resnet_ff_new(combo=combo, downsample_en=downsample_list)

def _init_wandb(opts, attempts=3):
    """Try to start a wandb run; return True on success, False otherwise."""
    import datetime

    os.environ["WANDB_API_KEY"] = opts.api_key
    # NOTE(review): opts.project_name is not consulted here; the project name
    # is hard-coded — confirm that this is intended.
    wandb_prjname = "FFbenchmarkCNN"+"_"+"CIFAR10"
    for attempt in range(1, attempts + 1):
        time.sleep(random.randint(0, 100))  # try to avoid overwriting
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        wandb_runname = (
            f"CIFAR_FF_dsResNet{opts.downsample_list}_theta_{opts.theta}"
            f"_epochs_{opts.epochs}wei_dec{opts.weight_decay}_lr_{opts.lr}"
            f"_combo_{opts.combo}_current_time_{current_time}"
        )
        try:
            wandb.init(project=wandb_prjname, name=wandb_runname)
        except Exception as err:
            # Best effort: logging is optional, training continues without it.
            # KeyboardInterrupt/SystemExit are deliberately not caught.
            print(f"wandb init failed (attempt {attempt}/{attempts}): {err}")
        else:
            return True
    return False


def _build_cifar10_loaders(opts):
    """Create the CIFAR-10 train/test loaders (downloads the data if needed)."""
    # Training uses a random horizontal flip only; no Normalize, ColorJitter
    # or Cutout step is applied to either split.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset = CIFAR10(root='./data/cifar10', train=True,
                       download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=opts.batch_size,
                                               shuffle=True, num_workers=4)

    testset = CIFAR10(root='./data/cifar10', train=False,
                      download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=opts.batch_size,
                                              shuffle=False, num_workers=4)
    return train_loader, test_loader


def _estimate_block_memory_kb(model):
    """Return the per-block memory estimate in KB, assuming 4-byte values."""
    num_params = sum(p.numel() for p in model.parameters())
    return [
        (num_params + block.activation_num + block.gradient_num + block.error_num) * 4 / 1024
        for block in model.blocks.children()
    ]


def main(opts):
    """Train and evaluate the model for ``opts.epochs`` epochs.

    Seeds the RNGs, optionally starts wandb logging, builds the CIFAR-10
    loaders and the model, prints a memory estimate, then runs the epoch loop
    with one AdamW optimizer per block and an externally scheduled learning
    rate. Progress and the best validation accuracy are printed.

    Args:
        opts: options object. ``online_visual`` is reset to 0 unless it was 1
            and wandb started successfully; ``diff_res`` and ``log_lr`` are
            written during training.
    """
    set_seed(opts.seed)

    vis_pass = opts.online_visual == 1 and _init_wandb(opts)
    if not vis_pass:
        opts.online_visual = 0

    print("----all parameters----")
    # NOTE: this re-parses the command line, so the values printed are the CLI
    # arguments, not any attribute that was changed on `opts` afterwards.
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, getattr(args, arg))

    output_folder = "Output"
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(output_folder, exist_ok=True)

    train_loader, test_loader = _build_cifar10_loaders(opts)

    model = load_model(opts.combo).to(opts.device)
    summary(model, depth=5)

    mem_dist = _estimate_block_memory_kb(model)
    peak_mem = max(mem_dist, default=0)
    print("~~~~~estimated peak mem~~~~~")
    print(">>>:", peak_mem, "KB")
    print(mem_dist)

    # One optimizer per block: blocks are trained independently.
    optimizers = [
        torch.optim.AdamW(block.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
        for block in model.blocks.children()
    ]

    train_loss_hist = []
    train_acc_hist, valid_acc_hist = [], []

    # initialize the diff_res for goodness recording
    opts.diff_res = np.ones((len(model.blocks),))

    best_acc = 0
    start_time = time.time()
    for step in range(opts.epochs):
        lr = cosine_annealing_lr_with_warmup(step, opts.epochs, opts.lr, opts.warmup_epochs, opts.start_lr, opts.adaptive_lr)
        opts.log_lr = lr
        for opt in optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = lr

        warmup = step < opts.warmup_epochs
        train_loss, train_acc_ff, valid_acc_ff = train(
            model, optimizers, test_loader, train_loader, opts, warmup=warmup)

        print(f"Step {step:04d} train_loss_FF: {train_loss:.4f} \
                train_acc_FF: {train_acc_ff:.4f} \
                valid_acc_FF: {valid_acc_ff:.4f} lr: {opts.log_lr:.6f}")
        # log the ff accuracy and losses
        train_loss_hist.append(train_loss.cpu())
        train_acc_hist.append(train_acc_ff)
        valid_acc_hist.append(valid_acc_ff)

        if valid_acc_ff > best_acc:
            best_acc = valid_acc_ff
            print("Best accuracy so far! ---> {:d}".format(round(best_acc*100)))
            # Checkpoint saving is currently disabled.

    print('TRAINING COMPLETE!')
    print('Best accuracy so far! ---> {:f}'.format(best_acc))
    print('elapsed time = {:.2f} minutes'.format((time.time()-start_time)/60))
    print('-'*50)
    
if __name__ == '__main__':
    # Script entry point. The options object is bound to the module-level name
    # `opts` on purpose: load_model() reads that global, so do not inline this
    # into `main(Opts())`.
    opts = Opts()
    # try:
    main(opts)
    # except Exception as e:
    #     import traceback
    #     traceback.print_exc()
    # finally:
    #     if opts.online_visual == 1:
    #         wandb.finish()