"""
Calculate the visual embedding of animal pose and appearance.

"""

from apca.models import AAPCA
import argparse
import joblib
import numpy as np
import os
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score

from src.config_utils import Config

# PRE_PCA_COMPONENTS: components kept by the initial sklearn PCA reduction.
# PCA_COMPONENTS: first positional argument to AAPCA (presumably the number of
# output components of the final embedding -- confirm against apca docs).
PRE_PCA_COMPONENTS, PCA_COMPONENTS = 2_000, 50


def _parse_args():
    """Parse the command line: one positional path to the config JSON file."""
    parser = argparse.ArgumentParser(description="Train script for the model")
    parser.add_argument("config", type=str, help="Path to the config JSON file")
    return parser.parse_args()


def _load_features(feature_fn):
    """Load the feature array and flatten each sample into one row.

    The first axis is treated as the sample axis; all remaining axes are
    flattened together, giving a 2-D ``(n_samples, n_features)`` array.
    Non-floating dtypes are cast to float64 so the caller's in-place mean
    subtraction does not fail with a casting error.
    """
    embed = np.load(feature_fn)
    embed = embed.reshape(len(embed), -1)
    if not np.issubdtype(embed.dtype, np.floating):
        embed = embed.astype(np.float64)
    return embed


def _load_angle_targets(center_rotation_fn):
    """Load ``angles`` from an .npz archive and encode them as (cos, sin).

    Returns ``np.stack([cos(angles), sin(angles)], axis=1)``; for a 1-D
    ``angles`` array of length N this is an (N, 2) array.
    """
    # The archive is indexed by key, so np.load returns an NpzFile; ``with``
    # closes its underlying file handle once the array has been read.
    with np.load(center_rotation_fn) as archive:
        angles = archive["angles"]
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)


def main():
    """Fit the PCA -> AAPCA pipeline and save both models and the embedding.

    Reads the config path from the command line, then writes
    ``pca_model.joblib``, ``aapca_model.joblib`` and ``embedding.npy`` into
    ``config.project_directory``.
    """
    args = _parse_args()
    config = Config(args.config)

    embed = _load_features(config.feature_fn)
    print("features:", embed.shape)

    angles2 = _load_angle_targets(config.center_rotation_fn)

    # NOTE(review): sklearn's PCA centers its input itself and stores the mean
    # in ``mean_``. Because the data is already centered here, the saved
    # pca_model.joblib has mean_ ~= 0 and the training feature mean is not
    # persisted anywhere. Confirm that downstream users of the saved model
    # center new features the same way.
    embed -= np.mean(embed, axis=0, keepdims=True)

    print("Doing Pre-PCA...")
    # PCA rejects n_components > min(n_samples, n_features). Clamp the value
    # so small datasets still run; the result is identical whenever both
    # dimensions are >= PRE_PCA_COMPONENTS.
    n_pre_components = min(PRE_PCA_COMPONENTS, *embed.shape)
    pca = PCA(n_pre_components, random_state=42)
    embed = pca.fit_transform(embed)
    joblib.dump(pca, os.path.join(config.project_directory, 'pca_model.joblib'))

    print("Doing adversarial PCA...")
    aapca = AAPCA(PCA_COMPONENTS, mu=1e2, pow_iter=20, random_state=42)
    reduced = aapca.fit_transform(embed, angles2)
    # The second value returned by reconstruct() is scored against the
    # (cos, sin) targets.
    _, rec_angles = aapca.reconstruct(embed, angles2)
    print("r2", r2_score(angles2, rec_angles))
    joblib.dump(aapca, os.path.join(config.project_directory, 'aapca_model.joblib'))

    # Save.
    np.save(os.path.join(config.project_directory, "embedding.npy"), reduced)


if __name__ == '__main__':
    main()


###