# this script was used to generate qualitative examples of two factor static/dynamic swap
import torch
import numpy as np

import argparse

from koopman_cnn import KoopmanCNN
from utils import t_to_np, reorder, set_seed_device, load_checkpoint, load_dataset, static_dynamic_split, imshow_seq


def define_args(argv=None):
    """Build the configuration namespace for the Sprites swap script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``define_args()`` behaves exactly as before. Passing an
            explicit list allows the function to be used from notebooks,
            other scripts, or tests without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options plus the fixed data
        dimensions ``n_frames``, ``n_channels``, ``n_height`` and
        ``n_width`` that are attached after parsing.
    """
    parser = argparse.ArgumentParser(description="Sprites Disentanglement")

    # general
    # NOTE(review): with action='store_false' args.cuda defaults to True and
    # passing --cuda on the command line sets it to False -- confirm that
    # this inversion is intended.
    parser.add_argument('--cuda', action='store_false')
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--batch_size', type=int, default=256, metavar='N')

    # model
    parser.add_argument('--conv_dim', type=int, default=32)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--rnn', type=str, default='both', choices=['enc', 'dec', 'both'])
    parser.add_argument('--hidden_dim', type=int, default=40, help='#feats for LSTM')

    # Koopman parameters
    parser.add_argument('--k_dim', type=int, default=40)
    parser.add_argument('--static_size', type=int, default=7)
    parser.add_argument('--w_rec', type=float, default=15.0)
    parser.add_argument('--w_pred', type=float, default=1.0)
    parser.add_argument('--w_eigs', type=float, default=1.0)
    parser.add_argument('--dynamic_thresh', type=float, default=0.5)

    # other
    parser.add_argument('--noise', type=str, default='none', help='adding blur to the sample (in the pixel space')

    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour
    args = parser.parse_args(argv)

    # data parameters (fixed here, not exposed on the command line)
    args.n_frames = 8
    args.n_channels = 3
    args.n_height = 64
    args.n_width = 64

    return args


if __name__ == '__main__':
    # Qualitative swap demo: encode the first test batch, exchange selected
    # eigen-coordinates between each sample and a randomly chosen partner,
    # and display the source, the partner and the two swapped decodings.

    # hyperparameters
    args = define_args()
    print(args)

    # set PRNG seed
    args.device = set_seed_device(args.seed)

    # Dataset path & load test data
    args.dataset_path = '../data/'
    test_data, test_loader = load_dataset(args)

    # create & load model
    model = KoopmanCNN(args).to(device=args.device)
    # opt is not used after loading; it is created because load_checkpoint
    # takes (and returns) an optimizer alongside the model
    opt = torch.optim.Adam(model.parameters(), args.lr)

    args.checkpoint_name = "../models/sprites_our.model"
    model, opt = load_checkpoint(model, opt, args.checkpoint_name)
    model.eval()

    # only the first batch of the test loader is visualised
    test_iter = iter(test_loader)
    batch = next(test_iter)

    X = reorder(batch['images'])
    bsz, fsz = len(X), args.n_frames
    with torch.no_grad():
        outputs = model(X.to(args.device))
        # outputs[-1] is the square matrix that gets eigendecomposed below
        # (presumably the learned Koopman operator -- TODO confirm);
        # outputs[2] holds the latent codes, reshaped to (batch, frames, -1)
        C = t_to_np(outputs[-1])
        Z = t_to_np(outputs[2].squeeze().reshape(bsz, fsz, -1))

    # eigendecomposition C = V diag(D) V^-1; U = V^-1 is handed to
    # model.swap together with the latent codes multiplied by V
    D, V = np.linalg.eig(C)
    U = np.linalg.inv(V)
    Zp = Z @ V

    # I[k] is the partner ("target") sample paired with sample k
    I = np.random.permutation(bsz)
    # presumably Jd / Js index the dynamic / static eigen-coordinates --
    # TODO confirm against static_dynamic_split; the first value is unused
    _, Jd, Js = static_dynamic_split(D, args.static_size)

    X = t_to_np(X)

    # decoding is inference only, so keep autograd disabled here as well
    with torch.no_grad():
        # swap dynamic content
        Xd = t_to_np(model.swap(Zp, I, Jd, U))

        # swap static content
        Xs = t_to_np(model.swap(Zp, I, Js, U))

    # print two samples: source=idx, target=I[idx]
    idx = 0
    X_tmp = np.array(((X[idx], X[I[idx]]), (Xs[idx], Xd[idx])))
    imshow_seq(X_tmp, titles=(('source', 'target'), ('src-dyn + tgt-stat', 'src-stat + tgt-dyn')))
