#test
import os
import numpy as np
import torch
import torch.nn as nn
import random
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm
import torch.nn.functional as F
from util import accuracy, set_seed, save_model, save_plots, SaveBestModel
import argparse
from torchinfo import summary

from network_v3 import mnist_conv2d

def extend_image_append_label(x, y, opts):
    """Pad each image by one pixel and write the label vector into the new border.

    The label values are written at positions 10 .. 10 + len(label) on all four
    sides of the one-pixel frame (left column, right column, top row, bottom
    row) and added to every channel of the padded image.

    Args:
        x: image batch of shape (batch, channels, height, width).
        y: label vectors, either (batch, n) or a single (n,) vector that is
            shared by every image in the batch.
        opts: options object; only ``opts.device`` is read (device on which the
            border buffer is created, so it must match ``x``'s device).

    Returns:
        Tensor of shape (batch, channels, height + 2, width + 2).
    """
    x_ce = torch.nn.functional.pad(x, (1, 1, 1, 1))
    height, width = x_ce.shape[-2], x_ce.shape[-1]
    # Border buffer, created directly on the target device.
    ll = torch.zeros(x.shape[0], height, width, device=opts.device)
    if y.dim() == 1:
        # One label for the whole batch: view it as (batch, n) without copying.
        y = y.unsqueeze(0).expand(x.shape[0], -1)

    label_slot = slice(10, 10 + y.shape[-1])
    # Index the far edges with -1 so the padded size is not tied to 30x30.
    ll[:, label_slot, 0] = y
    ll[:, label_slot, -1] = y
    ll[:, 0, label_slot] = y
    ll[:, -1, label_slot] = y

    # The singleton channel axis broadcasts the border over all channels of x_ce.
    return x_ce + ll.unsqueeze(1)

class MNISTDataset(Dataset):
    """Thin ``Dataset`` wrapper that delegates storage and indexing to torchvision's MNIST."""

    def __init__(self, train=True, transform=None):
        """
        Args:
            train (bool): True selects the training split, False the test split.
            transform (callable, optional): Transform handed to the underlying
                MNIST dataset; it is also kept on ``self.transform``.
        """
        self.transform = transform
        # Downloads into ./data when the files are not already present.
        self.mnist_data = MNIST(
            root='./data',
            train=train,
            download=True,
            transform=transform,
        )

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.mnist_data)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        img, target = self.mnist_data[index]
        return img, target
    
def load_mnist_data(batchsize = 128, train=True, val=True):
    """
    Build DataLoaders for the requested MNIST splits.

    The training loader (shuffled) comes first, the test loader (not shuffled)
    second. When exactly one split is requested the loader itself is returned
    instead of a one-element list; when none is requested an empty list is
    returned.
    """
    print("start building MNIST data loaders")

    # Each entry is the `train` flag of a requested split; the same flag also
    # decides whether that split's loader shuffles.
    requested_splits = []
    if train:
        requested_splits.append(True)
    if val:
        requested_splits.append(False)

    loaders = [
        DataLoader(
            MNISTDataset(train=is_train, transform=transforms.ToTensor()),
            batch_size=batchsize,
            shuffle=is_train,
        )
        for is_train in requested_splits
    ]

    if len(loaders) == 1:
        return loaders[0]
    return loaders

def load_model(opts):
    """Instantiate the ``mnist_conv2d`` network, passing it ``opts.combo``."""
    return mnist_conv2d(opts.combo)


