import os, time
from argparse import Namespace
from functions import *
import ctypes
import glob
import numpy as np 
import torch

def vec_triu(mat):
    """Vectorize the upper triangle of each row's outer product.

    For every row ``x`` of ``mat`` this returns the entries ``x[r] * x[c]``
    for all ``r <= c``, in row-major order of the upper triangle (the same
    order as ``np.triu_indices(d, 0)``).

    Args:
        mat: 2-D torch tensor of shape ``(bs, d)``.

    Returns:
        Tensor of shape ``(bs, d * (d + 1) // 2)`` on ``mat.device``. The
        result is always in torch's default floating dtype (float32 unless
        changed globally), regardless of ``mat.dtype``; the previous
        implementation behaved this way because it wrote into a
        ``torch.zeros`` buffer, and callers in this file pass float64 input
        while expecting float32 output.
    """
    _, d = mat.shape  # unpacking also enforces a 2-D input
    # Index pairs (r, c) with r <= c, created directly on the input's device.
    rows, cols = torch.triu_indices(d, d, device=mat.device)
    # Element (r, c) of the outer product x x^T is x[r] * x[c], so only the
    # required entries are computed; no (bs, d, d) tensor and no Python loop
    # over the batch.
    vec_dij = mat[:, rows] * mat[:, cols]
    return vec_dij.to(torch.get_default_dtype())

def _to_float_tensor(value, use_cuda):
    """Return a float32 torch copy of ``value``, moved to the GPU if ``use_cuda``."""
    if isinstance(value, torch.Tensor):
        # torch.tensor(tensor) emits a copy-construct warning; detach().clone()
        # is the equivalent, warning-free way to copy an existing tensor.
        out = value.detach().clone().to(torch.float32)
    else:
        out = torch.tensor(value, dtype=torch.float32)
    return out.cuda() if use_cuda else out


def _relative_drift(mu_t, mu_0):
    """Return ``||mu_t - mu_0||_2 / max(mu_0)`` as a Python float (used for logging only)."""
    def as_numpy(value):
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    current, prior = as_numpy(mu_t), as_numpy(mu_0)
    return float(np.sqrt(np.sum((current - prior) ** 2)) / np.max(prior))


