import math
import numpy as np
import tensorflow as tf
import argparse
import random
import os.path
import matplotlib.pyplot as plt

from data import get_data_generator
from model import get_model_generator
from evaluator import get_evaluator


def set_random_seeds(parameter_random_seed, data_random_seed):
    """Seed the random number generators used by this script.

    Args:
        parameter_random_seed: Seed passed to ``tf.random.set_seed``.
        data_random_seed: Seed passed to both Python's ``random`` module
            and ``np.random``.
    """
    tf.random.set_seed(parameter_random_seed)
    # The Python and NumPy generators share the data seed.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(data_random_seed)


def train(args, dg, model, ev, fixed_train_data, test_data):
    """Fit the model one batch at a time and record statistics per step.

    Args:
        args: Namespace providing ``steps``, ``batch_size`` and
            ``log_interval``.
        dg: Data generator; ``get_training_samples`` supplies each batch.
        model: Object whose ``fit`` method is called once per step.
        ev: Evaluator; used for the printed metrics and handed to
            ``analyze_model`` as its gradient/loss/accuracy source.
        fixed_train_data: Training samples forwarded to ``analyze_model``.
        test_data: Test samples forwarded to ``analyze_model``.

    Returns:
        ``np.ndarray`` built from one ``analyze_model`` result taken before
        training plus one after every step (``args.steps + 1`` entries).
    """
    def snapshot():
        # One row of statistics for the current model state.
        return analyze_model(args, fixed_train_data, test_data, ev)

    print(0, *ev.evaluate_all())
    history = [snapshot()]
    for step in range(1, args.steps + 1):
        batch_x, batch_y = dg.get_training_samples(args.batch_size)
        model.fit(batch_x, batch_y, batch_size=args.batch_size, epochs=1,
                  verbose=0)
        # Print on the last step of every log interval.
        if (step - 1) % args.log_interval == args.log_interval - 1:
            print(step, *ev.evaluate_all())
        # Statistics are recorded after every step, not only when logging.
        history.append(snapshot())
    print("final", *ev.large_evaluate_all())
    return np.asarray(history)


def unified_ddr(a, b, c):
    """Return ``dot(b, c) / |dot(a, c)|`` with a defined value at zero.

    Args:
        a: Vector whose projection on ``c`` forms the denominator.
        b: Vector whose projection on ``c`` forms the numerator.
        c: Direction vector; must have the same length as ``a`` and ``b``.

    Returns:
        The ratio ``dot(b, c) / abs(dot(a, c))``. When ``dot(a, c)`` is
        zero the result is ``math.inf``, ``-math.inf`` or ``0`` depending
        on whether ``dot(b, c)`` is positive, negative or neither.
    """
    assert len(a) == len(c)
    denominator = np.dot(a, c)
    assert len(b) == len(c)
    numerator = np.dot(b, c)

    if denominator != 0:
        # Only the magnitude of the denominator matters; the sign of the
        # result follows the numerator.
        return numerator / abs(denominator)

    # Degenerate denominator: keep only the sign of the numerator.
    if numerator > 0:
        return math.inf
    if numerator < 0:
        return -math.inf
    return 0


def analyze_model(args, train, test, model):
    """Collect loss, accuracy and gradient statistics on two data sets.

    Args:
        args: Unused here; kept for interface compatibility.
        train: Training samples; ``train[0]`` must be non-empty.
        test: Test samples; ``test[0]`` must be non-empty.
        model: Object exposing ``get_gradient_loss_acc(data)`` that returns
            ``(gradient, loss, accuracy)``. Within this file ``train``
            passes the evaluator in this position.

    Returns:
        Tuple ``(train_loss, test_loss, train_acc, test_acc, rdd_train,
        rdd_all, aa, ab, ca, cb)`` where ``a``/``b``/``c`` are the train,
        test and sample-count-weighted combined gradients and the two-letter
        entries are their dot products.
    """
    train_size, test_size = len(train[0]), len(test[0])
    assert train_size > 0
    assert test_size > 0

    g_train, loss_train, acc_train = model.get_gradient_loss_acc(train)
    g_test, loss_test, acc_test = model.get_gradient_loss_acc(test)
    assert len(g_train) == len(g_test)

    # Combined gradient: per-set gradients weighted by their sample counts.
    total_size = train_size + test_size
    g_all = (train_size * g_train + test_size * g_test) / total_size

    # Ratios along the train gradient and along the combined gradient.
    ratios = (
        unified_ddr(g_train, g_test, g_train),
        unified_ddr(g_train, g_test, g_all),
    )

    # Raw dot products: aa, ab, ca, cb.
    dots = (
        np.dot(g_train, g_train),
        np.dot(g_train, g_test),
        np.dot(g_all, g_train),
        np.dot(g_all, g_test),
    )

    return (loss_train, loss_test, acc_train, acc_test) + ratios + dots


