from __future__ import print_function, division
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F
import numpy as np
import time
from tensorboardX import SummaryWriter

from models import __models__, model_loss
from utils import *
from torch.utils.data import DataLoader
import gc
from PIL import Image
from tqdm import tqdm
import cv2


from PIL import Image

# Enable the cuDNN benchmark flag for the whole process.
cudnn.benchmark = True

# Command-line interface.
# NOTE(review): the description string and the default paths refer to CFNet /
# KITTI, while the model built below is SpikeFusionet -- presumably inherited
# from the script this file was derived from; confirm before relying on them.
parser = argparse.ArgumentParser(description='Cascade and Fused Cost Volume for Robust Stereo Matching(CFNet)')
parser.add_argument('--model', default='cfnet', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=160, help='maximum disparity')

# Dataset locations (machine-specific absolute defaults).
parser.add_argument('--datapath', default='/home/lijianing/kitti/', required=False, help='data path')
parser.add_argument('--trainlist', default='/home/lijianing/depth/CFNet-mod/filenames/kitti15_train.txt', required=False, help='training list')
parser.add_argument('--testlist', default='/home/lijianing/depth/CFNet-mod/filenames/kitti15_val.txt', required=False, help='testing list')

# Optimisation hyper-parameters.
parser.add_argument('--lr', type=float, default=0.001, help='base learning rate')
parser.add_argument('--batch_size', type=int, default=1, help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size')
parser.add_argument('--epochs', type=int, default=150, required=False, help='number of epochs to train')
parser.add_argument('--lrepochs', type=str, default='50:5',required=False, help='the epochs to decay lr: the downscale rate')

# Logging / checkpointing.
parser.add_argument('--logdir', default='/home/lijianing/depth/CFNet-mod/logs', required=False, help='the directory to save logs and checkpoints')
parser.add_argument('--loadckpt', help='load the weights from a specific checkpoint')
parser.add_argument('--resume', action='store_true', help='continue training the model')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

parser.add_argument('--summary_freq', type=int, default=1, help='the frequency of saving summary')
parser.add_argument('--save_freq', type=int, default=1, help='the frequency of saving checkpoint')
# parse arguments, set seeds
# NOTE: this runs at import time, so importing this module parses sys.argv,
# seeds torch and creates the log directory as side effects.
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
os.makedirs(args.logdir, exist_ok=True)

# create summary logger
print("creating new summary file")
logger = SummaryWriter(args.logdir)

# dataset, dataloader

# model, optimizer
from models.ugde_real import SpikeFusionet

#model = StereoNet(1, "subtract", 32)

# NOTE(review): the GPU index is hard-coded to 3; there is no CLI option for it
# in this parser, so a different GPU requires editing this line.
device = torch.device("cuda:{}".format(3))

# NOTE(review): max_disp is the literal 128 here; `--maxdisp` (default 160) is
# not passed to the model in the visible part of this file -- confirm intent.
model = SpikeFusionet(max_disp=128, device = device)
#device = torch.device("cuda")
#model = nn.DataParallel(model)

# Adam over all model parameters, using the learning rate from the CLI.
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))


# Move the model to the selected device (the optimizer above was created first).
model.to(device)


from collections import OrderedDict