@torch.no_grad()
def test(network, test_loader, opts, diff_res=None):
    """Evaluate classification accuracy by trying every candidate label.

    For each batch the network is run once per class (0-9) with that class
    presented as the label. The network output is summed over dim 1 to give one
    score per sample and candidate; the candidate with the smallest score is
    taken as the prediction.

    Args:
        network: callable invoked as ``network(x, label_one_hot, opts, diff_res)``.
        test_loader: iterable yielding ``(images, targets)`` batches.
        opts: options object; ``opts.device`` and ``opts.combo`` are read here.
        diff_res: forwarded unchanged to ``network``.

    Returns:
        Fraction of samples whose predicted class equals the target.
    """
    score_batches = []
    target_batches = []

    for images, targets in test_loader:
        images = images.float().to(opts.device)
        targets = targets.float().to(opts.device)
        batch = images.shape[0]

        # Used as-is when opts.combo == 0; otherwise replaced per candidate below.
        net_input = images
        candidate_scores = []
        for candidate in range(10):
            # Integer tensor shaped like targets.argmax(dim=-1), filled with the candidate class.
            candidate_idx = torch.ones_like(targets.argmax(dim=-1)).fill_(candidate)
            candidate_vec = F.one_hot(candidate_idx, num_classes=10).float()
            candidate_batch = candidate_vec.repeat(batch, 1)
            if opts.combo != 0:
                # Also embed the candidate label into the image border.
                net_input = extend_image_append_label(images, candidate_vec, opts)
            response = network(net_input, candidate_batch, opts, diff_res)
            candidate_scores.append(response.sum(dim=[1]))

        # One column per candidate label.
        score_batches.append(torch.stack(candidate_scores, dim=1))
        target_batches.append(targets)

    scores = torch.cat(score_batches)
    labels = torch.cat(target_batches)
    hits = scores.argmin(dim=-1).eq(labels).sum().item()

    return hits / len(labels)

def train(network, optimizer, train_loader, valid_loader, opts):
    """Run one epoch of layer-local training, then evaluate.

    Every block in ``network.blocks`` is trained on its own loss: a softplus
    penalty on the sum of squared activations that grows when the positive
    pass (image with its true label) exceeds ``opts.theta`` and when the
    negative pass (image with a wrong label) falls below it. Inputs to the next
    block are detached, so gradients never cross block boundaries.

    Args:
        network: model exposing ``blocks``; each block is called as
            ``block(h, label_one_hot, opts)`` and returns ``(f, h_out)`` where
            ``f`` is 4-D (it is reduced over dims 1, 2, 3).
        optimizer: sequence of optimizers, one per block, indexed by block position.
        train_loader: iterable of ``(x, y)`` batches with integer class labels.
        valid_loader: loader passed to ``test`` for validation accuracy.
        opts: options object; reads ``device``, ``combo``, ``theta`` and
            ``online_visual``; ``opts.diff_res`` is written as a side effect.

    Returns:
        Tuple ``(mean_loss, train_acc, valid_acc)``. ``mean_loss`` is the loss
        summed over blocks and averaged over batches; ``train_acc`` is 0 unless
        ``opts.online_visual == 1``.
    """
    running_loss = 0.
    num_blocks = len(network.blocks)
    layer_res = [[] for _ in range(num_blocks)]   # per-block positive sums of squares
    layer_res2 = [[] for _ in range(num_blocks)]  # per-block negative sums of squares

    for x, y_ground in tqdm(train_loader, total=len(train_loader)):
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)
        # Size everything from the real batch: the last batch may be smaller
        # than opts.batch_size.
        batch = x.shape[0]

        h = x
        c = F.one_hot(y_ground, num_classes=10).float()  # batch x 10

        # Negative labels: shift the true label by a random non-zero offset
        # modulo 10, so a negative sample never carries its correct label.
        offset = torch.randint(1, 10, (batch,), device=y_ground.device)
        y_rand = (y_ground + offset) % 10
        c_rand = F.one_hot(y_rand, num_classes=10).float()
        h_neg = x.detach().clone()

        if opts.combo != 0:
            h = extend_image_append_label(h, c, opts)
            h_neg = extend_image_append_label(h_neg, c_rand, opts)

        for layer_idx, layer in enumerate(network.blocks.children()):
            with torch.enable_grad():
                f_o, h_o = layer(h, c, opts)
                f_n, h_n = layer(h_neg, c_rand, opts)

                zp_square_sum = f_o.pow(2).sum(dim=[1, 2, 3])
                zn_square_sum = f_n.pow(2).sum(dim=[1, 2, 3])

                positive_loss = F.softplus(zp_square_sum - opts.theta, beta=1, threshold=20).mean()
                negative_loss = F.softplus(-zn_square_sum + opts.theta, beta=1, threshold=20).mean()

                act_loss = positive_loss + negative_loss
                act_loss.backward()

                optimizer[layer_idx].step()
                optimizer[layer_idx].zero_grad()

            running_loss += act_loss.detach()
            layer_res[layer_idx].append(zp_square_sum.detach().cpu().numpy().mean())
            layer_res2[layer_idx].append(zn_square_sum.detach().cpu().numpy().mean())

            # Detach so the next block's loss cannot back-propagate into this one.
            h = h_o.detach()
            h_neg = h_n.detach()

    layer_res = [np.mean(layer) for layer in layer_res]
    layer_res2 = [np.mean(layer) for layer in layer_res2]
    print("----pos square_sums----", layer_res, "\n----neg square_sums----", layer_res2)
    diff_res = np.array(layer_res) - np.array(layer_res2)
    print("----diff square_sums----", diff_res)
    # Only the magnitude of the positive/negative gap is kept and handed to test().
    diff_res = np.abs(diff_res)
    opts.diff_res = diff_res

    # Average over batches exactly once.
    mean_loss = running_loss / len(train_loader)

    if opts.online_visual == 1:
        train_acc = test(network, train_loader, opts, diff_res)
    else:
        train_acc = 0
    valid_acc = test(network, valid_loader, opts, diff_res)
    return mean_loss, train_acc, valid_acc

