import os
import sys

# Absolute directory holding this file. It is prepended to sys.path so that the
# `lib` package sitting next to this file wins the `from lib... import` lookups
# below; the entry is taken off sys.path again once those imports are done.
HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path[:0] = [HERE]

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F


# Forget any previously imported `lib` package AND all of its `lib.*` submodules,
# so the imports below are resolved afresh against sys.path (where HERE was just
# prepended). Dropping only the top-level 'lib' key is not enough: an import such
# as `from lib.options import ...` returns an already-cached
# sys.modules['lib.options'] as-is, i.e. a stale submodule of the other package.
for _stale in [name for name in sys.modules if name == 'lib' or name.startswith('lib.')] :
	sys.modules.pop(_stale)

from lib.options import BaseOptions
from lib.model import *
from lib.evaluation_utils import *
from lib.utils import *


import numpy as np

# Undo the earlier sys.path.insert(0, HERE) now that the project imports are done.
# Remove HERE by value instead of popping index 0: if any module imported above
# prepended its own sys.path entry, pop(0) would drop that entry and leave HERE
# behind. remove() drops only the first occurrence, i.e. the copy inserted here.
if HERE in sys.path :
	sys.path.remove(HERE)

class Bunch:
	"""Expose the entries of a dict as attributes of a plain object.

	Used here as a lightweight stand-in for a parsed options namespace, e.g.
	``Bunch({'maxdisp': 256}).maxdisp == 256``.
	"""

	def __init__(self, adict):
		# Merge every entry directly into the instance namespace (vars(self) is self.__dict__).
		vars(self).update(adict)

def getModel(pretrained, backbone = 'PSMNet', device = 'cuda:0') :
	"""Build an ``SMDHead`` model, load pretrained weights, and return a predictor.

	Args:
		pretrained: checkpoint path (or file-like object) accepted by ``torch.load``.
			The loaded object must be a mapping with a ``'state_dict'`` entry.
		backbone: backbone name forwarded to ``SMDHead`` as ``opt.backbone``.
		device: device (string or ``torch.device``) the model is moved to and that
			is passed on to ``predict``. Defaults to ``'cuda:0'``, the device that
			was previously hard-coded.

	Returns:
		A callable ``(imgLeft, imgRight) -> torch.Tensor`` that runs one stereo
		pair through the model in eval mode and returns the ``"pred_disp"`` entry
		of ``predict``'s result, converted with ``torch.from_numpy``. That entry
		must therefore be a NumPy array.
	"""
	# Fixed inference configuration consumed by SMDHead.
	opt = Bunch({
		'maxdisp' : 256,
		'mindisp' : 0,
		'superes_factor' : 1,
		'dilation_factor' : 10,
		'backbone' : backbone,
		'output_representation' : 'bimodal',
		'aspect_ratio' : 1,
		'no_sine' : False,
		'no_residual' : False,
	})
	dev = torch.device(device)

	model = SMDHead(opt).to(device=dev)
	model.eval()

	# NOTE(security): torch.load unpickles the file, so only load trusted checkpoints.
	# Deserialize onto the CPU. load_state_dict copies the weights into the model's
	# parameters, which already live on `dev`, so the full checkpoint never has to
	# be materialized in accelerator memory.
	checkpoint = torch.load(pretrained, map_location='cpu')
	model.load_state_dict(checkpoint['state_dict'])
	del checkpoint

	def _predict(imgLeft, imgRight) :
		"""Predict disparity for one left/right image pair and return it as a tensor."""
		# Add a leading batch axis to both views; the unbatched left-image shape is
		# passed along as 'o_shape'. Inputs are presumably NumPy arrays (they are
		# indexed with np.newaxis) -- TODO confirm against callers.
		sample = {
			'left' : imgLeft[np.newaxis, ...],
			'right' : imgRight[np.newaxis, ...],
			'o_shape' : imgLeft.shape,
		}
		return torch.from_numpy(predict(model, dev, sample)["pred_disp"])

	return _predict