def validate_spike_(model, dataloader):
    """Evaluate a single-input depth model and return its mean ``abs_rel``.

    The model is switched to eval mode and run on the ``'left'`` tensor of
    every batch. Its output is squeezed on dim 1, converted to a float32
    NumPy array and scored against the batch's ``'depth'`` tensor with
    ``validate_errors``. Each metric is summed over the batches and divided
    by ``len(dataloader)``, so the reported values are per-batch means.

    Args:
        model: module called as ``model(left)``; it is left in eval mode.
        dataloader: sized iterable of dict batches that provide at least the
            ``'left'`` and ``'depth'`` tensors.

    Returns:
        The mean ``abs_rel`` over all batches.

    Raises:
        ZeroDivisionError: if ``len(dataloader)`` is 0.
    """
    # Order in which validate_errors returns its metrics.
    metric_names = ("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")

    model.eval()
    totals = dict.fromkeys(metric_names, 0)

    length = len(dataloader)  # number of batches, not number of samples
    print(length)

    # Pure inference: disable gradient tracking for the forward passes.
    with torch.no_grad():
        for sample in tqdm(dataloader):
            imgL = sample['left'].to(device)

            pred_depth = model(imgL).squeeze(1)

            pred_depth_ = np.array(pred_depth.detach().cpu(), dtype=np.float32)
            # Ground truth is only consumed on the host, so it is not moved
            # to the device first.
            depth_gt_ = np.array(sample['depth'].detach().cpu(), dtype=np.float32)

            batch_scores = validate_errors(depth_gt_, pred_depth_)
            for name, value in zip(metric_names, batch_scores):
                totals[name] = totals[name] + value

    means = {name: totals[name] / length for name in metric_names}

    print("errors evaluate depth:\n abs_rel: {}, rmse: {}, sq_rel: {}, rmse_log: {}, a1:{}, a2:{}, a3:{}".format(
          means["abs_rel"], means["rmse"], means["sq_rel"], means["rmse_log"],
          means["a1"], means["a2"], means["a3"]))

    return means["abs_rel"]



        
def validate_spike(model, dataloader):
    """Evaluate the stereo, monocular and fusion outputs of a two-input model.

    The model is switched to eval mode and called as ``model(left, right)``
    for every batch. From the returned mapping three predictions are scored
    against the batch's ``'depth'`` tensor:

    * ``pred["stereo"][-1]``: its element-wise reciprocal is scored with
      ``validate_errors_``;
    * ``pred["monocular"]["depth"]``: scored with ``validate_errors``;
    * ``pred["fusion"]``: scored with ``validate_errors``.

    Every metric is summed over the batches and divided by
    ``len(dataloader)``, so the printed values are per-batch means.

    Args:
        model: module called as ``model(left, right)`` returning the mapping
            described above; it is left in eval mode.
        dataloader: sized iterable of dict batches that provide at least the
            ``'left'``, ``'right'`` and ``'depth'`` tensors.

    Returns:
        The mean ``abs_rel`` of the fusion output over all batches.

    Raises:
        ZeroDivisionError: if ``len(dataloader)`` is 0.
    """
    # Order in which validate_errors / validate_errors_ return their metrics.
    metric_names = ("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")
    # Order in which the metrics appear in the printed report.
    report_order = ("abs_rel", "rmse", "sq_rel", "rmse_log", "a1", "a2", "a3")
    branches = ("stereo", "monocular", "fusion")

    model.eval()
    totals = {branch: dict.fromkeys(metric_names, 0) for branch in branches}

    length = len(dataloader)  # number of batches, not number of samples
    print(length)

    # Pure inference: disable gradient tracking for the forward passes.
    with torch.no_grad():
        for sample in tqdm(dataloader):
            imgL = sample['left'].to(device)
            imgR = sample['right'].to(device)

            pred = model(imgL, imgR)

            # The last stereo output is inverted before being compared with
            # the depth ground truth.
            stereo_pred = np.array(1 / pred["stereo"][-1].detach().cpu(), dtype=np.float32)
            mono_pred = np.array(pred["monocular"]["depth"].detach().cpu(), dtype=np.float32)
            fusion_pred = np.array(pred["fusion"].detach().cpu(), dtype=np.float32)
            # Ground truth is only consumed on the host, so it is not moved
            # to the device first.
            depth_gt_ = np.array(sample['depth'].detach().cpu(), dtype=np.float32)

            batch_scores = {
                "stereo": validate_errors_(depth_gt_, stereo_pred),
                "monocular": validate_errors(depth_gt_, mono_pred),
                "fusion": validate_errors(depth_gt_, fusion_pred),
            }
            for branch in branches:
                for name, value in zip(metric_names, batch_scores[branch]):
                    totals[branch][name] = totals[branch][name] + value

    means = {
        branch: {name: totals[branch][name] / length for name in metric_names}
        for branch in branches
    }

    # Flatten to stereo, monocular, fusion -- the order the report expects.
    report_values = [means[branch][name] for branch in branches for name in report_order]

    print("errors evaluate disparity:\n abs_rel: {}, rmse: {}, sq_rel: {}, rmse_log: {}, a1: {}, a2: {}, a3:{} \n errors evaluate depth:\n abs_rel: {}, rmse: {}, sq_rel: {}, rmse_log: {}, a1:{}, a2:{}, a3:{} \n errors evaluate fusion : \n abs_rel: {}, rmse: {}, sq_rel: {}, rmse_log: {}, a1: {}, a2: {}, a3:{}".format(
          *report_values))

    return means["fusion"]["abs_rel"]