# Command-line interface. Help texts that mention a default use argparse's
# %(default)s placeholder so the printed value always matches the real default.
parser = argparse.ArgumentParser(description='forward-forward-benchmark training args')
parser.add_argument('--device', default='cuda', help='cuda|cpu')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: %(default)s)')
parser.add_argument('--theta', type=float, default=8, help='layer loss param')
parser.add_argument('--online_visual', type=int, default=0, help='enable wandb')
parser.add_argument('--api_key', default='xxxxxxxxxxxxxxxxxxxxx', help='wandb api key default is a wrong one')
parser.add_argument('--weight_decay', type=float, default=3e-3, help='weight decay')
parser.add_argument('--seed', type=int, default=0, help='random seed (default: %(default)s)')
parser.add_argument('--adaptive_lr', type=int, default=1, help='enable adaptive learning rate')
parser.add_argument('--start_lr', type=float, default=0.00005, help='start learning rate for warmup')
parser.add_argument('--warmup_epochs', type=int, default=5, help='warmup epochs')

class Opts:
    """Run configuration exposed as class attributes.

    The command line is parsed once, when this class body is executed, and the
    selected values are mirrored as attributes so the rest of the script can
    read them as ``opts.<name>``.
    """
    # Full parsed namespace; holds every command-line argument.
    args = parser.parse_args()
    lr = args.lr
    weight_decay = args.weight_decay
    epochs = args.epochs
    theta = args.theta
    online_visual = args.online_visual
    seed = args.seed
    device = args.device
    # Read by the wandb setup in the __main__ section as ``opts.api_key``.
    api_key = args.api_key

    batch_size = args.batch_size
    # Passed to mnist_conv2d; any non-zero value also makes train/test embed
    # the label into the image border.
    combo = 0

