import os
import sys

# Directory containing this file. It is put at the front of sys.path so the
# imports that follow search this directory before any other location.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F


# Drop any module already registered under the name 'nets' so the
# `import nets` below is resolved again (HERE is first on sys.path at this
# point) instead of reusing the cached sys.modules entry.
if 'nets' in sys.modules :
	sys.modules.pop('nets')
	
import nets


import numpy as np

# Remove the first sys.path entry again. This undoes the insert of HERE at the
# top of the file, assuming none of the imports in between prepended an entry
# of their own.
sys.path.pop(0)

class Bunch(object):
    """Expose the entries of a mapping as attributes of an instance."""

    def __init__(self, adict):
        # Copy every key/value pair of `adict` into the instance namespace,
        # so that adict['name'] becomes reachable as self.name.
        vars(self).update(adict)

def getModel(pretrained, level=-1) :
    """Build an AANet model, load pretrained weights and return a predictor.

    Args:
        pretrained: Checkpoint location handed to ``torch.load``. The loaded
            object is either a state dict itself or a dict holding one under
            the key ``'upt_state_dict'`` or ``'state_dict'`` (checked in that
            order).
        level: Index applied to the value returned by ``model.forward``.
            The default ``-1`` selects its last element (presumably the
            finest-scale prediction -- confirm against ``nets.AANet``).

    Returns:
        A callable ``f(imgLeft, imgRight)``. It adds a leading axis to both
        inputs, moves them to CUDA, runs the model, picks output ``level``,
        adds a leading axis to it and resizes it bilinearly to
        ``(imgLeft.shape[1], imgLeft.shape[2])``. The inputs are presumably
        ``(C, H, W)`` tensors -- confirm with the caller.
    """
    max_disp = 192
    # Fixed AANet hyper-parameters; they must match the architecture the
    # checkpoint was trained with, otherwise load_state_dict fails.
    net_options = {
        'num_downsample': 2,
        'feature_type': 'aanet',
        'no_feature_mdconv': False,
        'feature_pyramid': False,
        'feature_pyramid_network': True,
        'feature_similarity': 'correlation',
        'aggregation_type': 'adaptive',
        'num_scales': 3,
        'num_fusions': 6,
        'num_stage_blocks': 1,
        'num_deform_blocks': 3,
        'no_intermediate_supervision': False,
        'refinement_type': 'stereodrnet',
        'mdconv_dilation': 2,
        'deformable_groups': 2,
    }

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from (e.g. a GPU index missing on this machine)
    # and does not keep a second copy of the weights on the GPU.
    # load_state_dict below copies the values into the CUDA parameters.
    checkpoint = torch.load(pretrained, map_location='cpu')

    model = nets.AANet(max_disp, **net_options).cuda()

    # Checkpoints may wrap the weights under one of these keys; otherwise the
    # loaded object is taken to be the state dict itself.
    state_dict = next(
        (checkpoint[key] for key in ('upt_state_dict', 'state_dict')
         if key in checkpoint),
        checkpoint,
    )
    model.load_state_dict(state_dict)

    model.eval()

    def predict(imgLeft, imgRight):
        """Run the model on one image pair and resize the chosen output."""
        # NOTE(review): the forward pass is not wrapped in torch.no_grad();
        # left unchanged in case callers rely on gradients.
        left = imgLeft[np.newaxis, ...].cuda()
        right = imgRight[np.newaxis, ...].cuda()
        selected = model.forward(left, right)[level]
        # F.interpolate with a 2-element size needs a 4-D input, hence the
        # extra leading axis on the selected output.
        return F.interpolate(
            selected[np.newaxis, ...],
            size=(imgLeft.shape[1], imgLeft.shape[2]),
            mode='bilinear',
            align_corners=False,
        )

    return predict
