
# Plain Forward-Forward (FF) / backprop (BP) baseline following Hinton's Forward-Forward paper.


from torch.utils.data import DataLoader
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchinfo import summary
from tqdm import tqdm
from util import set_seed, save_model, save_plots, SaveBestModel
import torch.nn.functional as F
import numpy as np
import wandb
import argparse
import math

from network_v2 import MLP_FF_Receptive

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
torch.autograd.set_detect_anomaly(True)

def cosine_annealing_lr_with_warmup(epoch, num_epochs, initial_lr, warmup_epochs=10, start_lr=0.0001, adaptive=1):
    """
    Cosine annealing learning-rate schedule with a warm-up phase.

    During the first ``warmup_epochs`` epochs the rate rises from ``start_lr``
    to ``initial_lr`` along a half-cosine ramp. After that it either stays at
    ``initial_lr`` (``adaptive != 1``) or decays from ``initial_lr`` towards 0
    along a half cosine spread over the remaining
    ``num_epochs - warmup_epochs`` epochs.

    Args:
        epoch: current 0-based epoch index.
        num_epochs: total number of epochs in the schedule.
        initial_lr: peak learning rate, reached at the end of warm-up.
        warmup_epochs: length of the warm-up phase; 0 disables warm-up.
        start_lr: learning rate at epoch 0 of the warm-up.
        adaptive: 1 enables the cosine decay after warm-up; any other value
            keeps the rate constant at ``initial_lr``.

    Returns:
        The learning rate (float) for ``epoch``.
    """
    if epoch < warmup_epochs:
        # Half-cosine ramp from start_lr up to initial_lr.
        return start_lr + (initial_lr - start_lr) * (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2

    if adaptive != 1:
        return initial_lr

    # Length of the decay phase. Clamp it to >= 1 so a warm-up that consumes
    # every epoch does not divide by zero.
    decay_epochs = max(num_epochs - warmup_epochs, 1)
    # Clamp progress so epochs beyond the schedule stay at the floor instead of
    # climbing back up the periodic cosine.
    shifted_epoch = min(epoch - warmup_epochs, decay_epochs)
    return 0.5 * initial_lr * (1 + math.cos(math.pi * shifted_epoch / decay_epochs))

# give the parameters in the command line
class Opts:
    parser = argparse.ArgumentParser(description='forward-forward-benchmark training args')
    parser.add_argument('--device', default='cuda', help='cuda|cpu')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train (default: 200)')
    parser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: 50)')
    parser.add_argument('--theta', type=float, default=24, help='layer loss param')
    parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
    parser.add_argument('--api_key', default='xxxxxxxxxxxx', help='wandb api key default is a wrong one')
    parser.add_argument('--project_name', default='forward-forward-benchmark-cifar', help='wandb project name')
    parser.add_argument('--weight_decay', type=float, default=3e-3, help='weight decay')
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--label_ext', type=int, default=1, help='label extension width for embedding')
    parser.add_argument('--adaptive_lr', type=int, default=0, help='enable adaptive learning rate')
    parser.add_argument('--start_lr', type=float, default=0.00001, help='start learning rate for warmup')
    parser.add_argument('--warmup_epochs', type=int, default=0, help='warmup epochs')
    
    args = parser.parse_args()
    batch_size = args.batch_size
    theta = args.theta
    lr = args.lr    
    weight_decay = args.weight_decay
    epochs = args.epochs
    seed = args.seed
    device = args.device
    label_ext = args.label_ext
    online_visual = args.online_visual
    api_key = args.api_key
    project_name = args.project_name
    adaptive_lr = args.adaptive_lr
    start_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    
    log_lr = lr
    

