from torch.utils.data import DataLoader, random_split, Subset
import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10
from torchinfo import summary
from tqdm import tqdm
from util import set_seed, save_model, save_plots, SaveBestModel, cosine_annealing_lr_with_warmup
import torch.nn.functional as F
import numpy as np
import wandb
import argparse
import math
import time
import random
from network_v3 import FF_mobilenet_v1

import ssl
# NOTE(review): replaces the default HTTPS context factory with the unverified
# one, i.e. certificate verification is turned off for code that relies on the
# default context (presumably added to get dataset/weight downloads working
# behind a broken certificate chain -- confirm it is still needed).
ssl._create_default_https_context = ssl._create_unverified_context
# NOTE(review): anomaly detection is switched on globally for every run. The
# PyTorch docs describe it as a debugging aid that slows execution down, so
# consider disabling it for regular training runs.
torch.autograd.set_detect_anomaly(True)

# from pypapi import events as papi_events
# import ctypes
# os.environ['LD_LIBRARY_PATH'] = '/home/yourname/papi-install/lib/:' + os.environ.get('LD_LIBRARY_PATH', '')
# papi_lib = ctypes.CDLL('code/papi_test.so')
# event = ctypes.c_int(papi_events.PAPI_SP_OPS)
# papi_lib.initializePAPI(event)

# Command-line interface of the training script. Defaults quoted in help
# strings use argparse's ``%(default)s`` placeholder so they cannot drift away
# from the real default values.
parser = argparse.ArgumentParser(description='forward-forward-benchmark vww training args')
parser.add_argument('--device', default='cuda', help='cuda|cpu')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: %(default)s)')
parser.add_argument('--theta', type=float, default=4, help='layer loss param')
parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
# NOTE(review): a key-like string is hard-coded as the fallback (the original
# help text calls it "a wrong one"). An exported WANDB_API_KEY now takes
# precedence, so running without --api_key no longer replaces the user's own
# key when main() writes this value back into os.environ.
parser.add_argument('--api_key',
                    default=os.environ.get('WANDB_API_KEY', 'eec626411e7ff3f4c229c1302489a9df4ab713f9'),
                    help='wandb api key (defaults to $WANDB_API_KEY when set, otherwise to a placeholder that is a wrong one)')
parser.add_argument('--dataset_dir', default='data/vw_coco2014_96', help='dataset dir')
parser.add_argument('--weight_decay', type=float, default=3e-1, help='weight decay')
parser.add_argument('--seed', type=int, default=0, help='random seed (default: %(default)s)')
parser.add_argument('--adaptive_lr', type=int, default=1, help='enable adaptive learning rate')
parser.add_argument('--start_lr', type=float, default=0.00005, help='start learning rate for warmup')
parser.add_argument('--warmup_epochs', type=int, default=5, help='warmup epochs')
parser.add_argument('--combo', type=int, default=0, help='selection of model structure')
parser.add_argument('--label_len', type=int, default=2, help='label_length')
# Kept as a string flag: the rest of the script compares it against '1'.
parser.add_argument('--fopmonitor', default='0', help='papi 1 enable monitoring')
parser.add_argument('--save_model', type=int, default=0, help='save model or not')
parser.add_argument('--pool_list', nargs='+', type=int, default=[6, 2, 2, 2, 2], help='pool list')

class Opts:
    """Run configuration: a class-level mirror of the parsed command line.

    The command line is parsed once, while this class body executes, and every
    option is copied onto a class attribute of the same name. ``main`` later
    attaches further run-time fields to the instance it receives.
    """
    args = parser.parse_args()

    # --- optimisation / schedule ---
    lr = args.lr
    log_lr = lr  # initialised to the base learning rate
    start_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    adaptive_lr = args.adaptive_lr
    weight_decay = args.weight_decay
    epochs = args.epochs
    batch_size = args.batch_size
    theta = args.theta
    seed = args.seed

    # --- model structure ---
    combo = args.combo
    label_len = args.label_len
    pool_list = args.pool_list

    # --- data, logging and checkpointing ---
    dataset_dir = args.dataset_dir
    online_visual = args.online_visual
    api_key = args.api_key
    save_model = args.save_model

    # --- execution device ---
    fopmonitor = args.fopmonitor
    # Enabling the monitor ('1') overrides whatever --device requested.
    device = 'cpu' if fopmonitor == '1' else args.device

