import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from evaluation.utils import get_nfeatures_seqlen_maxlag, get_runid_and_epoch, get_data_real, get_data_synthetic, \
    get_data_real_crvae

# Font sizes handed to matplotlib for every figure this script produces.
FONT_SIZE_TITLE: int = 28
FONT_SIZE_LEGEND: int = 22
FONT_SIZE_TICKS: int = 16
FONT_SIZE_LABELS: int = 18


# Figure title for each supported ``--run`` value; also the set of accepted runs.
_RUN_TITLES = {
    'no_additional_losses': 'Our',
    'l2': 'Our w/L2',
    'l2_dtw': 'Our w/L2 w/DTW',
    'l1_dtw': 'Our w/L1 w/DTW',
    'l2_fourier': 'Our w/L2 w/Fourier',
    'crvae': 'CR-VAE',
    'causaltime4010': 'CausalTime',
    'causaltime2010': 'CausalTime',
    'causaltime105': 'CausalTime'
}


def _plot_embedding(embedding, n_real, title, x_label, y_label, out_dir, file_stem):
    """Scatter a 2-D embedding of stacked real/synthetic rows and save it as a PDF.

    Rows ``embedding[:n_real]`` are drawn in red as "Original" and rows
    ``embedding[n_real:]`` in blue as "Synthetic".

    Args:
        embedding: Array whose first two columns are the 2-D coordinates.
        n_real: Number of leading rows that belong to the real data.
        title: Figure title.
        x_label: X-axis label.
        y_label: Y-axis label.
        out_dir: Output directory, created if missing.
        file_stem: File name without the ``.pdf`` extension.

    Returns:
        The path of the written file, ``f'{out_dir}/{file_stem}.pdf'``.
    """
    fig, ax = plt.subplots()
    ax.scatter(embedding[:n_real, 0], embedding[:n_real, 1],
               color='red', alpha=0.2, label="Original")
    ax.scatter(embedding[n_real:, 0], embedding[n_real:, 1],
               color='blue', alpha=0.2, label="Synthetic")
    ax.legend(fontsize=FONT_SIZE_LEGEND)
    ax.set_title(title, fontsize=FONT_SIZE_TITLE)
    ax.set_xlabel(x_label, fontsize=FONT_SIZE_LABELS)
    ax.set_ylabel(y_label, fontsize=FONT_SIZE_LABELS)
    ax.tick_params('both', labelsize=FONT_SIZE_TICKS)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    file_path = f'{out_dir}/{file_stem}.pdf'
    fig.savefig(file_path)
    # Close this specific figure so the next plot starts from a clean canvas.
    plt.close(fig)
    return file_path


def main():
    """Plot t-SNE and PCA projections of real vs. synthetic samples for one run.

    Command-line arguments:
        -r/--run      Run name; must be one of the keys of ``_RUN_TITLES``.
        -d/--dataset  Dataset name forwarded to the ``evaluation.utils`` loaders.
        -s/--seed     Integer seed (default 0). It is forwarded to the data
                      loaders and also drives the evaluation subsampling and
                      t-SNE/PCA, so repeated invocations give identical figures.

    Writes ``experiments/dimensionality_reduction/tsne/<dataset>/<run>.pdf`` and
    ``experiments/dimensionality_reduction/pca/<dataset>/<run>.pdf``.
    """
    parser = argparse.ArgumentParser()
    # ``choices`` turns an unknown run into a clear CLI error instead of a
    # KeyError on the title lookup further down.
    parser.add_argument('-r', '--run', required=True, choices=list(_RUN_TITLES))
    parser.add_argument('-d', '--dataset', required=True)
    parser.add_argument('-s', '--seed', type=int, default=0)
    args = parser.parse_args()

    run = args.run
    dataset = args.dataset
    seed = args.seed
    title = _RUN_TITLES[run]

    n_features, seq_len, _ = get_nfeatures_seqlen_maxlag(dataset)

    n_samples_evaluation = 1024

    run_id, epoch = get_runid_and_epoch(run, dataset)

    print(f'Dataset: {dataset}\tRun: {run}\tID: {run_id}\tEpoch: {epoch}')

    if 'crvae' not in run:
        data_real, _ = get_data_real(run, dataset, seq_len, 'tsne_pca')
    else:
        data_real = get_data_real_crvae(run, dataset, epoch, seed)

    data_synthetic = get_data_synthetic(run_id, epoch, seed, seq_len, n_features)

    # Draw the evaluation subsample from a generator seeded with ``seed`` so the
    # figures are reproducible. Sampling is with replacement, which also works
    # when a dataset has fewer than ``n_samples_evaluation`` rows.
    rng = np.random.default_rng(seed)
    data_real_ev = data_real[rng.integers(0, len(data_real), n_samples_evaluation)]
    data_synthetic_ev = data_synthetic[rng.integers(0, len(data_synthetic), n_samples_evaluation)]

    # Average over axis 2 so each sample becomes a 1-D vector, giving the 2-D
    # (n_samples, k) input that t-SNE/PCA require.
    # NOTE(review): whether axis 2 is time or features depends on the layout
    # returned by the evaluation.utils loaders — confirm.
    data_real_ev = np.mean(data_real_ev, 2)
    data_synthetic_ev = np.mean(data_synthetic_ev, 2)

    # Real rows first, synthetic rows second; _plot_embedding relies on this order.
    data_ev = np.concatenate((data_real_ev, data_synthetic_ev), axis=0)

    # t-SNE
    tsne = TSNE(n_components=2, verbose=0, perplexity=40, max_iter=300, random_state=seed)
    tsne_results = tsne.fit_transform(data_ev)
    file_path = _plot_embedding(
        tsne_results, n_samples_evaluation, title, 'x t-SNE', 'y t-SNE',
        f'experiments/dimensionality_reduction/tsne/{dataset}', run
    )
    print(f'Seed: {seed}\tSaved: {file_path}')

    # PCA
    pca = PCA(n_components=2, random_state=seed)
    pca_results = pca.fit_transform(data_ev)
    file_path = _plot_embedding(
        pca_results, n_samples_evaluation, title, 'x-PCA', 'y-PCA',
        f'experiments/dimensionality_reduction/pca/{dataset}', run
    )
    print(f'Seed: {seed}\tSaved: {file_path}')


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()
