import os
import sys

HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math


# Another module loaded earlier in this process may already have imported a
# different package called ``models``. Evict it -- together with every cached
# ``models.*`` submodule, which the import system would otherwise reuse straight
# from sys.modules -- so the import below resolves against the ``models``
# directory next to this file (HERE was prepended to sys.path above).
for _stale_name in [m for m in sys.modules if m == 'models' or m.startswith('models.')]:
    sys.modules.pop(_stale_name)

try:
    from models.SHRNet import SHRNet
finally:
    # Undo the sys.path insertion made at the top of this file, even when the
    # import fails. Remove HERE by value rather than popping index 0, because
    # the import may itself have prepended further entries to sys.path.
    try:
        sys.path.remove(HERE)
    except ValueError:
        pass

class Bunch:
    """Lightweight attribute container built from a mapping.

    Each key of ``adict`` becomes an instance attribute holding the matching
    value, so ``Bunch({'max_disp': 192}).max_disp == 192``.
    """

    def __init__(self, adict):
        attrs = vars(self)
        attrs.update(adict)

def getModel(pretrained, iters=32):
    """Build SHRNet, load checkpoint weights, and return an inference closure.

    Args:
        pretrained: Anything ``torch.load`` accepts (e.g. a file path). The
            loaded object must contain a ``'state_dict'`` entry.
        iters: Unused. Kept so existing callers that pass it keep working
            (presumably shared with other model loaders -- confirm).

    Returns:
        ``testFunc(imgL, imgR)``, which runs the model on a single left/right
        pair and returns the model output with all size-1 dimensions squeezed
        out. ``imgL``/``imgR`` must be unbatched torch tensors: a leading batch
        axis is added here before they are copied to the GPU. Assumes (C, H, W)
        -- TODO confirm against callers.
    """
    opt = Bunch({'max_disp': 192})

    model = nn.DataParallel(SHRNet(opt.max_disp), device_ids=[0])
    # DataParallel does not move the wrapped module; its forward() requires the
    # parameters to already live on device_ids[0], so move them explicitly.
    model.cuda()

    # Load onto the CPU first so a checkpoint saved from a GPU index that does
    # not exist here still loads; load_state_dict copies onto the model's device.
    checkpoint = torch.load(pretrained, map_location='cpu')
    # strict=False: mismatched/missing keys are silently ignored, as before.
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.eval()

    def testFunc(imgL, imgR):
        # Inference only: disable autograd so no graph is built or retained,
        # which saves GPU memory across repeated calls.
        with torch.no_grad():
            disparity = model(imgL.unsqueeze(0).cuda(), imgR.unsqueeze(0).cuda())
        return torch.squeeze(disparity)

    return testFunc
