import torch
import numpy as np 
import argparse
from argparse import Namespace
from utils import *
from functions import *
from sklearn import decomposition
from sklearn.preprocessing import StandardScaler
from trainBLMNNmodel_pytorch import trainBLMNNmodel
import os, logging, json
import time
from sklearn.preprocessing import StandardScaler
from shutil import copyfile

# import pdb
## Digit recognition on the MNIST dataset
# The feature representation is extracted with the CSVDDNet model
def bool_func(x):
    """Return True only for the literal strings 'True' or 'true'; anything else is False."""
    return x in ['True', 'true']


# Stage 1: the real command line carries nothing but the path of the config file.
# The option is mandatory so that a missing path is reported by argparse as a usage
# error instead of failing later when the file is opened.
parser_app = argparse.ArgumentParser()
parser_app.add_argument("--config_file", type=str, required=True)
app_args = parser_app.parse_args()

# Stage 2: every experiment setting is declared here; the values are read from the
# config file (not from the command line) and parsed through this parser.
parser = argparse.ArgumentParser(description='Train Bayesian Neural Net on MNIST with Variational Inference')
# Logging / bookkeeping
parser.add_argument('--log_name', type=str, nargs='?', action='store', default='test')
parser.add_argument('--runtime', type=str, nargs='?', action='store', default='rt1')
parser.add_argument('--results_dir', type=str, default='train_results')
# Reproducibility: False pins the random seeds
parser.add_argument('--randomize', type=bool_func, default=True)
# Data selection and label noise
parser.add_argument('--data_file', type=str)
parser.add_argument('--preprocessed_data', type=str, default=None)
parser.add_argument('--samples_per_class', type=int, default=500)
parser.add_argument('--noise_rates', type=str, default="[0.3]")  # JSON list
parser.add_argument('--noise_type', type=str, default="symmetric")
# Feature pipeline
parser.add_argument('--PCA_dim', type=int, default=100)
parser.add_argument('--standardize', type=bool_func, default=False)
parser.add_argument('--data_scalar', type=float, default=1)
# Neighbourhood sizes
parser.add_argument('--KNN_test', type=int, default=3)
parser.add_argument('--KNN_triplet', type=int, default=21)
parser.add_argument('--use_intraclass_pairs', type=bool_func, default=True)
# Optimisation
parser.add_argument('--batch_size', type=int, default=40000)
parser.add_argument('--lr', type=float, default=1e-2)
parser.add_argument('--lr_milestones', type=str, default="[3,5]")  # JSON list of epoch indices
parser.add_argument('--num_epochs', type=int, default=10)
# Parameter initialisation scales
parser.add_argument('--mu_scalar', type=float, default=0.001)
parser.add_argument('--v_scalar', type=float, default=0.01)
parser.add_argument('--nondiag_scalar', type=float, default=1e-5)
parser.add_argument('--mode', type=str, default='acc')
### Parse argument
# The config file holds one `key=value` setting per line. Each setting becomes a
# `--key value` token pair that is fed to `parser`, so the file accepts exactly the
# options declared on that parser.
config_tokens = []
with open(app_args.config_file, 'r') as f:
    for line in f:
        line = line.strip()
        if len(line) <= 1:
            continue  # blank or one-character line: nothing to parse
        if line[0] == '[' and line[-1] == ']':
            continue  # section header such as [train]
        if line[0] in '#;':
            continue  # comment line
        # Split on the first '=' only, so a value may itself contain '='.
        key, sep, value = line.partition("=")
        config_tokens.append('--' + key.strip())
        if sep:
            config_tokens.append(value.strip())
# An empty token list (config without settings) simply yields the parser defaults.
args = parser.parse_args(config_tokens)


# Reproducible mode: when randomisation is switched off, every random number
# generator used by the script is pinned to the same fixed seed.
if not args.randomize:
    manualSeed = 11
    # CPU-side generators (NumPy and PyTorch).
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)
    # GPU-side generators, only when a CUDA device is present.
    if torch.cuda.is_available():
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(manualSeed)
        # cuDNN is disabled and its auto-tuner turned off alongside the
        # deterministic flag.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.enabled = False

def create_stepsize(init_lr, milestones, num_epochs, gamma=0.5):
    """Build the per-epoch learning-rate schedule.

    The rate starts at ``init_lr`` and is multiplied by ``gamma`` at the start of
    every epoch whose index appears in ``milestones``.

    Args:
        init_lr: learning rate before any decay.
        milestones: container of epoch indices (0-based) at which to decay.
        num_epochs: number of entries to produce.
        gamma: multiplicative decay factor applied at each milestone.

    Returns:
        A list with ``num_epochs`` learning rates, one per epoch.
    """
    def _rates():
        rate = init_lr
        for epoch in range(num_epochs):
            if epoch in milestones:
                rate *= gamma
            yield rate

    return list(_rates())