def load_model(opts):
    """Instantiate the FF_mobilenet_v1 network described by ``opts``.

    Args:
        opts: object exposing ``combo``, ``label_len`` and ``pool_list``, which
            are forwarded unchanged as keyword arguments to the constructor.

    Returns:
        The newly constructed ``FF_mobilenet_v1`` instance (not yet moved to
        any device).
    """
    return FF_mobilenet_v1(
        combo=opts.combo,
        label_len=opts.label_len,
        pool_list=opts.pool_list,
    )

def train(network_ff, optimizer, test_loader, train_loader, opts):
    """Run one epoch of block-local Forward-Forward training, then evaluate.

    For every batch a positive pair (input + true label) and a negative pair
    (input + flipped label, ``(1 + y) % 2``) are pushed through the blocks one
    after another. Each block is updated on its own with its own optimizer;
    the inputs handed to the next block are detached, so no gradient crosses
    block boundaries. The per-block loss is
    ``softplus(g_pos - theta) + softplus(theta - g_neg)`` where ``g`` is the
    per-sample sum of squared activations.

    Args:
        network_ff: model exposing ``blocks``; each block is called as
            ``block(x, label_onehot, opts)`` and returns ``(z, x_next)`` with
            a 4-D ``z`` (it is reduced over dims 1, 2 and 3).
        optimizer: sequence of optimizers, indexed like the blocks.
        test_loader: loader used for the validation accuracy.
        train_loader: loader providing ``(x, y)`` training batches.
        opts: options object; reads ``device``, ``theta_list`` and
            ``online_visual``, writes ``diff_res``.

    Returns:
        ``(running_loss, train_acc, valid_acc)``. ``running_loss`` is the sum
        of all block losses divided by the number of batches; ``train_acc`` is
        only computed when ``opts.online_visual == 1`` and is ``0`` otherwise.
    """
    num_blocks = len(network_ff.blocks)
    pos_goodness_log = [[] for _ in range(num_blocks)]
    neg_goodness_log = [[] for _ in range(num_blocks)]
    running_loss = 0.

    batches = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training")
    for batch_no, (x, y_ground) in batches:
        x = x.to(opts.device)
        y_ground = y_ground.to(opts.device)

        # Positive samples carry the true label, negative samples the other one.
        y_random = (1 + y_ground) % 2
        pos_label = F.one_hot(y_ground, num_classes=2).float()
        neg_label = F.one_hot(y_random, num_classes=2).float()
        pos_input = x.float()
        neg_input = x.float()

        # if opts.fopmonitor == '1':
        #     papi_lib.startCounting()
        for layer_idx, layer in enumerate(network_ff.blocks.children()):
            with torch.enable_grad():
                z_pos, pos_output = layer(pos_input, pos_label, opts)
                z_neg, neg_output = layer(neg_input, neg_label, opts)

                # "Goodness": per-sample sum of squared activations.
                pos_goodness = z_pos.pow(2).sum(dim=[1, 2, 3])
                neg_goodness = z_neg.pow(2).sum(dim=[1, 2, 3])

                theta = opts.theta_list[layer_idx]
                positive_loss = F.softplus(pos_goodness - theta, beta=1, threshold=40).mean()
                negative_loss = F.softplus((-neg_goodness + theta), beta=1, threshold=40).mean()

                block_loss = positive_loss + negative_loss
                block_loss.backward()
                running_loss += block_loss.detach()

                block_optimizer = optimizer[layer_idx]
                block_optimizer.step()
                block_optimizer.zero_grad()

            # Cut the graph so the next block trains purely locally.
            pos_input = pos_output.detach()
            neg_input = neg_output.detach()

            pos_goodness_log[layer_idx].append(pos_goodness.detach().cpu().numpy().mean())
            neg_goodness_log[layer_idx].append(neg_goodness.detach().cpu().numpy().mean())

        # if opts.fopmonitor == '1':
        #     value = ctypes.c_longlong()
        #     papi_lib.stopAndRead(ctypes.byref(value))
        #     print("done")

    # Epoch-level mean goodness per block, for positive and negative samples.
    layer_res = [np.mean(values) for values in pos_goodness_log]
    layer_res2 = [np.mean(values) for values in neg_goodness_log]
    print("----pos square_sums----", layer_res, "\n----neg square_sums----", layer_res2)

    diff_res = np.array(layer_res) - np.array(layer_res2)
    print("----diff square_sums----", diff_res)
    # Only the magnitude of the pos/neg gap is kept; it is handed to test().
    diff_res = np.abs(diff_res)
    opts.diff_res = diff_res

    running_loss /= len(train_loader)

    # Training accuracy costs a full extra pass, so it is only computed when
    # online logging is enabled.
    train_acc = test(network_ff, train_loader, opts, diff_res) if opts.online_visual == 1 else 0
    valid_acc = test(network_ff, test_loader, opts, diff_res)
    if opts.online_visual == 1:
        wandb.log({"train_loss": running_loss, "train_acc": train_acc, "valid_acc": valid_acc})
    return running_loss, train_acc, valid_acc

