import os
import sys

# Absolute directory that holds this file. It is pushed to the front of
# sys.path so the bundled `networks` package wins the import below; the
# entry is removed again once PSMNet has been imported.
HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path[:0] = [HERE]

import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math


# Evict any cached `networks` package *and all of its submodules* so the
# import below resolves against HERE (first on sys.path). Evicting only the
# package and `stackhourglass` would still let a stale `networks.<sub>` from
# a same-named package imported earlier elsewhere be reused.
_stale_networks = [
	_mod for _mod in list(sys.modules)
	if _mod == 'networks' or _mod.startswith('networks.')
]
for _mod in _stale_networks:
	sys.modules.pop(_mod, None)

try:
	from networks.stackhourglass import PSMNet
finally:
	# Undo the sys.path.insert(0, HERE) made at the top of the file, even if
	# the import fails, so this directory is not left on the global path.
	sys.path.pop(0)

class Bunch:
	"""Attribute-style view over a mapping, e.g. ``Bunch({'a': 1}).a == 1``.

	Every entry of ``adict`` is copied into the instance ``__dict__`` and so
	becomes readable as an attribute.
	"""

	def __init__(self, adict):
		vars(self).update(adict)

def getModel(pretrained, refine=True):
	"""Build PSMNet, load ``pretrained`` weights and return an inference function.

	The network is configured with fixed hyper-parameters (max disparity 192,
	4 LSP channels fused in 'separate' mode, dilations [1, 2, 4, 8], 'csr'
	refinement, 3x3 affinity window, udc disabled). It is wrapped in
	``nn.DataParallel``, moved to the GPU and put in eval mode.

	Args:
		pretrained: checkpoint passed to ``torch.load``. The loaded object is
			handed directly to ``load_state_dict`` on the DataParallel
			wrapper, so it must be a bare state dict whose keys match it
			(presumably ``module.``-prefixed -- confirm against the file).
		refine: forwarded to ``PSMNet`` as ``refined_result``.

	Returns:
		``testFunc(imgL, imgR)``, which runs the model on one stereo pair
		without tracking gradients and returns the squeezed output.
	"""
	args = Bunch({'max_disp' : 192, 'lsp_channel' : 4, 'lsp_mode' : 'separate', 'lsp_dilation' : [1, 2, 4, 8], 'refine' : 'csr'})

	affinity_settings = {
		'win_w': 3,
		'win_h': 3,
		'dilation': args.lsp_dilation,
	}
	udc = False

	model = PSMNet(maxdisp=args.max_disp, struct_fea_c=args.lsp_channel, fuse_mode=args.lsp_mode,
			affinity_settings=affinity_settings, udc=udc, refine=args.refine, refined_result=refine)

	model = nn.DataParallel(model)
	model.cuda()

	# Deserialize onto the CPU. load_state_dict copies each tensor into the
	# already-GPU-resident parameters, and this avoids failures when the
	# checkpoint was saved from a GPU index that does not exist here.
	checkpoint = torch.load(pretrained, map_location='cpu')
	model.load_state_dict(checkpoint)

	model.eval()

	def testFunc(imgL, imgR):
		"""Run the model on a single left/right pair.

		``imgL`` / ``imgR`` are assumed to be unbatched torch tensors
		(presumably (C, H, W) -- TODO confirm). A leading batch axis is added
		and they are moved to the GPU. The third model argument is passed as
		``None``; its meaning is defined by ``PSMNet.forward``.

		Returns the model output with all size-1 dimensions squeezed out.
		"""
		# Inference only: without no_grad the forward pass records an autograd
		# graph, keeping intermediate activations alive in GPU memory and
		# returning a tensor that requires grad (blocking e.g. ``.numpy()``).
		with torch.no_grad():
			disparity = model(imgL[np.newaxis, ...].cuda(), imgR[np.newaxis, ...].cuda(), None)
		return torch.squeeze(disparity)

	return testFunc
