# this script was used to quantitatively compute the accuracy of factorial swap of skin/hair
import torch
import numpy as np

import copy
import argparse

from koopman_cnn import KoopmanCNN
from utils import t_to_np, reorder, set_seed_device, load_checkpoint, load_classifier, load_dataset, static_dynamic_split


def define_args(argv=None):
    """Build the configuration namespace for the Sprites swap evaluation.

    Args:
        argv: Optional list of command-line tokens to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers without arguments have always had. Passing an
            explicit list makes the function usable from other code and tests.

    Returns:
        argparse.Namespace with every parsed option plus the fixed data
        dimensions ``n_frames``, ``n_channels``, ``n_height`` and ``n_width``.
    """
    parser = argparse.ArgumentParser(description="Sprites Disentanglement")

    # general
    # NOTE: 'store_false' means args.cuda defaults to True and passing
    # --cuda on the command line switches it to False.
    parser.add_argument('--cuda', action='store_false')
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--batch_size', type=int, default=256, metavar='N')

    # model
    parser.add_argument('--conv_dim', type=int, default=32)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--rnn', type=str, default='both', choices=['enc', 'dec', 'both'])
    parser.add_argument('--hidden_dim', type=int, default=40, help='#feats for LSTM')

    # Koopman parameters
    parser.add_argument('--k_dim', type=int, default=40)
    parser.add_argument('--static_size', type=int, default=7)
    parser.add_argument('--w_rec', type=float, default=15.0)
    parser.add_argument('--w_pred', type=float, default=1.0)
    parser.add_argument('--w_eigs', type=float, default=1.0)
    parser.add_argument('--dynamic_thresh', type=float, default=0.5)

    # other
    parser.add_argument('--noise', type=str, default='none', help='adding blur to the sample (in the pixel space')

    # argv=None keeps the historical behaviour of reading sys.argv[1:]
    args = parser.parse_args(argv)

    # data parameters (fixed, not exposed on the command line)
    args.n_frames = 8
    args.n_channels = 3
    args.n_height = 64
    args.n_width = 64

    return args


if __name__ == '__main__':
    # parse the hyperparameters and echo them for the run log
    args = define_args()
    print(args)

    # set PRNG seed; the helper's return value is stored as the device
    args.device = set_seed_device(args.seed)

    # dataset location, then the test split and its loader
    args.dataset_path = '../data/'
    test_data, test_loader = load_dataset(args)

    # the judge classifier is loaded from its own checkpoint and receives
    # a private deep copy of the arguments
    args.checkpoint_name = '../models/sprites_judge.pth'
    classifier = load_classifier(copy.deepcopy(args))

    # instantiate the model (and an optimizer, which load_checkpoint expects)
    model = KoopmanCNN(args).to(device=args.device)
    opt = torch.optim.Adam(model.parameters(), args.lr)

    # restore the trained weights and switch to inference mode
    args.checkpoint_name = "../models/sprites_our.model"
    model, opt = load_checkpoint(model, opt, args.checkpoint_name)
    model.eval()

    print('\nFactorial swap of hair|skin color')

    def swap_error(accs, j):
        """Score a swap result for factor ``j``; the smaller score is preferred.

        Adds the Euclidean norm of every accuracy except entry ``j`` to the
        reciprocal of entry ``j``.
        """
        return np.linalg.norm(np.delete(accs, j, axis=0)) + 1 / accs[j]

    # positions inside an accuracy vector, matching the summary print below
    SKIN_POS, HAIR_POS = 1, 4

    accs_hair, accs_skin = [], []
    for batch in test_loader:
        X = reorder(batch['images'])
        # labels are fetched as in the original script but not used below
        A_labels = batch['A_label']
        D_labels = batch['D_label']
        bsz = len(X)
        fsz = args.n_frames

        with torch.no_grad():
            outputs = model(X.to(args.device))
            X_dec = outputs[0]
            # last model output; presumably the Koopman matrix - TODO confirm
            C = t_to_np(outputs[-1])
            # third model output reshaped to (batch, frames, features)
            Z = t_to_np(outputs[2].squeeze().reshape(bsz, fsz, -1))

        # eigendecomposition of C; project Z onto the eigenvector basis
        D, V = np.linalg.eig(C)
        U = np.linalg.inv(V)
        Zp = Z @ V

        # first value returned by the split is the index set sliced below
        I, _, _ = static_dynamic_split(D, args.static_size)
        # random pairing of samples within the batch for the swap
        K = np.random.permutation(bsz)

        # swap only hair color: try the first 4 and the first 6 indices
        hair_with_4, _, _ = model.factorial_swap(classifier, X_dec, Zp, K, I[:4], U)
        hair_with_6, _, _ = model.factorial_swap(classifier, X_dec, Zp, K, I[:6], U)
        four_is_better = (swap_error(np.array(hair_with_4), HAIR_POS)
                          < swap_error(np.array(hair_with_6), HAIR_POS))
        # J holds the leading positions of I consumed by the hair swap
        J = np.arange(4 if four_is_better else 6)
        accs_hair.append(hair_with_4 if four_is_better else hair_with_6)

        # swap only skin color: drop the hair positions from the first 8 / 10
        skin_from_8 = np.delete(I[:8], J)
        skin_from_10 = np.delete(I[:10], J)
        skin_with_8, _, _ = model.factorial_swap(classifier, X_dec, Zp, K, skin_from_8, U)
        skin_with_10, _, _ = model.factorial_swap(classifier, X_dec, Zp, K, skin_from_10, U)
        eight_is_better = (swap_error(np.array(skin_with_8), SKIN_POS)
                           < swap_error(np.array(skin_with_10), SKIN_POS))
        accs_skin.append(skin_with_8 if eight_is_better else skin_with_10)

    # summarize results: average over batches, expressed in percent
    accs_hair = np.mean(np.array(accs_hair), axis=0) * 100
    accs_skin = np.mean(np.array(accs_skin), axis=0) * 100

    summary = 'action {:.2f}, skin {:.2f}, pant {:.2f}, top {:.2f}, hair {:.2f}'
    print(summary.format(*accs_hair[:5]))
    print(summary.format(*accs_skin[:5]))
