import os
import sys

# Directory containing this file. It is temporarily placed first on sys.path so
# that the `from models import AnyNet` further down is looked up here before
# any other location on the path.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torch.backends.cudnn as cudnn


# Evict any already-imported module named `models` from the import cache so the
# `from models import AnyNet` below is resolved afresh against the current
# sys.path instead of reusing whatever `models` was imported earlier.
if 'models' in sys.modules :
	sys.modules.pop('models')
	
from models import AnyNet 

import numpy as np

# Remove the first sys.path entry, undoing the insert done near the top of
# this file.
# NOTE(review): this assumes none of the imports above placed another entry at
# sys.path[0] in the meantime -- confirm.
sys.path.pop(0)

class Bunch(object):
    """Attribute-style view of a mapping: ``Bunch({'a': 1}).a == 1``."""

    def __init__(self, adict):
        # Copy every entry of `adict` straight into the instance namespace.
        namespace = vars(self)
        namespace.update(adict)

def getModel(pretrained, level=-1):
    """Build an AnyNet model, load checkpoint weights, and return an inference callable.

    The network is created with the fixed hyper-parameters below, wrapped in
    ``nn.DataParallel``, moved to CUDA and put into eval mode.

    Args:
        pretrained: Checkpoint location handed to ``torch.load``. The loaded
            object must contain a ``'state_dict'`` entry.
        level: Index applied to the model's output. The default ``-1``
            selects its last element.

    Returns:
        A callable ``f(imgLeft, imgRight)`` that prepends a new leading axis
        to each input, moves both to CUDA, runs the model and returns element
        ``level`` of the result. Autograd is not disabled by this wrapper;
        wrap the call in ``torch.no_grad()`` if gradients are not needed.
    """
    # Hyper-parameters exposed to AnyNet as attributes of a Bunch.
    arguments = {
        'maxdisp': 192,
        'init_channels': 1,
        'maxdisplist': [12, 3, 3],
        'spn_init_channels': 8,
        'nblocks': 2,
        'layers_3d': 4,
        'channels_3d': 4,
        'growth_rate': [4, 1, 1],
        'with_spn': True,
    }
    args = Bunch(arguments)

    model = nn.DataParallel(AnyNet(args)).cuda()

    # Deserialize onto the CPU first: without map_location the tensors are
    # restored to the device they were saved from, which fails when that GPU
    # index does not exist on this machine. load_state_dict() then copies the
    # values into the parameters that already live on CUDA.
    # NOTE: torch.load unpickles its input -- only load trusted checkpoints.
    checkpoint = torch.load(pretrained, map_location='cpu')

    # strict=False: missing or unexpected keys in the checkpoint are tolerated
    # rather than raising.
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.eval()

    def predict(imgLeft, imgRight):
        """Run the model on one left/right input pair and return output ``level``."""
        left = imgLeft[np.newaxis, ...].cuda()
        right = imgRight[np.newaxis, ...].cuda()
        # Call the module itself rather than .forward() so the regular
        # nn.Module call path is used.
        return model(left, right)[level]

    return predict
