import importlib.util

import numpy as np
from scipy.signal import convolve2d

import lexrtools_pyreadExrChannels as rexr
import atiwipy as atw

import os
import glob
import argparse as args
import sys

import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize


# Table of the stereo networks that can be evaluated. Each entry is consumed by
# processModel:
#   "name"         : identifier of the entry; used as the output sub-folder name,
#                    as the key of the result tables and matched against --selected.
#   "module"       : path of a python file that is executed and must expose
#                    getModel(pretrained, **kwargs).
#   "pretrained_*" : checkpoint path handed to getModel. When several are present,
#                    processModel uses the first one found among kitti15, kitti12,
#                    sceneflow, middleburry, realsense, activestereo.
#   'isGrayscale'  : True when the network must be fed single-channel images
#                    (color inputs are averaged over their channel axis).
#   'kwargs'       : optional extra keyword arguments forwarded to getModel.
#   'cropsize'     : optional (height, width) the input pair is padded/cropped to
#                    before inference; outputs are brought back to the image size.
models = [
	{
		"name" : "AnyNet",
		"module" : "./AnyNet/callable.py",
		"pretrained_kitti15" : "./AnyNet/models/checkpoint/kitti2015_ck/kitti_2015.tar",
		'isGrayscale' : False
	},
	{
		"name" : "AnyNetLvl1",
		"module" : "./AnyNet/callable.py",
		"pretrained_kitti15" : "./AnyNet/models/checkpoint/kitti2015_ck/kitti_2015.tar",
		'kwargs' : {'level' : 1},
		'isGrayscale' : False
	},
	{
		"name" : "AnyNetLvl0",
		"module" : "./AnyNet/callable.py",
		"pretrained_kitti15" : "./AnyNet/models/checkpoint/kitti2015_ck/kitti_2015.tar",
		'kwargs' : {'level' : 0},
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereo",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_kitti15" : "./cascade-stereo/CasStereoNet/models/casgwcnet.ckpt",
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereoStage2",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_kitti15" : "./cascade-stereo/CasStereoNet/models/casgwcnet.ckpt",
		'kwargs' : {'stage' : 2},
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereoStage1",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_kitti15" : "./cascade-stereo/CasStereoNet/models/casgwcnet.ckpt",
		'kwargs' : {'stage' : 1},
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereoFt",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_activestereo" : "./cascade-stereo/CasStereoNet/models/cas-gwcnet-c-apstereo.pth",
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereoFtStage2",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_activestereo" : "./cascade-stereo/CasStereoNet/models/cas-gwcnet-c-apstereo.pth",
		'kwargs' : {'stage' : 2},
		'isGrayscale' : False
	},
	{
		"name" : "CascadeStereoFtStage1",
		"module" : "./cascade-stereo/CasStereoNet/callable.py",
		"pretrained_activestereo" : "./cascade-stereo/CasStereoNet/models/cas-gwcnet-c-apstereo.pth",
		'kwargs' : {'stage' : 1},
		'isGrayscale' : False
	},
	{
		"name" : "DeepPruner_Best",
		"module" : "./DeepPruner/deeppruner/callable.py",
		"pretrained_kitti15" : "./DeepPruner/deeppruner/pretrain/DeepPruner-best-kitti-bmvc-version.tar",
		'isGrayscale' : False
	},
	{
		"name" : "DeepPruner_Fast",
		"module" : "./DeepPruner/deeppruner/callable.py",
		"pretrained_kitti15" : "./DeepPruner/deeppruner/pretrain/DeepPruner-fast-kitti.tar",
		'isGrayscale' : False,
		'kwargs' : {'arch' : 'fast'},
		'cropsize' : (512,640)
	},
	{
		"name" : "DeepPruner_Best_norefine",
		"module" : "./DeepPruner/deeppruner/callable.py",
		"pretrained_kitti15" : "./DeepPruner/deeppruner/pretrain/DeepPruner-best-kitti-bmvc-version.tar",
		'isGrayscale' : False,
		'kwargs' : {'withRefinement' : False}
	},
	{
		"name" : "DeepPruner_Fast_norefine",
		"module" : "./DeepPruner/deeppruner/callable.py",
		"pretrained_kitti15" : "./DeepPruner/deeppruner/pretrain/DeepPruner-fast-kitti.tar",
		'isGrayscale' : False,
		'kwargs' : {'arch' : 'fast', 'withRefinement' : False},
		'cropsize' : (512,640)
	},
	{
		"name" : "GANet",
		"module" : "./GANet/callable.py",
		"pretrained_kitti15" : "./GANet/pretrained/kitti2015_final.pth",
		"pretrained_kitti12" : "./GANet/pretrained/kitti2012_final.pth",
		'isGrayscale' : False,
		'cropsize' : (480,864)
	},
	{
		"name" : "GANetLvl1",
		"module" : "./GANet/callable.py",
		"pretrained_kitti15" : "./GANet/pretrained/kitti2015_final.pth",
		"pretrained_kitti12" : "./GANet/pretrained/kitti2012_final.pth",
		'isGrayscale' : False,
		'cropsize' : (480,864),
		'kwargs' : {'level' : 1}
	},
	{
		"name" : "GANetLvl0",
		"module" : "./GANet/callable.py",
		"pretrained_kitti15" : "./GANet/pretrained/kitti2015_final.pth",
		"pretrained_kitti12" : "./GANet/pretrained/kitti2012_final.pth",
		'isGrayscale' : False,
		'cropsize' : (480,864),
		'kwargs' : {'level' : 0}
	},
	{
		"name" : "HighResStereo",
		"module" : "./high-res-stereo/callable.py",
		"pretrained_kitti15" : "./high-res-stereo/pretrained/kitti.tar",
		"pretrained_middleburry" : "./high-res-stereo/pretrained/final-768px.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 1},
		'cropsize' : (512,640) #'cropsize' : (512,640)
	},
	{
		"name" : "HighResStereoLvl2",
		"module" : "./high-res-stereo/callable.py",
		"pretrained_kitti15" : "./high-res-stereo/pretrained/kitti.tar",
		"pretrained_middleburry" : "./high-res-stereo/pretrained/final-768px.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 2},
		'cropsize' : (512,640) #'cropsize' : (512,640)
	},
	{
		"name" : "HighResStereoLvl3",
		"module" : "./high-res-stereo/callable.py",
		"pretrained_kitti15" : "./high-res-stereo/pretrained/kitti.tar",
		"pretrained_middleburry" : "./high-res-stereo/pretrained/final-768px.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 3},
		'cropsize' : (512,640) #'cropsize' : (512,640)
	},
	{
		"name" : "PSMNet",
		"module" : "./PSMNet/callable.py",
		"pretrained_kitti15" : "./PSMNet/pretrained/pretrained_model_KITTI2015.tar",
		'isGrayscale' : False
	},
	{
		"name" : "PSMNetLvl1",
		"module" : "./PSMNet/callable.py",
		"pretrained_kitti15" : "./PSMNet/pretrained/pretrained_model_KITTI2015.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 1}
	},
	{
		"name" : "PSMNetLvl0",
		"module" : "./PSMNet/callable.py",
		"pretrained_kitti15" : "./PSMNet/pretrained/pretrained_model_KITTI2015.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 0}
	},
	{
		"name" : "RealTimeStereo",
		"module" : "./RealtimeStereo/callable.py",
		"pretrained_kitti15" : "./RealtimeStereo/pretrained/pretrained_Kitti2015_realtime.tar",
		'isGrayscale' : False
	},
	{
		"name" : "RealTimeStereoLvl1",
		"module" : "./RealtimeStereo/callable.py",
		"pretrained_kitti15" : "./RealtimeStereo/pretrained/pretrained_Kitti2015_realtime.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 1}
	},
	{
		"name" : "RealTimeStereoLvl0",
		"module" : "./RealtimeStereo/callable.py",
		"pretrained_kitti15" : "./RealtimeStereo/pretrained/pretrained_Kitti2015_realtime.tar",
		'isGrayscale' : False,
		'kwargs' : {'level' : 0}
	},
	{
		"name" : "StereoNet",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_sceneflow" : "./StereoNet-ActiveStereoNet/pretrained/ps_sceneflow_checkpoint.pth",
		'isGrayscale' : False,
		'kwargs' : {'mode' : 'passive'}
	},
	{
		"name" : "StereoNetNoRefineModule",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_sceneflow" : "./StereoNet-ActiveStereoNet/pretrained/ps_sceneflow_checkpoint.pth",
		'isGrayscale' : False,
		'kwargs' : {'mode' : 'passive', 'refineModuleStereoNet' : False}
	},
	{
		"name" : "StereoNetFt",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_activestereo" : "./StereoNet-ActiveStereoNet/pretrained/sn_finetuned_sim_stereo.pth",
		'isGrayscale' : False,
		'kwargs' : {'mode' : 'passive'}
	},
	{
		"name" : "StereoNetFtNoRefineModule",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_activestereo" : "./StereoNet-ActiveStereoNet/pretrained/sn_finetuned_sim_stereo.pth",
		'isGrayscale' : False,
		'kwargs' : {'mode' : 'passive', 'refineModuleStereoNet' : False}
	},
	{
		"name" : "ActiveStereoNet",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_realsense" : "./StereoNet-ActiveStereoNet/pretrained/as_checkpoint.pth",
		'isGrayscale' : True,
		'kwargs' : {'mode' : 'active'}
	},
	{
		"name" : "ActiveStereoNetNoRefineModule",
		"module" : "./StereoNet-ActiveStereoNet/callable.py",
		"pretrained_realsense" : "./StereoNet-ActiveStereoNet/pretrained/as_checkpoint.pth",
		'isGrayscale' : True,
		'kwargs' : {'mode' : 'active', 'refineModuleStereoNet' : False}
	}
]
	