def extend_image_append_label(x,y,opts):
    """Pad images by one pixel and write a one-hot label into the border.

    Each image is zero-padded by one pixel on every side. The one-hot label
    vector is then added ``opts.label_ext`` pixels deep into all four borders
    (left/right columns and top/bottom rows), centred along each edge, in every
    channel. For a 32x32 image with 10 classes this writes rows/columns 12:22
    of the 34x34 padded image.

    Args:
        x: image batch; the indexing assumes (batch, channels, H, W).
        y: one-hot labels, either (batch, n_classes) or a single (n_classes,)
            vector that is broadcast to the whole batch.
        opts: options object; reads ``label_ext`` (border depth in pixels).

    Returns:
        Tensor of shape (batch, channels, H + 2, W + 2).
    """
    x_ce = torch.nn.functional.pad(x, (1, 1, 1, 1))
    height, width = x_ce.shape[-2], x_ce.shape[-1]

    if y.dim() == 1:
        # A single label vector is shared by the whole batch.
        y = y.unsqueeze(0).expand(x.shape[0], -1)

    # Centre the label along each edge instead of hard-coding 12:22 / 33-i.
    n_classes = y.shape[-1]
    row0 = (height - n_classes) // 2
    col0 = (width - n_classes) // 2
    rows = slice(row0, row0 + n_classes)
    cols = slice(col0, col0 + n_classes)

    ll = torch.zeros(x.shape[0], height, width, device=x.device, dtype=x_ce.dtype)
    for i in range(opts.label_ext):
        ll[:, rows, i] = y
        ll[:, rows, width - 1 - i] = y
        ll[:, i, cols] = y
        ll[:, height - 1 - i, cols] = y

    # Copy the label map to every channel and add it onto the padded images.
    return x_ce + ll.unsqueeze(1).expand(-1, x_ce.shape[1], -1, -1)

def partition_and_reorder(images, n_blocks):
    """Split each image into an n_blocks x n_blocks grid and swap adjacent blocks.

    Blocks are numbered in row-major order (0, 1, 2, ...), and the pairs
    (0, 1), (2, 3), ... trade places. With an odd number of blocks the last
    block stays where it is. Block contents are not modified, so applying the
    function twice returns the original images.

    Args:
        images: tensor whose last two dims are (H, W), e.g. (batch, 3, 32, 32).
        n_blocks: number of blocks along each spatial axis; must divide H and W.

    Returns:
        Tensor with the same shape as ``images``.
    """
    height, width = images.shape[-2], images.shape[-1]
    assert height % n_blocks == 0 and width % n_blocks == 0, \
        "Image size should be divisible by the number of blocks"

    block_h = height // n_blocks
    block_w = width // n_blocks
    lead = images.shape[:-2]
    n_total = n_blocks * n_blocks

    # (..., H, W) -> (..., n_blocks, W, block_h) -> (..., n_blocks, n_blocks, block_h, block_w).
    # Both unfolds target dim -2: after the first unfold the original W axis sits
    # at -2 (the new window axis is appended last). Unfolding -1 instead would cut
    # the row windows and yield transposed strips, not square patches.
    blocks = images.unfold(-2, block_h, block_h).unfold(-2, block_w, block_w)
    blocks = blocks.reshape(*lead, n_total, block_h, block_w)

    # Deterministic permutation that swaps each adjacent pair of block indices;
    # an unpaired trailing block (odd count) keeps its position.
    order = torch.arange(n_total, device=images.device)
    n_paired = n_total - n_total % 2
    order[:n_paired] = order[:n_paired].view(n_paired // 2, 2).flip(1).reshape(-1)

    reordered = blocks[..., order, :, :]

    # (..., n_blocks, n_blocks, block_h, block_w) -> (..., n_blocks, block_h, n_blocks, block_w) -> (..., H, W)
    return (reordered.reshape(*lead, n_blocks, n_blocks, block_h, block_w)
            .transpose(-3, -2)
            .reshape(images.shape))




def _ff_layer_loss(zp_square_mean, zn_square_mean, theta):
    """Return one block's Forward-Forward loss for a batch.

    Positive samples are pushed to a mean-square activation below ``theta`` and
    negative samples above it (``test`` correspondingly predicts the label with
    the lowest goodness). ``F.softplus`` evaluates log(1 + exp(.)) without
    overflowing to inf for large arguments, unlike ``torch.log(1 + torch.exp(.))``.
    """
    positive_loss = F.softplus(zp_square_mean - theta).mean()
    negative_loss = F.softplus(theta - zn_square_mean).mean()
    return positive_loss + negative_loss


def train(network_ff, optimizer, test_loader, train_loader, opts, linear_clf=None,lc_loss=None, optimizer_cf=None):
    """Run one epoch of greedy, block-local Forward-Forward training, then evaluate.

    Each block of ``network_ff.blocks`` is trained with its own optimizer. The
    block sees a positive pair (images, true labels) and a negative pair (the
    same images, different random labels). Its FF loss is back-propagated and
    that block's optimizer steps. The detached outputs then become the next
    block's inputs.

    Args:
        network_ff: model exposing ``.blocks``; each block is called as
            ``block(x, one_hot_labels, opts)``.
        optimizer: sequence of optimizers, one per block, in block order.
        test_loader: loader used to compute validation accuracy.
        train_loader: loader of (images, class_index) batches.
        opts: options; reads ``device``, ``theta`` and ``online_visual``.
        linear_clf, lc_loss, optimizer_cf: unused; kept for call compatibility.

    Returns:
        (loss summed over blocks and averaged over batches, as a tensor;
        train accuracy; validation accuracy).
    """
    running_loss = 0.
    square_sums = []   # per batch: positive goodness summed over blocks
    square_sums2 = []  # per batch: negative goodness summed over blocks

    layer_res = np.zeros(len(network_ff.blocks))

    for batch_no, (x, y_ground) in tqdm(enumerate(train_loader), total=len(train_loader), desc="Training"):
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)
        y_ground_ce = F.one_hot(y_ground, num_classes=10)
        batch_size = y_ground.shape[0]

        # Negative labels: shift the true label by a random offset in [1, 9] so
        # the negative label is guaranteed to differ from y_ground. A plain
        # random label equals the true label 10% of the time, which would feed
        # the same (image, label) pair as both positive and negative.
        offsets = torch.randint(1, 10, (batch_size,), device=opts.device)
        y_random = (y_ground + offsets) % 10
        y_random_ce = F.one_hot(y_random, num_classes=10)

        # Labels are passed to every block directly instead of being embedded
        # in the image (cf. extend_image_append_label), so the positive and
        # negative inputs both start from the same images.
        x_ce = x
        x_neg_ce = x

        # ----- FF pass ----- #
        posit_sum = 0
        negat_sum = 0
        for layer_idx, layer in enumerate(network_ff.blocks.children()):

            with torch.enable_grad():
                z_pos = layer(x_ce, y_ground_ce, opts)
                z_neg = layer(x_neg_ce, y_random_ce, opts)

                # "Goodness": mean squared activation per sample (block outputs
                # must have at least 4 dims for this reduction).
                zp_square_mean = z_pos.pow(2).mean(dim=[1,2,3])
                zn_square_mean = z_neg.pow(2).mean(dim=[1,2,3])

                posit_sum += zp_square_mean.mean().item()
                negat_sum += zn_square_mean.mean().item()

                actor_loss = _ff_layer_loss(zp_square_mean, zn_square_mean, opts.theta)
                actor_loss.backward()

                running_loss += actor_loss.detach()

                optimizer[layer_idx].step()
                optimizer[layer_idx].zero_grad()

                # Detach so the next block's loss does not reach this block.
                x_ce = z_pos.detach()
                x_neg_ce = z_neg.detach()

            if batch_no == 10:
                # Snapshot of each block's positive goodness on one fixed batch.
                layer_res[layer_idx] = zp_square_mean.detach().cpu().numpy().mean()

        square_sums.append(posit_sum)
        square_sums2.append(negat_sum)

    print("----square_sums----", np.mean(square_sums), np.mean(square_sums2))
    print("layer_res", layer_res)
    running_loss /= len(train_loader)
    train_acc = test(network_ff, train_loader, opts)
    valid_acc = test(network_ff, test_loader, opts)

    if opts.online_visual == 1:
        wandb.log({"train_loss": running_loss, "train_acc": train_acc, "valid_acc": valid_acc, \
                   "positive_mean": np.mean(square_sums), "negative_mean": np.mean(square_sums2)
                   })

    return running_loss, train_acc, valid_acc

