import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import os
import datetime
import wandb
import argparse
from collections import OrderedDict
from tqdm import tqdm
import math
from torch.utils.data import TensorDataset, DataLoader
from model_giff import DS_CNN_FF
from torchinfo import summary
import datetime

class Opts:
    """Command-line configuration for forward-forward training.

    The argument parser is built and ``parse_args()`` is called while the
    class body executes, i.e. as soon as this module is imported/run.  The
    parsed values are then re-exported as plain class attributes so callers
    can read ``Opts.lr``, ``Opts.device``, etc. without an instance.
    """
    parser = argparse.ArgumentParser(description='forward-forward-temporal training args')
    parser.add_argument('--device', default='cuda:0', help='which cuda device to use')
    parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
    parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: 50)')
    parser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: 200)')
    parser.add_argument('--theta', type=float, default=1, help='layer loss param')
    # Integer flag: wandb logging is enabled only when the value equals 1.
    parser.add_argument('--online_visual', type=int, default=1, help='enable wandb')
    # The default key is a placeholder (see help text); pass a real key to log to wandb.
    parser.add_argument('--api_key', default='4b12acacf91e1245e461234', help='wandb api key default is a wrong one')
    parser.add_argument('--project_name', default='forward-forward-benchmark-DSCNN_GIFF_Fixed', help='wandb project name')
    parser.add_argument('--weight_decay', type=float, default=5e-3, help='weight decay')
    parser.add_argument('--label_ext', type=int, default=1, help='label extension width for embedding')
    # Forwarded as ``adaptive`` to cosine_annealing_lr_with_warmup:
    # 1 selects cosine decay after warm-up, any other value selects 1/sqrt decay.
    parser.add_argument('--adaptive_lr', type=int, default=0, help='enable adaptive learning rate')
    parser.add_argument('--start_lr', type=float, default=0.001, help='start learning rate for warmup')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='warmup epochs')
    # Parsing happens here, at class-definition time, against sys.argv.
    args = parser.parse_args()
    # Mirror the parsed values as class attributes.
    batch_size = args.batch_size
    theta = args.theta
    lr = args.lr    
    weight_decay = args.weight_decay
    epochs = args.epochs
    device = args.device
    label_ext = args.label_ext
    online_visual = args.online_visual
    api_key = args.api_key
    project_name = args.project_name
    adaptive_lr = args.adaptive_lr
    start_lr = args.start_lr
    warmup_epochs = args.warmup_epochs
    # Learning rate currently in effect; starts at ``lr`` and is overwritten
    # every epoch by run_dscnn_giff for logging.
    log_lr = lr
    # Prefix used when composing the wandb run name.
    runtime_name = "settings"

def load_data(file_path):
    """Load a feature/label archive and wrap it in a ``TensorDataset``.

    The file is read with ``np.load`` and must expose a ``"data"`` entry
    whose element 0 is the sequence of feature arrays and whose element 1 is
    the sequence of integer labels.

    Args:
        file_path: Path of the archive to read.

    Returns:
        A ``TensorDataset`` of ``(features, labels)``.  Features are float32
        with their axes reordered by ``permute(0, 3, 1, 2)`` (the last axis
        is moved to position 1); labels are int64.
    """
    # SECURITY NOTE: allow_pickle=True unpickles object arrays, which can run
    # arbitrary code.  Only load archives that come from a trusted source.
    archive = np.load(file_path, allow_pickle=True)
    try:
        # Look "data" up exactly once: every lookup on an .npz archive
        # re-reads and re-decodes the member from disk.
        payload = archive["data"]
    finally:
        # .npz archives keep the underlying file open until closed; plain
        # arrays returned for other formats have no close() method.
        if hasattr(archive, "close"):
            archive.close()

    # Stack the per-sample entries into dense arrays.
    features = np.array(list(payload[0]))
    labels = np.array(list(payload[1]))

    feature_tensor = torch.tensor(features, dtype=torch.float32).permute(0, 3, 1, 2)
    label_tensor = torch.tensor(labels, dtype=torch.long)

    return TensorDataset(feature_tensor, label_tensor)