## Loading data
def prepare_data(args):
    """Load the dataset file and draw a class-stratified training subset.

    Args:
        args: parsed options; this function reads ``args.data_file`` (path given to
            ``torch.load``) and ``args.samples_per_class`` (quota per class).

    Returns:
        ``((trX, trY, trY_gt), (teX, teY), num_classes)`` where the training arrays
        are restricted to the sampled rows, ``trY_gt`` holds the ground-truth
        labels of those rows, and ``num_classes`` is the number of distinct labels
        found in the full training label vector.
    """
    print('Loading the dataset...\n')
    data = torch.load(args.data_file)
    trX, trY, teX, teY = data['trX'], data['trY'], data['teX'], data['teY']
    # When the file stores no separate ground truth, the given labels are taken
    # as ground truth.
    trY_gt = data['trY_gt'] if 'trY_gt' in data else trY.copy()
    del data

    num_classes = len(set(trY))

    # Per-class sample quota. The quota is read from `args` so the function does
    # not depend on a module-level global being defined before it is called.
    # NOTE(review): the quota vector is hard-coded to 10 entries (MNIST); confirm
    # what randomStratifiedSampleData expects before sizing it by num_classes.
    perN = np.ones((10), dtype='int') * args.samples_per_class
    selected_indices = randomStratifiedSampleData(trY, perN)
    trX, trY = trX[selected_indices, :].copy(), trY[selected_indices].copy()
    trY_gt = trY_gt[selected_indices].copy()

    return (trX, trY, trY_gt), (teX, teY), num_classes

#### Generate random noise
def inject_noise(trY_gt, noiserates):
    """Return a noisy version of the ground-truth labels.

    Args:
        trY_gt: ground-truth label array (must provide ``.copy()``).
        noiserates: sequence of noise rates. One entry selects symmetric noise
            with that rate, several entries select the asymmetric generator,
            and an empty sequence leaves the labels untouched.

    Returns:
        The labels produced by the selected noise generator, or a plain copy of
        ``trY_gt`` when ``noiserates`` is empty.
    """
    untouched = trY_gt.copy()
    rate_count = len(noiserates)

    if rate_count == 1:
        message = "injecting symmetric noise ..."
        print(message)
        flogger.info(message)
        return generateRandomLabelNoise2(trY_gt, noiserates[0])

    if rate_count > 1:
        message = "Asymmetric noises: {}".format(noiserates)
        print(message)
        flogger.info(message)
        return generateImBalancedRandomLabelNoise2(trY_gt, noiserates)

    return untouched

def feature_engineering(trX, trY, teX, teY, args):
    """Project the features with PCA and score the projection with a KNN test.

    Args:
        trX, trY: training features (2-D) and labels.
        teX, teY: test features and labels.
        args: parsed options; reads ``PCA_dim``, ``standardize`` and ``KNN_test``.

    Returns:
        ``(s0trX, s0teX, acc0)``: the projected (and optionally standardised)
        training and test features, and the accuracy reported by ``KNNtest``.
    """
    ## Dimensionality reduction with PCA
    print('Dimensionality reduction with PCA...\n')
    num_samples, num_features = trX.shape
    # Two PCA routines are available; the one used depends on whether there are
    # more samples than feature dimensions.
    pca_routine = myPCA if num_samples > num_features else myPCA2
    projection, _, _ = pca_routine(trX, args.PCA_dim)

    s0trX, s0teX = trX.dot(projection), teX.dot(projection)

    if args.standardize:
        # Statistics are fitted on the training features only and reused for test.
        scaler = StandardScaler(with_std=True)
        s0trX = scaler.fit_transform(s0trX)
        s0teX = scaler.transform(s0teX)

    _, knn_acc, _, _ = KNNtest(args.KNN_test, s0trX, trY, s0teX, teY)
    mess = 'PCA Test performance (acc%%): %.2f\n' % (knn_acc * 100)
    print(mess)
    flogger.info(mess)
    return s0trX, s0teX, knn_acc


