import os
import sys

# Absolute path of the directory that contains this file.
HERE = os.path.dirname(os.path.abspath(__file__))
# Put this directory first on the import path so the ``models`` import further
# down this file is searched for here before anywhere else (presumably picking up
# a ``models`` package that sits next to this file). The entry is taken off
# sys.path again once that import has run.
sys.path.insert(0, HERE)

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F


# Evict any previously imported ``models`` package *and all of its submodules*
# from the module cache, so the import below re-resolves ``models`` through
# sys.path (where HERE was inserted first). Evicting only the parent would let
# a stale ``models.<x>`` from another package shadow this one's submodule.
for _name in [_n for _n in list(sys.modules)
              if _n == 'models' or _n.startswith('models.')]:
    sys.modules.pop(_name, None)

import numpy as np

try:
    from models import __models__
finally:
    # Undo the sys.path.insert(0, HERE) made near the top of this file, even if
    # the import above failed. Remove HERE by value instead of popping index 0,
    # because the imported package may itself have changed sys.path.
    if HERE in sys.path:
        sys.path.remove(HERE)

class Bunch(object):
    """Expose the entries of a mapping as attributes of a plain object.

    Example: ``Bunch({'maxdisp': 192}).maxdisp == 192``.
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__; updating it turns every entry of
        # ``adict`` into an attribute in one step.
        vars(self).update(adict)

def getModel(pretrained, level=-1, model='gwcnet-g', maxdisp=192):
    """Load a pretrained stereo network and return a single-pair inference function.

    Args:
        pretrained: Checkpoint source passed to ``torch.load``. The loaded object
            must hold the network state dict under the ``'model'`` key, with keys
            matching an ``nn.DataParallel``-wrapped network (``module.`` prefix),
            because ``load_state_dict`` is strict by default.
        level: Index applied to the network's output; ``-1`` (default) takes the
            last element. Presumably selects one of the per-level predictions
            produced when the network is built with ``allLevels=True`` -- confirm
            against the model code.
        model: Key into ``__models__`` selecting the architecture
            (default ``'gwcnet-g'``).
        maxdisp: Maximum disparity passed to the architecture constructor
            (default 192, the value previously hard-coded here).

    Returns:
        A callable ``f(imgLeft, imgRight)``. It adds a leading batch axis to both
        image tensors, moves them to the GPU, runs the network and returns output
        element ``level``. Autograd is not disabled, so wrap calls in
        ``torch.no_grad()`` for pure inference.

    Raises:
        KeyError: If ``model`` is not a registered architecture name.
    """
    try:
        factory = __models__[model]
    except KeyError:
        raise KeyError('unknown model {!r}; available: {}'.format(
            model, ', '.join(sorted(map(str, __models__))))) from None

    # Wrap in DataParallel *before* loading so the parameter names line up with
    # the ``module.``-prefixed checkpoint keys, then move everything to the GPU.
    net = nn.DataParallel(factory(maxdisp, allLevels=True)).cuda()

    # Deserialize on the CPU: load_state_dict copies the values into the existing
    # (CUDA-resident) parameters anyway. This avoids errors when the checkpoint
    # was saved from a GPU device index that is not present on this machine.
    checkpoint = torch.load(pretrained, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    # Evaluation mode (affects e.g. dropout / batch-norm layers, if present).
    net.eval()

    def predict(imgLeft, imgRight):
        # Assumes each image is an unbatched torch tensor, e.g. (C, H, W) -- TODO confirm.
        left = imgLeft[np.newaxis, ...].cuda()
        right = imgRight[np.newaxis, ...].cuda()
        # Call the module itself rather than .forward() so registered hooks run.
        return net(left, right)[level]

    return predict
