import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import axes3d
from numpy.linalg import svd
from pylab import *


def standardize(X):
    """Center each column of ``X`` and scale it to unit standard deviation.

    Args:
        X: array of samples; statistics are taken along axis 0, so for a
            2-D input every column is standardized independently.

    Returns:
        Array of the same shape as ``X`` whose columns have zero mean and,
        unless the column is constant, unit (population) standard
        deviation. Constant columns are returned as all zeros.
    """
    Xbar = X - np.mean(X, axis=0)
    scale = np.std(Xbar, axis=0)
    # A constant column has zero spread; dividing by it would turn the whole
    # column into NaN (0 / 0). Dividing by 1 keeps the centered zeros instead.
    scale = np.where(scale == 0, 1.0, scale)
    return Xbar / scale


def pca(X, k):
    """Reduce the rows of ``X`` to ``k`` coordinates via SVD of the centered data.

    Args:
        X: 2-D array with one sample per row (``X.shape[1]`` is read).
        k: number of leading components to keep.

    Returns:
        Array with one row per sample. If ``X`` already has exactly ``k``
        columns, no decomposition is done and the standardized data is
        returned. Otherwise column ``i`` is the ``i``-th left singular
        vector of the column-centered data multiplied by the square root of
        the ``i``-th singular value.
    """
    if X.shape[1] == k:
        return standardize(X)
    Xbar = X - np.mean(X, axis=0)
    # Only the leading columns of U are used, so the reduced SVD is enough;
    # the full one would allocate an (n_samples x n_samples) matrix.
    U, S, _ = svd(Xbar, full_matrices=False)
    # NOTE(review): components are scaled by sqrt(S), not S as in textbook
    # PCA scores; kept as-is because the saved outputs depend on it.
    # Broadcasting multiplies column i by sqrt(S[i]) -- same result as a
    # matrix product with diag(sqrt(S[:k])).
    return U[:, :k] * np.sqrt(S[:k])


def tonumpy(x):
    """Return the contents of ``x`` as a standalone NumPy array.

    ``x`` is presumably a ``torch.Tensor`` (it must provide ``to``,
    ``detach`` and ``numpy``). It is moved to the CPU, detached, converted,
    and finally copied so the returned array is a separate object from the
    one produced by ``numpy()``.
    """
    on_cpu = x.to("cpu")
    detached = on_cpu.detach()
    as_array = detached.numpy()
    return as_array.copy()


def _define_numvecs_ratio(data, cap):
    cap *= len(data)
    freq_vecs = []
    for _data in data:
        X = np.array([tonumpy(xx) for xx in _data])
        freq_vecs.append(X.shape[0])
    num_vecs = [int(cap * freq_vec / sum(freq_vecs)) for freq_vec in freq_vecs]

    for num_vec, freq_vec in zip(num_vecs, freq_vecs):
        assert num_vec <= freq_vec, "VecsAreNotEnough: consider more smaller cap"

    return num_vecs


def _define_numvecs_cap(data, cap):
    num_vecs = []
    for _data in data:
        X = np.array([tonumpy(xx) for xx in _data])
        num_vecs.append(min(X.shape[0], cap))

    return num_vecs


def analyze(data, output, cap=100, define_ratio=True):
    """Project vectors from one or more datasets with PCA, then plot and save them.

    The leading vectors of every dataset are stacked, reduced to three
    components (two when the vectors are 2-dimensional) with ``pca``, and
    drawn as one scatter series per dataset.

    Args:
        data: sequence of datasets; each dataset is an iterable of tensors
            accepted by ``tonumpy``.
        output: path prefix for the generated files.
        cap: limit on the number of vectors taken per dataset (see
            ``define_ratio`` for how it is applied with several datasets).
        define_ratio: with more than one dataset, ``True`` distributes the
            vectors proportionally via ``_define_numvecs_ratio`` and
            ``False`` caps each dataset via ``_define_numvecs_cap``. Ignored
            for a single dataset, which is simply capped at ``cap``.

    Side effects:
        Writes ``{output}_{i}.dat`` (first two components of dataset ``i``)
        for every dataset and the scatter plot ``{output}.png``.

    Raises:
        ValueError: if the vectors have fewer than two dimensions.
    """
    arrays = [np.array([tonumpy(xx) for xx in _data]) for _data in data]

    # Decide how many vectors each dataset contributes.
    if len(data) > 1:
        if define_ratio:
            num_vecs = _define_numvecs_ratio(data, cap)
        else:
            num_vecs = _define_numvecs_cap(data, cap)
    else:
        num_vecs = [min(X.shape[0], cap) for X in arrays]

    stacked = np.vstack([X[0:num_vec, :] for X, num_vec in zip(arrays, num_vecs)])

    D = stacked.shape[1]
    if D < 2:
        # Previously this called exit(0), silently ending the whole program
        # with a success status; report the problem to the caller instead.
        raise ValueError(
            f"cannot plot {D}-dimensional vectors: at least 2 dimensions are required"
        )
    n_components = 3 if D >= 3 else 2
    projected = pca(stacked, n_components)

    fig = plt.figure()
    try:
        # Use flat axes for 2-D data; 3-D axes would draw it in perspective.
        if n_components == 3:
            ax = fig.add_subplot(1, 1, 1, projection="3d")
        else:
            ax = fig.add_subplot(1, 1, 1)

        # Rows of ``projected`` are grouped per dataset in the original
        # order, so walk through them with a running offset.
        curr = 0
        for data_id, num_vec in enumerate(num_vecs):
            Y = projected[curr : curr + num_vec, :]
            if n_components == 3:
                ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
            else:
                ax.scatter(Y[:, 0], Y[:, 1])
            np.savetxt(f"{output}_{data_id}.dat", Y[:, :2])
            curr += num_vec

        fig.savefig(f"{output}.png", bbox_inches="tight")
    finally:
        # Close exactly this figure. The old ``plt.close(); plt.clf()``
        # sequence created a fresh empty figure after closing and left it
        # open, leaking one figure per call.
        plt.close(fig)


def load(file):
    """Load and return the object pickled in ``file``.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # The ``with`` block guarantees the handle is closed, even on errors.
    with open(file, "rb") as handle:
        return pickle.load(handle)


def usage():
    """Print the command-line synopsis to stdout and exit with status 0."""
    synopsis = "usage: % preprocess.py cap data_0.pkl (data_1.pkl ... data_T.pkl)"
    print(synopsis)
    # Equivalent to sys.exit(0), which raises SystemExit under the hood.
    raise SystemExit(0)


def main():
    """Command-line entry point: analyze every word of the given pickle files.

    Expects ``sys.argv`` to hold the cap followed by one or more pickle
    paths; with fewer arguments the usage text is shown and the program
    exits. Every key of the first loaded file is looked up in all loaded
    files and handed to ``analyze`` with ``define_ratio=False``.
    """
    args = sys.argv
    if len(args) < 3:
        usage()
    cap = int(args[1])
    loaded_files = [load(path) for path in args[2:]]
    for word in loaded_files[0].keys():
        # Runs of whitespace in the word become single underscores so the
        # word can serve as the output file prefix.
        prefix = "_".join(word.split())
        # NOTE(review): indexing fails if a later file lacks this word.
        per_file = [loaded[word] for loaded in loaded_files]
        analyze(per_file, prefix, cap, define_ratio=False)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
