import sys
from tabnanny import verbose
import os
import datetime
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from utilities import *
import time
import torch
from torch import nn
import numpy as np
import pickle
from torch.cuda.amp import autocast,GradScaler
import logging
import pickle
from tqdm import tqdm
import torch
from utilities.new_map import *
from utilities.ontology_loss import *
import wandb

def logging_info(rank, msg):
    """Emit ``msg`` through ``logging.info`` only from the rank-0 process.

    Every other rank returns silently, so a multi-process run writes each
    message to the log once. ``msg`` is handed to ``logging.info`` unchanged
    (callers sometimes pass non-string objects such as datetimes).
    """
    if rank != 0:
        return
    logging.info(msg)

# Module-level accumulator referenced only by the commented-out
# "measure the mean and std of the dataset" snippet in `train`
# (via `global DATA`); it is not used by any active code path.
DATA = None

def _build_loss_fn(loss_name):
    """Return the classification criterion named by ``args.loss``.

    ``'BCE'`` yields an unreduced ``nn.BCELoss`` (the caller averages it,
    optionally with per-element weights); ``'CE'`` yields a mean-reduced
    ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: for any other name (previously this surfaced later as an
            UnboundLocalError on ``loss_fn``).
    """
    if loss_name == 'BCE':
        return nn.BCELoss(reduction="none")
    if loss_name == 'CE':
        return nn.CrossEntropyLoss()
    raise ValueError("Unsupported loss %r; expected 'BCE' or 'CE'" % (loss_name,))


def _classification_loss(loss_fn, audio_output, labels, args):
    """Compute the scalar classification loss for one training batch.

    CE receives the raw model output and the argmax of ``labels`` as class
    indices. BCE clamps the output into (0, 1), then averages the unreduced
    loss, optionally re-weighted by ``args.weight_func`` (and, with
    ``args.non_weighted_loss``, averaged with the un-weighted mean).
    """
    if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
        return loss_fn(audio_output, torch.argmax(labels.long(), axis=1))
    epsilon = 1e-7
    audio_output = torch.clamp(audio_output, epsilon, 1. - epsilon)
    loss = loss_fn(audio_output, labels)
    if not args.reweight_loss:
        return torch.mean(loss)
    # NOTE(review): eval() of a command-line string; acceptable only while the
    # arguments come from a trusted launcher.
    loss_weight = eval(args.weight_func)(labels, args.graph_weight_path, beta=args.beta)
    if args.non_weighted_loss:
        return (torch.mean(loss * loss_weight) + torch.mean(loss)) / 2
    return torch.mean(loss * loss_weight)


def _score_regularizers(loss, audio_input, score_pred, energy_score, args, device):
    """Add the empty-frame ("zero") penalty to ``loss`` and gather logging terms.

    Frames whose de-normalised, exponentiated mean feature value is within 1e-6
    of the batch minimum are treated as empty. For each sample, the mean score
    on empty frames is added to ``loss`` (scaled by ``args.lambda_zero_loss``
    and the batch size) when it exceeds
    ``args.apply_zero_loss_threshold * args.preserve_ratio``. The std of
    ``energy_score`` / ``score_pred`` on non-empty frames is only accumulated
    for logging. Samples whose statistic is NaN (e.g. no matching frames) are
    skipped, exactly as before.

    ``score_pred`` is presumably (batch, length, 1) per the original inline
    comment — TODO confirm.

    Returns:
        (loss, zero_loss_final, energy_final, std_loss_final); the last three
        are detached 1-element tensors on ``device`` (0 if nothing contributed).
    """
    score_mask = torch.mean((audio_input * args.dataset_std + args.dataset_mean).exp(), dim=-1, keepdim=True)
    score_mask = score_mask < (torch.min(score_mask) + 1e-6)
    batch_size = score_pred.size(0)
    # Start from zero tensors on the training device (the original hard-coded
    # .cuda(rank), which fails on CPU, and left zero_loss_final as None when
    # every sample was skipped, crashing the later .item()).
    zero_loss_final = torch.zeros(1, device=device)
    energy_final = torch.zeros(1, device=device)
    std_loss_final = torch.zeros(1, device=device)
    for idx in range(batch_size):
        zero_loss = torch.mean(score_pred[idx][score_mask[idx]])
        if torch.isnan(zero_loss).item():
            continue
        if zero_loss > args.apply_zero_loss_threshold * args.preserve_ratio:
            loss = loss + args.lambda_zero_loss * zero_loss / batch_size
        zero_loss_final = zero_loss_final + zero_loss.detach() / batch_size

        energy_loss = torch.std(energy_score[idx][~score_mask[idx]])
        if torch.isnan(energy_loss).item():
            continue
        energy_final = energy_final + energy_loss.detach() / energy_score.size(0)

        std_loss = torch.std(score_pred[idx][~score_mask[idx]])
        if torch.isnan(std_loss).item():
            continue
        std_loss_final = std_loss_final + std_loss.detach() / batch_size
    return loss, zero_loss_final, energy_final, std_loss_final