def plot_results(args, data, name, legends):
    """Plot several series as labelled lines and save the figure as a PDF.

    Args:
        args: Namespace providing ``log_dir``; an empty string means the
            current working directory.
        data: Iterable of series, one line is drawn per entry.
        legends: Labels paired with ``data`` via ``zip`` (extra entries on
            either side are ignored).
        name: Base file name; the figure is written to
            ``<log_dir>/<name>.pdf``.
    """
    for series, label in zip(data, legends):
        plt.plot(series, label=label)
    plt.legend()

    # os.makedirs('') raises FileNotFoundError, and '' is the default
    # log_dir, so only create the directory when one was actually given.
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    plt.savefig(os.path.join(args.log_dir, name + ".pdf"))
    # Clear the current figure so the next plot starts empty.
    plt.clf()


def pivot_results(args, data, name, legends):
    """Scatter-plot one series against another and save the figure as a PDF.

    Args:
        args: Namespace providing ``log_dir``; an empty string means the
            current working directory.
        data: Exactly two series; ``data[0]`` is the x axis and ``data[1]``
            the y axis.
        name: Base file name; the figure is written to
            ``<log_dir>/pivot_<name>.pdf``.
        legends: Exactly two labels; only their count is checked here.
    """
    assert len(data) == 2
    assert len(legends) == 2
    plt.scatter(data[0], data[1])

    # os.makedirs('') raises FileNotFoundError, and '' is the default
    # log_dir, so only create the directory when one was actually given.
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    plt.savefig(os.path.join(args.log_dir, 'pivot_' + name + ".pdf"))
    # Clear the current figure so the next plot starts empty.
    plt.clf()


def smooth_rdd(data):
    data = np.sign(data) * np.log(np.abs(data) + 1)

    window = 50
    original_length = data.shape[-1]
    smoothed_length = original_length // window
    length = smoothed_length * window
    if length < original_length:
        data = np.split(data, [length, -1], -1)[0]
    data = np.reshape(data, [-1, smoothed_length, window])
    data = np.average(data, -1)
    return data


