from __future__ import print_function

import sys
import argparse
import os
import json


import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim


from train_utils import train_epoch , save_test_preds , is_degenerate
from model_sbevnet import SBEVNet
import dataloader 


# TODO: detect crashes and restart training automatically.

parser = argparse.ArgumentParser(description="Train and/or evaluate SBEVNet.")

# Data and checkpoint locations. --datapath and --save_path are dereferenced
# unconditionally later in the script, so they are required up front instead
# of failing later with an opaque TypeError.
parser.add_argument("--datapath", type=str, required=True,
                    help="JSON index with 'train'/'test' splits; the listed "
                         "paths are relative to the index file's directory")
parser.add_argument("--save_path", type=str, required=True,
                    help="output directory for checkpoints and test "
                         "predictions; must end with '/'")

parser.add_argument("--loadmodel", type=str, default=None,
                    help="checkpoint (.tar) to initialise the weights from")
parser.add_argument("--loadmodel_iter", type=int, default=-1,
                    help="if > 0, load <save_path>model_<N>.tar instead of "
                         "--loadmodel")
parser.add_argument("--save_freq", type=int, default=1,
                    help="save a checkpoint (and test predictions) every N "
                         "epochs")

# Image size and BEV / camera geometry; all of these are forwarded to the
# model through sys_confs.
parser.add_argument("--img_w", type=int)
parser.add_argument("--img_h", type=int)
parser.add_argument("--max_disp", type=int, default=64)
parser.add_argument("--n_hmap", type=int)

parser.add_argument("--xmin", type=float)
parser.add_argument("--xmax", type=float)
parser.add_argument("--ymin", type=float)
parser.add_argument("--ymax", type=float)

parser.add_argument("--tx", type=float, default=None)
parser.add_argument("--cx", type=float, default=None)
parser.add_argument("--cy", type=float, default=None)
parser.add_argument("--f", type=float, default=None)

parser.add_argument("--camera_ext_x", type=float, default=0.0)
parser.add_argument("--camera_ext_y", type=float, default=0.0)

# Training schedule and feature switches.
parser.add_argument("--epochs", type=int, default=30,
                    help="number of training epochs (forced to 3 by "
                         "--sanity_test)")

parser.add_argument('--only_testing', action='store_true',
                    help="skip training and only write test predictions")
parser.add_argument('--sanity_test', action='store_true',
                    help="short run: 3 epochs; also forwarded to the "
                         "train/test helpers")
parser.add_argument('--fixed_cam_confs', action='store_true',
                    help="do not read per-sample camera 'confs' from the "
                         "data index")
parser.add_argument('--do_ipm_rgb', action='store_true',
                    help="load 'top_ipm' entries and enable the model's "
                         "do_ipm_rgb option")
parser.add_argument('--do_ipm_feats', action='store_true',
                    help="load 'top_ipm_m' entries and enable the model's "
                         "do_ipm_feats option")


args = parser.parse_args()


# Geometry / camera configuration handed to SBEVNet. Each entry is copied
# verbatim from the CLI argument of the same name; the tuple fixes the key
# order, so the printed summary is the same on every run.
_SYS_CONF_KEYS = (
    "img_w", "img_h",
    "xmin", "xmax", "ymin", "ymax",
    "max_disp",
    "cx", "cy", "f", "tx",
    "camera_ext_x", "camera_ext_y",
    "n_hmap",
)
sys_confs = {name: getattr(args, name) for name in _SYS_CONF_KEYS}

print(sys_confs)

# Height-map prediction is hard-wired off in this script.
do_predict_hmap = False
datapath = args.datapath
save_path = args.save_path
only_testing = args.only_testing
sanity_test = args.sanity_test
fixed_cam_confs = args.fixed_cam_confs
do_ipm_rgb = args.do_ipm_rgb
do_ipm_feats = args.do_ipm_feats
save_freq = args.save_freq

# A positive --loadmodel_iter takes precedence over --loadmodel and points at
# a checkpoint named the way this script's save loop names them.
if args.loadmodel_iter > 0:
    loadmodel = save_path + 'model_' + str(args.loadmodel_iter) + '.tar'
else:
    loadmodel = args.loadmodel

# Sanity runs are capped at 3 epochs regardless of --epochs.
epochs = 3 if sanity_test else args.epochs