def main(opts):
    """Train the model for ``opts.epochs`` epochs and save the best checkpoint and plots.

    Steps: seed, build the MNIST loaders and the model, print a rough per-block
    memory estimate, create one Adam optimizer per block, then train. The model
    state is saved whenever validation accuracy improves. When
    ``opts.online_visual == 1`` the per-epoch metrics are also sent to wandb
    (the wandb run itself must already be initialised by the caller).

    Args:
        opts: options object; reads ``seed``, ``batch_size``, ``device``,
            ``lr``, ``epochs``, ``theta``, ``online_visual`` and, if present,
            ``nn_strct`` (used only as a plot-filename suffix).
    """
    set_seed(opts.seed)
    train_loader, valid_loader = load_mnist_data(batchsize=opts.batch_size)

    model = load_model(opts).to(opts.device)
    summary(model)

    num_params = sum(p.numel() for p in model.parameters())
    # Per-block estimate in KB, assuming 4 bytes (float32) per counted value.
    mem_dist = [
        (num_params + block.activation_num + block.gradient_num + block.error_num) * 4 / 1024
        for block in model.blocks.children()
    ]
    peak_mem = max(mem_dist, default=0)
    print("~~~~~estimated peak mem~~~~~")
    print(">>>:", peak_mem, "KB")
    print(mem_dist)

    # NOTE(review): opts.weight_decay is defined but not passed to Adam here — confirm intent.
    optimizers = [torch.optim.Adam(block.parameters(), lr=opts.lr)
                  for block in model.blocks.children()]

    checkpoint_path = 'Output/model_state_InfTest_MNISTFFcnn.pth'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    log_to_wandb = opts.online_visual == 1
    if log_to_wandb:
        # Imported locally so main() does not rely on a module-level import
        # that only happens when the file is run as a script.
        import wandb

    train_loss_hist = []
    train_acc_hist, valid_acc_hist = [], []
    best_acc = 0
    for step in range(opts.epochs):
        train_loss, train_acc, valid_acc = train(model, optimizers, train_loader, valid_loader, opts)
        print(f"Epoch {step:04d} train_loss: {train_loss:.3f} \
                train_acc: {train_acc:.3f} \
                valid_acc: {valid_acc:.3f}")
        train_loss_hist.append(train_loss.cpu())
        train_acc_hist.append(train_acc)
        valid_acc_hist.append(valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), checkpoint_path)

        if log_to_wandb:
            wandb.log({"train_loss": train_loss, "train_acc": train_acc, "valid_acc": valid_acc})

    print('TRAINING COMPLETE!')
    print('='*50)
    # Build the name in a single parenthesised expression; a line starting with
    # "+" is a separate statement, not a continuation.
    plot_filename = (
        "MINST_FFinc_lr" + str(opts.lr)
        + "_epochs_" + str(opts.epochs)
        + "_bs_" + str(opts.batch_size)
        + "_theta_" + str(opts.theta)
    )
    # The structure tag is optional: append it only when opts provides one.
    nn_strct = getattr(opts, "nn_strct", None)
    if nn_strct is not None:
        plot_filename += "_nn_strct_" + str(nn_strct)
    save_plots(plot_filename, train_acc_hist, valid_acc_hist, train_loss_hist, [])

if __name__ == '__main__':
    # Script entry point: optionally start a wandb run, train, then close the run.
    opts = Opts()
    use_wandb = opts.online_visual == 1

    if use_wandb:
        import wandb
        # Take the key from the parsed argparse namespace kept on Opts.args,
        # which always carries --api_key.
        os.environ["WANDB_API_KEY"] = opts.args.api_key
        wandb_prjname = "FFbenchmark"+"_"+"MNIST"
        wandb_runname = (
            f"MINST_FFinc_lr{opts.lr}_theta_{opts.theta}"
            f"_epochs_{opts.epochs}_bs_{opts.batch_size}"
        )
        # The structure tag is optional: append it only when opts provides one.
        nn_strct = getattr(opts, "nn_strct", None)
        if nn_strct is not None:
            wandb_runname += "_nn_strct_" + str(nn_strct)
        wandb.init(project=wandb_prjname, name=wandb_runname)

    main(opts)

    if use_wandb:
        wandb.finish()