
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

from models import AnyNet

from dataloader.exrDatasetLoader import exrImagePairDataset

import argparse as args
import collections

import gc
import sys

class Bunch(object):
  """Expose the entries of a mapping as attributes of an object.

  ``Bunch({'maxdisp': 192}).maxdisp`` evaluates to ``192``.
  """

  def __init__(self, adict):
    # Merge every key/value pair of ``adict`` into the instance namespace.
    namespace = vars(self)
    namespace.update(adict)

if __name__ == "__main__":

	# Command-line options as (flags, add_argument keyword arguments) pairs.
	# They are registered in exactly this order.
	# NOTE: a "--validationdata" option (path to the validation images) is
	# currently disabled.
	option_table = [
		(("--traindata",), dict(help="Path to the training images")),
		(("--numepochs",), dict(default=10, type=int, help="Number of epochs to run")),
		(("--batchsize",), dict(default=4, type=int, help="Batch size")),
		(("--numworkers",), dict(default=4, type=int, help="Number of workers threads used for loading dataset")),
		(("--learningrate",), dict(default=5e-5, type=float, help="Learning rate for the optimizer")),
		(("--ramcache",), dict(action="store_true", help="cache the whole dataset into ram. Do this only if you are certain it can fit.")),
		(("--maxdisp",), dict(type=int, default=192, help='maximum disparity')),
		(("-p", "--pretrained"), dict(default='./models/checkpoint/kitti2015_ck/kitti_2015.tar', help="Pretrained weights")),
		(("-o", "--output"), dict(default='./models/checkpoint/finetuned_apstereo.pth', help="Trained weights")),
	]

	parser = args.ArgumentParser(description='Finetune ActiveStereoNet on our dataset')
	for option_flags, option_settings in option_table:
		parser.add_argument(*option_flags, **option_settings)

	# From here on ``args`` names the parsed namespace, no longer the argparse
	# module alias imported at the top of the file.
	args = parser.parse_args()

	# Settings handed to AnyNet as attributes of a Bunch; only maxdisp comes
	# from the command line, the rest is fixed here.
	network_settings = Bunch(dict(
		maxdisp=args.maxdisp,
		init_channels=1,
		maxdisplist=[12, 3, 3],
		spn_init_channels=8,
		nblocks=2,
		layers_3d=4,
		channels_3d=4,
		growth_rate=[4, 1, 1],
		with_spn=True,
	))

	# The network is wrapped in DataParallel and moved to the GPU before the
	# weights are loaded, so the state dict is applied to the wrapped module.
	model = nn.DataParallel(AnyNet(network_settings)).cuda()

	# The pretrained file must provide its weights under the 'state_dict' key.
	checkpoint = torch.load(args.pretrained)
	pretrained_weights = checkpoint['state_dict']
	model.load_state_dict(pretrained_weights)
	
	#model = model.module.cpu()
	
	# Keyword arguments for the EXR image-pair dataset. Its ``cache`` option is
	# fixed to False; ``ramcache`` follows the --ramcache flag.
	dataset_settings = {
		'imagedir': args.traindata,
		'left_nir_channel': 'Left.SimulatedNir.A',
		'right_nir_channel': 'Right.SimulatedNir.A',
		'cache': False,
		'ramcache': args.ramcache,
		'direction': 'l2r',
	}
	training_pairs = exrImagePairDataset(**dataset_settings)

	# Shuffled mini-batches; batch size and worker count come from the CLI.
	loader_settings = dict(batch_size=args.batchsize,
						   shuffle=True,
						   num_workers=args.numworkers)
	datl = DataLoader(training_pairs, **loader_settings)
	
	def buildOptimizer(parameters) :
		"""Create the Adam optimizer used for fine-tuning.

		parameters: the parameters (as accepted by ``Adam``) to optimize.
		Returns an ``Adam`` instance using the --learningrate value and
		betas (0.9, 0.999).
		"""
		adam_settings = {'lr': args.learningrate, 'betas': (0.9, 0.999)}
		return Adam(parameters, **adam_settings)
	
	def loss_func(disps, gt, mask) :
		"""Weighted sum of smooth-L1 losses over the predicted disparity stages.

		disps: sequence of predicted disparity tensors, one per stage.
		gt: ground-truth disparity tensor.
		mask: boolean tensor used to index both each prediction and ``gt``;
			only the selected elements contribute to the loss.
		Returns ``sum(w_i * smooth_l1(disps[i][mask], gt[mask]))`` with the
		stage weights (0.25, 0.5, 1, 1).
		"""
		loss_weights = (0.25, 0.5, 1., 1.)
		# The masked target is identical for every stage, so select it once.
		target = gt[mask]
		# reduction='mean' replaces the deprecated size_average=True argument.
		# zip() pairs each weight with its stage, so fewer than four predicted
		# stages are handled instead of raising IndexError.
		return sum(weight * F.smooth_l1_loss(stage[mask], target, reduction='mean')
				for weight, stage in zip(loss_weights, disps))
	
	def getLoss() :
		"""Return the callable used as training criterion (``loss_func``)."""
		criterion = loss_func
		return criterion
	
	# Optimizer over all model parameters, plus the training criterion.
	optimizer = buildOptimizer(model.parameters())
	loss = getLoss()

	print("ready to train")

	for ep in range(args.numepochs) :

		for batch_id, sampl in enumerate(datl) :

			# Move the stereo pair and its ground-truth disparity to the GPU.
			imgLeft = sampl['frameLeft'].cuda()
			imgRight = sampl['frameRight'].cuda()
			imgGtDisp = sampl['trueDisparity'].cuda()

			print("Data loaded", end = ' ')

			# Only pixels with a ground-truth disparity strictly between 0 and
			# maxdisp are supervised.
			mask = (imgGtDisp < args.maxdisp) & (imgGtDisp > 0)

			if not mask.any() :
				# A mean-reduced loss over zero selected pixels is NaN, and the
				# optimizer would still step. Skip the batch before the forward
				# pass. The message also ends the "Data loaded" line.
				print(f"Epoch {ep}, batch {batch_id}: no valid ground truth pixel, batch skipped")
				continue

			# Each image is concatenated with itself three times along dim 1.
			# NOTE(review): presumably turns single-channel NIR frames into
			# 3-channel network input - confirm against the dataset.
			imgLeft = torch.cat((imgLeft, imgLeft, imgLeft), dim=1)
			imgRight = torch.cat((imgRight, imgRight, imgRight), dim=1)

			disps = model(imgLeft, imgRight)

			print("disps and mask computed", end = ' ')

			l = loss(disps, imgGtDisp, mask)

			# Standard update: clear old gradients, backpropagate, step.
			optimizer.zero_grad()
			l.backward()
			optimizer.step()

			lval = l.item()

			gc.collect()

			print(f"Epoch {ep}, batch {batch_id}: loss = {lval}")

	# Weights are written once, after the last epoch, under the
	# "upt_state_dict" key. Note that the pretrained checkpoint is read from
	# the 'state_dict' key instead.
	torch.save({"upt_state_dict" : model.state_dict()}, args.output)