def train(rank, n_gpus, audio_model, train_loader, test_loader, args):
    """Train ``audio_model`` for ``args.n_epochs`` epochs; rank 0 validates and checkpoints.

    Resumes from ``<exp_dir>/models/best_audio_model.pth`` and
    ``best_optim_state.pth`` when present (warm-up is then disabled via
    ``args.warmup = False``). Every ``args.val_interval`` epochs rank 0 runs
    ``validate`` and ``validate_ensemble``, logs to wandb, writes
    ``result.csv``, per-epoch model/optimizer states, ``stats_<epoch>.pickle``
    and ``progress.pkl``, and refreshes the "best" checkpoints when the main
    metric improves. With ``args.wa`` true, rank 0 finally evaluates the
    weight-averaged checkpoints and writes ``wa_result.csv``.

    Args:
        rank: process rank; selects ``cuda:<rank>`` when CUDA is available.
        n_gpus: number of processes/GPUs (forwarded to the validators).
        audio_model: model returning ``(audio_output, score_pred, energy_score)``.
        train_loader: yields ``(audio_input, labels, fnames)`` batches.
        test_loader: validation loader passed to ``validate``.
        args: experiment namespace; ``args.loss_fn`` and ``args.warmup`` are
            written by this function.

    Returns:
        None. Returns early (after logging an error) if the training loss
        becomes NaN.

    Raises:
        ValueError: if ``args.loss`` is neither ``'BCE'`` nor ``'CE'``.
    """
    # if n_gpus > 1:
    #     train_loader.batch_sampler.set_epoch(epoch)
    device = torch.device("cuda:%s" % rank if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(True)

    # Running statistics; every one of them is reset at the end of each epoch.
    batch_time = AverageMeter()
    per_sample_time = AverageMeter()
    data_time = AverageMeter()
    per_sample_data_time = AverageMeter()
    loss_meter = AverageMeter()
    score_loss_meter = AverageMeter()
    energy_meter = AverageMeter()
    zero_loss_meter = AverageMeter()
    per_sample_dnn_time = AverageMeter()
    epoch_meters = (batch_time, per_sample_time, data_time, per_sample_data_time, loss_meter,
                    score_loss_meter, energy_meter, zero_loss_meter, per_sample_dnn_time)
    progress = []
    # best_ensemble_mAP is checkpoint ensemble from the first epoch to the best epoch
    best_epoch, best_ensemble_epoch, best_mAP, best_acc, best_ensemble_mAP = 0, 0, -np.inf, -np.inf, -np.inf
    global_step, epoch = 0, 0
    start_time = time.time()
    exp_dir = args.exp_dir

    def _save_progress():
        progress.append([epoch, global_step, best_epoch, best_mAP, time.time() - start_time])
        with open("%s/progress.pkl" % exp_dir, "wb") as f:
            pickle.dump(progress, f)

    best_model_path = os.path.join(exp_dir, "models/best_audio_model.pth")
    best_optim_path = os.path.join(exp_dir, "models/best_optim_state.pth")

    if os.path.exists(best_model_path):
        logging_info(rank, "Reloading model params" + best_model_path)
        model_checkpoint = torch.load(best_model_path, map_location="cpu")
        audio_model.load_state_dict(model_checkpoint["state_dict"])
        epoch = model_checkpoint["epoch"]
        global_step = model_checkpoint["global_step"]
        args.warmup = False

    audio_model = audio_model.to(device)
    # Set up the optimizer
    trainables = [p for p in audio_model.parameters() if p.requires_grad]
    logging_info(rank, 'Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in audio_model.parameters()) / 1e6))
    logging_info(rank, 'Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))
    optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999))

    if os.path.exists(best_optim_path):
        logging_info(rank, "Reloading optimizer" + best_optim_path)
        opt_checkpoint = torch.load(best_optim_path, map_location="cpu")
        optimizer.load_state_dict(opt_checkpoint["state_dict"])
        # Read the counters from the optimizer checkpoint itself; the original
        # read them from `model_checkpoint`, which is unbound when only the
        # optimizer state exists. Both files are saved together with the same
        # epoch/global_step, so the values agree when both are present.
        epoch = opt_checkpoint["epoch"]
        global_step = opt_checkpoint["global_step"]
        args.warmup = False

    # On resume the scheduler is built with last_epoch >= 0, which requires
    # 'initial_lr' in every param group. A freshly constructed optimizer (model
    # checkpoint without optimizer state) lacks it; a fresh run is unaffected
    # because the scheduler would set the same default itself.
    for param_group in optimizer.param_groups:
        param_group.setdefault('initial_lr', param_group['lr'])

    # Dataset-specific decay cadence: every epoch for speechcommands/esc*/nsynth*,
    # every 5 epochs otherwise, starting at args.lrscheduler_start.
    if args.dataset == "speechcommands" or "esc" in args.dataset or "nsynth" in args.dataset:
        milestone_stride = 1
    else:
        milestone_stride = 5
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, list(range(args.lrscheduler_start, 1000, milestone_stride)),
        gamma=args.lrscheduler_decay, last_epoch=epoch - 1)

    main_metrics = args.metrics
    loss_fn = _build_loss_fn(args.loss)
    warmup = args.warmup
    args.loss_fn = loss_fn
    logging_info(rank, 'now training with {:s}, main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'.format(str(args.dataset), str(main_metrics), str(loss_fn), str(scheduler)))
    logging_info(rank, 'The learning rate scheduler starts at {:d} epoch with decay rate of {:.3f} '.format(args.lrscheduler_start, args.lrscheduler_decay))

    epoch += 1

    logging_info(rank, "current #steps=%s, #epochs=%s" % (global_step, epoch))
    logging_info(rank, "start training...")
    result = np.zeros([args.n_epochs, 10])
    audio_model.train()
    while epoch < args.n_epochs + 1:
        print("Epoch:", epoch)
        begin_time = time.time()
        end_time = time.time()
        audio_model.train()
        logging_info(rank, '---------------')
        logging_info(rank, datetime.datetime.now())
        logging_info(rank, "current #epochs=%s, #steps=%s" % (epoch, global_step))
        for i, (audio_input, labels, fnames) in enumerate(train_loader):
            B = audio_input.size(0)
            audio_input = audio_input.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # If you want to measure the mean and std of the dataset
            # global DATA
            # if(DATA is None):
            #     DATA = audio_input.flatten()
            # else:
            #     DATA = torch.cat([DATA, audio_input.flatten()])
            # print(torch.mean(DATA), torch.std(DATA))

            data_time.update(time.time() - end_time)
            per_sample_data_time.update((time.time() - end_time) / audio_input.shape[0])
            dnn_start_time = time.time()

            # Linear warm-up over the first 1000 steps, updated every 50 steps.
            if global_step <= 1000 and global_step % 50 == 0 and warmup == True:
                warm_lr = (global_step / 1000) * args.lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = warm_lr
                logging_info(rank, 'warm-up learning rate is {:f}'.format(optimizer.param_groups[0]['lr']))

            audio_output, score_pred, energy_score = audio_model(audio_input)
            loss = _classification_loss(loss_fn, audio_output, labels, args)
            loss, zero_loss_final, energy_final, std_loss_final = _score_regularizers(
                loss, audio_input, score_pred, energy_score, args, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # record loss
            loss_meter.update(loss.item(), B)
            energy_meter.update(energy_final.item(), B)
            score_loss_meter.update(std_loss_final.item(), B)
            zero_loss_meter.update(zero_loss_final.item(), B)
            batch_time.update(time.time() - end_time)
            per_sample_time.update((time.time() - end_time) / audio_input.shape[0])
            per_sample_dnn_time.update((time.time() - dnn_start_time) / audio_input.shape[0])

            print_step = global_step % args.n_print_steps == 0
            # NOTE(review): epoch starts at 1 here, so this early-print branch never fires.
            early_print_step = epoch == 0 and global_step % (args.n_print_steps / 10) == 0
            print_step = print_step or early_print_step

            end_time = time.time()
            global_step += 1

            if rank == 0 and print_step and global_step != 0:
                info = {
                    "train-loss": float(loss_meter.avg),
                    "score-loss": float(score_loss_meter.avg),
                    "zero-loss": float(zero_loss_meter.avg),
                    "speed-data": float(per_sample_data_time.avg),
                    "speed-total": float(per_sample_time.avg),
                    "speed-dnn": float(per_sample_dnn_time.avg)}

                wandb.log(info, step=global_step)

                msg = ('Epoch: [{0}][{1}/{2}]\t Per Sample Total Time {per_sample_time.avg:.5f}\t '
                       'Per Sample Data Time {per_sample_data_time.avg:.5f}\t '
                       'Per Sample DNN Time {per_sample_dnn_time.avg:.5f}\t '
                       'Train Loss {loss_meter.avg:.4f}\t std Loss {score_loss_meter.avg:.4f}\t '
                       'zero Loss {zero_loss_meter.avg:.4f}\t energy Loss {energy_meter.avg:.4f}\t').format(
                    epoch, i, len(train_loader), per_sample_time=per_sample_time,
                    per_sample_data_time=per_sample_data_time, per_sample_dnn_time=per_sample_dnn_time,
                    loss_meter=loss_meter, score_loss_meter=score_loss_meter,
                    zero_loss_meter=zero_loss_meter, energy_meter=energy_meter)
                logging_info(rank, msg)
                print(msg)
                if np.isnan(loss_meter.avg):
                    logging.error("training diverged...")
                    return

        if rank == 0:
            if epoch % args.val_interval == 0:
                logging_info(rank, 'start validation')
                print('start validation')
                stats, valid_loss = validate(rank, n_gpus, audio_model, test_loader, args, epoch)

                mAP = np.mean([stat['AP'] for stat in stats])
                mAUC = np.mean([stat['auc'] for stat in stats])

                acc = stats[0]['acc']
                fps_ap = stats[0]['fps_ap']  # Use the average of the fps_curve
                fps_ap_mm = stats[0]['fps_ap_mm']  # Use the matmul distance based weight matrix
                fps_curve = stats[0]['fps_curve']
                ontology_ap = stats[0]['fps_curve'][0]  # Use the min distance based weight matrix. The ontology based metric
                draw_fps_curve(epoch, fps_curve, args.exp_dir)

                logging_info(rank, "mAP %s, mAUC %s, acc %s, fps_curve_average %s, ontology_ap %s" % (mAP, mAUC, acc, fps_ap, ontology_ap))

                val_info = {"val-mAP": mAP, "val-mAUC": mAUC, "val-acc": acc, "fps_curve_average": fps_ap, "fps_ap_mm": fps_ap_mm, "ontology_ap": ontology_ap}
                for k, v in val_info.items():
                    print(k, v)
                wandb.log(val_info, step=global_step)

                # ensemble results
                ensemble_stats = validate_ensemble(rank, n_gpus, args, epoch)
                ensemble_mAP = np.mean([stat['AP'] for stat in ensemble_stats])
                ensemble_mAUC = np.mean([stat['auc'] for stat in ensemble_stats])
                ensemble_acc = ensemble_stats[0]['acc']

                middle_ps = [stat['precisions'][int(len(stat['precisions']) / 2)] for stat in stats]
                middle_rs = [stat['recalls'][int(len(stat['recalls']) / 2)] for stat in stats]
                average_precision = np.mean(middle_ps)
                average_recall = np.mean(middle_rs)

                if main_metrics == 'mAP':
                    logging_info(rank, "mAP: {:.6f}".format(mAP))
                else:
                    logging_info(rank, "acc: {:.6f}".format(acc))
                logging_info(rank, "AUC: {:.6f}".format(mAUC))
                logging_info(rank, "Avg Precision: {:.6f}".format(average_precision))
                logging_info(rank, "Avg Recall: {:.6f}".format(average_recall))
                logging_info(rank, "d_prime: {:.6f}".format(d_prime(mAUC)))
                logging_info(rank, "train_loss: {:.6f}".format(loss_meter.avg))
                logging_info(rank, "valid_loss: {:.6f}".format(valid_loss))

                if main_metrics == 'mAP':
                    result[epoch - 1, :] = [mAP, mAUC, average_precision, average_recall, d_prime(mAUC), loss_meter.avg, valid_loss, ensemble_mAP, ensemble_mAUC, optimizer.param_groups[0]['lr']]
                else:
                    result[epoch - 1, :] = [acc, mAUC, average_precision, average_recall, d_prime(mAUC), loss_meter.avg, valid_loss, ensemble_acc, ensemble_mAUC, optimizer.param_groups[0]['lr']]
                np.savetxt(exp_dir + '/result.csv', result, delimiter=',')
                logging_info(rank, 'validation finished')

                if mAP > best_mAP:
                    best_mAP = mAP
                    if main_metrics == 'mAP':
                        best_epoch = epoch

                if acc > best_acc:
                    best_acc = acc
                    if main_metrics == 'acc':
                        best_epoch = epoch

                if ensemble_mAP > best_ensemble_mAP:
                    best_ensemble_epoch = epoch
                    best_ensemble_mAP = ensemble_mAP

                if best_epoch == epoch:
                    torch.save({"state_dict": audio_model.state_dict(), "epoch": best_epoch, "global_step": global_step}, "%s/models/best_audio_model.pth" % (exp_dir))
                    torch.save({"state_dict": optimizer.state_dict(), "epoch": best_epoch, "global_step": global_step}, "%s/models/best_optim_state.pth" % (exp_dir))

                torch.save(audio_model.state_dict(), "%s/models/audio_model.%d.pth" % (exp_dir, epoch))
                torch.save(optimizer.state_dict(), "%s/models/optim_state.%d.pth" % (exp_dir, epoch))

                with open(exp_dir + '/stats_' + str(epoch) + '.pickle', 'wb') as handle:
                    pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
                _save_progress()

            logging_info(rank, 'Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))

            finish_time = time.time()
            logging_info(rank, 'epoch {:d} training time: {:.3f}'.format(epoch, finish_time - begin_time))

        scheduler.step()

        epoch += 1

        # Reset every running meter (the original forgot the std/zero/energy
        # meters, so their logged averages leaked across epochs).
        for meter in epoch_meters:
            meter.reset()

    # if test weight averaging
    if rank == 0 and args.wa == True:
        stats = validate_wa(rank, n_gpus, audio_model, test_loader, args, args.wa_start, args.wa_end)
        mAP = np.mean([stat['AP'] for stat in stats])
        mAUC = np.mean([stat['auc'] for stat in stats])
        wa_result = [mAP, mAUC]
        logging_info(rank, '---------------Training Finished---------------')
        np.savetxt(exp_dir + '/wa_result.csv', wa_result)

def draw_fps_curve(epoch, fps_curve, exp_dir):
    """Plot ``fps_curve`` on a fixed [0, 1] y-axis and save it as a PNG.

    The image goes to ``<exp_dir>/fps_curve_<epoch>.png``. The pyplot
    current figure is drawn on and then closed, so repeated calls do not
    accumulate open figures.
    """
    import matplotlib.pyplot as plt  # deferred import, as in the rest of the plotting code

    png_path = os.path.join(exp_dir, "fps_curve_%s.png" % epoch)
    plt.plot(fps_curve)
    plt.ylim([0.0, 1.0])
    plt.savefig(png_path)
    plt.close()

def validate(rank, n_gpus, audio_model, val_loader, args, epoch, eval_target=False):
    """Run ``audio_model`` over ``val_loader`` and compute evaluation statistics.

    Side effects:
      * pickles ``{audio_name, clipwise_output, target}`` to
        ``<sampler>_<epoch>.pkl`` in the directory of the root logger's file
        handler (falling back to ``args.exp_dir`` if there is none);
      * writes ``<exp_dir>/predictions/predictions_<epoch>.csv``, and
        ``predictions/target.csv`` if it does not exist yet;
      * writes ``predictions/eval_target.csv`` when ``eval_target`` is true
        and that file does not exist.

    Args:
        rank: selects ``cuda:<rank>`` when CUDA is available.
        n_gpus: unused here; kept for interface compatibility.
        audio_model: model returning ``(audio_output, _, _)``; left in eval mode.
        val_loader: yields ``(audio_input, labels, fnames)`` batches.
        args: needs ``loss_fn``, ``sampler`` and ``exp_dir``.
        epoch: tag used in output file names (an int or e.g. ``'wa'``).
        eval_target: also store the targets as ``eval_target.csv``.

    Returns:
        ``(stats, loss)``: the ``calculate_stats`` result and the mean of the
        per-batch validation losses.
    """
    device = torch.device("cuda:%s" % rank if torch.cuda.is_available() else "cpu")
    batch_time = AverageMeter()
    audio_model = audio_model.to(device)
    # switch to evaluate mode
    audio_model.eval()
    end = time.time()
    A_predictions = []
    A_targets = []
    A_loss = []
    A_fname = []
    with torch.no_grad():
        for audio_input, labels, fname in tqdm(val_loader):
            audio_input = audio_input.to(device)
            audio_output, _, _ = audio_model(audio_input)

            A_predictions.append(audio_output.to('cpu').detach())
            A_targets.append(labels)
            # Flatten the per-batch name lists so there is exactly one name per
            # prediction row (the original kept only the first name of each batch).
            A_fname.extend(fname)

            labels = labels.to(device)
            if isinstance(args.loss_fn, torch.nn.CrossEntropyLoss):
                # CE expects raw logits, exactly as in training; do not clamp.
                loss = torch.mean(args.loss_fn(audio_output, torch.argmax(labels.long(), axis=1)))
            else:
                epsilon = 1e-7
                audio_output = torch.clamp(audio_output, epsilon, 1. - epsilon)
                loss = torch.mean(args.loss_fn(audio_output, labels))
            A_loss.append(loss.item())
            batch_time.update(time.time() - end)
            end = time.time()
        audio_output = torch.cat(A_predictions)
        target = torch.cat(A_targets)
        loss = np.mean(A_loss)
        predictions_np = audio_output.cpu().numpy()
        target_np = target.cpu().numpy()

        print("save the model prediction pickle file")
        output_dict = {
            "audio_name": np.array(A_fname),
            "clipwise_output": predictions_np,
            "target": target_np,
        }
        # Use the first file handler of the root logger, if any; the original
        # assumed handlers[0] was one and crashed otherwise.
        log_file = next((h.baseFilename for h in logging.getLogger().handlers
                         if getattr(h, "baseFilename", None)), None)
        path = os.path.dirname(log_file) if log_file else args.exp_dir
        save_pickle(output_dict, os.path.join(path, "%s_%s.pkl" % (args.sampler, epoch)))

        stats = calculate_stats(predictions_np, target_np, args)
        # save the prediction here
        pred_dir = args.exp_dir + '/predictions'
        os.makedirs(pred_dir, exist_ok=True)
        # Write the targets whenever they are missing, not only when the
        # directory itself is new (validate_ensemble reads this file).
        if not os.path.exists(pred_dir + '/target.csv'):
            np.savetxt(pred_dir + '/target.csv', target_np, delimiter=',')
        np.savetxt(pred_dir + '/predictions_' + str(epoch) + '.csv', predictions_np, delimiter=',')
        # save the target for the separate eval set if there's one.
        if eval_target == True and not os.path.exists(pred_dir + '/eval_target.csv'):
            np.savetxt(pred_dir + '/eval_target.csv', target_np, delimiter=',')
    return stats, loss

def validate_ensemble(rank, n_gpus, args, epoch):
    """Evaluate the running mean of all validation predictions made so far.

    Intended to be called right after ``validate`` at epochs that are
    multiples of ``args.val_interval``. ``predictions/ensemble_predictions.csv``
    holds the mean over the previously validated checkpoints; this call folds
    in ``predictions_<epoch>.csv``, rewrites the mean, deletes the previous
    checkpoint's prediction file to save space, and scores the mean.

    Args:
        rank, n_gpus: unused; kept for interface compatibility.
        args: needs ``exp_dir`` and ``val_interval`` (plus whatever
            ``calculate_stats`` reads).
        epoch: the epoch just validated.

    Returns:
        The ``calculate_stats`` result for the ensembled predictions.
    """
    exp_dir = args.exp_dir
    pred_dir = exp_dir + '/predictions'
    ensemble_path = pred_dir + '/ensemble_predictions.csv'
    target = np.loadtxt(pred_dir + '/target.csv', delimiter=',')
    predictions = np.loadtxt(pred_dir + '/predictions_%s.csv' % epoch, delimiter=',')

    if epoch == args.val_interval or not os.path.exists(ensemble_path):
        # First ensembled checkpoint: the mean is just its own predictions.
        ensemble_predictions = predictions
    else:
        # Number of checkpoints in the ensemble, this one included. The original
        # divided by `epoch`, which equals this count only when val_interval == 1
        # (otherwise the stored mean was scaled by 1 / val_interval).
        n_models = max(epoch // args.val_interval, 1)
        previous_mean = np.loadtxt(ensemble_path, delimiter=',')
        ensemble_predictions = (previous_mean * (n_models - 1) + predictions) / n_models
        # remove the prediction file to save storage space
        stale_path = pred_dir + '/predictions_' + str(epoch - args.val_interval) + '.csv'
        if os.path.exists(stale_path):
            os.remove(stale_path)

    np.savetxt(ensemble_path, ensemble_predictions, delimiter=',')

    stats = calculate_stats(ensemble_predictions, target, args)
    return stats

def validate_wa(rank, n_gpus, audio_model, val_loader, args, start_epoch, end_epoch):
    """Evaluate the uniform weight average of the saved per-epoch checkpoints.

    Loads every existing ``<exp_dir>/models/audio_model.<e>.pth`` for ``e`` in
    ``[start_epoch, end_epoch]``, averages their state dicts entry by entry,
    loads the result into ``audio_model``, saves it as
    ``models/audio_model_wa.pth`` and runs ``validate`` with epoch tag ``'wa'``.
    When ``args.save_model == False`` every consumed checkpoint file is deleted
    (the original skipped deleting the first one).

    Returns:
        The ``stats`` returned by ``validate``.

    Raises:
        FileNotFoundError: if no checkpoint exists in the requested range
            (previously an opaque TypeError on ``None``).
    """
    device = torch.device("cuda:%s" % rank if torch.cuda.is_available() else "cpu")
    exp_dir = args.exp_dir

    sdA = None
    model_cnt = 0
    for epoch in range(start_epoch, end_epoch + 1):
        ckpt_path = exp_dir + '/models/audio_model.' + str(epoch) + '.pth'
        if not os.path.exists(ckpt_path):
            continue
        sdB = torch.load(ckpt_path, map_location=device)
        if sdA is None:
            sdA = sdB
        else:
            for key in sdA:
                sdA[key] = sdA[key] + sdB[key]
        model_cnt += 1

        # if choose not to save models of epoch, remove to save space
        if args.save_model == False:
            os.remove(ckpt_path)

    if sdA is None:
        raise FileNotFoundError(
            "No checkpoints %s/models/audio_model.<epoch>.pth found for epochs %s..%s"
            % (exp_dir, start_epoch, end_epoch))

    # averaging
    for key in sdA:
        sdA[key] = sdA[key] / float(model_cnt)

    audio_model.load_state_dict(sdA)

    torch.save({"state_dict": audio_model.state_dict(), "epoch": -1, "global_step": -1}, exp_dir + '/models/audio_model_wa.pth')

    stats, loss = validate(rank, n_gpus, audio_model, val_loader, args, 'wa')
    return stats