import os
import json
from argparse import ArgumentParser

import torch

from model.network import GFNet
from estimation import demo_estimation

if __name__ == "__main__":
    # Demo entry point: build a GFNet model, restore its weights from a
    # checkpoint, and run demo_estimation on the sample image pair that
    # ships under assets/<dataset>/.

    # Datasets for which a sample directory assets/<name>/ is expected.
    supported_datasets = ("mscoco", "vis_ir", "googlemap", "virat")

    parser = ArgumentParser()
    # All three options are required: without them the script cannot get
    # past opening the config / checkpoint / sample directory, so fail at
    # parse time with a usage message instead of a late TypeError/NameError.
    parser.add_argument("--conf_path", type=str, required=True,
                        help="path to the JSON model configuration")
    parser.add_argument("--ckpt_path", type=str, required=True,
                        help="path to the checkpoint holding a 'model' state dict")
    parser.add_argument("--dataset", type=str, required=True,
                        choices=supported_datasets,
                        help="which sample pair under assets/ to run on")
    # parse_known_args: unrecognized extra options are ignored, not rejected.
    args, _ = parser.parse_known_args()

    with open(args.conf_path, 'r', encoding='utf-8') as conf_file:
        conf = json.load(conf_file)

    training_resolution = (448, 448)
    upsampling_resolution = (560, 560)

    model = GFNet(conf=conf,
                  initial_res=training_resolution,
                  upsample_res=upsampling_resolution,
                  symmetric=True,
                  upsample_preds=True,
                  attenuate_cert=True).cuda()

    # Echo the effective settings as stored on the model (one per line,
    # each followed by a blank line).
    for setting in ("initial_res", "upsample_res", "symmetric",
                    "upsample_preds", "attenuate_cert"):
        print(f'{setting}: {getattr(model, setting)}\n')

    # Deserialize onto the CPU so the checkpoint loads regardless of the
    # device it was saved from; load_state_dict then copies the values into
    # the model's existing (CUDA) parameters.
    states = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(states["model"])

    # argparse has already restricted args.dataset to supported_datasets.
    test_path = f'assets/{args.dataset}'

    img1_path = os.path.join(test_path, '1.jpg')
    img2_path = os.path.join(test_path, '2.jpg')
    # Passed to demo_estimation as its H_s2t_path argument; presumably the
    # reference homography for the pair -- TODO confirm against demo_estimation.
    H_s2t_path = os.path.join(test_path, '12.json')

    demo_estimation(model, img1_path, img2_path, H_s2t_path)

    