from torchvision import datasets
from torchvision import transforms
 # Set up data loaders
from datasets_ import TripletDataset, TensorDatasetWrapper
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset
import math
from trainer import fit
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from networks import EmbeddingNet, TripletNet
from losses import TripletLoss
from utils_ import *
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold

import argparse
from trainer import train_epoch, test_epoch, test_epoch_pred
import os, logging, json
from shutil import copyfile


def _str2bool(text):
    """Return True only for the literal strings 'True' / 'true' (config-file booleans)."""
    return text in ['True', 'true']


# The command line only names the config file; the experiment parameters
# themselves are declared on `parser` below.
parser_app = argparse.ArgumentParser()
parser_app.add_argument("--config_file", type=str, required=True,
                        help='Path to the config file holding one key=value parameter per line.')
app_args = parser_app.parse_args()

# NOTE: help strings use argparse's %(default)s placeholder so the documented
# default can never drift away from the real one.
parser = argparse.ArgumentParser()
## DATA
parser.add_argument('--KNN_triplet', type=int, default=11, help='Number of nearest neighbours used to form triplets, bigger helps. Default: %(default)s.')
parser.add_argument('--PCA_dim', type=int, default=100, help='Default: %(default)s.')
parser.add_argument('--datafile', type=str, default='../deep-bayesian-LMNN/cifar10_featues_SimCLR-0.81.dat')
parser.add_argument('--noise_rates', type=str, default='[0.3]', help='JSON list of label noise rates. Default: %(default)s.')
parser.add_argument('--noise_type', type=str, default='symmetric', help='Default: %(default)s.')
parser.add_argument('--samples_per_class', type=int, default=100, help='Default: %(default)s.')
parser.add_argument('--standardize', type=_str2bool, default=True)
parser.add_argument('--use_LMNN_triplets', type=_str2bool, default=True, help='Default: %(default)s.')
parser.add_argument('--use_PCA', type=_str2bool, default=False, help='Use PCA to reduce the feature size. Default: %(default)s.')
parser.add_argument('--preprocessed_data', type=str, default=None, help='Use features extracted from pretrained model as input. Default: %(default)s.')

## Model
parser.add_argument('--KNN_test', type=int, default=3, help='Number of nearest neighbours for determining the label of a datapoint. Default: %(default)s.')
parser.add_argument('--embed_size', type=int, default=100, help='Embedding dimension. Default: %(default)s.')
parser.add_argument('--loss_func', type=str, nargs='?', action='store', default='LMNN',
                    help='Choose likelihood function p(D|w). Options are \'Gaussian_approx\', \'LMNN\'. Default: %(default)s.')
parser.add_argument('--use_pred_loss', type=_str2bool, default=False)
parser.add_argument('--margin', type=float, default=1.0, help='Margin for LMNN likelihood')
parser.add_argument('--use_pretrain', type=_str2bool, default=False, help='Use pretrained model. Default: %(default)s.')
parser.add_argument('--hidden_sizes', type=str, default="[512,512]", help='Hidden dimensions of the neural network (JSON list). Default: %(default)s.')

## TRAINING
parser.add_argument('--batch_size', type=int, default=512, help='Default: %(default)s.')
parser.add_argument('--epochs', type=int, nargs='?', action='store', default=50, help='How many epochs to train. Default: %(default)s.')
parser.add_argument('--lr', type=float, nargs='?', action='store', default=1e-5, help='Learning rate. Default: %(default)s.')
parser.add_argument('--lr_milestones', type=str, default='[10, 20, 40, 50, 80]')
parser.add_argument('--lr_gamma', type=float, default=0.5)
parser.add_argument('--randomize', type=_str2bool, default=False)
parser.add_argument('--cross_validation', type=_str2bool, default=False)

## LOG
parser.add_argument('--models_dir', type=str, nargs='?', action='store', default='models', help='Where to save learnt weights and train vectors. Default: %(default)s.')
parser.add_argument('--results_dir', type=str, nargs='?', action='store', default='results', help='Where to save learnt training plots. Default: %(default)s.')
parser.add_argument('--verbose', type=_str2bool, default=False, help='Print values each minibatch. Default: %(default)s.')
parser.add_argument('--plot_embed', type=_str2bool, default=True, help='Default: %(default)s.')
parser.add_argument('--log_name', type=str, default='DeterNN', help='Default: %(default)s.')
parser.add_argument('--runtime', type=str, default='0', help='Default: %(default)s.')