@torch.no_grad()
def test(network, test_loader, opts, diff_res):
    """Compute classification accuracy by trying every candidate label.

    Each batch is passed through ``network`` once per candidate label (0 and
    1) with that label one-hot encoded for every sample. The network output is
    summed over dim 1 to give one goodness value per sample and label, and the
    label with the *lowest* goodness is taken as the prediction.

    Args:
        network: callable as ``network(x, label_onehot, opts, diff_res)``;
            its output is reduced with ``sum(dim=[1])``.
        test_loader: loader yielding ``(x, y)`` batches with integer class
            labels ``y`` of shape ``(batch,)``.
        opts: options object; only ``opts.device`` is read here, the object is
            also forwarded to ``network``.
        diff_res: per-block values forwarded unchanged to ``network``.

    Returns:
        Fraction of evaluated samples that were classified correctly, or
        ``0.0`` when the loader yields no batches.
    """
    # NOTE(review): the module's train/eval mode is not changed here. If the
    # blocks contain BatchNorm or Dropout, evaluation runs in whatever mode the
    # caller left the model in -- confirm this is intended.
    num_classes = 2
    all_goodness = []
    all_labels = []

    batches = tqdm(enumerate(test_loader), total=len(test_loader), desc="Testing")
    for batch_no, (x_test, y_test) in batches:
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)
        num_samples = x_test.shape[0]

        goodness_for_labels = []
        for label in range(num_classes):
            # Same candidate label for every sample in the batch.
            candidate = torch.full((num_samples,), label, dtype=torch.long, device=x_test.device)
            candidate_onehot = F.one_hot(candidate, num_classes=num_classes).float()
            acts = network(x_test, candidate_onehot, opts, diff_res)
            goodness_for_labels.append(acts.sum(dim=[1]))

        # (batch, num_classes): one goodness column per candidate label.
        all_goodness.append(torch.stack(goodness_for_labels, dim=1))
        all_labels.append(y_test)

    if not all_labels:
        return 0.0

    all_goodness = torch.cat(all_goodness)
    all_labels = torch.cat(all_labels)
    correct = all_goodness.argmin(dim=-1).eq(all_labels).sum().item()
    # Divide by the number of samples actually evaluated so the result stays
    # correct for loaders that do not visit the whole dataset.
    return correct / all_labels.numel()
        