def trainBLMNNmodel(trX, trY, teX, teY, mu_0, v_0, mu_t, LPN, I_triuA, option, checkpointer, flogger):
    """Run the batch SVI training loop and return the fitted model.

    Each outer iteration walks over the triplets in ``LPN`` in fixed order,
    ``option.batchsize`` at a time. For every batch the difference vectors
    ``xi - xj`` and ``xi - xl`` are turned into upper-triangle outer-product
    features with ``vec_triu`` and passed to ``getLambda_`` / ``getPosterior``
    to update ``mu_t``. After each outer iteration the current ``mu_t`` is
    scored with ``evaluate`` on the training and test sets and
    ``checkpointer.step`` decides whether to stop early.

    Args:
        trX: Training features; indexed as ``trX[idx, :]`` and ``trX.shape``.
        trY: Training labels (iterable of hashable values).
        teX: Test features, only forwarded to ``evaluate``.
        teY: Test labels, only forwarded to ``evaluate``.
        mu_0: Prior parameter forwarded to ``getPosterior`` and used as the
            reference point of the logged ``dis`` value.
        v_0: Prior parameter forwarded to ``getPosterior``.
        mu_t: Initial value of the parameter being learned.
        LPN: Triple ``(lp1_all, lp2_all, ln_all)`` of index arrays into
            ``trX`` giving, per triplet, ``xi``, ``xj`` and ``xl``.
        I_triuA: Forwarded unchanged to ``evaluate``.
        option: Namespace-like object providing ``use_cuda``, ``KNN``,
            ``batchsize``, ``maxIter``, ``stepSize`` (one entry per outer
            iteration), ``data_scalar``, ``lambda_``, ``flag``,
            ``use_intraclass_pairs`` and ``lb_start_index``.
        checkpointer: Object whose ``step(test_accuracy)`` returns a truthy
            value when training should stop.
        flogger: Logger-like object with an ``info(message)`` method.

    Returns:
        ``(model, train_accs, test_accs)`` where ``model.mu`` is the final
        ``mu_t`` (a numpy array when ``option.use_cuda`` is set, otherwise
        whatever ``getPosterior`` returned), ``model.trainingTime`` is the
        summed wall-clock time of the batch updates (evaluation excluded),
        and the two lists hold one accuracy per completed outer iteration.
    """
    model = Namespace()
    use_cuda = option.use_cuda
    KNN_tr = option.KNN
    batchsize = option.batchsize
    maxIter = option.maxIter
    stepSize = option.stepSize
    data_scalar = option.data_scalar
    lambda_ = option.lambda_
    flag = option.flag
    numClasses = len(set(trY))
    # Random intraclass pairs are only used for data sets with few classes.
    use_rs_pairs = option.use_intraclass_pairs and numClasses < 100

    lp1_all, lp2_all, ln_all = LPN
    batchnum = int(np.ceil(len(ln_all) / batchsize))

    # The prior parameters do not change between batches: convert them once.
    mu_0_dev = _to_float_tensor(mu_0, use_cuda)
    v_0_dev = _to_float_tensor(v_0, use_cuda)

    total_runtime = 0
    train_accs, test_accs = [], []

    for epoch in range(1, maxIter + 1):
        tic = time.time()
        rho = stepSize[epoch - 1]

        if use_rs_pairs:
            label_index = np.arange(1, numClasses + 1)
            if option.lb_start_index == 0:
                label_index = label_index - 1
            RSlp1_all, RSlp2_all, RSln_all = getTripletIndex(trY, label_index, KNN_tr)
            RS_size = int(np.floor(len(RSln_all) / batchnum))

        # Batch SVI
        for lpb in range(batchnum):
            t1 = time.time()
            # Triplets are consumed in order (shuffling was found not to help);
            # slicing clamps the last, possibly shorter, batch automatically.
            ibegin = lpb * batchsize
            iend = ibegin + batchsize
            lp1_t = lp1_all[ibegin:iend]
            lp2_t = lp2_all[ibegin:iend]
            ln_t = ln_all[ibegin:iend]

            pX1 = trX[lp1_t, :]
            p_cxt = _to_float_tensor(pX1 - trX[lp2_t, :], use_cuda)  # xi - xj: (bs, D)
            n_cxt = _to_float_tensor(pX1 - trX[ln_t, :], use_cuda)   # xi - xl: (bs, D)
            mu_t_dev = _to_float_tensor(mu_t, use_cuda)

            # xij = vectorize_triu((xi - xj)(xi - xj)^T): (D(D+1)/2, bs) after .T
            pxt = vec_triu(p_cxt).T
            nxt = vec_triu(n_cxt).T
            WWt = nxt - pxt  # minus version: xil - xij

            RS_pxt = pxt
            if use_rs_pairs:
                RS_ibegin = lpb * RS_size
                RS_iend = min(RS_ibegin + RS_size, len(RSln_all))
                rs_diff = trX[RSlp1_all[RS_ibegin:RS_iend], :] - trX[RSlp2_all[RS_ibegin:RS_iend], :]
                # float32, matching the triplet features above
                RS_pxt = vec_triu(_to_float_tensor(rs_diff, use_cuda)).T

            # L is \bar{\lambda}_{ijl}; mu_t is \bar{\gamma}_{t-1}; WWt is x_{ijl}
            L = getLambda_(mu_t_dev, WWt * data_scalar)
            mu_t = getPosterior(L, WWt * data_scalar, RS_pxt * data_scalar, mu_0_dev,
                                v_0_dev, mu_t_dev, lambda_, rho, flag)

            if use_cuda:
                mu_t = mu_t.cpu().numpy()

            t2 = time.time() - t1
            # Euclidean distance from the prior, relative to max(mu_0).
            dis = _relative_drift(mu_t, mu_0_dev)
            mess = 'Iteration: %d in %d || step: %d in %d || dis: %.2f || time(s): %.1f'\
                    % (epoch, maxIter, lpb + 1, batchnum, dis, t2)
            print('\r' + mess, end='')
            flogger.info(mess)
            total_runtime = total_runtime + t2

        toc = time.time() - tic
        acc_tr = evaluate(mu_t, trX, trX, trY, trY, I_triuA, option)
        acc_te = evaluate(mu_t, trX, teX, trY, teY, I_triuA, option)
        mess = 'Iteration: %d in %d, performance: Train: %.2f, Test: %.2f, time(s): %.1f\n' % (epoch, maxIter, acc_tr*100, acc_te*100, toc)
        print(mess)
        flogger.info(mess)
        train_accs.append(acc_tr)
        test_accs.append(acc_te)

        if checkpointer.step(acc_te):
            mess = "Early Stopped!"
            print(mess)
            flogger.info(mess)
            break

    model.mu = mu_t
    model.trainingTime = total_runtime
    return model, train_accs, test_accs