def cosine_annealing_lr_with_warmup(epoch, num_epochs, initial_lr, warmup_epochs=10, start_lr=0.0001, adaptive=1):
    """Return the learning rate to use for ``epoch``.

    Schedule:
      * ``epoch < warmup_epochs``: linear ramp from ``start_lr`` towards
        ``initial_lr``.
      * afterwards, ``adaptive == 1``: cosine annealing from ``initial_lr``
        over the remaining ``num_epochs - warmup_epochs`` epochs.
      * afterwards, any other ``adaptive``: ``initial_lr / sqrt(k + 1)``
        where ``k`` counts epochs since the end of warm-up.
    """
    # Warm-up: climb linearly by a fixed increment each epoch.
    if epoch < warmup_epochs:
        ramp_per_epoch = (initial_lr - start_lr) / warmup_epochs
        return start_lr + ramp_per_epoch * epoch

    # Epoch index measured from the end of the warm-up period.
    epochs_since_warmup = epoch - warmup_epochs

    if adaptive != 1:
        # Inverse square-root decay.
        return initial_lr / math.sqrt(epochs_since_warmup + 1)

    # Cosine annealing across the post-warm-up span.
    decay_span = num_epochs - warmup_epochs
    return 0.5 * initial_lr * (1 + math.cos(math.pi * epochs_since_warmup / decay_span))

def load_model(config=None):
    """Build and return a ``DS_CNN_FF`` network constructed from ``config``."""
    return DS_CNN_FF(config)

@torch.no_grad()
def test(network, test_loader, opts, num_classes=12):
    """Measure classification accuracy by scoring every candidate label.

    For each batch the network is evaluated once per class, with that class's
    one-hot vector supplied as the label input.  The "goodness" of a
    candidate is the sum of squared network outputs over the last dimension,
    and the predicted class is the one with the LOWEST goodness.  This
    matches ``train``, whose loss pushes the goodness of correctly-labelled
    inputs below ``theta`` and that of wrongly-labelled inputs above it.

    Args:
        network: Callable as ``network(x, one_hot_label)``.
        test_loader: Iterable of ``(inputs, labels)`` batches.  Labels are
            class indices; one-hot labels are reduced with ``argmax``.
        opts: Object providing ``device``.
        num_classes: Number of candidate classes to score (default 12).

    Returns:
        Fraction of correctly classified samples as a float; ``0.0`` when
        the loader yields no samples.
    """
    # NOTE(review): the network's train/eval mode is left untouched here --
    # confirm whether normalisation layers are meant to run in eval mode.
    correct = 0
    total = 0

    for x_test, y_test in test_loader:
        x_test = x_test.float().to(opts.device)
        y_test = y_test.to(opts.device)
        if y_test.dim() > 1:
            # Accept one-hot encoded targets as well as class indices.
            y_test = y_test.argmax(dim=-1)
        batch = x_test.shape[0]

        goodness_per_label = []
        for label in range(num_classes):
            # Same candidate label for every sample in the batch.
            candidate = torch.full((batch,), label, dtype=torch.long, device=x_test.device)
            candidate = F.one_hot(candidate, num_classes=num_classes).float()
            acts = network(x_test, candidate)
            goodness_per_label.append(acts.pow(2).sum(dim=-1))

        # Stack so that dim 1 indexes the candidate label.
        goodness = torch.stack(goodness_per_label, dim=1)
        predictions = goodness.argmin(dim=-1)

        # Accumulate per batch instead of concatenating all outputs at the end.
        correct += predictions.eq(y_test).sum().item()
        total += y_test.shape[0]

    if total == 0:
        return 0.0
    return correct / total

def train(network, optimizer, train_loader, valid_loader, opts, num_classes=12):
    """Run one epoch of layer-local forward-forward training.

    Every block in ``network.blocks`` is trained with its own optimizer on a
    local loss.  Each block receives a "positive" stream (input paired with
    the true label) and a "negative" stream (same input paired with a wrong
    label).  The block outputs passed on to the next block are detached, so
    no gradient flows between blocks.

    Args:
        network: Model exposing ``blocks``; each child is called as
            ``block(x, one_hot_label)`` and returns ``(z, x_out)``, where
            ``z`` is reduced over dims 1-3 to obtain the goodness and
            ``x_out`` feeds the next block.
        optimizer: Sequence of optimizers, one per block, in block order.
        train_loader: Iterable of ``(inputs, class_index_labels)`` batches.
        valid_loader: Unused; kept so existing callers keep working.
        opts: Object providing ``device`` and ``theta``.
        num_classes: Number of label classes (default 12); must be >= 2.

    Returns:
        The sum, over all batches and blocks, of the detached block losses
        (a 0-d tensor, or ``0.0`` if the loader is empty).
    """
    running_loss = 0.

    for x, y_ground in tqdm(train_loader, total=len(train_loader), desc="Training"):
        x, y_ground = x.to(opts.device), y_ground.to(opts.device)
        batch_size = y_ground.shape[0]

        # Draw a wrong label for every sample: adding an offset in
        # [1, num_classes - 1] modulo num_classes can never land on the true
        # class, whereas a plain randint(0, num_classes) would about 1 time
        # in num_classes.
        offsets = torch.randint(1, num_classes, (batch_size,), device=y_ground.device)
        y_random = (y_ground + offsets) % num_classes

        y_pos_onehot = F.one_hot(y_ground, num_classes=num_classes).float()
        y_neg_onehot = F.one_hot(y_random, num_classes=num_classes).float()

        # Both streams start from the same raw input.
        x_pos_in = x
        x_neg_in = x

        for layer_idx, layer in enumerate(network.blocks.children()):
            with torch.enable_grad():
                z_pos, x_pos = layer(x_pos_in, y_pos_onehot)
                z_neg, x_neg = layer(x_neg_in, y_neg_onehot)

                goodness_pos = z_pos.pow(2).sum(dim=[1, 2, 3])
                goodness_neg = z_neg.pow(2).sum(dim=[1, 2, 3])

                # softplus(v) == log(1 + exp(v)) but does not overflow to
                # inf (and NaN gradients) when the goodness is large.
                positive_loss = F.softplus(goodness_pos - opts.theta).mean()
                negative_loss = F.softplus(opts.theta - goodness_neg).mean()

                layer_loss = positive_loss + negative_loss
                layer_loss.backward()

                running_loss += layer_loss.detach()
                optimizer[layer_idx].step()
                optimizer[layer_idx].zero_grad()

            # Cut the graph so the next block trains only on its own loss.
            x_pos_in = x_pos.detach()
            x_neg_in = x_neg.detach()

    return running_loss