# Virtual camera calibration used by processFile to turn the depth channel into
# ground-truth disparity:
#   f_pix     = focal / sensorWidth * image_width      (focal length in pixels)
#   disparity = baseline * f_pix / z                    (in pixels)
# focal and sensorWidth only appear as a ratio, so they must share a unit
# (presumably mm — TODO confirm); baseline presumably shares the unit of the
# depth channel z (meters?) — confirm against the renderer settings.
calibration = {'focal' : 50, 'sensorWidth' : 36, 'baseline' : 0.16}
	
def processModel(model, images, color_transform, nir_transform, imgScale = 1.0, outFolder = None) :
	"""Load the network described by ``model`` and evaluate it on ``images``.
	
	Args:
		model: entry of the ``models`` table. Reads 'name', 'module', 'isGrayscale',
			one of the 'pretrained_*' checkpoint keys and, optionally, 'kwargs'
			and 'cropsize'.
		images: paths of the exr files to evaluate.
		color_transform: object whose ``applyTransform`` is applied to the RGB views.
		nir_transform: callable applied to the (left, right) NIR pair.
		imgScale: factor applied to the images after the color/NIR transforms.
		outFolder: when given, per-image outputs are written to ``outFolder/<name>``
			(the sub-folder is created if needed).
	
	Returns:
		(name, results, comparisons): the entry name and two dicts mapping each
		image name to the statistics / comparison indicators of ``processFile``.
	
	Raises:
		ImportError: when ``model['module']`` cannot be loaded as a python module.
		ValueError: when the entry has no 'pretrained_*' checkpoint.
	"""
	
	name = model['name']
	print("Processing model: ", name)
	
	target = None
	if outFolder is not None :
		target = os.path.join(outFolder, name)
		# exist_ok replaces the racy isdir + mkdir pair
		os.makedirs(target, exist_ok=True)
	
	modulePath = model['module']
	# submodule_search_locations must be a list of directories, not a plain string
	spec = importlib.util.spec_from_file_location("network.call",
												   modulePath,
												   submodule_search_locations=[os.path.dirname(modulePath)])
	if spec is None :
		raise ImportError("Cannot load network module {}".format(modulePath))
	callableModule = importlib.util.module_from_spec(spec)
	spec.loader.exec_module(callableModule)
	
	# checkpoint keys, by order of preference
	pretrainedKeys = ('pretrained_kitti15',
					  'pretrained_kitti12',
					  'pretrained_sceneflow',
					  'pretrained_middleburry',
					  'pretrained_realsense',
					  'pretrained_activestereo')
	pretrained = next((model[key] for key in pretrainedKeys if key in model), None)
	
	if pretrained is None :
		raise ValueError("Missing pretrained model")
	
	needsGrayImages = model['isGrayscale']
	crop = model.get('cropsize')
	kwargs = model.get('kwargs', {})
	
	network = callableModule.getModel(pretrained, **kwargs)
	
	results = {}
	comparisons = {}
	
	for im in images :
		im_name, stats, comps = processFile(im, network, color_transform, nir_transform, imgScale, needsGrayImages, crop, outDir = target)
		results[im_name] = stats
		comparisons[im_name] = comps
	
	return name, results, comparisons
		
	