def init_model():
    """Build a fresh SBEVNet and an Adam optimizer over its parameters.

    The network is wrapped in ``nn.DataParallel`` and moved to the GPU. The
    training loop calls this again (after start-up) whenever ``is_degenerate``
    triggers a restart, so each call returns new weights and a new optimizer.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    model = SBEVNet(
        sys_confs,
        # Keep the network's disparity range in sync with --max_disp (which
        # also reaches the model through sys_confs["max_disp"]). It used to be
        # hard-coded to 64, silently ignoring the flag; 64 is still the
        # flag's default, so default runs are unchanged.
        maxdisp=sys_confs["max_disp"],
        n_classes_seg=25,
        do_predict_hmap=do_predict_hmap,
        do_ipm_rgb=do_ipm_rgb,
        do_ipm_feats=do_ipm_feats,
        fixed_cam_confs=fixed_cam_confs,
    )
    model = nn.DataParallel(model)
    model.cuda()
    # NOTE(review): lr=0.1 is unusually high for Adam; kept as-is — confirm
    # it is intentional.
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
    return model, optimizer


model, optimizer = init_model()

# Optionally warm-start from a checkpoint. Only the network weights are
# restored: the checkpoints this script writes hold just 'epoch' and
# 'state_dict', so the optimizer always starts fresh.
# NOTE(review): the training loop still begins at epoch 1 after loading, so a
# resumed run renumbers (and may overwrite) model_<N>.tar from 1 — confirm.
if loadmodel is not None:
    checkpoint = torch.load(loadmodel)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    print("model loaded", loadmodel)
    
    







# Read the dataset index. A `with` block closes the file handle, which the old
# open(...).read() one-liner leaked.
with open(datapath) as index_file:
    jj = json.load(index_file)

rootp = os.path.dirname(datapath)

# Resolve every entry of the 'train' and 'test' splits against the index
# file's directory. NOTE: every field of a split (including 'confs') is
# treated as an iterable of relative paths.
for split in ('train', 'test'):
    for field, rel_paths in jj[split].items():
        jj[split][field] = [os.path.join(rootp, p) for p in rel_paths]

test_rgb = [jj['test']["rgb_left"], jj['test']["rgb_right"]]
train_rgb = [jj['train']["rgb_left"], jj['train']["rgb_right"]]
train_left_seg = jj['train']["top_seg"]
test_left_seg = jj['test']["top_seg"]

train_mask = jj['train']["mask"]
test_mask = jj['test']["mask"]

# Optional inputs: each is looked up only when its feature is enabled, so the
# index may omit the field otherwise.
train_img_ipm = jj['train']["top_ipm"] if do_ipm_rgb else None
test_img_ipm = jj['test']["top_ipm"] if do_ipm_rgb else None

train_img_ipm_m = jj['train']["top_ipm_m"] if do_ipm_feats else None
test_img_ipm_m = jj['test']["top_ipm_m"] if do_ipm_feats else None

train_cam_confs = None if fixed_cam_confs else jj['train']['confs']
test_cam_confs = None if fixed_cam_confs else jj['test']['confs']

train_hmap = jj['train']["hmap"] if do_predict_hmap else None
test_hmap = jj['test']["hmap"] if do_predict_hmap else None
    
    
    



# Checkpoint file names are built by plain string concatenation
# (save_path + 'model_<epoch>.tar'), so save_path must be a directory path
# ending in "/". Check it before the loaders are built, and raise a real
# exception rather than use an `assert` (which `python -O` strips).
if save_path is None or not save_path.endswith("/"):
    raise ValueError(
        "--save_path must be a directory path ending with '/', got %r"
        % (save_path,))


def _make_bev_dataset(rgb, ipm_imgs, mask, seg_imgs, ipm_m, hmap, cam_confs,
                      training):
    """Build one BEVDataLoader dataset with this script's fixed options.

    All splits share the image size from sys_confs, mask_imgs=True and
    hmap_max=500; only the per-split file lists and `training` differ.
    """
    return dataloader.BEVDataLoader(
        rgb,
        ipm_imgs=ipm_imgs,
        mask=mask,
        seg_imgs=seg_imgs,
        training=training,
        th=sys_confs['img_h'],
        tw=sys_confs['img_w'],
        mask_imgs=True,
        imp_m=ipm_m,
        hmap=hmap,
        hmap_max=500,
        cam_confs=cam_confs,
    )


# Shuffled mini-batches used for optimisation.
train_loader = torch.utils.data.DataLoader(
    _make_bev_dataset(train_rgb, train_img_ipm, train_mask, train_left_seg,
                      train_img_ipm_m, train_hmap, train_cam_confs,
                      training=True),
    batch_size=3, shuffle=True, num_workers=8, drop_last=False)

# The same training data one sample at a time in a fixed order; consumed by
# is_degenerate() in the training loop. A separate dataset instance is built
# on purpose, as before.
train_loader_bs1 = torch.utils.data.DataLoader(
    _make_bev_dataset(train_rgb, train_img_ipm, train_mask, train_left_seg,
                      train_img_ipm_m, train_hmap, train_cam_confs,
                      training=True),
    batch_size=1, shuffle=False, num_workers=8, drop_last=False)

test_loader = torch.utils.data.DataLoader(
    _make_bev_dataset(test_rgb, test_img_ipm, test_mask, test_left_seg,
                      test_img_ipm_m, test_hmap, test_cam_confs,
                      training=False),
    batch_size=1, shuffle=False, num_workers=4, drop_last=False)




if not only_testing:

    epoch = 1

    # A while-loop rather than `for`, because a degenerate run resets `epoch`.
    while epoch < epochs + 1:
        epoch_loss = train_epoch(epoch=epoch, model=model,
                                 data_loader=train_loader,
                                 optimizer=optimizer, save_path=save_path,
                                 do_predict_hmap=do_predict_hmap,
                                 sanity_test=sanity_test)

        # Restart the optimization from scratch (fresh weights from
        # init_model(), not the --loadmodel checkpoint) if is_degenerate()
        # flags the model on the training data.
        # NOTE(review): the number of restarts is unbounded, so a run that
        # keeps degenerating never finishes.
        if is_degenerate(model, train_loader_bs1, n_iter=100):
            print("restarting optimization ")
            model, optimizer = init_model()
            epoch = 1
            continue

        print("epoch ", epoch, "Mean loss = ", epoch_loss)

        if save_path is not None:
            if not save_freq:
                # None or 0 disables checkpointing; 0 previously crashed with
                # ZeroDivisionError in the modulo below.
                print("not saving weights")
            elif epoch % save_freq == 0:
                savefilename = save_path + 'model_' + str(epoch) + '.tar'
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict()
                }, savefilename)
                print("saved ", savefilename)

                # Refresh test-set predictions alongside each checkpoint.
                save_test_preds(model=model, save_path=save_path,
                                data_loader=test_loader,
                                sanity_test=sanity_test)

        epoch += 1

# Always write test-set predictions for the final model; with --only_testing
# this is the only work done.
save_test_preds(model=model, save_path=save_path, data_loader=test_loader,
                sanity_test=sanity_test)


