import os
import argparse
import numpy as np
import pylab
from tsne import tsne


def load_file(fn):
    """Read paired feature vectors from a tab-separated file.

    Each non-blank line must have at least four tab-separated columns.
    Columns 2 and 3 (0-based) each hold a whitespace-separated list of
    floats. The other columns are ignored.

    Args:
        fn: path of the file to read.

    Returns:
        A tuple ``(X1, X2)`` of lists with one inner list of floats per
        non-blank line. ``X1`` is parsed from column 2 and ``X2`` from
        column 3. An empty vector column yields an empty inner list.

    Raises:
        IndexError: if a non-blank line has fewer than four columns.
        ValueError: if a vector entry is not a valid float.
    """
    X1 = []
    X2 = []
    with open(fn, 'r') as f:
        # Iterate the file object directly instead of materialising every
        # line with readlines().
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing empty line) rather than
            # failing with IndexError on terms[2].
            if not line:
                continue
            terms = line.split('\t')
            # split() with no argument collapses runs of whitespace.
            # split(' ') would yield empty strings, and float('') raises.
            X1.append([float(x) for x in terms[2].split()])
            X2.append([float(x) for x in terms[3].split()])
    return X1, X2


def main(args):
    """Embed train/test representations with t-SNE and save a scatter plot.

    Train and test rows are read with ``load_file`` from ``args.dir_path``
    and concatenated, so each t-SNE run embeds both splits jointly. The
    first and second representations are each embedded separately, and
    column 0 of each embedding is plotted against the other. Train points
    are drawn in blue and test points in orange. The figure is saved to
    ``<dir_path>/<output_prefix>.png``.

    Args:
        args: parsed namespace providing ``dir_path``, ``input_train_file``,
            ``input_test_file`` and ``output_prefix``.
    """
    train_X1, train_X2 = load_file(
        os.path.join(args.dir_path, args.input_train_file))
    test_X1, test_X2 = load_file(
        os.path.join(args.dir_path, args.input_test_file))

    # Rows [0, spl) are training points; rows [spl, end) are test points.
    spl = len(train_X1)

    # dtype=float makes ragged rows fail here with a clear error instead of
    # producing an object array that breaks inside tsne().
    X1 = np.asarray(train_X1 + test_X1, dtype=float)
    X2 = np.asarray(train_X2 + test_X2, dtype=float)

    # NOTE(review): positional args presumably mean (no_dims, initial_dims,
    # perplexity) in the local tsne module -- confirm. Only column 0 of the
    # result is used below.
    Y1 = tsne(X1, 1, 50, 20.0)[:, 0]
    Y2 = tsne(X2, 1, 50, 20.0)[:, 0]

    fontsize = 18
    pylab.tick_params(axis='both', which='major', labelsize=fontsize)
    pylab.scatter(Y1[:spl], Y2[:spl], 30, color='blue', marker='o',
                  label='Train')
    pylab.scatter(Y1[spl:], Y2[spl:], 30, color='orange', marker='o',
                  label='Test')
    pylab.xlabel('First representation', fontsize=fontsize)
    pylab.ylabel('Second representation', fontsize=fontsize)
    # The 'Train'/'Test' labels are only rendered once a legend is drawn.
    pylab.legend(fontsize=fontsize)

    # output
    ofn = os.path.join(args.dir_path, args.output_prefix + '.png')
    pylab.savefig(ofn, bbox_inches='tight')
    # Release the figure so a later call does not draw over this plot.
    pylab.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Plot t-SNE embeddings of two representations for '
                    'train and test data.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dir_path', type=str,
                        default='../logs/overlap_proposed_A',
                        help='directory containing the input files; the '
                             'output image is written here too.')
    parser.add_argument('--input_train_file', type=str,
                        default='eval_hidden.txt',
                        help='input train file, relative to --dir_path.')
    parser.add_argument('--input_test_file', type=str,
                        default='test_hidden.txt',
                        help='input test file, relative to --dir_path.')
    parser.add_argument('--output_prefix', type=str, default='hidden_layer',
                        help="output image name without the '.png' "
                             'extension, relative to --dir_path.')
    args = parser.parse_args()

    main(args)
