"""
Run inference with a trained SV2P model.
"""
from GNS.fabric_vsf.vismpc.SV2P import SV2P
from dotmap import DotMap
import pickle
import numpy as np
import sys
import argparse
import cv2
import os
from chester import logger

def get_default_args():
    """Return the default inference configuration as an ``argparse.Namespace``.

    The parser is run against an empty argv, so the process's real command
    line is ignored and every field takes its declared default. ``run_task``
    overrides fields afterwards by merging its ``vv`` dict into the result.

    Returns:
        Namespace with ``data_dir``, ``model_dir``, ``horizon``, ``input_img``
        and ``batch``.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--data_dir", dict(type=str, default="./GNS/fabric_vsf/data/train_data")),
        ("--model_dir", dict(type=str, default="./GNS/fabric_vsf/data/output_data")),
        ("--horizon", dict(type=int, default=1)),
        ("--input_img", dict(type=str,
                             default='./GNS/fabric_vsf/data/fold_test_data-8.pkl',
                             help="filepath of input image as .pkl file")),
        # When true, input_img is a pkl of images and actions.
        # NOTE(review): store_true combined with default=True means the flag can
        # never be switched off from a command line; override it through vv.
        ("--batch", dict(action="store_true", default=True)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args(args=[])

def _to_uint8(img):
    """Clip ``img`` to [0, 255] and cast it to uint8.

    A bare float -> uint8 cast does not saturate, so out-of-range predicted
    pixel values would otherwise turn into garbage.
    """
    return np.clip(img, 0, 255).astype(np.uint8)


def _predict_episodes(sv2p, args, logdir):
    """Run SV2P over every horizon-length window of each stored episode.

    ``args.input_img`` is unpickled and iterated as ``zip(*pkl)``, so it must
    be a pair ``(observations, actions)`` of per-episode sequences. The
    original note gave shapes of roughly (steps+1, 56, 56, 3) for obs and
    (steps, 4) for acts; TODO confirm.

    For each window starting at step ``i``, a PNG showing the true frame
    ``i+1`` beside the first predicted frame is written to ``logdir``.

    Returns:
        A list with one dict per processed episode:
        ``{'pred': uint8 array of the last predicted frame of each window,
        'act': array of the action windows}``. Only the first
        ``args.max_episodes`` episodes are processed (default 5; ``None``
        means all).
    """
    print("input_image: ", args.input_img)
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(args.input_img, 'rb') as f:
        pkl = pickle.load(f)

    horizon = args.horizon
    max_episodes = getattr(args, 'max_episodes', 5)
    output = []
    for ep_idx, (episode_obs, episode_act) in enumerate(zip(*pkl)):
        if max_episodes is not None and ep_idx >= max_episodes:
            break
        print(ep_idx)
        num_steps = len(episode_act)
        all_acts = np.array(episode_act)
        preds, acts = [], []
        for i in range(num_steps + 1 - horizon):
            currim = episode_obs[i].astype(np.float32)
            # Action window for steps i .. i+horizon-1, with a leading axis of size 1.
            curracts = all_acts[i:i + horizon][np.newaxis, :].astype(np.float32)
            pred = sv2p.predict(currim, curracts)[0][0]

            # Visual check: ground-truth next frame next to the first prediction.
            combined = np.concatenate(
                [_to_uint8(episode_obs[i + 1]), _to_uint8(pred[0])], axis=1)
            cv2.imwrite(os.path.join(logdir, '{}_{}.png'.format(ep_idx, i)), combined)

            # For VAE data generation keep only the last image in the horizon.
            preds.append(pred[horizon - 1])
            acts.append(curracts[0])
        # CAREFUL: uint8 makes the frames viewable but quantises them, which can
        # distort L2 metrics computed on these predictions downstream.
        output.append({'pred': _to_uint8(np.array(preds)), 'act': np.array(acts)})
    return output


def _predict_random_actions(sv2p, args):
    """Predict from a single pickled image under one random action sequence.

    The actions have shape (1, 5, 4). The first two dims are drawn from
    U(-1, 1) and the last two from U(-0.3, 0.3); modify as desired.

    Returns:
        ``{"act": (5, 4) actions, "pred": uint8 predictions}``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(args.input_img, 'rb') as f:
        currim = pickle.load(f)
    acts1 = np.random.uniform(-1, 1, (1, 5, 2))
    acts2 = np.random.uniform(-0.3, 0.3, (1, 5, 2))
    acts = np.dstack((acts1, acts2))
    pred = sv2p.predict(currim, acts)[0][0]
    return {"act": acts[0], "pred": _to_uint8(pred)}


def run_task(vv, log_dir, exp_name):
    """Load a trained SV2P model and write its predictions to ``predictions.pkl``.

    Args:
        vv: Overrides merged into ``get_default_args()``. Extra recognised
            keys are ``adim`` (action dimension for SV2P; defaults to 4,
            matching the 4-dim actions this script builds) and
            ``max_episodes`` (batch mode only; default 5, ``None`` for all).
        log_dir: Directory passed to ``logger.configure``. Batch mode writes
            its comparison PNGs to the logger's directory.
        exp_name: Experiment name passed to ``logger.configure``.

    In batch mode the pickle holds a list of per-episode dicts. Otherwise it
    holds a single dict of random actions and predictions. The output is
    always written to ``predictions.pkl`` in the current working directory.
    """
    args = get_default_args()
    args.__dict__.update(**vv)

    logger.configure(dir=log_dir, exp_name=exp_name)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    params = DotMap()
    params.name = 'cloth'
    params.model_dir = args.model_dir
    params.data_dir = args.data_dir
    params.popsize = 1
    params.nparts = 1
    params.plan_hor = args.horizon
    params.adim = getattr(args, 'adim', 4)
    params.stochastic_model = True
    # NOTE(review): presumably SV2P parses sys.argv itself (e.g. via TF flags)
    # and would reject this script's arguments, so keep only argv[0]. Confirm.
    sys.argv = sys.argv[:1]
    sv2p = SV2P(params)

    if args.batch:
        output = _predict_episodes(sv2p, args, logdir)
    else:
        output = _predict_random_actions(sv2p, args)

    with open('predictions.pkl', 'wb') as f:
        pickle.dump(output, f)
    print("Successfully wrote predictions to predictions.pkl")

if __name__ == '__main__':
    # SV2P needs the action dimension, and the actions this script builds are
    # 4-dimensional. An empty dict would leave 'adim' unspecified.
    run_task(dict(adim=4), None, None)