@torch.no_grad()
def test(network, test_loader, opts):
    """Return the classification accuracy of ``network`` on ``test_loader``.

    Every batch is run through the network once per candidate label (0-9),
    with that label given to every sample. The prediction is the label whose
    run yields the lowest goodness (mean squared activation).

    Args:
        network: callable as ``network(images, one_hot_labels, opts)``.
        test_loader: loader of (images, class_index) batches; its ``dataset``
            length is used as the denominator.
        opts: options; reads ``device``.

    Returns:
        Fraction of correctly classified samples (float).
    """
    all_goodness = []
    all_labels = []

    for (x_test, y_test) in tqdm(test_loader, desc="Testing"):
        x_test, y_test = x_test.to(opts.device), y_test.to(opts.device)

        goodness_for_labels = []
        for label in range(0, 10):
            # Per-sample one-hot candidate label, shape (batch, 10), the same
            # layout used for labels during training. The former
            # `y_test.argmax(dim=-1)` collapsed the 1-D class-index tensor to a
            # scalar and produced a single (10,) vector.
            test_label = F.one_hot(torch.full_like(y_test, label), num_classes=10)
            acts = network(x_test, test_label, opts)
            # Mean squared activation over all non-batch dims -> (batch,).
            goodness = acts.pow(2).flatten(1).mean(dim=1)
            goodness_for_labels.append(goodness)

        # (batch, n_labels)
        all_goodness.append(torch.stack(goodness_for_labels, dim=1))
        all_labels.append(y_test)

    all_goodness = torch.cat(all_goodness)
    all_labels = torch.cat(all_labels)
    correct = all_goodness.argmin(dim=-1).eq(all_labels).sum().item()
    return correct / len(test_loader.dataset)