### Parse argument
# Translate the config file into an argv-style list for `parser`:
#   key=value  ->  ['--key', 'value']
# Blank/one-character lines, [section] headers and '#' / ';' comment lines are
# skipped. Only the first '=' separates key from value, so values may contain '='.
config_argv = []
with open(app_args.config_file, 'r') as f:
    for line in f:
        line = line.strip()
        if len(line) <= 1: continue
        if line[0] == '[' and line[-1] == ']': continue
        if line[0] in '#;': continue
        key, sep, value = line.partition('=')
        config_argv.append('--' + key.strip())
        if sep:
            config_argv.append(value.strip())
# An empty config now yields an empty argv (all defaults) instited of a bogus '' argument.
args = parser.parse_args(config_argv)

# Output locations: model weights, and plots / error / accuracy vectors.
models_dir = args.models_dir
results_dir = args.results_dir
mkdir(models_dir)
mkdir(results_dir)

# Per-experiment folder inside the results directory.
savefolder = os.path.join(results_dir, args.log_name)
mkdir(str(savefolder))

# Keep a copy of the config next to the results; an existing copy is left untouched.
config_file = os.path.join(savefolder, args.runtime + '.ini')
if not os.path.exists(config_file):
    copyfile(app_args.config_file, config_file)

# Fixed seeds unless the run is explicitly randomized.
if not args.randomize:
    manualSeed = 42
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    # if you are using GPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(manualSeed)
        torch.cuda.manual_seed_all(manualSeed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

# Module-level settings read by the functions below.
noiserates = json.loads(args.noise_rates)
preprocessed_data = args.preprocessed_data
embed_dim = args.embed_size
batch_size = args.batch_size
dim = args.PCA_dim
loss_fn = TripletLoss(args.margin)
runtime = args.runtime
hidden_sizes = json.loads(args.hidden_sizes)
cuda = torch.cuda.is_available()
# cuda = False

# File logger for this run; the full parameter set is the first record.
flogger = get_logger(args.runtime, savefolder)
param_dump = json.dumps(vars(args), indent=2, sort_keys=True)
flogger.info("Param: {}".format(param_dump))

def prepare_data(args):
    """Load the feature file and draw a per-class subsample of the training set.

    Reads ``args.datafile`` with ``torch.load`` and expects the keys ``trX``,
    ``trY``, ``teX`` and ``teY``; ``trY_gt`` is optional and falls back to a
    copy of ``trY``. ``args.samples_per_class`` is the per-class quota handed
    to ``randomStratifiedSampleData``.

    Returns:
        ``(trX, trY, trY_gt), (teX, teY), num_classes`` where the training
        entries are copies restricted to the sampled indices and
        ``num_classes`` is the number of distinct values in the full ``trY_gt``.
    """
    data = torch.load(args.datafile)
    trX, trY = data['trX'], data['trY']
    teX, teY = data['teX'], data['teY']
    trY_gt = data['trY_gt'] if 'trY_gt' in data else trY.copy()
    del data

    num_classes = len(set(trY_gt))

    #### Random sample small subset of data for training
    # Same quota for every class.
    perN = np.ones((num_classes), dtype='int') * args.samples_per_class
    selected_indices = randomStratifiedSampleData(trY, perN)

    sampled_train = (
        trX[selected_indices, :].copy(),
        trY[selected_indices].copy(),
        trY_gt[selected_indices].copy(),
    )
    return sampled_train, (teX, teY), num_classes

def inject_noise(trY_gt, noiserates):
    """Return a noisy version of the labels ``trY_gt``.

    One entry in ``noiserates`` selects ``generateRandomLabelNoise2`` with that
    rate, several entries select ``generateImBalancedRandomLabelNoise2`` with
    the whole list, and an empty ``noiserates`` returns a plain copy of the
    labels. The chosen mode is printed and logged.
    """
    untouched = trY_gt.copy()
    n_rates = len(noiserates)
    if n_rates < 1:
        return untouched

    if n_rates == 1:
        message = "injecting symmetric noise ..."
        print(message)
        flogger.info(message)
        return generateRandomLabelNoise2(trY_gt, noiserates[0])

    message = "Asymmetric noises: {}".format(noiserates)
    print(message)
    flogger.info(message)
    return generateImBalancedRandomLabelNoise2(trY_gt, noiserates)

def feature_engineering(trX, trY, valX, valY, teX, teY, args):
    """Optionally PCA-reduce and standardize the features, logging KNN test accuracy.

    Args:
        trX, trY: training features and labels.
        valX, valY: validation features and labels; ``valX`` may be None for
            the standardization step, ``valY`` is not used here.
        teX, teY: test features and labels.
        args: parsed options; reads ``KNN_test``, ``use_PCA`` and ``standardize``.

    The PCA target dimension is the module-level ``dim``. PCA is computed (and
    its KNN accuracy logged) whenever the feature dimension exceeds ``dim``,
    but the projected features are only *returned* when ``args.use_PCA`` is set.

    Returns:
        ``(s0trX, s0valX, s0teX, (pca_acc0, acc0))``. ``pca_acc0`` is 0.0 when
        PCA was not computed; ``acc0`` is the accuracy on the standardized
        features when ``args.standardize`` is set, otherwise on the raw ones.
    """
    #### Dimensionality reduction with PCA
    print('Dimensionality reduction with PCA...\n')

    [_, acc0, _, _] = KNNtest(args.KNN_test, trX, trY, teX, teY)
    mess = 'Initial Test performance (acc%%): %.2f\n' % (acc0 * 100)
    print(mess), flogger.info(mess)

    pca_acc0 = 0.
    # Start from the raw features; they are replaced only if PCA was actually
    # computed AND requested. Previously use_PCA=True with feature dim <= dim
    # hit unbound *_pca variables (NameError).
    s0trX, s0valX, s0teX = trX, valX, teX
    if trX.shape[1] > dim:
        trX_pca, valX_pca, teX_pca = Wrapper_PCA(trX, valX, teX, dim)
        [_, pca_acc0, _, _] = KNNtest(args.KNN_test, trX_pca, trY, teX_pca, teY)
        mess = 'PCA Test performance (acc%%): %.2f\n' % (pca_acc0 * 100)
        print(mess), flogger.info(mess)
        if args.use_PCA:
            s0trX, s0valX, s0teX = trX_pca, valX_pca, teX_pca
    elif args.use_PCA:
        mess = 'PCA skipped: feature dim %d <= PCA dim %d, using raw features\n' % (trX.shape[1], dim)
        print(mess), flogger.info(mess)

    if args.standardize:
        print("rescale data ...")
        # Statistics come from the training split only and are reused for val/test.
        mean = s0trX.mean(axis=0, keepdims=True)
        std = s0trX.std(axis=0, keepdims=True)
        std[std == 0] = 1  # constant features: avoid division by zero

        s0trX = (s0trX - mean)/std
        s0teX = (s0teX - mean)/std
        if s0valX is not None:
            s0valX = (s0valX - mean)/std

        [_, acc0, _, _] = KNNtest(args.KNN_test, s0trX, trY, s0teX, teY)
        mess = 'Initial Test performance rescale (acc%%): %.2f\n' % (acc0 * 100)
        print(mess), flogger.info(mess)
    return s0trX, s0valX, s0teX, (pca_acc0, acc0)

###===================###
### Creating triplets ###
###===================###
def create_dataloaders(feature, target, train=False, transform=None, use_target=False, LMNN_triplets=False, KNN_triplet=21, batch_size=256, shuffle=False):
    """Build the tensor dataset plus a triplet loader and a plain loader for it.

    Args:
        feature, target: features and labels; converted with ``torch.tensor``
            (labels as int64).
        train: when False and CUDA is available the loaders use pinned memory.
        transform: forwarded to ``TripletDataset`` and ``TensorDatasetWrapper``.
        use_target: forwarded to ``TripletDataset``.
        LMNN_triplets: when True, triplet indices are precomputed with
            ``find_KNN_triplets(feature, target, KNN_triplet)``; otherwise
            ``None`` is passed as ``triplets_index``.
        KNN_triplet: neighbour count for ``find_KNN_triplets``.
        batch_size: batch size of both loaders.
        shuffle: shuffling of the triplet loader only; the plain loader never shuffles.

    Returns:
        ``(dataset, triplet_loader, loader)``.
    """
    triplets_indices = find_KNN_triplets(feature, target, KNN_triplet) if LMNN_triplets else None

    feature = torch.tensor(feature)
    target = torch.tensor(target, dtype=torch.int64)
    dataset = TensorDataset(feature, target)

    # Pinned memory only for non-training loaders on CUDA (as in the original).
    kwargs = {'pin_memory': True} if (not train and cuda) else {}

    # NOTE(review): TripletDataset is always built with train=True, regardless
    # of this function's `train` argument - confirm this is intended.
    triplet_dataset = TripletDataset(dataset, train=True, transform=transform, use_target=use_target,
                            triplets_index=triplets_indices)
    triplet_loader = torch.utils.data.DataLoader(triplet_dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    # `dataset` is always a TensorDataset here, so it is always wrapped; the old
    # fallback branch was unreachable and referenced an undefined name.
    dataset_wrapper = TensorDatasetWrapper(dataset, transform=transform)
    loader = torch.utils.data.DataLoader(dataset_wrapper, batch_size=batch_size, shuffle=False, **kwargs)
    return dataset, triplet_loader, loader


def create_network(args, feature_size, embed_dim, num_classes):
    """Create the triplet network together with its optimizer and LR scheduler.

    The embedding net takes its hidden layer sizes from the module-level
    ``hidden_sizes``; the model is moved to the GPU when the module-level
    ``cuda`` flag is set. ``args`` supplies ``use_pred_loss``, ``lr``,
    ``lr_milestones`` (JSON list) and ``lr_gamma``.

    Returns:
        ``(model, optimizer, scheduler)`` - a ``TripletNet``, an Adam optimizer
        without weight decay, and a ``MultiStepLR`` scheduler.
    """
    backbone = EmbeddingNet(
        feature_size,
        embed_dim,
        hidden_sizes=hidden_sizes,
        num_classes=num_classes,
        use_pred_loss=args.use_pred_loss,
    )
    model = TripletNet(backbone)
    if cuda:
        model.cuda()

    decay_epochs = json.loads(args.lr_milestones)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=decay_epochs,
        gamma=args.lr_gamma,
        last_epoch=-1,
    )
    return model, optimizer, scheduler


def train(args, model, optimizer, scheduler, triplet_train_loader, triplet_val_loader, triplet_test_loader,
        train_loader, val_loader, test_loader, loss_fn=None, checkpointer=None, log_interval=100, start_epoch=0):
    """Train ``model`` on triplets and evaluate on the validation split after every epoch.

    Args:
        args: parsed options; reads ``epochs``, ``use_pred_loss``, ``KNN_test``
            and ``plot_embed``.
        model, optimizer, scheduler: network, its optimizer and LR scheduler.
            The scheduler is stepped once per epoch and fast-forwarded by
            ``start_epoch`` steps first.
        triplet_train_loader: triplet batches for ``train_epoch``.
        triplet_val_loader: triplet batches for ``test_epoch``.
        train_loader, val_loader: plain loaders used to extract embeddings for
            the KNN accuracy (and for ``test_epoch_pred`` on the validation split).
        triplet_test_loader, test_loader: accepted for interface compatibility;
            not used in this function.
        loss_fn: loss passed to the epoch helpers.
        checkpointer: optional object whose ``step(acc)`` returns True to stop
            early; ``None`` disables early stopping.
        log_interval: forwarded to ``train_epoch``.
        start_epoch: first epoch index to run.

    Returns:
        ``(train_accs, eval_accs, imgs_train, imgs_eval)`` - per-epoch KNN
        accuracies and, when ``args.plot_embed`` is set, per-epoch embedding plots.
    """
    n_epochs = args.epochs
    metrics = []
    # Label of the evaluation split in log messages. This name was previously
    # undefined, which raised NameError on the first epoch.
    mode = 'Val'

    # Fast-forward the LR schedule when resuming from a later epoch.
    for _ in range(start_epoch):
        scheduler.step()

    imgs_train, imgs_eval = [], []
    train_accs, eval_accs = [], []

    # Evaluation always uses the validation loaders.
    eval_triplet_loader = triplet_val_loader
    eval_loader = val_loader

    for epoch in range(start_epoch, n_epochs):
        # Train stage
        train_loss, metrics, train_bad_triplets = train_epoch(triplet_train_loader, model, loss_fn, optimizer, cuda,
                                                        log_interval, metrics, use_pred_loss=args.use_pred_loss)
        message = 'Epoch: {}/{}. Train loss: {:.4f}'.format(epoch, n_epochs, train_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())
        message += ', #train_bad_triplets: {:.4f}'.format(train_bad_triplets)

        # Evaluation stage
        eval_loss, metrics, eval_bad_triplets = test_epoch(eval_triplet_loader, model, loss_fn, cuda,
                                                        metrics, use_pred_loss=args.use_pred_loss)

        # NOTE(review): presumably test_epoch returns a loss summed over batches - confirm.
        eval_loss /= len(eval_triplet_loader)
        message += ' {} loss: {:.4f}'.format(mode, eval_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())
        message += ', #eval_bad_triplets: {:.4f}'.format(eval_bad_triplets)

        if args.use_pred_loss:
            eval_pred_acc = test_epoch_pred(eval_loader, model, loss_fn, cuda)
            message += ', eval_pred_acc: {:.4f}'.format(eval_pred_acc)

        print(message)
        flogger.info(message)
        scheduler.step()

        # KNN accuracy in embedding space; the training embeddings are the reference set.
        train_embeddings_tl, train_labels_tl = extract_embeddings(train_loader, model, embed_dim=embed_dim, cuda=cuda)
        eval_embeddings_tl, eval_labels_tl = extract_embeddings(eval_loader, model, embed_dim=embed_dim, cuda=cuda)

        [_, acc_train, _, _] = KNNtest(args.KNN_test, train_embeddings_tl, train_labels_tl, train_embeddings_tl, train_labels_tl)
        [_, acc_eval, _, _] = KNNtest(args.KNN_test, train_embeddings_tl, train_labels_tl, eval_embeddings_tl, eval_labels_tl)
        # Report the K that was actually used (was hard-coded to 3).
        mess = '  %d-KNN acc: Train: %.2f, %s: %.2f\n' % (args.KNN_test, acc_train*100, mode, acc_eval*100)
        train_accs.append(acc_train)
        eval_accs.append(acc_eval)
        print(mess)
        flogger.info(mess)

        if args.plot_embed:
            train_img = plot_highest_var_dim(train_embeddings_tl, train_labels_tl, epoch=epoch)
            eval_img = plot_highest_var_dim(eval_embeddings_tl, eval_labels_tl, epoch=epoch)
            imgs_train.append(train_img)
            imgs_eval.append(eval_img)

        # The default checkpointer=None means "no early stopping" rather than a crash.
        if checkpointer is not None and checkpointer.step(acc_eval):
            print("Early Stopped!")
            break
    return train_accs, eval_accs, imgs_train, imgs_eval

class EearlyStopping:
    """Track a monitored metric and signal when it has stopped improving.

    Args:
        net: model to checkpoint; its ``save(path)`` method is called on every
            improvement when ``path`` is given.
        patience: number of consecutive non-improving steps after which
            ``step`` returns True.
        criterion: ``'max'`` if larger values are better, ``'min'`` if smaller.
        path: optional checkpoint destination passed to ``net.save``.

    Raises:
        ValueError: if ``criterion`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, net, patience=10, criterion='max', path=None):
        if criterion not in ('max', 'min'):
            raise ValueError("criterion must be 'max' or 'min', got {!r}".format(criterion))
        self.patience = patience
        self.cnt_patience = -1
        self.criterion = criterion
        self.best_perf = -np.inf if criterion == 'max' else np.inf
        self.num_epoch = 0
        # 0 means "no improvement recorded yet"; keeps the attribute readable
        # even if step() never sees an improvement.
        self.best_epoch = 0
        self.path = path
        # Store the model that was passed in (previously this silently picked
        # up a module-level `model` and ignored the `net` argument).
        self.model = net

    def step(self, curr_perf):
        """Record one evaluation result; return True when training should stop."""
        self.num_epoch += 1
        if self.criterion == 'max':
            improved = curr_perf > self.best_perf
        else:
            improved = curr_perf < self.best_perf

        if improved:
            self.cnt_patience = 0
            self.best_epoch = self.num_epoch
            self.best_perf = curr_perf
            if self.path is not None:
                self.model.save(self.path)
        else:
            self.cnt_patience += 1
        return self.cnt_patience == self.patience

def save_plots(fold, imgs_train, imgs_eval):
    """Write the train and eval frame lists of one fold as GIFs (fps=1).

    Files are named ``train.<runtime>.<fold>.gif`` and
    ``eval.<runtime>.<fold>.gif`` inside the module-level ``savefolder``.
    """
    train_gif = '{}/train.{}.{}.gif'.format(savefolder, runtime, fold)
    eval_gif = '{}/eval.{}.{}.gif'.format(savefolder, runtime, fold)
    for gif_path, frames in ((train_gif, imgs_train), (eval_gif, imgs_eval)):
        imageio.mimsave(gif_path, frames, fps=1)


(trX_origin, trY_origin, trY_gt_origin), (teX, teY), num_classes = prepare_data(args)
if args.noise_type == 'asymmetric':
    # Per-class noise rates drawn from {0.1, ..., 0.6} with a fixed seed;
    # this replaces the rates read from the config.
    rng = np.random.RandomState(0)
    noiserates = rng.randint(1, 7, num_classes) / 10

### K-Fold cross validation
# Folds are stratified on the ground-truth labels; label noise is injected
# into the training part of each fold only.
skf = StratifiedKFold(n_splits=5)
eval_accs = []
eval_epochs = []
for fold_idx, (train_index, val_index) in enumerate(skf.split(trX_origin, trY_gt_origin)):
    print("###Fold {}: ".format(fold_idx))
    trX, trY_gt = trX_origin[train_index], trY_gt_origin[train_index]
    valX, valY = trX_origin[val_index], trY_gt_origin[val_index]

    trY = inject_noise(trY_gt, noiserates)

    # Fraction of training labels that actually changed.
    true_noiserate = np.mean(trY != trY_gt)
    message = 'Shape:%s || #Training:%d || #Val: %d || #Test:%d || label noise (%%): %.2f\n' %(
        trX.shape, len(trY), len(valY), len(teY), true_noiserate * 100)
    print(message)
    flogger.info(message)

    s0trX, s0valX, s0teX, init_acc = feature_engineering(trX, trY, valX, valY, teX, teY, args)
    feature_size = s0trX.shape[1]
    message = 'feature dim: %d, embed_dim: %d' % (feature_size, embed_dim)
    cprint('c', message)
    flogger.info(message)

    # Dataloaders
    train_dataset, triplet_train_loader, train_loader = create_dataloaders(s0trX, trY, train=True, transform=None, use_target=args.use_pred_loss,
                            LMNN_triplets=True, KNN_triplet=args.KNN_triplet, batch_size=batch_size, shuffle=True)
    val_dataset, triplet_val_loader, val_loader = create_dataloaders(s0valX, valY, train=False, transform=None, use_target=False,
                                LMNN_triplets=False, KNN_triplet=args.KNN_triplet, batch_size=batch_size, shuffle=False)
    test_dataset, triplet_test_loader, test_loader = create_dataloaders(s0teX, teY, train=False, transform=None, use_target=False,
                                LMNN_triplets=False, KNN_triplet=args.KNN_triplet, batch_size=batch_size, shuffle=False)

    print('\nNtriplets: %d'%(len(triplet_train_loader.dataset)))
    flogger.info('Ntriplets: %d'%(len(triplet_train_loader.dataset)))

    # embed_dim = min(feature_size, embed_dim)

    model, optimizer, scheduler = create_network(args, feature_size, embed_dim, num_classes)
    # Match the model's parameter dtype to the feature dtype.
    feature_dtype = train_dataset.tensors[0].dtype
    if feature_dtype == torch.float32:
        model.float()
    elif feature_dtype == torch.float64:
        model.double()
    else:
        print("datatype not expected: {}".format(feature_dtype))
        # Same exit status as before, without relying on the site-provided exit().
        raise SystemExit(1)

    if fold_idx == 0:
        cprint('c', '\nNetwork:')
        print(str(model))
        flogger.info(str(model))

    tic0 = time.time()
    checkpointer = EearlyStopping(model, patience=10, criterion='max')
    train_acc_rt, eval_acc_rt, imgs_train, imgs_eval = train(args, model, optimizer, scheduler,
            triplet_train_loader, triplet_val_loader, triplet_test_loader,
            train_loader, val_loader, test_loader, loss_fn, checkpointer)
    if args.plot_embed:
        save_plots(fold_idx, imgs_train, imgs_eval)

    toc0 = time.time()
    # Average over the epochs that actually ran: early stopping can end the
    # fold before args.epochs, and this also avoids dividing by zero.
    epochs_run = max(len(train_acc_rt), 1)
    runtime_per_it = (toc0 - tic0) / float(epochs_run)
    print('Best validation acc: {} at epoch {}'.format(checkpointer.best_perf, checkpointer.best_epoch))
    cprint('r', ' total runtime: %.2f,  average time: %.2f seconds\n' % (toc0-tic0, runtime_per_it))
    flogger.info(' total runtime: %.2f,  average time: %.2f seconds\n' % (toc0-tic0, runtime_per_it))
    eval_accs.append(checkpointer.best_perf)
    eval_epochs.append(checkpointer.best_epoch)

message = "Valdiation Accuracies: {}, average: {}".format(eval_accs, np.array(eval_accs).mean())
print(message)
flogger.info(message)
message = "Evaluation at: {}".format(eval_epochs)
print(message)
flogger.info(message)