class EearlyStopping:
    """Track the best score seen so far and signal when training should stop.

    Attributes:
        patience: number of consecutive non-improving steps tolerated.
        cnt_patience: current count of consecutive non-improving steps
            (``-1`` before the first call to ``step``).
        criterion: ``'max'`` if higher scores are better, ``'min'`` if lower.
        best_perf: best score seen so far.
        best_epoch: 1-based index of the step that produced ``best_perf``
            (``0`` while no step has improved on the initial value).
        num_epoch: number of ``step`` calls made so far.
    """

    def __init__(self, patience=10, criterion='max'):
        """Create the tracker.

        Raises:
            ValueError: if ``criterion`` is neither ``'max'`` nor ``'min'``.
        """
        # An unknown criterion used to be accepted silently and turned `step`
        # into a no-op that never stops; reject it up front instead.
        if criterion not in ('max', 'min'):
            raise ValueError(
                "criterion must be 'max' or 'min', got {!r}".format(criterion))
        self.patience = patience
        self.cnt_patience = -1
        self.criterion = criterion
        self.best_perf = -np.inf if criterion == 'max' else np.inf
        # Defined from the start so readers never hit an AttributeError.
        self.best_epoch = 0
        self.num_epoch = 0

    def step(self, curr_perf):
        """Record the score of one more epoch.

        Args:
            curr_perf: score obtained at this epoch.

        Returns:
            True once ``patience`` consecutive steps have failed to improve on
            the best score; the signal stays True until a new best is seen.
        """
        self.num_epoch += 1
        if self.criterion == 'max':
            improved = curr_perf > self.best_perf
        else:
            improved = curr_perf < self.best_perf

        if improved:
            self.cnt_patience = 0
            self.best_epoch = self.num_epoch
            self.best_perf = curr_perf
        else:
            self.cnt_patience += 1
        # `>=` rather than `==`: a caller that keeps stepping after the stop
        # signal must not see it flip back to False.
        return self.cnt_patience >= self.patience

## Setting
# Values decoded from the parsed arguments that are also used directly later on.
samplesPerClass = args.samples_per_class
noiserates = json.loads(args.noise_rates)  # JSON list of label-noise rates (default "[0.3]")
preprocessed_data = args.preprocessed_data
lr_milestones = json.loads(args.lr_milestones)

otherDML = {'GMML', 'LSSL', 'NCM', 'DMML', 'LCML', 'LMNN'} # comparison with other DML methods

# All training options travel to the trainer in one Namespace.
option = Namespace(
    use_cuda=torch.cuda.is_available(),
    dim=args.PCA_dim,
    KNN=args.KNN_triplet,            # #must-links of each point (default: 21, bigger helps)
    KNN_te=args.KNN_test,
    lambda_=20,                      # inclass variance
    data_scalar=args.data_scalar,    # data scaling is useful in Bayesian DML
    maxIter=args.num_epochs,
    # one learning rate per epoch, halved at every milestone
    stepSize=create_stepsize(args.lr, lr_milestones, args.num_epochs, 0.5),
    batchsize=args.batch_size,       # batch SVI: bigger batch helps but needs large memory
    flag=Namespace(
        apx=False,                   # true: using accelerating trick
        fold=4,                      # [2 3 4] bigger value -> training faster, but with lower performance
    ),
    mode=args.mode,                  # accuracy
    use_intraclass_pairs=args.use_intraclass_pairs,
    samplesPerClass=samplesPerClass,
    noiserates=noiserates,
    mu_scalar=args.mu_scalar,
    v_scalar=args.v_scalar,
    nondiag_scalar=args.nondiag_scalar,
)
print("stepsize: ", option.stepSize)


# Output folder: <results_dir>/<log_name>. makedirs builds any missing parent
# directories and does not fail when the folder already exists, which the previous
# exists()-then-mkdir() sequence could not guarantee.
os.makedirs(args.results_dir, exist_ok=True)
savefolder = os.path.join(args.results_dir, args.log_name)
os.makedirs(savefolder, exist_ok=True)

# Keep a copy of the config next to the results; an existing copy is left untouched.
config_file = os.path.join(savefolder, args.runtime + '.ini')
if not os.path.exists(config_file):
    copyfile(app_args.config_file, config_file)