def _cropPadding(crop, shape) :
	"""Return the ``F.pad`` tuple that turns an image of ``shape`` into ``crop``.
	
	``crop`` and ``shape`` are both (height, width). The result is
	(left, right, top, bottom): positive entries pad, negative entries crop.
	Left/top receive floor(d/2) of the size difference d, right/bottom the rest.
	Returns None when ``crop`` is None.
	"""
	if crop is None :
		return None
	
	dh = int(crop[0] - shape[0])
	dw = int(crop[1] - shape[1])
	
	return (dw//2, dw - dw//2, dh//2, dh - dh//2)


def _saveMapFigure(img, norm, figName, label, path) :
	"""Save ``img`` with the given normalization and a labelled colorbar to ``path``."""
	fig = plt.figure(figName, figsize=(5, 3), dpi=300)
	plt.imshow(img, norm=norm)
	plt.xticks([])
	plt.yticks([])
	cbar = plt.colorbar()
	cbar.set_label(label, rotation=270, verticalalignment='baseline')
	plt.savefig(path)
	plt.close(fig)


def _errorStatistics(error_passive, error_active, edges_mask, flat_mask) :
	"""Return the error statistics keyed '<metric><region>-<passive|active>'.
	
	Metrics: MAE, RMSE and BADn (fraction of pixels with |error| > n, for n in
	0.5, 1, 2, 4). Regions: whole image (''), ``edges_mask`` ('_edges') and
	``flat_mask`` ('_flat'). An empty mask yields NaN (mean of an empty array).
	"""
	metrics = [
		("MAE", lambda e : np.mean(np.abs(e))),
		("RMSE", lambda e : np.sqrt(np.mean(e**2))),
	]
	for label, threshold in (("BAD0.5", 0.5), ("BAD1", 1), ("BAD2", 2), ("BAD4", 4)) :
		# default argument binds the threshold now, not at call time
		metrics.append((label, lambda e, t=threshold : np.mean(np.abs(e) > t)))
	
	stats = {}
	for region, mask in (("", None), ("_edges", edges_mask), ("_flat", flat_mask)) :
		ep = error_passive if mask is None else error_passive[mask]
		ea = error_active if mask is None else error_active[mask]
		for label, metric in metrics :
			stats[f"{label}{region}-passive"] = metric(ep)
			stats[f"{label}{region}-active"] = metric(ea)
	
	return stats


def processFile(f,
				model,
				color_transform,
				nir_transform,
				imgScale,
				needsGrayImages,
				crop = None,
				outDir = None,
				quantize = False) :
	"""Evaluate ``model`` on one exr frame, for the RGB (passive) and NIR (active) pairs.
	
	Args:
		f: path of the exr file. The 'Left.Color'/'Right.Color' layers, the
			'Left.SimulatedNir.A'/'Right.SimulatedNir.A' channels and the
			'Left.Depth.Z' channel are read from it.
		model: callable taking the (left, right) tensors, see ``processImgPair``.
		color_transform: object whose ``applyTransform`` is called on each RGB view.
		nir_transform: callable mapping (left_nir, right_nir) to the transformed pair.
		imgScale: factor the images are multiplied by before inference.
		needsGrayImages: forwarded to ``processImgPair``.
		crop: optional (height, width) the network input is padded/cropped to.
		outDir: when given, an exr with the disparities/errors and pdf figures are
			written there.
		quantize: round the images to multiples of 1/255 before scaling.
	
	Returns:
		(name, stats, comparisons): the file name without its extension, the error
		statistics of both modes (whole image, edge and flat regions) and the
		passive-vs-active comparison indicators.
	"""
	
	print("\tProcessing image file: ", f)
	
	name = os.path.splitext(os.path.basename(f))[0]
	
	# [:,:,::-1] reverses the channel order of the layer returned by the reader
	right_rgb = rexr.readExrLayer(f, 'Right.Color')[:,:,::-1]
	left_rgb = rexr.readExrLayer(f, 'Left.Color')[:,:,::-1]
	
	# NOTE(review): the return value is ignored, so applyTransform presumably works in place — confirm
	color_transform.applyTransform(left_rgb)
	color_transform.applyTransform(right_rgb)
	
	right_nir = rexr.readExrChannel(f, 'Right.SimulatedNir.A')
	left_nir = rexr.readExrChannel(f, 'Left.SimulatedNir.A')
	
	left_nir, right_nir = nir_transform(left_nir, right_nir)
	
	if quantize :
		# simulate the quantization of 8 bits per channel images (assumes values in [0, 1])
		quantizeImage = lambda img : np.round(img*255)/255.
		left_rgb = quantizeImage(left_rgb)
		right_rgb = quantizeImage(right_rgb)
		left_nir = quantizeImage(left_nir)
		right_nir = quantizeImage(right_nir)
	
	left_rgb *= imgScale
	right_rgb *= imgScale
	left_nir *= imgScale
	right_nir *= imgScale
	
	# ground-truth disparity from depth: baseline * f_pix / z (z clamped away from 0);
	# no off-axis cosine correction of z is applied
	z = rexr.readExrChannel(f, 'Left.Depth.Z')
	_, w = z.shape
	f_pix = calibration['focal']/calibration['sensorWidth']*w
	true_disp = calibration['baseline']*f_pix/np.maximum(z, 0.0001)
	
	padding = _cropPadding(crop, true_disp.shape)
	
	disp_passive = processImgPair(left_rgb, right_rgb, model, needsGrayImages, padding)
	disp_active = processImgPair(left_nir, right_nir, model, needsGrayImages, padding)
	
	if padding is not None :
		# negative entries cropped the network input: crop the ground truth the same way
		padding_gt = tuple(p if p < 0 else 0 for p in padding)
		# positive entries padded the network input: remove that padding from the outputs
		padding_out = tuple(-p if p > 0 else 0 for p in padding)
		
		if any(padding_gt) :
			true_disp = F.pad(torch.from_numpy(true_disp), padding_gt, mode='constant', value=0).numpy()
		
		if any(padding_out) :
			disp_passive = F.pad(torch.from_numpy(disp_passive), padding_out, mode='constant', value=0).numpy()
			disp_active = F.pad(torch.from_numpy(disp_active), padding_out, mode='constant', value=0).numpy()
	
	# edge pixels: more than one strong Laplacian response of the ground truth
	# inside the surrounding 21x21 window; all other pixels are "flat"
	laplacian = torch.tensor([[-1., -1., -1.], [-1., 8., -1.], [-1., -1., -1.]], dtype=torch.float32)[None, None]
	# conv2d requires the input dtype to match the float32 kernel
	gt = torch.from_numpy(true_disp).to(torch.float32)[None, None]
	highPass = torch.abs(F.conv2d(gt, laplacian, padding='same'))
	edges = (highPass > 2).to(torch.float32)
	extender = torch.ones(1, 1, 21, 21)
	edges_mask = (F.conv2d(edges, extender, padding='same') > 1)[0, 0].numpy()
	flat_mask = np.logical_not(edges_mask)
	
	error_passive = true_disp - disp_passive
	error_active = true_disp - disp_active
	
	if outDir is not None :
		
		attributes = ['File', 'Camera', 'Date']
		
		fileContent = {
			"disp_passive.A" : disp_passive,
			"disp_active.A" : disp_active,
			"error_passive.A" : error_passive,
			"error_active.A" : error_active
		}
		rexr.writeExrFile(os.path.join(outDir, "result_{}.exr".format(name)), fileContent, rexr.getAttributesInfos(f, attributes))
		
		# all disparity maps share the ground-truth range so they compare visually
		dispNorm = Normalize(vmin=np.floor(np.min(true_disp)), vmax=np.ceil(np.max(true_disp)))
		for img, figName, suffix in ((true_disp, "Disparity ground truth", "disparity_gt"),
									 (disp_passive, "Disparity passive", "disparity_passive"),
									 (disp_active, "Disparity active", "disparity_active")) :
			_saveMapFigure(img, dispNorm, figName, 'disparity [px]',
						   os.path.join(outDir, "{}_{}.pdf".format(name, suffix)))
		
		errNorm = Normalize(vmin=0, vmax=4)
		for err, figName, suffix in ((error_passive, "Error passive", "error_passive"),
									 (error_active, "Error active", "error_active")) :
			_saveMapFigure(np.abs(err), errNorm, figName, 'absolute error [px]',
						   os.path.join(outDir, "{}_{}.pdf".format(name, suffix)))
	
	stats = _errorStatistics(error_passive, error_active, edges_mask, flat_mask)
	
	comparisons = {
		"Prop-Imp-Pixs" : np.mean(np.abs(error_active) < np.abs(error_passive)),
		"HighFreqERatio" : getHighPassEnergyRatio(disp_passive, disp_active)
	}
	
	return name, stats, comparisons

def processImgPair(left, right, model, needsGrayImages, padding = None) :
	"""Run ``model`` on a stereo pair and return its squeezed output as a numpy array.
	
	Layout conversion, decided from the rank of ``left``:
	- needsGrayImages: arrays with 3 or more dims are averaged over their last
	  axis; 2D arrays are passed through unchanged.
	- otherwise: (H, W, C) arrays become (C, H, W); 2D arrays are replicated
	  into 3 stacked channels.
	``padding``, when given, is an ``F.pad`` tuple applied to the last two
	dimensions of both tensors (negative entries crop).
	"""
	
	isMultiChannel = len(left.shape) >= 3
	
	if needsGrayImages :
		if isMultiChannel :
			convert = lambda img : np.mean(img, axis=-1)
		else :
			convert = lambda img : img
	elif isMultiChannel :
		convert = lambda img : np.rollaxis(img, 2)
	else :
		convert = lambda img : np.stack((img, img, img))
	
	# copy() gives torch a contiguous, positively-strided buffer
	tensors = [torch.from_numpy(convert(img).copy()) for img in (left, right)]
	
	if padding is not None :
		tensors = [F.pad(t, padding, mode='constant', value=0) for t in tensors]
	
	output = model(tensors[0], tensors[1])
	return torch.squeeze(output).cpu().numpy()

def getHighPassEnergyRatio(dispPassive, dispActive) :
	"""Ratio of the normalized high-frequency energy of the active vs passive disparity.
	
	The high-frequency energy of a map is the sum of squares of its 3x3 Laplacian
	response (``convolve2d`` with 'same' size and symmetric boundary). It is divided
	by the map's own energy (sum of squares). The result is that normalized value for
	``dispActive`` over the one for ``dispPassive``. Values above 1 mean the active
	disparity carries relatively more high-frequency content.
	"""
	
	laplacian = np.full((3, 3), -1)
	laplacian[1, 1] = 8
	
	def highPass(disp) :
		return convolve2d(disp, laplacian, mode='same', boundary = 'symm')
	
	def energy(arr) :
		return np.sum(arr**2)
	
	return energy(highPass(dispActive))/energy(highPass(dispPassive)) * energy(dispPassive)/energy(dispActive)

def printAggregatedResultsTable(table, outFile = None) :
	"""Print a method x statistic table of aggregated results.
	
	Args:
		table: mapping method name -> {statistic name -> value}.
		outFile: writable text file; the table is printed to stdout when None.
	
	The columns are the union of the statistics of all methods, in first-seen
	order. Missing values are shown as '-' and values use 2 decimals.
	"""
	
	def writeText(txt) :
		if outFile is None:
			print(txt, end="")
		else :
			outFile.write(txt)
	
	variables = []
	methods = []
	
	maxMethLen = 10
	maxVarLen = 10
	
	# the per-method dict gets its own name so it does not shadow the column list
	for method, m_vars in table.items() :
		
		if method not in methods:
			methods.append(method)
			maxMethLen = max(maxMethLen, len(method))
		
		for var in m_vars :
			if var not in variables :
				variables.append(var)
				maxVarLen = max(maxVarLen, len(var))
	
	# +2 keeps at least two spaces between columns
	maxMethLen += 2
	maxVarLen += 2
	
	writeText("Method".ljust(maxMethLen))
	
	for var in variables :
		writeText(f"{var}".rjust(maxVarLen))
	writeText("\n")
	
	for method in methods :
		writeText(f"{method}".ljust(maxMethLen))
		
		for var in variables :
			varValue = '-'
			
			if var in table[method] :
				varValue = f"{table[method][var]:.2f}"
			
			writeText(f"{varValue}".rjust(maxVarLen))
		writeText("\n")
	
	writeText("\n")
	


def printRelativeResultsTable(table, outFile = None) :
	"""Print the relative (passive vs active) indicators of every method.
	
	Args:
		table: mapping method -> statistic name -> {indicator name -> value}.
		outFile: writable text file; the table is printed to stdout when None.
	
	There is one column per (indicator, statistic) pair, labelled
	``indicator(statistic)`` and grouped by indicator. Statistics and indicators
	are the union over all methods, in first-seen order. Missing values are shown
	as '-' and values use 4 decimals.
	"""
	
	def writeText(txt) :
		if outFile is None:
			print(txt, end="")
		else :
			outFile.write(txt)
	
	variables = []
	indicators = []
	methods = []
	
	maxMethLen = 12
	maxVarLen = 12
	maxIndLen = 12
	
	# the per-method dict gets its own name so it does not shadow the column list
	for method, m_vars in table.items() :
		
		if method not in methods:
			methods.append(method)
			maxMethLen = max(maxMethLen, len(method))
		
		for var, inds in m_vars.items() :
			if var not in variables :
				variables.append(var)
				maxVarLen = max(maxVarLen, len(var))
			
			for ind in inds :
				if ind not in indicators :
					indicators.append(ind)
					maxIndLen = max(maxIndLen, len(ind))
	
	maxMethLen += 2
	# room for "ind(var)" plus at least two separating spaces
	maxColLen = maxVarLen + maxIndLen + 4
	
	writeText("Method".ljust(maxMethLen))
	
	for ind in indicators :
		for var in variables :
			writeText(f"{ind}({var})".rjust(maxColLen))
	writeText("\n")
	
	for method in methods :
		writeText(f"{method}".ljust(maxMethLen))
		
		for ind in indicators :
			for var in variables :
				varIndValue = '-'
				if var in table[method] :
					if ind in table[method][var] :
						varIndValue = f"{table[method][var][ind]:.4f}"
				
				writeText(f"{varIndValue}".rjust(maxColLen))
		writeText("\n")
	
	writeText("\n")
	


def printComparisonResultsTable(table, outFile = None) :
	"""Print the aggregated comparison indicators of every method.
	
	Args:
		table: mapping method -> indicator name -> {aggregate name -> value}
			(e.g. 'mean', 'max', 'min').
		outFile: writable text file; the table is printed to stdout when None.
	
	There is one column per (indicator, aggregate) pair, labelled
	``indicator(aggregate)`` and grouped by indicator. Both lists are the union
	over all methods, in first-seen order. Missing values are shown as '-' and
	values use 4 decimals.
	"""
	
	def writeText(txt) :
		if outFile is None:
			print(txt, end="")
		else :
			outFile.write(txt)
	
	variables = []
	aggregates = []
	methods = []
	
	maxMethLen = 10
	maxVarLen = 10
	maxAggLen = 10
	
	# the per-method dict gets its own name so it does not shadow the column list
	for method, m_vars in table.items() :
		
		if method not in methods:
			methods.append(method)
			maxMethLen = max(maxMethLen, len(method))
		
		for var, aggs in m_vars.items() :
			if var not in variables :
				variables.append(var)
				maxVarLen = max(maxVarLen, len(var))
			
			for agg in aggs :
				if agg not in aggregates :
					aggregates.append(agg)
					maxAggLen = max(maxAggLen, len(agg))
	
	maxMethLen += 2
	# room for "var(agg)" plus at least two separating spaces
	maxColLen = maxVarLen + maxAggLen + 4
	
	writeText("Method".ljust(maxMethLen))
	
	for var in variables :
		for agg in aggregates :
			writeText((f"{var}({agg})").rjust(maxColLen))
	writeText("\n")
	
	for method in methods :
		writeText(f"{method}".ljust(maxMethLen))
		
		for var in variables :
			for agg in aggregates :
				varIndValue = '-'
				if var in table[method] :
					if agg in table[method][var] :
						varIndValue = f"{table[method][var][agg]:.4f}"
				
				writeText(f"{varIndValue}".rjust(maxColLen))
		writeText("\n")
	
	writeText("\n")
	
	
def printFullResultsTable(table, outFile = None) :
	"""Print every per-image statistic, one row per (method, image) pair.
	
	Args:
		table: mapping method -> image name -> {statistic name -> value}.
		outFile: writable text file; stdout is used when None.
	
	The columns are the union of all statistics, in first-seen order. A missing
	value is shown as '-', values use 2 decimals, and a blank line follows the
	rows of each method.
	"""
	
	if outFile is None :
		emit = lambda txt : print(txt, end="")
	else :
		emit = lambda txt : outFile.write(txt)
	
	columns = []
	for per_image in table.values() :
		for stats in per_image.values() :
			for stat_name in stats :
				if stat_name not in columns :
					columns.append(stat_name)
	
	# each width is at least 10 characters, plus a 2 character margin
	methWidth = 2 + max([10] + [len(method) for method in table])
	imgWidth = 2 + max([10] + [len(img) for per_image in table.values() for img in per_image])
	colWidth = 2 + max([10] + [len(stat_name) for stat_name in columns])
	
	emit("Method".ljust(methWidth))
	emit("Image".ljust(imgWidth))
	for stat_name in columns :
		emit(f"{stat_name}".rjust(colWidth))
	emit("\n")
	
	for method, per_image in table.items() :
		for img, stats in per_image.items() :
			emit(f"{method}".ljust(methWidth))
			emit(f"{img}".ljust(imgWidth))
			for stat_name in columns :
				cell = f"{stats[stat_name]:.2f}" if stat_name in stats else '-'
				emit(cell.rjust(colWidth))
			emit("\n")
		emit("\n")
	
	emit("\n")
		
		

	
def printFullComparisonTable(table, outFile = None) :
	"""Print every per-image comparison indicator, one row per (method, image) pair.
	
	Args:
		table: mapping method -> image name -> {indicator name -> value}.
		outFile: writable text file; the table is printed to stdout when None.
	
	The columns are the union of all indicators, in first-seen order. A missing
	value is shown as '-', values use 2 decimals, and a blank line follows the
	rows of each method.
	"""
	
	def writeText(txt) :
		if outFile is None:
			print(txt, end="")
		else :
			outFile.write(txt)
	
	variables = []
	methods = []
	
	maxMethLen = 10
	maxVarLen = 10
	maxImLen = 10
	
	for method, m_stats in table.items() :
		
		if method not in methods:
			methods.append(method)
			maxMethLen = max(maxMethLen, len(method))
		
		for img, im_variables in m_stats.items() :
			
			maxImLen = max(maxImLen, len(img))
			
			for var in im_variables :
				if var not in variables :
					variables.append(var)
					maxVarLen = max(maxVarLen, len(var))
	
	# +2 keeps at least two spaces between right-aligned columns
	maxMethLen += 2
	maxVarLen += 2
	maxImLen += 2
	
	writeText("Method".ljust(maxMethLen))
	writeText("Image".ljust(maxImLen))
	
	for var in variables :
		writeText(f"{var}".rjust(maxVarLen))
	writeText("\n")
	
	for method in methods :
		
		for img in table[method] :
			writeText(f"{method}".ljust(maxMethLen))
			writeText(f"{img}".ljust(maxImLen))
			
			for var in variables :
				
				varValue = '-'
				
				if var in table[method][img] :
					varValue = f"{table[method][img][var]:.2f}"
				
				writeText(f"{varValue}".rjust(maxVarLen))
			
			writeText("\n")
		writeText("\n")
	
	writeText("\n")
	

if __name__ == "__main__" :
	
	plt.rcParams["font.family"] = "serif"
	
	# NOTE: ``args`` is the argparse module alias until it is rebound to the parsed arguments below
	parser = args.ArgumentParser(description='Test different methods on a given set of images') 
	
	parser.add_argument('inputFolder', help="Path of the folder with the exr images to parse.")
	parser.add_argument('-o', '--outputFolder', default=None, help="Path of the folder where to write the output images (disparity and error maps are not saved if this is not provided).")
	parser.add_argument('-r', '--resultFile', default="out.dat", help="Path of the file into which the results are written.")
	
	parser.add_argument('--ocioconfig', default = '/usr/share/appimages/blender2_93/2.93/datafiles/colormanagement/config.ocio', help="OCIO config used for display")
	parser.add_argument('--ociodisplay', default = 'sRGB', help="OCIO display device to target")
	parser.add_argument('--ocioview', default = 'Filmic', help="OCIO view to use")

	parser.add_argument('--imgscale', type = float, default = 1.0, help="scale factor to apply after color transform (expected either 1, or 255)")
	
	parser.add_argument('--selected', help="limit the computations to the selected models (names separated by commas)")

	args = parser.parse_args()
	
	color_transform = atw.color.OcioColorTransformer(args.ocioconfig, 'Linear', args.ocioview, args.ociodisplay)
	
	outFolder = args.outputFolder
	
	# --outputFolder is optional (default None, which os.path.isdir does not accept):
	# only create the folder when one was requested
	if outFolder is not None and not os.path.isdir(outFolder) :
		try:
			os.makedirs(outFolder)
		except OSError as error:
			print(error)
			sys.exit(1)
	
	def nir_transform(fl, fr) :
		"""Normalize a pair of simulated NIR images.
		
		Values above the mean of both 99% quantiles are clipped (in place on the
		input arrays). Each image is then shifted and scaled to [0, 1]
		independently and gamma-encoded with an exponent of 1/2.2.
		"""
		
		q99l = np.quantile(fl, 0.99)
		q99r = np.quantile(fr, 0.99)
		q99m = (q99l + q99r)/2

		fl[fl > q99m] = q99m
		fr[fr > q99m] = q99m

		fl = fl - np.min(fl)
		fr = fr - np.min(fr)

		fl = fl/np.max(fl)
		fr = fr/np.max(fr)
		
		fl = fl**(1/2.2)
		fr = fr**(1/2.2)
		
		return (fl, fr)
	
	inFiles = glob.glob(os.path.join(args.inputFolder, "*.exr"))
	
	values = dict()
	comparisons = dict()
	
	selected = None
	
	if args.selected is not None :
		selected = set(args.selected.split(','))
	
	# inference only: no autograd graph is needed
	with torch.no_grad() :
		for m in models :
			
			if selected is not None and m["name"] not in selected :
				print("Skipping model ", m["name"])
				continue
			
			method, stats, comps = processModel(m, inFiles, color_transform, nir_transform, imgScale = args.imgscale, outFolder = outFolder)
			values[method] = stats
			comparisons[method] = comps
	
	#compute the aggregates: per method, the mean over images of each statistic,
	#split by the '-passive' / '-active' suffix of the statistic name
	aggregates_passive = dict()
	aggregates_active = dict()
	stats_func = []
	
	for method, m_stats in values.items() :
		
		if method not in aggregates_passive :
			aggregates_passive[method] = dict()
			
		if method not in aggregates_active :
			aggregates_active[method] = dict()
		
		for img, im_stats in m_stats.items() :
			for stat, val in im_stats.items() :
				stat_name, act_vs_pas = stat.split('-')
				
				if stat_name not in stats_func :
					stats_func.append(stat_name)
				
				if act_vs_pas == 'passive' :
				
					if stat_name not in aggregates_passive[method] :
						aggregates_passive[method][stat_name] = []
					
					aggregates_passive[method][stat_name].append(val)
					
				elif act_vs_pas == 'active' :
				
					if stat_name not in aggregates_active[method] :
						aggregates_active[method][stat_name] = []
					
					aggregates_active[method][stat_name].append(val)
		
		for stat_name, lst in aggregates_passive[method].items() :
			aggregates_passive[method][stat_name] = np.mean(lst)
			
		for stat_name, lst in aggregates_active[method].items() :
			aggregates_active[method][stat_name] = np.mean(lst)

	#compute the relative indicators: for each statistic, the proportion of images
	#where the active score is lower than the passive one, and the mean relative
	#difference (passive - active) / max(passive, active, 1e-6)
	
	relative_indicators = dict()
	
	for method, m_stats in values.items() :
		
		if method not in relative_indicators :
			relative_indicators[method] = dict()
		
		for img, im_stats in m_stats.items() :
			for stat, val in im_stats.items() :
				stat_name, act_vs_pas = stat.split('-')
				
				if stat_name not in relative_indicators[method] :
					relative_indicators[method][stat_name] = dict()
					
				if act_vs_pas not in relative_indicators[method][stat_name] :
					relative_indicators[method][stat_name][act_vs_pas] = dict()
					
				relative_indicators[method][stat_name][act_vs_pas][img] = val
		
		for stat_name in stats_func :
			
			passive = relative_indicators[method][stat_name]['passive']
			active = relative_indicators[method][stat_name]['active']
			
			act_sml = 0.
			rel_impr = 0.
			
			intersection = set(passive.keys()).intersection(active.keys())
			
			for img in intersection :
				score_pas = passive[img]
				score_act = active[img]
				
				act_sml += 1. if score_act < score_pas else 0.
				rel_impr += (score_pas - score_act)/(max(score_pas, score_act, 1e-6))
			
			act_sml /= len(intersection)
			rel_impr /= len(intersection)
			
			# replace the per-image scores by the two indicators
			relative_indicators[method][stat_name] = dict()
			relative_indicators[method][stat_name]['propImprovedImgs'] = act_sml
			relative_indicators[method][stat_name]['averageRelativeImprovement'] = rel_impr

	#compute the aggregate comparisons: mean / max / min over images of each indicator
	
	aggregate_comps = dict()
	stats_comp = []
	
	for method, m_stats in comparisons.items() :
		
		if method not in aggregate_comps :
			aggregate_comps[method] = dict()
		
		for img, im_stats in m_stats.items() :
			for stat_name, val in im_stats.items() :
				
				if stat_name not in aggregate_comps[method] :
					aggregate_comps[method][stat_name] = []
				
				aggregate_comps[method][stat_name].append(val)
				
				if stat_name not in stats_comp :
					stats_comp.append(stat_name)
					
		for stat_name in stats_comp :
			data = np.array(aggregate_comps[method][stat_name])
			
			ind_mean = np.mean(data)
			ind_max = np.max(data)
			ind_min = np.min(data)
			
			aggregate_comps[method][stat_name] = {
				"mean" : ind_mean,
				"max" : ind_max,
				"min" : ind_min
			}

	#summary tables go both to stdout and to the result file; detailed ones only to the file
	with open(args.resultFile, 'w') as rFile :
	
		print("\naggregated results passive: ")
		printAggregatedResultsTable(aggregates_passive)
		
		rFile.write("\naggregated results passive: \n")
		printAggregatedResultsTable(aggregates_passive, outFile = rFile)
		
		
		print("\naggregated results active: ")
		printAggregatedResultsTable(aggregates_active)
		
		rFile.write("\naggregated results active: \n")
		printAggregatedResultsTable(aggregates_active, outFile = rFile)
		
		
		print("\nrelative indicators results: ")
		printRelativeResultsTable(relative_indicators)
		
		rFile.write("\nrelative indicators results: \n")
		printRelativeResultsTable(relative_indicators, outFile = rFile)
		
		
		print("\nimage comparison indicators results: ")
		printComparisonResultsTable(aggregate_comps)
		
		rFile.write("\nimage comparison indicators results: \n")
		printComparisonResultsTable(aggregate_comps, outFile = rFile)
		
		
		rFile.write("\nDetailled results: \n")
		printFullResultsTable(values, outFile = rFile)
		
		rFile.write("\nDetailled comparison operators: \n")
		printFullComparisonTable(comparisons, outFile = rFile)
