import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import os
import datetime
import wandb
import argparse
from collections import OrderedDict
from tqdm import tqdm
import math
from torch.utils.data import TensorDataset, DataLoader
from model import DS_CNN_BP
from model import DS_CNN_BP_v2
from torchinfo import summary

class Opts:
    """Command-line options for the BP training script.

    ``parser.parse_args()`` runs while this class body executes (i.e. when the
    module is imported) and every parsed value is stored as a class attribute,
    so all ``Opts()`` instances share the same values until an attribute is
    overridden on an instance.
    """
    parser = argparse.ArgumentParser(description='forward-forward-temporal training args')
    parser.add_argument('--device', default='cuda:0', help='which cuda device to use')
    parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
    parser.add_argument('--epochs', type=int, default=80, help='number of epochs to train (default: 80)')
    parser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: 200)')
    parser.add_argument('--theta', type=float, default=8, help='layer loss param')
    parser.add_argument('--online_visual', type=int, default=1, help='enable wandb')
    # Read the key from the environment first so a real key never has to be
    # passed on the command line (visible in shell history and process
    # listings) or committed to source. The literal fallback is the original
    # placeholder, which the help text describes as not a valid key.
    parser.add_argument('--api_key',
                        default=os.environ.get('WANDB_API_KEY', '1b8ca0acf91e90969e46c124523a84dea6a1526d'),
                        help='wandb api key (defaults to $WANDB_API_KEY; the built-in fallback is a wrong one)')
    parser.add_argument('--project_name', default='forward-forward-benchmark-temporal-BP', help='wandb project name')
    # NOTE(review): help lists BP/GIFF but the default is 'basic', and main()
    # in this file never reads opts.mode.
    parser.add_argument('--mode', default='basic', help='BP/GIFF')
    parser.add_argument('--weight_decay', type=float, default=5e-3, help='weight decay')
    parser.add_argument('--label_ext', type=int, default=1, help='label extension width for embedding')
    # NOTE(review): main() in this file does not forward this flag to the LR
    # schedule, so it currently has no effect.
    parser.add_argument('--adaptive_lr', type=int, default=0, help='enable adaptive learning rate')
    parser.add_argument('--start_lr', type=float, default=0.001, help='start learning rate for warmup')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='warmup epochs')
    args = parser.parse_args()
    batch_size = args.batch_size
    theta = args.theta
    lr = args.lr
    weight_decay = args.weight_decay
    epochs = args.epochs
    device = args.device
    label_ext = args.label_ext
    online_visual = args.online_visual
    api_key = args.api_key
    project_name = args.project_name
    adaptive_lr = args.adaptive_lr
    start_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    # Initial value only; main() overwrites it with the scheduled LR each epoch.
    log_lr = lr
    mode = args.mode