def run_dscnn_giff(opts, config):
    """Train and evaluate the DS-CNN forward-forward model end to end.

    Steps: build the model, print a per-block memory estimate, load the
    train/val/test archives from ``./data``, create one AdamW optimizer per
    block, then for every epoch set the scheduled learning rate, train,
    evaluate on the validation and test sets, print a summary line and
    (optionally) log to wandb.

    Args:
        opts: Settings object providing ``device``, ``batch_size``, ``lr``,
            ``weight_decay``, ``epochs``, ``warmup_epochs``, ``start_lr``,
            ``adaptive_lr``, ``online_visual``, ``api_key``,
            ``project_name`` and ``runtime_name``.  ``opts.log_lr`` is
            overwritten every epoch with the learning rate in effect.
        config: Passed through to ``load_model``.
    """
    # Load model
    model = load_model(config=config).to(opts.device)
    summary(model)

    # Memory estimate per block: all model parameters plus that block's
    # activation/gradient/error element counts, 4 bytes each (float32
    # assumed), reported in KB.
    num_params = sum(p.numel() for p in model.parameters())
    mem_dist = [
        (num_params + block.activation_num + block.gradient_num + block.error_num) * 4 / 1024
        for block in model.blocks.children()
    ]
    peak_mem = max([0, *mem_dist])
    print("~~~~~estimated peak mem~~~~~")
    print(">>>:", peak_mem, "KB")
    print(mem_dist)

    # Load the dataset
    train_dataset = load_data("./data/mfcc_train_data.npz")
    val_dataset = load_data("./data/mfcc_val_data.npz")
    test_dataset = load_data("./data/mfcc_test_data.npz")

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=4)

    # Initialize wandb
    use_wandb = opts.online_visual == 1
    if use_wandb:
        os.environ["WANDB_API_KEY"] = opts.api_key
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        wandb_runname = opts.runtime_name + "_peak_mem_" + str(peak_mem) + "_time_" + str(current_time)
        wandb.init(project=opts.project_name, name=wandb_runname)

    # One optimizer per block: every block is trained on its own local loss.
    optimizers = [
        torch.optim.AdamW(block.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
        for block in model.blocks.children()
    ]

    for step in range(opts.epochs):
        # Apply this epoch's scheduled learning rate to every optimizer.
        lr = cosine_annealing_lr_with_warmup(
            step, opts.epochs, opts.lr, opts.warmup_epochs, opts.start_lr, opts.adaptive_lr
        )
        opts.log_lr = lr
        for opt in optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = lr

        train_loss = train(model, optimizers, train_loader, val_loader, opts)

        # Evaluate after every epoch.
        valid_acc = test(model, val_loader, opts)
        test_acc = test(model, test_loader, opts)
        # Built from adjacent literals so no stray indentation ends up in
        # the printed line.
        print(
            f"Step {step:04d} train_loss_FF: {train_loss:.4f} "
            f"test_acc_FF: {test_acc:.4f} lr: {opts.log_lr:.6f}"
        )
        # NOTE: checkpointing the best model by validation accuracy is
        # currently not performed.

        if use_wandb:
            wandb.log({"train_loss": train_loss, "test_acc": test_acc, "valid_acc": valid_acc, "lr": opts.log_lr})

    # Only close the wandb run if one was started above.
    if use_wandb:
        wandb.finish()
    print('TRAINING COMPLETE!')
    print('-'*50)

