import tensorflow as tf
# Hide every GPU from TensorFlow before anything else initializes it.
# Within this file `tf` is used only for this call; presumably the point is
# to keep TF from reserving GPU memory that JAX needs for sampling — TODO
# confirm whether hv.* modules rely on TF seeing no GPUs.
tf.config.experimental.set_visible_devices([], "GPU")
import argparse
import numpy as np
import os.path as osp
import os
import jax
import time
import multiprocessing as mp

from hv.train_utils import seed_all
from hv.utils import flatten
from hv.models.xmap import get_sample, load_ckpt
from hv.data_pyt import Data


def _to_uint8_frames(x):
    """Merge the two leading axes of ``x`` and map values in [-1, 1] to uint8."""
    x = x.reshape(-1, *x.shape[2:])
    x = (x * 0.5 + 0.5) * 255
    # Clip before casting: float -> uint8 casts of out-of-range values wrap
    # (or are undefined), so sampler overshoot would turn into noise pixels.
    return np.clip(x, 0, 255).astype(np.uint8)


def worker(queue, opts=None):
    """Consume sampled batches from ``queue`` and save them as ``.npz`` files.

    Each queue item is a tuple ``(i, s, r)``: a batch index, the sampled
    ("fake") videos and the real videos. A ``None`` item stops the loop.

    Both arrays have their first two axes merged (presumably the device axis
    and the per-device batch axis — TODO confirm against ``get_sample``).
    They are then mapped from [-1, 1] to uint8 [0, 255] and written to
    ``<ckpt>/samples[_action]_<open_loop_ctx>/data_<i>.npz`` under the keys
    ``real`` and ``fake``.

    Args:
        queue: queue-like object with a blocking ``get()``.
        opts: parsed command-line options (``ckpt``, ``include_actions``,
            ``open_loop_ctx``, ``no_context``). Defaults to the module-level
            ``args`` assigned in ``__main__``; a child process only inherits
            that global under the ``fork`` start method.
    """
    if opts is None:
        opts = args
    ctx = opts.open_loop_ctx

    folder = osp.join(opts.ckpt, 'samples')
    if opts.include_actions:
        folder += '_action'
    folder += f'_{ctx}'

    while True:
        item = queue.get()
        if item is None:
            break
        i, s, r = item
        s = _to_uint8_frames(s)
        r = _to_uint8_frames(r)

        if opts.no_context:
            # Drop the first `ctx` frames (axis 1) from BOTH videos so only
            # the frames after the conditioning context are saved.
            s = s[:, ctx:]
            r = r[:, ctx:]
        elif ctx is not None:
            # Show the real conditioning frames in place of the generated
            # ones. With ctx None, `[:, :None]` would overwrite the whole
            # fake video with the real one, so skip the copy in that case.
            s[:, :ctx] = r[:, :ctx]

        os.makedirs(folder, exist_ok=True)
        np.savez_compressed(osp.join(folder, f'data_{i}.npz'), real=r, fake=s)


MAX_BATCH = 16  # cap on videos loaded per sampling step (all local devices)


def main(args):
    """Sample videos from a checkpoint and hand them to a writer process.

    Loads the checkpoint and its config and builds the test dataloader with
    ``args.seq_len`` frames per clip. It then samples successive batches of
    at most ``MAX_BATCH`` videos until roughly ``args.batch_size`` videos
    are produced. Each ``(idx, fake, real)`` result goes through a
    multiprocessing queue to ``worker``, which converts and saves it.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if the per-step batch is smaller than the number of
            local JAX devices, which would make the sampling step size zero.
    """
    seed_all(args.seed)
    global MAX_BATCH
    MAX_BATCH = min(MAX_BATCH, args.batch_size)

    kwargs = dict()
    if args.batch_size is not None:
        # Load with the capped per-step batch, not the full requested count.
        kwargs['batch_size'] = MAX_BATCH
    if args.open_loop_ctx is not None:
        kwargs['open_loop_ctx'] = args.open_loop_ctx

    model, state, config = load_ckpt(args.ckpt, return_config=True,
                                     **kwargs, data_path=args.data_path)
    print(config)

    if args.include_actions:
        assert config.use_actions
    if config.use_actions and not args.include_actions:
        # Actions are replaced by -1 below; presumably the model only handles
        # that if it was trained with action dropout — TODO confirm.
        assert config.dropout_actions

    # Build the loader with the requested clip length, then restore the
    # config's own seq_len before building the sampler.
    # NOTE(review): args.seq_len defaults to None; confirm Data accepts that.
    old_seq_len = config.seq_len
    config.seq_len = args.seq_len
    config.eval_seq_len = args.seq_len
    data = Data(config)
    loader = iter(data.test_dataloader())
    config.seq_len = old_seq_len

    sample = get_sample(config)

    # Validate before starting child processes: a failure after start() would
    # leave the writer blocked on queue.get() and hang interpreter shutdown.
    n_dev = jax.local_device_count()
    B = MAX_BATCH // n_dev
    if B == 0:
        raise ValueError(
            f'per-step batch {MAX_BATCH} is smaller than the number of '
            f'local devices ({n_dev})')
    assert args.n_repeat == 1

    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(queue,)) for _ in range(1)]
    for p in procs:
        p.start()

    start = time.time()
    idx = 0
    try:
        for _ in range(args.n_repeat):
            for _ in range(0, args.batch_size // n_dev, B):
                batch = next(loader)
                # Split the leading batch axis into (devices, per-device).
                batch = {k: np.reshape(v.numpy(), (n_dev, -1, *v.shape[1:]))
                         for k, v in batch.items()}
                print(batch['video'].shape)

                v_in = batch['video']
                act_in = batch.get('actions')
                if config.use_actions and not args.include_actions:
                    act_in = np.full_like(act_in, -1)
                # NOTE(review): every batch reuses args.seed; confirm the
                # sampler's noise is meant to repeat across batches.
                s, r = sample(model, state, v_in, act_in,
                              seed=args.seed, log_output=True)
                queue.put((idx, s, r))
                idx += 1
    finally:
        # One sentinel per worker, sent even on error, so every worker exits
        # and neither the join below nor interpreter shutdown can hang.
        for _ in procs:
            queue.put(None)
    print('sampling', time.time() - start)

    for p in procs:
        p.join()

    folder = osp.join(args.ckpt, 'samples')
    if args.include_actions:
        folder += '_action'
    folder += f'_{args.open_loop_ctx}'
    print('Saved to', folder)


if __name__ == '__main__':
    # Every command-line option as (flags, add_argument keyword options).
    _CLI_OPTIONS = (
        (('-c', '--ckpt'), dict(type=str, required=True)),
        (('-d', '--data_path'), dict(type=str, required=True)),
        (('-n', '--batch_size'), dict(type=int, default=32)),
        (('-r', '--n_repeat'), dict(type=int, default=1)),
        (('-l', '--seq_len'), dict(type=int, default=None)),
        (('-o', '--open_loop_ctx'), dict(type=int, default=None)),
        (('-a', '--include_actions'), dict(action='store_true')),
        (('-s', '--seed'), dict(type=int, default=0)),
        (('--no_context',), dict(action='store_true')),
    )

    parser = argparse.ArgumentParser()
    for flags, options in _CLI_OPTIONS:
        parser.add_argument(*flags, **options)

    # `args` must remain a module-level global: `worker` reads it directly
    # in the forked writer process.
    args = parser.parse_args()
    main(args)