def main(opts):
    """Train the Forward-Forward MobileNet on an image folder and report accuracy.

    Steps: seed the RNGs, optionally start a wandb run (online logging is
    switched off when the run cannot be started), build the train/validation
    loaders from ``opts.dataset_dir``, print a per-block memory estimate, then
    run ``opts.epochs`` epochs of ``train`` with one AdamW optimizer per
    block, optionally saving the weights whenever validation accuracy improves.

    Args:
        opts: options object. Reads ``seed``, ``online_visual``, ``api_key``,
            ``lr``, ``theta``, ``epochs``, ``weight_decay``, ``combo``,
            ``dataset_dir``, ``batch_size``, ``device``, ``warmup_epochs``,
            ``start_lr``, ``adaptive_lr`` and ``save_model``; writes
            ``online_visual``, ``diff_res``, ``theta_list`` and ``log_lr``.
    """
    set_seed(opts.seed)

    ### wandb visualization ###
    vis_pass = False
    if opts.online_visual == 1:
        import datetime
        os.environ["WANDB_API_KEY"] = opts.api_key
        wandb_prjname = "FFbenchmarkCNN" + "_" + "VWW"
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
                wandb_runname = (
                    f"VWW_FF_lr{opts.lr}_theta_{opts.theta}_epochs_{opts.epochs}"
                    f"wei_dec{opts.weight_decay}_combo_{opts.combo}_current_time_{current_time}"
                )
                wandb.init(project=wandb_prjname, name=wandb_runname)
                vis_pass = True
                break
            except Exception as err:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit propagate instead of being swallowed.
                print(f"wandb.init failed (attempt {attempt}/{max_attempts}): {err!r}")
    if not vis_pass:
        # Either logging was not requested or wandb could not be started.
        opts.online_visual = 0

    print("----all parameters----")
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, getattr(args, arg))

    output_folder = "Output"
    os.makedirs(output_folder, exist_ok=True)

    BASE_DIR = opts.dataset_dir
    IMAGE_SIZE = 96

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # extra transform for the training data, in order to achieve better performance
    transform_train = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.9, 1.1), ratio=(0.9, 1.1)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # For RGB channels, adjust if necessary
    ])

    # Two views of the same folder, one per transform. Splitting a single
    # ImageFolder and then assigning ``subset.dataset.transform`` twice would
    # leave BOTH splits with the last assigned transform, because the subsets
    # share the underlying dataset object.
    train_view = datasets.ImageFolder(root=BASE_DIR, transform=transform_train)
    val_view = datasets.ImageFolder(root=BASE_DIR, transform=transform_val)

    num_samples = len(train_view)
    train_size = int(0.9 * num_samples)

    # One shuffled index list drawn from the global torch RNG keeps the two
    # splits disjoint (set_seed above presumably seeds that RNG -- confirm).
    shuffled_indices = torch.randperm(num_samples).tolist()
    train_dataset = Subset(train_view, shuffled_indices[:train_size])
    val_dataset = Subset(val_view, shuffled_indices[train_size:])

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)

    model = load_model(opts).to(opts.device)
    summary(model, depth=5)

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    # Rough memory estimate per block: all parameters plus the block's
    # per-sample buffers scaled by the batch size, 4 bytes each, in KB.
    num_params = count_parameters(model)
    est_batchsizes = [1, 32, 128]
    for est_batchsize in est_batchsizes:
        # Reset for every batch size so the figures are not mixed together.
        mem_dist = []
        for block in model.blocks.children():
            num_activation = block.activation_num
            num_grad = block.gradient_num
            num_error = block.error_num
            current_mem = (num_params + (num_grad + num_activation + num_error)*est_batchsize)*4/1024  #Assumed float32
            mem_dist.append(current_mem)
        peak_mem = max(mem_dist, default=0)
        print("~~~~~estimated peak mem~~~~~", est_batchsize)
        print(">>>:", peak_mem, "KB")
        print(mem_dist)

    # One optimizer per block: blocks are trained independently in train().
    optimizers = [torch.optim.AdamW(block.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
            for block in model.blocks.children()
            ]
    train_loss_hist = []
    train_acc_hist, valid_acc_hist = [], []

    # initialize the diff_res for goodness recording
    opts.diff_res = np.ones((len(model.blocks),))
    opts.theta_list = [opts.theta for _ in range(len(model.blocks))]

    best_acc = 0
    start_time = time.time()

    for step in range(0, opts.epochs):
        # The schedule is applied by hand to every optimizer each epoch.
        lr = cosine_annealing_lr_with_warmup(step, opts.epochs, opts.lr, opts.warmup_epochs, opts.start_lr, opts.adaptive_lr)
        opts.log_lr = lr
        for opt in optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = lr
        train_loss, train_acc_ff, valid_acc_ff = train(model, optimizers, val_loader, train_loader, opts)
        print(f"Step {step:04d} train_loss_FF: {train_loss:.4f} \
                    train_acc_FF: {train_acc_ff:.4f} \
                    valid_acc_FF: {valid_acc_ff:.4f} lr: {opts.log_lr:.6f}")
        # log the ff accuracy and losses
        train_loss_hist.append(train_loss.cpu())
        train_acc_hist.append(train_acc_ff)
        valid_acc_hist.append(valid_acc_ff)

        if valid_acc_ff > best_acc:
            best_acc = valid_acc_ff
            print("Best accuracy so far! ---> {:d}".format(round(best_acc*100)))
            # keep the weights of the best validation accuracy seen so far
            if opts.save_model == 1:
                torch.save(model.state_dict(), 'Output/model_state_InfTest_VWWmb0.pth')

    print('TRAINING COMPLETE!')
    print('Best accuracy so far! ---> {:f}'.format(best_acc))
    print('elapsed time = {:.2f} minutes'.format((time.time()-start_time)/60))
    print('-'*50)
    
if __name__ == '__main__':
    # The command line was already parsed when the Opts class body executed;
    # the instance simply carries those values (plus run-time fields) into main.
    main(Opts())