def main(args):
    """Run one experiment: sample data, build the model, train and plot.

    Outputs written under ``args.log_dir`` (the current directory when it is
    empty): ``model_summary.txt``, ``rdd.npy`` with the raw per-step
    statistics, and the PDF figures produced by ``plot_results`` and
    ``pivot_results``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    # get data
    # Constant seeds are set before sampling and the user-supplied seeds
    # only afterwards, so the sampling below does not depend on the
    # command-line seeds.
    set_random_seeds(42, 43)
    dg = get_data_generator(args)
    fixed_train_data = dg.get_training_samples(args.test_sample_size)
    eval_data = dg.get_training_samples_for_evaluation(args.test_sample_size)
    eval_new_data = dg.get_eval_samples(args.test_sample_size)
    test_data = dg.get_test_samples(args.test_sample_size, randomize=False)
    random_data = dg.get_test_samples(args.test_sample_size, randomize=True)
    large_eval_data = dg.get_training_samples_for_evaluation(
        10 * args.test_sample_size)
    large_eval_new_data = dg.get_eval_samples(10 * args.test_sample_size)
    large_test_data = dg.get_test_samples(10 * args.test_sample_size,
                                          randomize=False)
    large_random_data = dg.get_test_samples(10 * args.test_sample_size,
                                            randomize=True)

    # set random seeds
    set_random_seeds(args.parameter_random_seed, args.data_random_seed)

    # get model and evaluator
    output_nodes = dg.get_output_nodes()
    mg = get_model_generator(args, dg.get_input_shape(), output_nodes)
    mg.set_vocab_size(dg.get_vocab_size())
    model, loss_fn = mg.get_model()

    # The log directory must exist before the first file is written into it.
    # os.makedirs('') raises FileNotFoundError, so skip it for the default
    # empty log_dir (which resolves to the current working directory).
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'model_summary.txt'), 'w') as f:
        # file.write adds no line terminator, so append one per call to keep
        # the summary lines from running together.
        model.summary(print_fn=lambda line: f.write(line + '\n'))

    test_label_pairs = dg.get_test_label_pairs()
    ev = get_evaluator(args, model, loss_fn,
                       [eval_data, test_data, random_data, eval_new_data],
                       [large_eval_data, large_test_data, large_random_data,
                        large_eval_new_data],
                       test_label_pairs)

    # train and evaluate
    data_points = train(args, dg, model, ev, fixed_train_data, test_data)

    # plot results
    np.save(os.path.join(args.log_dir, 'rdd.npy'), data_points)

    # Rows become the individual statistics, columns the training steps.
    # Row order follows the tuple returned by analyze_model.
    data_points = np.transpose(data_points)

    legends = ['loss-train', 'loss-test']
    plot_results(args, data_points[:2], 'loss', legends)
    pivot_results(args, data_points[:2], 'loss', legends)

    legends = ['acc-train', 'acc-test']
    plot_results(args, data_points[2:4], 'acc', legends)
    pivot_results(args, data_points[2:4], 'acc', legends)

    # Rows 4 and 5 are rdd_train and rdd_all (ratio along the combined
    # gradient), so the second series is labelled accordingly.
    rdd_data = smooth_rdd(data_points[4:6])
    legends = ['rdd-train', 'rdd-all']
    plot_results(args, rdd_data, 'rdd', legends)

    legends = ['aa', 'ab', 'ca', 'cb']
    plot_results(args, data_points[6:], 'changes', legends)


if __name__ == '__main__':
    # Command-line entry point: declare every option, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir', type=str, default='',
                        help='Log directory.')
    # Two independent seeds: see set_random_seeds for which generators
    # each one controls.
    parser.add_argument('--data_random_seed', type=int, default=8,
                        help='Data random seed.')
    parser.add_argument('--parameter_random_seed', type=int, default=7,
                        help='Parameter random seed.')
    parser.add_argument('--depth', type=int, default=6,
                        help='Number of layers.')
    parser.add_argument('--n_hidden_nodes', type=int, default=32,
                        help='Number of nodes in hidden layer.')
    parser.add_argument('--loss_type', type=str, default='cross_entropy',
                        help='Loss type.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size.')
    parser.add_argument('--test_sample_size', type=int, default=1000,
                        help='Test sample size.')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='Log interval.')
    parser.add_argument('--steps', type=int, default=500,
                        help='Steps.')
    parser.add_argument('--combined_labels', type=int, default=3,
                        help='Combined labels.')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    parser.add_argument('--merge_type', type=str, default='colored',
                        help='Merge type.')
    parser.add_argument('--evaluator_type', type=str, default='normal',
                        help='Evaluator type.')
    parser.add_argument('--save_image', action='store_true', default=False,
                        help='Show image and stop.')
    parser.add_argument('--adversarial', action='store_true',
                        default=False,
                        help='Use adversarial learning on test.')
    parser.add_argument('--dataset', type=str, default='mnist',
                        help='Dataset.')
    parser.add_argument('--dataset1', type=str, default='reuters',
                        help='Dataset 1.')
    parser.add_argument('--dataset2', type=str, default='reuters',
                        help='Dataset 2.')
    parser.add_argument('--any_generalization', action='store_true',
                        default=False, help='Any systematic generalization.')
    parser.add_argument('--model_type', type=str, default='cnn',
                        help='Model type.')
    parser.add_argument('--input_permutation', action='store_true',
                        default=False, help='Permute input.')
    parser.add_argument('--label_split', type=str, default='tile',
                        help='Label split.')
    parser.add_argument('--rotate_second_input', action='store_true',
                        default=False, help='Rotate second input.')
    parser.add_argument('--pretrain', action='store_true',
                        default=False, help='Pretrain.')
    parser.add_argument('--dataset_dir', type=str,
                        default='../../data/zeroshot_datasets',
                        help='Zero-shot dataset directory.')
    parser.add_argument('--partition_threshold_percentage', type=int, default=50,
                        help='Partition threshold percentage for evaluation.')
    parser.add_argument('--mcd_split', type=str, default='mcd1', help='MCD split.')
    parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer.')
    parser.add_argument('--use_all_training_data', action='store_true',
                        default=False, help='Use all training data.')
    main(parser.parse_args())
