"""Inference wrapper around the MegEngine CREStereo model.

Importing this module loads the ``nets`` package that sits next to this file
and exposes ``getModel``, which returns an object mapping a (left, right)
image pair to a disparity ``torch.Tensor``.
"""
import os
import sys

# Directory containing this file. It is prepended to sys.path so that the
# ``from nets import ...`` below resolves to the ``nets`` package shipped here.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
import megengine as mge
import megengine.functional as F


# Evict any previously imported ``nets`` module so the import below is resolved
# against HERE instead of reusing a cached module. Presumably another wrapper in
# the same process ships its own ``nets`` package — NOTE(review): confirm.
if 'nets' in sys.modules :
	sys.modules.pop('nets')
	
from nets import Model as CREStereo

import numpy as np

# Undo the sys.path.insert above. NOTE(review): this assumes nothing else has
# prepended an entry to sys.path since then; otherwise the wrong entry is removed.
sys.path.pop(0)

class Bunch(object):
	"""Expose the entries of a mapping as attributes of the instance.

	Whatever ``dict.update`` accepts (a mapping, or an iterable of key/value
	pairs) is accepted here, e.g. ``Bunch({"a": 1}).a == 1``.
	"""

	def __init__(self, adict):
		# vars(self) is the instance's own __dict__, so this merges the entries in.
		vars(self).update(adict)

class callable:
	"""Run a CREStereo model on a stereo pair with a coarse-to-fine schedule.

	NOTE: this class name shadows the builtin ``callable`` inside this module.
	It is kept because ``getModel`` (and possibly external callers) refer to it.

	Attributes:
		model: network invoked as ``model(left, right, iters=..., flow_init=...)``.
		iter: refinement iteration count passed as ``iters`` on every call.
	"""

	def __init__(self, model, n_iter):
		self.model = model
		self.iter = n_iter

	@staticmethod
	def _to_batch_tensor(img):
		"""Prepend a batch axis and return a contiguous float32 MegEngine tensor."""
		batched = np.ascontiguousarray(img[np.newaxis, ...])
		return mge.tensor(batched).astype("float32")

	@staticmethod
	def _resize(img, size):
		"""Bilinearly resize a 4-D tensor's last two axes to ``size`` (align_corners)."""
		return F.nn.interpolate(img, size=size, mode="bilinear", align_corners=True)

	def __call__(self, imgLeft, imgRight):
		"""Estimate disparity for a single left/right image pair.

		Each image gets a leading batch axis and is cast to float32. Presumably
		the inputs are (C, H, W) arrays — TODO confirm against callers.

		A first pass runs on copies resized to half height and width, with both
		sizes taken from the left image. Its output then seeds a second pass at
		full resolution.

		Returns:
			torch.Tensor holding channel 0 of the predicted flow, with every
			size-1 axis squeezed away.
		"""
		left = self._to_batch_tensor(imgLeft)
		right = self._to_batch_tensor(imgRight)

		# Both half-resolution inputs are sized from the *left* image.
		half_size = (left.shape[2] // 2, left.shape[3] // 2)
		left_half = self._resize(left, half_size)
		right_half = self._resize(right, half_size)

		coarse_flow = self.model(left_half, right_half, iters=self.iter, flow_init=None)
		flow = self.model(left, right, iters=self.iter, flow_init=coarse_flow)

		disparity = F.squeeze(flow[:, 0, :, :]).numpy()
		# Hand torch a fresh copy rather than the array MegEngine returned.
		return torch.from_numpy(disparity.copy())

def getModel(pretrained, n_iter=20, max_disp=256):
	"""Build a CREStereo model from a checkpoint and wrap it for inference.

	Args:
		pretrained: checkpoint passed to ``megengine.load``. Its ``"state_dict"``
			entry holds the model weights.
		n_iter: refinement iterations used on every forward pass.
		max_disp: ``max_disp`` forwarded to the CREStereo constructor. Defaults
			to 256, the value that was previously hard-coded.

	Returns:
		A ``callable`` instance that maps a (left, right) image pair to a
		disparity ``torch.Tensor``.

	Raises:
		KeyError: if the loaded checkpoint mapping has no ``"state_dict"`` entry.
	"""
	# SECURITY: megengine.load unpickles the file; only load trusted checkpoints.
	checkpoint = mge.load(pretrained)
	model = CREStereo(max_disp=max_disp, mixed_precision=False, test_mode=True)

	# strict=True: fail loudly if checkpoint keys do not match the network.
	model.load_state_dict(checkpoint["state_dict"], strict=True)

	model.eval()

	return callable(model, n_iter)