def validate_errors(gt, pred):
    """Compute depth-error metrics after scaling both inputs by 20.

    In this file the function scores the monocular and fusion predictions.
    Ground truth and prediction are multiplied by 20, the prediction is
    limited to the range [0, 20], and only elements whose scaled ground
    truth lies in (0, 20] contribute to the metrics. The input arrays are
    not modified.

    NOTE: a prediction that is zero or negative ends up as 0, which makes the
    ratio and log terms infinite for that element.

    Args:
        gt: ground-truth array.
        pred: prediction array with the same shape as ``gt``.

    Returns:
        Tuple ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``.
    """
    upper = 20.0

    target_all = 20 * gt
    estimate_all = np.clip(20 * pred, 0, upper)

    # Keep only elements with a ground truth inside the evaluated range.
    valid = (target_all > 0) & (target_all <= upper)
    estimate = estimate_all[valid]
    target = target_all[valid]

    # Accuracy under threshold: share of elements whose ratio is below 1.25**k.
    ratio = np.maximum(target / estimate, estimate / target)
    a1, a2, a3 = ((ratio < 1.25 ** power).mean() for power in (1, 2, 3))

    diff = target - estimate
    squared = diff ** 2

    rmse = np.sqrt(squared.mean())

    log_diff = np.log(target) - np.log(estimate)
    rmse_log = np.sqrt((log_diff ** 2).mean())

    abs_rel = np.mean(np.abs(diff) / target)
    sq_rel = np.mean(squared / target)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
        
def validate_errors_(gt, pred):
    """Compute depth-error metrics for a prediction that is already scaled.

    In this file the function scores the reciprocal of the last stereo
    output. Only the ground truth is multiplied by 20; the prediction is used
    as given, limited to the range [0.1, 20]. Elements contribute to the
    metrics only when their scaled ground truth lies in (0.1, 20].

    Neither input array is modified: the prediction is clamped into a new
    array rather than in place.

    Args:
        gt: ground-truth array.
        pred: prediction array with the same shape as ``gt``.

    Returns:
        Tuple ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``.
    """
    lower, upper = 0.1, 20.0

    target_all = 20 * gt
    # np.clip returns a fresh array, so the caller's `pred` stays untouched.
    estimate_all = np.clip(pred, lower, upper)

    # Keep only elements with a ground truth inside the evaluated range.
    valid = (target_all > lower) & (target_all <= upper)
    estimate = estimate_all[valid]
    target = target_all[valid]

    # Accuracy under threshold: share of elements whose ratio is below 1.25**k.
    thresh = np.maximum((target / estimate), (estimate / target))
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = np.sqrt(((target - estimate) ** 2).mean())

    rmse_log = np.sqrt(((np.log(target) - np.log(estimate)) ** 2).mean())

    abs_rel = np.mean(np.abs(target - estimate) / target)

    sq_rel = np.mean(((target - estimate) ** 2) / target)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


if __name__ == '__main__':
    # Script entry point: exactly one of the calls below is active at a time;
    # the others are kept commented out as alternative entry points.
    #test()
    #test_spike()
    # NOTE(review): `to_video` is not defined in the visible part of this file;
    # presumably it is provided by `from utils import *` -- confirm, otherwise
    # this call raises NameError.
    to_video()
    #gt_video()
    #addvideo()


