import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import argparse
from tqdm import tqdm, trange
import math
import numpy as np
from scipy.sparse.linalg.eigen.arpack import eigsh
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from mnist import MNIST
from model import ConvYu
from utils import *

def _positive_int(text):
    """Parse a command-line integer, rejecting values below 1.

    Used for options that act as a modulus (``--frequence``), as the
    eigenpair count handed to ``eigsh`` (``--lead``), or as a batch size.
    For these, 0 or a negative value would otherwise fail much later with a
    less helpful error (e.g. ``ZeroDivisionError`` from ``epoch % 0``).
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {}'.format(text))
    return value


parser = argparse.ArgumentParser()
parser.add_argument('--resume', default=None, type=str,
                    help='checkpoint path to restore before training '
                         '(default: start from freshly initialized variables)')
parser.add_argument('--maxiter', default=20000, type=int,
                    help='number of training iterations (called epochs in the loop)')
parser.add_argument('--keep-prob', default=1.0, type=float,
                    help='keep probability forwarded to train()')
parser.add_argument('--lr', default=0.07, type=float,
                    help='learning rate forwarded to train()')
parser.add_argument('--scale', default=5.0, type=float,
                    help='multiplier applied to the injected noise')
parser.add_argument('--lead', default=100, type=_positive_int,
                    help='number of leading covariance eigenpairs computed with eigsh')
parser.add_argument('--batch-size', default=2000, type=_positive_int,
                    help='batch size of the covariance-estimation and clean-train loaders')
parser.add_argument('--frequence', default=10, type=_positive_int,
                    help='re-estimate the covariance (and inject noise) every N iterations')
parser.add_argument('--datadir', default='../data/fashion_tf_1k_cor', type=str,
                    help='directory containing clean_train.npz and val.npz')
parser.add_argument('--logdir', default='logs/gld_lead_nosqrt_clean_5.0_10_100_seed2', type=str,
                    help='directory for TensorBoard summaries and checkpoints')

if __name__ == '__main__':
    args = parser.parse_args()

    # Seed TF's graph-level RNG *and* NumPy's global RNG. The injected noise is
    # drawn with np.random below, so seeding TF alone left runs
    # non-reproducible despite the "seed2" suffix in the default logdir.
    tf.random.set_random_seed(2)
    np.random.seed(2)

    # data loader
    # NOTE(review): train_set and clean_train_set both load clean_train.npz,
    # so training uses the clean split — confirm this is intended.
    train_set = MNIST(os.path.join(args.datadir, 'clean_train.npz'))
    clean_train_set = MNIST(os.path.join(args.datadir, 'clean_train.npz'))
    val_set = MNIST(os.path.join(args.datadir, 'val.npz'))
    val_loader = Loader(val_set, batch_size=2000, shuffle=False)
    # one_loader feeds the covariance estimate; clean_train_loader feeds the
    # per-iteration clean-train evaluation.
    one_loader = Loader(train_set, batch_size=args.batch_size, shuffle=True)
    clean_train_loader = Loader(clean_train_set, batch_size=args.batch_size, shuffle=True)

    # model
    model = ConvYu()

    # summary: scalar placeholders fed with values computed in NumPy.
    _loss = tf.placeholder(tf.float32)
    _acc = tf.placeholder(tf.float32)
    # _trhessian / _trcov / _trhescov only back the commented-out summaries below.
    _trhessian = tf.placeholder(tf.float32)
    _trcov = tf.placeholder(tf.float32)
    _trhescov = tf.placeholder(tf.float32)
    _maxecov = tf.placeholder(tf.float32)
    _minecov = tf.placeholder(tf.float32)
    _ratioecov = tf.placeholder(tf.float32)

    train_summary_list = [tf.summary.scalar('loss/train', _loss),
                          tf.summary.scalar('acc/train', _acc),
                          tf.summary.scalar('maxecov/train', _maxecov),
                          tf.summary.scalar('minecov/train', _minecov),
                          tf.summary.scalar('ratioecov/train', _ratioecov),
                          # tf.summary.scalar('trhessian/train', _trhessian),
                          # tf.summary.scalar('trcov/train', _trcov),
                          # tf.summary.scalar('trhescov/train', _trhescov),
                          ]
    train_summary_merged = tf.summary.merge(train_summary_list)
    val_summary_list = [tf.summary.scalar('loss/val', _loss),
                        tf.summary.scalar('acc/val', _acc)]
    val_summary_merged = tf.summary.merge(val_summary_list)
    clean_train_summary_list = [tf.summary.scalar('loss/clean_train', _loss),
                                tf.summary.scalar('acc/clean_train', _acc)]
    clean_train_summary_merged = tf.summary.merge(clean_train_summary_list)
    _loss_gap = tf.placeholder(tf.float32)
    _acc_gap = tf.placeholder(tf.float32)
    gap_summary_list = [tf.summary.scalar('loss/gap', _loss_gap),
                        tf.summary.scalar('acc/gap', _acc_gap)]
    gap_summary_merged = tf.summary.merge(gap_summary_list)

    # Spectrum of the estimated covariance, re-computed every args.frequence iterations.
    _ecov = tf.placeholder(tf.float32, [args.lead])
    _vcov = tf.placeholder(tf.float32, [model.n_weights, args.lead])  # currently never fed
    _l2norm_cov = tf.placeholder(tf.float32)
    _trnorm_cov = tf.placeholder(tf.float32)

    # _trnorm_cov is fed sqrt(trace(cov)), so 'leadingPer' is
    # sum(leading eigenvalues) / trace(cov): the share of the trace captured
    # by the top args.lead eigenvalues.
    matrix_summary_list = [tf.summary.scalar('matrix/l2_norm/cov', _l2norm_cov),
                           tf.summary.scalar('matrix/tr_norm/cov', _trnorm_cov),
                           tf.summary.scalar('leadingPer', tf.reduce_sum(_ecov)/tf.square(_trnorm_cov))]
    for i in range(args.lead):
        # 'lead<k>' holds the k-th largest eigenvalue (eigenpairs are sorted
        # descending before being fed).
        matrix_summary_list.append(tf.summary.scalar('eigenValue/lead'+str(i+1)+'/ecov', _ecov[i]))
    matrix_summary_merged = tf.summary.merge(matrix_summary_list)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.2

    # prepare
    saver = tf.train.Saver(max_to_keep=100)
    with tf.Session(config=config) as sess:
        # initialize
        sess.run(tf.global_variables_initializer())

        # resume
        if args.resume is not None:
            saver.restore(sess, args.resume)
            print("Model restored.")

        # writer
        writer = tf.summary.FileWriter(make_dir(args.logdir), sess.graph)

        # The full training set, handed to train() as a single batch each
        # iteration. It is loop-invariant, so it is built once here rather
        # than inside the covariance branch.
        data_tuple = (train_set.X, train_set.Y)

        try:
            # training and eval
            for epoch in trange(args.maxiter):

                # NOTE(review): noise is zero except on iterations where the
                # covariance is re-estimated. If the intent is to perturb every
                # step using the cached eigenpairs, move the sampling out of the
                # branch below — confirm.
                noise = np.zeros(model.n_weights)
                if epoch % args.frequence == 0:
                    # eval matrix (eval_Cov is defined in utils; presumably a
                    # gradient covariance over one pass of one_loader — confirm)
                    # hessian = eval_Hess(sess, [data_tuple], model)
                    covariance, _ = eval_Cov(sess, one_loader(), model)
                    # Frobenius norm; the 'l2_norm' tag name is kept for
                    # continuity with existing logs.
                    l2norm_cov = np.linalg.norm(covariance, ord='fro')
                    trnorm_cov = np.sqrt(covariance.trace())
                    ecov, vcov = eigsh(covariance, args.lead, which='LM')
                    # eigsh does not return eigenvalues largest-first; sort
                    # descending so 'eigenValue/lead1' is the largest, reordering
                    # the eigenvector columns in step.
                    order = np.argsort(ecov)[::-1]
                    ecov, vcov = ecov[order], vcov[:, order]
                    summary = sess.run(matrix_summary_merged, feed_dict={
                        _l2norm_cov: l2norm_cov,
                        _trnorm_cov: trnorm_cov,
                        _ecov: ecov})
                    writer.add_summary(summary, epoch)
                    # trhessian = (hessian).trace()
                    # trcov = (covariance).trace()
                    # trhescov = (hessian * covariance).trace()
                    # noise = np.matmul(vcov, np.random.normal(0, np.sqrt(np.abs(ecov))))
                    # Noise lies in the span of the leading eigenvectors; the
                    # coefficient along eigenvector i has std |e_i| ** 0.25.
                    noise = args.scale * np.matmul(vcov, np.random.normal(0, np.sqrt(np.sqrt(np.abs(ecov)))))

                    maxecov = np.max(ecov)
                    # epsilon keeps the ratio finite if the smallest eigenvalue is 0
                    minecov = np.abs(np.min(ecov)) + 1e-10
                    ratioecov = maxecov / minecov

                # maxecov/minecov/ratioecov are always bound here: epoch 0
                # satisfies epoch % args.frequence == 0.
                loss, acc = train(sess, [data_tuple], model, args.lr, args.keep_prob, noise)
                summary = sess.run(train_summary_merged, feed_dict={_loss:loss, _acc:acc, _minecov:minecov, _maxecov:maxecov, _ratioecov:ratioecov})
                writer.add_summary(summary, epoch)
                print('Epoch: {:}    loss: {:.6f}    acc: {:.2f}    In Train'.format(epoch, loss, acc))

                # validate step: clean training set
                train_loss, train_acc = validate(sess, clean_train_loader(), model)
                summary = sess.run(clean_train_summary_merged, feed_dict={_loss:train_loss, _acc:train_acc})
                writer.add_summary(summary, epoch)
                print('Epoch: {:}    loss: {:.6f}    acc: {:.2f}    In Clean Train'.format(epoch, train_loss, train_acc))

                # validate step: held-out validation set
                valid_loss, valid_acc = validate(sess, val_loader(), model)
                summary = sess.run(val_summary_merged, feed_dict={_loss:valid_loss, _acc:valid_acc})
                writer.add_summary(summary, epoch)
                print('Epoch: {:}    loss: {:.6f}    acc: {:.2f}    In Validation'.format(epoch, valid_loss, valid_acc))

                # generalization gap between clean-train and validation metrics
                loss_gap = valid_loss - train_loss
                acc_gap = train_acc - valid_acc
                summary = sess.run(gap_summary_merged, feed_dict={_loss_gap: loss_gap, _acc_gap: acc_gap})
                writer.add_summary(summary, epoch)
                print('Epoch: {:}    loss: {:.6f}    acc: {:.2f}    In Gap'.format(epoch, loss_gap, acc_gap))

                # save ckpt
                saver.save(sess, os.path.join(args.logdir, 'model'), epoch)
        finally:
            # Flush buffered summaries to disk; otherwise trailing events can be
            # lost when training ends or is interrupted.
            writer.close()