def load_data(file_path):
    """Load a data split from a NumPy archive and wrap it as a ``TensorDataset``.

    The archive must contain a ``"data"`` entry. Element 0 of that entry is
    the sequence of input samples and element 1 is the matching sequence of
    labels. Samples are converted to ``float32`` and their axes are permuted
    from ``(N, d1, d2, d3)`` to ``(N, d3, d1, d2)``. This is presumably
    channels-last -> channels-first; TODO confirm against the data-prep
    script.

    Args:
        file_path: Path to the ``.npz`` archive.

    Returns:
        A ``TensorDataset`` of ``(float32 sample, int64 label)`` pairs.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code; only load archives from trusted sources.
    archive = np.load(file_path, allow_pickle=True)
    try:
        records = archive["data"]
    finally:
        # For .npz files np.load returns a lazily-read NpzFile that holds the
        # file handle open until closed; the entry above is already fully read
        # into memory, so release the handle. A plain ndarray has no close().
        close = getattr(archive, "close", None)
        if close is not None:
            close()

    data = np.array(list(records[0]))
    label = np.array(list(records[1]))
    train_tensor = torch.tensor(data, dtype=torch.float32).permute(0, 3, 1, 2)
    label_tensor = torch.tensor(label, dtype=torch.long)

    return TensorDataset(train_tensor, label_tensor)

def cosine_annealing_lr_with_warmup(epoch, num_epochs, initial_lr, warmup_epochs=10, start_lr=0.0001, adaptive=1):
    """Return the learning rate for ``epoch`` under a warm-up + decay schedule.

    For ``epoch < warmup_epochs`` the rate rises linearly from ``start_lr``
    toward ``initial_lr``. After warm-up, with ``adaptive == 1``, it follows
    half a cosine from ``initial_lr`` over the remaining
    ``num_epochs - warmup_epochs`` epochs. With any other ``adaptive`` value
    it decays as ``initial_lr / sqrt(k + 1)``, where ``k`` is the number of
    epochs since warm-up ended.
    """
    if epoch < warmup_epochs:
        per_epoch_gain = (initial_lr - start_lr) / warmup_epochs
        return start_lr + per_epoch_gain * epoch

    since_warmup = epoch - warmup_epochs
    decay_span = num_epochs - warmup_epochs
    if adaptive != 1:
        return initial_lr / math.sqrt(since_warmup + 1)
    return 0.5 * initial_lr * (1 + math.cos(math.pi * since_warmup / decay_span))

def load_model(model="BP"):
    """Build the network trained by this script.

    Args:
        model: accepted for call compatibility but currently ignored; a
            ``DS_CNN_BP_v2`` is always constructed.

    Returns:
        A freshly constructed ``DS_CNN_BP_v2`` instance.
    """
    return DS_CNN_BP_v2()

def train_bp(model, optimizer, train_loader, loss_fcn, opts):
    """Run one back-propagation training epoch over ``train_loader``.

    Args:
        model: network to train; called as ``model(x)`` on each batch.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        train_loader: sized iterable of ``(x, y)`` batches.
        loss_fcn: criterion called as ``loss_fcn(outputs, y)``.
        opts: options object; only ``opts.device`` is read.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean of the per-batch
        losses, as a 0-dim tensor on ``opts.device``. ``accuracy`` is the
        fraction of samples whose argmax prediction matches the label.
    """
    running_loss = 0.
    correct = 0
    seen = 0

    for x, y_ground in tqdm(train_loader, total=len(train_loader)):
        x, y_ground = x.float().to(opts.device), y_ground.to(opts.device)

        # Force autograd on even if the caller is inside a no_grad context.
        with torch.enable_grad():
            ys = model(x)
            loss = loss_fcn(ys, y_ground)
            loss.backward()
        running_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

        # Softmax is monotonic along the class axis, so argmax over the raw
        # outputs picks the same class as softmax-then-argmax, without the
        # deprecated implicit-dim softmax call. Assumes ys is 2-D
        # (batch, classes); TODO confirm for the model in use. The count stays
        # on-device so there is only one host sync, at the end of the epoch.
        correct += (ys.argmax(dim=-1) == y_ground).sum()
        seen += len(y_ground)

    running_loss /= len(train_loader)
    return running_loss, int(correct) / seen

@torch.no_grad()
def test_bp(model, test_loader, loss_fcn, opts):
    all_outputs = []
    all_labels = []
    test_loss = 0.
    for (x_test, y_test) in test_loader:
        x_test, y_test = x_test.float().to(opts.device), y_test.to(opts.device)
        acts = model(x_test)
        loss = loss_fcn(acts, y_test)
        test_loss += loss.detach()
        all_outputs.append(torch.nn.functional.softmax(acts).argmax(dim=-1))
        all_labels.append(y_test)

    test_loss /= len(test_loader)
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    correct = all_outputs.eq(all_labels).sum().item()
    return test_loss, correct/len(all_labels)

def main(opts):
    """Train the BP model on the MFCC splits and checkpoint the best epoch.

    Loads the train/val/test splits from ``./data`` and builds the network
    with ``load_model``. If ``opts.online_visual == 1``, it initialises
    Weights & Biases. Then, for each epoch, it:

    1. sets the learning rate from ``cosine_annealing_lr_with_warmup``;
    2. trains one epoch;
    3. evaluates on the validation and test splits;
    4. saves the model's ``state_dict`` whenever validation accuracy improves.

    Args:
        opts: ``Opts``-like object. Reads batch_size, device, online_visual,
            api_key, project_name, epochs, weight_decay, lr, warmup_epochs
            and start_lr; writes log_lr.
    """
    train_dataset = load_data("./data/mfcc_train_data.npz")
    val_dataset = load_data("./data/mfcc_val_data.npz")
    test_dataset = load_data("./data/mfcc_test_data.npz")

    # Only the training split is shuffled; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)

    model = load_model().to(opts.device)
    summary(model)
    num_params = sum(p.numel() for p in model.parameters())
    # Footprint estimate in KB at 4 bytes per value (fp32). gradient_num,
    # activation_num and error_num come from the model class; presumably
    # element counts. Confirm in model.py.
    peak_mem = (num_params + model.gradient_num + model.activation_num + model.error_num)*4/1024

    use_wandb = opts.online_visual == 1
    if use_wandb:
        os.environ["WANDB_API_KEY"] = opts.api_key
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        wandb_runname = ("BP_Paras_" + str(num_params) + "_Peak_Mem_" + str(peak_mem) + "KB_epochs_"
                         + str(opts.epochs) + "wei_dec" + str(opts.weight_decay)
                         + "_current_time_" + str(current_time))
        wandb.init(project=opts.project_name, name=wandb_runname)

    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
    # torch.save does not create parent directories; create the checkpoint
    # folder up front instead of crashing on the first improvement.
    os.makedirs('models', exist_ok=True)

    best_acc = 0.
    for step in range(opts.epochs):
        # NOTE(review): opts.adaptive_lr is not forwarded here, so the schedule
        # always takes its default adaptive=1 (cosine) branch.
        epoch_lr = cosine_annealing_lr_with_warmup(step, opts.epochs, opts.lr, opts.warmup_epochs, opts.start_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = epoch_lr
        opts.log_lr = optimizer.param_groups[0]['lr']

        model.train()
        train_loss, train_acc = train_bp(model, optimizer, train_loader, loss_fcn, opts)

        model.eval()
        valid_loss, valid_acc = test_bp(model, val_loader, loss_fcn, opts)
        test_loss, test_acc = test_bp(model, test_loader, loss_fcn, opts)
        # Adjacent f-strings keep the log on one line without embedding the
        # source indentation into the printed text.
        print(f"Epoch {step:04d} train_loss: {train_loss:.3f} train_acc: {train_acc:.3f} "
              f"test_loss: {test_loss:.3f} test_acc: {test_acc:.3f} "
              f"valid_loss: {valid_loss:.3f} valid_acc: {valid_acc:.3f} log_lr: {opts.log_lr:.6f}")

        if use_wandb:
            wandb.log({"train_loss": train_loss, "train_acc": train_acc, "valid_loss": valid_loss,"valid_acc": valid_acc,
                "test_loss": test_loss, "test_acc": test_acc, "log_lr": opts.log_lr,
                })

        # Checkpoint whenever validation accuracy strictly improves.
        if valid_acc > best_acc:
            best_acc = valid_acc
            print("model saved!!!")
            torch.save(model.state_dict(), 'models/bp_DSCNN-fixed_InferenceTest.pth')

    if use_wandb:
        wandb.finish()
    print('TRAINING COMPLETE!')
    print('='*50)
    
if __name__ == '__main__':
    # Script-level overrides applied on top of the parsed CLI options.
    # NOTE(review): these replace whatever was passed for --online_visual,
    # --project_name and --epochs on the command line.
    opts = Opts()
    for attr_name, forced_value in {
        "online_visual": 1,
        "project_name": "forward-forward-benchmark-dscnn-BP-fixed",
        "epochs": 36,
    }.items():
        setattr(opts, attr_name, forced_value)
    main(opts)