def load_model():
    """Construct and return a new ``MLP_FF_Receptive`` with default arguments (no weights are loaded)."""
    return MLP_FF_Receptive()

def main(opts):
    """Train the FF network on CIFAR-10 and report accuracy after every epoch.

    Steps:
    1. Seed the random generators.
    2. If ``opts.online_visual == 1``, start a wandb run.
    3. Build the CIFAR-10 loaders (data under ``data/cifar10``).
    4. Create one AdamW optimizer per network block.
    5. Run ``opts.epochs`` epochs of ``train``. Before each epoch, every
       optimizer's learning rate is overwritten with the value from
       ``cosine_annealing_lr_with_warmup``.
    """
    set_seed(opts.seed)

    if opts.online_visual == 1:
        import datetime
        os.environ["WANDB_API_KEY"] = opts.api_key
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        # NOTE(review): the --project_name option is not used; the wandb project
        # name is hard-coded below.
        wandb_prjname = "FFbenchmark"+"_"+"CIFAR10"
        wandb_runname = "CIFAR_FF_lr"+str(opts.lr)+"_theta_"+str(opts.theta)+"_epochs_" \
        +str(opts.epochs)+"wei_dec"+str(opts.weight_decay)+"_current_time_"+str(current_time)
        wandb.init(project=wandb_prjname, name=wandb_runname)

    output_folder = "Output"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_folder, exist_ok=True)

    # Normalize each channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
    stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*stats, inplace=True),
    ])

    # Extra augmentation for the training data.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*stats, inplace=True),
        # NOTE(review): with padding=0 a 32x32 crop of a 32x32 image is a no-op;
        # CIFAR setups commonly use padding=4 here — confirm the intent.
        transforms.RandomCrop(32, padding=0),
        transforms.RandomHorizontalFlip(),
    ])

    train_loader = DataLoader(CIFAR10(root='data/cifar10', train=True, download=True, transform=train_transform),
                              batch_size=opts.batch_size,
                              shuffle=True)
    test_loader = DataLoader(CIFAR10(root='data/cifar10', train=False, download=True, transform=transform),
                             batch_size=opts.batch_size,
                             shuffle=False)
    model = load_model().to(opts.device)
    summary(model)

    # One optimizer per block: FF trains each block on its own local loss.
    # (The unused CosineAnnealingLR schedulers were dropped: they were never
    # stepped, and the learning rate is set manually each epoch below.)
    optimizers = [torch.optim.AdamW(block.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
                  for block in model.blocks.children()]

    # Histories kept for plotting (e.g. with util.save_plots); plotting is
    # currently disabled.
    train_loss_hist = []
    train_acc_hist, valid_acc_hist = [], []

    best_acc = 0
    try:
        for step in range(opts.epochs):

            lr = cosine_annealing_lr_with_warmup(step, opts.epochs, opts.lr, opts.warmup_epochs, opts.start_lr, opts.adaptive_lr)
            opts.log_lr = lr
            for opt in optimizers:
                for param_group in opt.param_groups:
                    param_group['lr'] = lr

            train_loss, train_acc_ff, valid_acc_ff = train(model, optimizers, test_loader, train_loader, opts)

            print(f"Step {step:04d} train_loss_FF: {train_loss:.4f} \
                train_acc_FF: {train_acc_ff:.4f} \
                valid_acc_FF: {valid_acc_ff:.4f} lr: {opts.log_lr:.6f}")
            # log the ff accuracy and losses
            train_loss_hist.append(train_loss.cpu())
            train_acc_hist.append(train_acc_ff)
            valid_acc_hist.append(valid_acc_ff)

            if valid_acc_ff > best_acc:
                best_acc = valid_acc_ff
                print("Best accuracy so far! ---> {:d}".format(round(best_acc*100)))

        print('TRAINING COMPLETE!')
        print('-'*50)
    finally:
        # Close the wandb run explicitly, including when training raises.
        if opts.online_visual == 1:
            wandb.finish()
    
if __name__ == '__main__':
    # Script entry point: Opts parses sys.argv, then training runs.
    cli_opts = Opts()
    main(cli_opts)