# File logger named after the run, writing to <savefolder>/<runtime>.log.
fname = args.runtime
flogger = logging.getLogger(name=fname)
flogger.setLevel(logging.INFO)
flogger.propagate = False  # do not propagate to the root logger (which prints to stdout)
f_handler = logging.FileHandler(os.path.join(savefolder, fname + ".log"))
f_format = logging.Formatter('[%(asctime)s] - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
f_handler.setFormatter(f_format)
flogger.addHandler(f_handler)

# Record the effective options; the nested `flag` Namespace is converted to a dict
# so the whole structure can be serialised as JSON.
save_args = dict(vars(option))
save_args['flag'] = vars(option.flag)
flogger.info("Param: {}".format(json.dumps(save_args, indent=2, sort_keys=True)))

# Disabled alternative kept for reference: the PCA + KNN baseline computed with
# scikit-learn's PCA. It sits inside a bare string literal, so none of it is executed.
'''
## PCA lib sklearn
pca = decomposition.PCA(n_components=dim)
# scaler = StandardScaler(with_std=False)
# trX = scaler.fit_transform(trX)
# teX = scaler.transform(teX)

s0trX = pca.fit_transform(trX)
s0teX = pca.transform(teX)
[pred0, acc0, _, _] = KNNtest(KNN_te, s0trX, trY, s0teX, teY)
print('PCA SKLEARN Test performance (acc%%): %.2f\n' % (acc0 * 100))
'''
option.lb_start_index = 0
train_accs, test_accs = [], []
# The experiment is repeated five times; the accuracies returned by each run are
# collected and reported together once all runs are finished.
for rt in range(5):
    mess = ' Runtime {}\n'.format(rt)
    print(mess)
    flogger.info(mess)

    if preprocessed_data in [None, 'None']:
        # No cached features: load and subsample the data, corrupt the labels and
        # compute the PCA features from scratch.
        (trX_origin, trY_origin, trY_gt_origin), (teX, teY), num_classes = prepare_data(args)
        trX, trY, trY_gt = trX_origin, trY_origin, trY_gt_origin
        if args.noise_type == 'asymmetric':
            # One rate per class drawn from {0.1, ..., 0.6}; the fixed seed makes
            # the drawn rates identical in every run.
            rng = np.random.RandomState(0)
            noiserates = rng.randint(1, 7, num_classes) / 10

        trY = inject_noise(trY, noiserates)

        # Fraction of labels that actually differ from the ground truth.
        true_noiserate = sum(trY != trY_gt) / len(trY)
        message = 'Shape:%s || #Training:%d || #Test:%d || label noise (%%): %.2f\n' %(
            trX.shape, len(trY), len(teY), true_noiserate * 100)
        print(message)
        flogger.info(message)
        s0trX, s0teX, init_acc = feature_engineering(trX, trY, teX, teY, args)
    else:
        # Cached features: run 0 reads the given file, run k > 0 reads the file
        # whose last extension is replaced by ".k".
        if rt != 0:
            datapath = ".".join(preprocessed_data.split(".")[:-1]) + ".{}".format(rt)
        else:
            datapath = preprocessed_data
        (s0trX, trY, trY_gt, s0teX, teY, init_acc, noiserates) = torch.load(datapath)
        mess = 'Restored preprocessed data: {}, with noise: {}'.format(datapath, noiserates)
        print(mess)
        flogger.info(mess)

    mess = "Init performance: {}".format(init_acc)
    print(mess)
    flogger.info(mess)

    ## finding K-nearest neighbors
    Idx_all = getbatchKNNindex(s0trX, 100, s0trX)
    Idx = Idx_all[:, 1:option.KNN + 1] # first point is itself
    [lp1_all, lp2_all, ln_all] = getLMNNidx(trY, Idx)
    LPN = [lp1_all, lp2_all, ln_all]
    # The log call previously formatted a string without a placeholder, so the
    # count never reached the log file.
    mess = "#triplets: {}".format(len(ln_all))
    print("#triplets: ", len(ln_all))
    flogger.info(mess)

    ## Initialization
    print('Parameter Initialization...\n')
    [A0, mu_0, v_0, I_triuA] = symInitA2(args.PCA_dim, args.mu_scalar, args.v_scalar, args.nondiag_scalar)
    mu_t = mu_0.copy()

    ## Training the BLMNN model
    print('Training the BLMNN model......\n')
    flogger.info('Training the BLMNN model......\n')
    # TODO: mu_0, v_0, I_triuA are all ROW CONTIGUOUS that is different from matlab COL CONTIGUOUS
    tic = time.time()

    checkpointer = EearlyStopping(patience=10, criterion='max')
    model, train_acc, test_acc = trainBLMNNmodel(s0trX, trY, s0teX, teY, mu_0, v_0, mu_t, LPN, I_triuA, option, checkpointer, flogger)
    train_accs.append(train_acc)
    test_accs.append(test_acc)

    acct = evaluate(model.mu, s0trX, s0teX, trY, teY, I_triuA, option)
    mess = 'BLMNN Test performance (acc%%): %.2f\n'%(acct * 100)
    print(mess)
    flogger.info(mess)

    # Elapsed time covers training and the final evaluation. It is reported as
    # its own message so the accuracy line is not printed and logged twice.
    toc = time.time()
    mess = "runtime: {}".format(toc - tic)
    print(mess)
    flogger.info(mess)

mess = 'Train accs: {}, Test accs: {}'.format(train_accs, test_accs)
print(mess)
flogger.info(mess)
