

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F
import skimage
import skimage.io
import skimage.transform
import numpy as np
import time
import math
import cv2
import torchgeometry

from submodule import feature_extraction , convbn_3d ,convbn 

class contracting(nn.Module):
    """Encoder (contracting path) of the U-Net.

    Five stages of two 3x3 conv + batch-norm + ReLU layers each. Stages 2-5
    are preceded by a 2x2 max-pool, so every stage after the first works at
    half the spatial resolution of the previous one.
    """

    def __init__(self, n_channels=3):
        super().__init__()

        def double_conv(c_in, c_out):
            # (conv3x3 -> BN -> ReLU) twice; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        self.layer1 = double_conv(n_channels, 64)
        self.layer2 = double_conv(64, 128)
        self.layer3 = double_conv(128, 256)
        self.layer4 = double_conv(256, 512)
        self.layer5 = double_conv(512, 512)

        self.down_sample = nn.MaxPool2d(2, stride=2)

    def forward(self, X):
        """Return the stage outputs ordered deepest first: (X5, X4, X3, X2, X1)."""
        stage_outputs = [self.layer1(X)]
        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            pooled = self.down_sample(stage_outputs[-1])
            stage_outputs.append(stage(pooled))
        return tuple(reversed(stage_outputs))



class expansive(nn.Module):
    """Decoder (expansive path) of the U-Net.

    Each step upsamples the running feature map by 2 with a transposed
    convolution, concatenates it with the encoder skip tensor of the same
    level along the channel axis and applies two 3x3 conv + ReLU layers.
    A final 3x3 convolution maps the 128 channels to ``n_channels``.

    Args:
        inp_shape: nominal input width. Stored for backward compatibility;
            padding is now derived from the tensors themselves instead of
            being enabled only when this equals 100.
        n_channels: number of channels of the output map.
    """

    def __init__(self, inp_shape=128, n_channels=3):
        super().__init__()

        self.inp_shape = inp_shape  # width of the image

        self.layer1 = nn.Conv2d(128, n_channels, 3, stride=1, padding=1)

        self.layer2 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(128, 128, 3, stride=1, padding=1),
                                    nn.ReLU())

        self.layer3 = nn.Sequential(nn.Conv2d(256, 128, 3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(128, 128, 3, stride=1, padding=1),
                                    nn.ReLU())

        self.layer4 = nn.Sequential(nn.Conv2d(512, 256, 3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(256, 256, 3, stride=1, padding=1),
                                    nn.ReLU())

        self.layer5 = nn.Sequential(nn.Conv2d(1024, 512, 3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(512, 512, 3, stride=1, padding=1),
                                    nn.ReLU())

        self.up_sample_54 = nn.ConvTranspose2d(512, 512, 2, stride=2)

        self.up_sample_43 = nn.ConvTranspose2d(512, 256, 2, stride=2, padding=0)

        self.up_sample_32 = nn.ConvTranspose2d(256, 128, 2, stride=2)

        self.up_sample_21 = nn.ConvTranspose2d(128, 64, 2, stride=2)

    @staticmethod
    def _match_size(X, skip):
        """Replicate-pad ``X`` on the bottom/right up to ``skip``'s height/width.

        A stride-2 max-pool floors odd sizes, so the x2 upsampled map can be
        one pixel short of its skip tensor (e.g. 24 vs 25 for a 100-pixel
        input). When the sizes already agree ``X`` is returned untouched.
        """
        missing_h = skip.shape[2] - X.shape[2]
        missing_w = skip.shape[3] - X.shape[3]
        if missing_h > 0 or missing_w > 0:
            # F.pad order for the last two dims: (left, right, top, bottom).
            X = F.pad(X, (0, max(missing_w, 0), 0, max(missing_h, 0)), mode='replicate')
        return X

    def forward(self, X5, X4, X3, X2, X1):
        """Decode from the deepest map ``X5`` using the skips ``X4`` .. ``X1``."""
        X = self._match_size(self.up_sample_54(X5), X4)
        X4 = self.layer5(torch.cat([X, X4], dim=1))

        X = self._match_size(self.up_sample_43(X4), X3)
        X3 = self.layer4(torch.cat([X, X3], dim=1))

        X = self._match_size(self.up_sample_32(X3), X2)
        X2 = self.layer3(torch.cat([X, X2], dim=1))

        X = self._match_size(self.up_sample_21(X2), X1)
        X1 = self.layer2(torch.cat([X, X1], dim=1))

        return self.layer1(X1)


class unet(nn.Module):
    """U-Net assembled from the ``contracting`` encoder and ``expansive`` decoder.

    Args:
        inp_shape: forwarded to the decoder as its ``inp_shape``.
        n_channels: channel count of the input; also passed to the decoder
            as its output channel count.
    """

    def __init__(self, inp_shape=128, n_channels=3):
        super().__init__()
        self.down = contracting(n_channels=n_channels)
        self.up = expansive(inp_shape=inp_shape, n_channels=n_channels)

    def forward(self, X):
        """Encode ``X``, then decode using the encoder outputs as skips."""
        # The encoder returns (X5, X4, X3, X2, X1), which is exactly the
        # positional order the decoder expects.
        deepest, skip4, skip3, skip2, skip1 = self.down(X)
        return self.up(deepest, skip4, skip3, skip2, skip1)


    

# Memoised remapping grids. Keys are tuples holding every parameter that
# influences the grid, so two different configurations can never share an entry.
mapping_cache = {}


def get_grid_one(cam_conf, img_h, img_w, n_hmap, xmax, xmin, ymax, ymin, max_disp, camera_ext_x, camera_ext_y):
    """Build the ``grid_sample`` grid for one camera configuration.

    For every BEV cell (row ``Y``, column ``X``) the grid stores where to read
    from the source volume, in normalised coordinates:

    * channel 0: ``f / ((x - camera_ext_x) / tx) / (max_disp / 2) - 1`` with
      ``x = (xmax - xmin) * X / n_hmap + xmin`` (x is the depth axis according
      to the original comment; the ratio is presumably the stereo disparity).
    * channel 1: ``(disp * ((y - camera_ext_y) / tx) + cx) / (img_w / 2) - 1``
      with ``y = (ymax - ymin) * Y / n_hmap + ymin`` and ``disp`` the
      un-normalised ratio from channel 0.

    Args:
        cam_conf: sequence of four scalars ``(f, cx, cy, tx)``; ``cy`` is not
            used by the computation.
        img_h: unused, kept for interface compatibility.
        img_w: divisor (as ``img_w / 2``) used to normalise channel 1.
        n_hmap: side length of the square BEV grid.
        xmax, xmin, ymax, ymin: extents mapped onto the grid columns / rows.
        max_disp: divisor (as ``max_disp / 2``) used to normalise channel 0.
        camera_ext_x, camera_ext_y: offsets subtracted from x and y.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, n_hmap, n_hmap, 2)`` on CPU.

    Raises:
        ZeroDivisionError: if any divisor in the formulas above is zero
            (e.g. a grid column with ``x == camera_ext_x`` or ``tx == 0``).
    """
    assert len(cam_conf) == 4
    f, cx, _cy, tx = (float(v) for v in cam_conf)

    # A tuple key cannot collide the way concatenated strings can
    # ("1.0" + "12.0" == "1.01" + "2.0"), and it covers the BEV / image
    # settings too, so changing them no longer returns a stale grid.
    key = (f, cx, tx) + tuple(
        float(v) for v in (img_w, n_hmap, xmax, xmin, ymax, ymin,
                           max_disp, camera_ext_x, camera_ext_y)
    )

    if key not in mapping_cache:
        cells = np.arange(n_hmap)
        remap_normed_inv = np.zeros((n_hmap, n_hmap, 2))
        try:
            # Turn NumPy's silent inf/nan into an exception so degenerate
            # configurations fail loudly, as the scalar arithmetic did.
            with np.errstate(divide='raise', invalid='raise'):
                # Depends on the column index only.
                disp = f / (((xmax - xmin) * cells / n_hmap + xmin - camera_ext_x) / tx)
                # Depends on the row index only.
                lateral = ((ymax - ymin) * cells / n_hmap + ymin - camera_ext_y) / tx

                remap_normed_inv[:, :, 0] = (disp / (max_disp / 2) - 1)[None, :]
                remap_normed_inv[:, :, 1] = (
                    (disp[None, :] * lateral[:, None] + cx) / (img_w / 2) - 1
                )
        except FloatingPointError as err:
            raise ZeroDivisionError(
                'division by zero while building the remapping grid'
            ) from err

        mapping_cache[key] = remap_normed_inv

    remap_normed_inv = mapping_cache[key]
    # astype() copies, so callers can never modify the cached array.
    grid = torch.from_numpy(remap_normed_inv[None].astype('float32'))
    return grid
    


def pt_costvol_to_hmap(reduced_vol, cam_confs, sys_confs):
    """Resample a per-image volume onto the BEV grid with ``grid_sample``.

    Args:
        reduced_vol: 4-D tensor whose dim 2 must equal ``sys_confs['img_w']``
            and whose dim 3 must equal ``sys_confs['max_disp']``.
        cam_confs: indexable by batch index; each entry is the
            ``(f, cx, cy, tx)`` configuration passed to ``get_grid_one``.
        sys_confs: dict providing ``img_h``, ``img_w``, ``n_hmap``, ``xmax``,
            ``xmin``, ``ymax``, ``ymin``, ``max_disp``, ``camera_ext_x`` and
            ``camera_ext_y``.

    Returns:
        The volume sampled at the grid positions (zero padding outside the
        source), on the same device as ``reduced_vol``.
    """
    img_h = sys_confs['img_h']
    img_w = sys_confs['img_w']
    n_hmap = sys_confs['n_hmap']
    xmax = sys_confs['xmax']
    xmin = sys_confs['xmin']
    ymax = sys_confs['ymax']
    ymin = sys_confs['ymin']
    max_disp = sys_confs['max_disp']
    camera_ext_x = sys_confs['camera_ext_x']
    camera_ext_y = sys_confs['camera_ext_y']

    assert reduced_vol.shape[2] == img_w
    assert reduced_vol.shape[3] == max_disp

    batch_size = reduced_vol.shape[0]
    grids = [
        get_grid_one(cam_confs[i], img_h=img_h, img_w=img_w, n_hmap=n_hmap,
                     xmax=xmax, xmin=xmin, ymax=ymax, ymin=ymin,
                     max_disp=max_disp, camera_ext_x=camera_ext_x,
                     camera_ext_y=camera_ext_y)
        for i in range(batch_size)
    ]

    # Follow the input's device instead of forcing CUDA, so the function also
    # works on CPU and always lands on the same GPU as the volume.
    grid = torch.cat(grids, 0).to(reduced_vol.device)

    # NOTE(review): align_corners is left at the library default, as before;
    # that default differs between PyTorch releases — confirm which
    # convention the trained weights expect.
    warped = torch.nn.functional.grid_sample(reduced_vol, grid, padding_mode='zeros')
    return warped
        
    

    


def warp_p_scale(img, ipm_m, sys_confs):
    """Warp ``img`` into the BEV frame with per-sample scaled homographies.

    Args:
        img: 4-D tensor; its last dimension is used as the source width.
        ipm_m: 2-D tensor with one row per sample. Columns 0-8 hold a
            row-major 3x3 matrix and column 10 a value that is divided by the
            source width to obtain the scale factor (presumably the width the
            matrix was computed for — confirm against the data loader).
            The tensor is not modified.
        sys_confs: dict; ``sys_confs['n_hmap']`` is the output side length.

    Returns:
        The perspective-warped tensor, flipped along its last axis and with
        the last two axes swapped.
    """
    params = ipm_m.detach().cpu().numpy()

    # reshape() of the column slice is only a view; without the copy the
    # in-place scaling below would write into ``ipm_m`` itself whenever it
    # already lives on the CPU, compounding the scale on every call.
    homographies = params[:, :9].reshape((-1, 3, 3)).copy()

    batch_size = img.shape[0]
    scales = params[:batch_size, 10] / img.shape[3]
    # Only the first two columns of each matrix are rescaled.
    homographies[:batch_size, :, :2] *= scales[:, None, None]

    # Keep the matrices on the same device as the image rather than forcing CUDA.
    m = torch.from_numpy(homographies).to(img.device)

    n_hmap = sys_confs['n_hmap']
    ans = torchgeometry.warp_perspective(img, m, dsize=(n_hmap, n_hmap))
    ans = torch.flip(ans, (3,))
    return ans.permute(0, 1, 3, 2)


    
class SBEVNet(nn.Module):
    """Stereo network predicting a bird's-eye-view (BEV) segmentation.

    Pipeline implemented by ``forward``: shared feature extraction on both
    images, a concatenation cost volume refined by 3-D residual blocks,
    collapse of the height axis into channels, resampling onto the BEV grid,
    optional concatenation of IPM inputs, a U-Net, three residual 2-D blocks
    and a per-class classifier.

    Args:
        sys_confs: dict of system settings. This class reads ``max_disp``,
            ``n_hmap``, ``img_h`` and, when ``fixed_cam_confs`` is true,
            ``f``, ``cx``, ``cy``, ``tx``; the helper functions read more.
        maxdisp: maximum disparity; must equal ``sys_confs['max_disp']``.
        n_classes_seg: number of output segmentation classes.
        do_predict_hmap: also regress a single-channel map through a sigmoid.
        do_ipm_rgb: expect a 3-channel IPM image among the inputs and
            concatenate it to the BEV features.
        do_ipm_feats: warp the left-image features with the supplied
            homographies and concatenate them (32 channels) to the BEV
            features.
        fixed_cam_confs: take the camera parameters from ``sys_confs``
            instead of from the inputs.
    """

    def __init__(self, sys_confs, maxdisp=64, n_classes_seg=25, do_predict_hmap=False, do_ipm_rgb=True, do_ipm_feats=True, fixed_cam_confs=False):
        super(SBEVNet, self).__init__()
        self.maxdisp = maxdisp

        assert maxdisp == sys_confs['max_disp']
        bev_size = sys_confs['n_hmap']

        self.feature_extraction = feature_extraction()

        self.n_classes_seg = n_classes_seg
        self.do_predict_hmap = do_predict_hmap

        self.do_ipm_rgb = do_ipm_rgb
        self.do_ipm_feats = do_ipm_feats

        self.sys_confs = sys_confs
        self.fixed_cam_confs = fixed_cam_confs

        # Channel count of the BEV feature map fed to the U-Net: 256 from the
        # resampled volume, plus the optional IPM inputs.
        nnn = 256

        if do_ipm_rgb:
            nnn += 3

        if do_ipm_feats:
            nnn += 32

        # 3-D residual stack applied to the cost volume.
        self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres4 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        # Not referenced by forward(); kept so the module's parameter names
        # and initialisation order stay unchanged.
        self.dres4_2d = nn.Sequential(convbn(32*48 + 32, 128, 3, 1, 1, 1),
                                      nn.ReLU(inplace=True),
                                      convbn(128, 128, 3, 1, 1, 1))

        # Input channels = 32 feature channels x (img_h / 4) rows folded
        # together in forward().
        self.dres4_2d_top = nn.Sequential(convbn(32*sys_confs['img_h']//4, 256, 3, 1, 1, 1),
                                          nn.ReLU(inplace=True),
                                          convbn(256, 256, 3, 1, 1, 1))

        self.dres4_2d_top_2 = nn.Sequential(convbn(256, 256, 3, 1, 1, 1),
                                            nn.ReLU(inplace=True),
                                            convbn(256, 256, 3, 1, 1, 1))

        # Residual 2-D blocks applied after the U-Net.
        self.dres_2d_seg_1 = nn.Sequential(convbn(nnn, nnn, 3, 1, 1, 1),
                                           nn.ReLU(inplace=True),
                                           convbn(nnn, nnn, 3, 1, 1, 1))

        self.dres_2d_seg_2 = nn.Sequential(convbn(nnn, nnn, 3, 1, 1, 1),
                                           nn.ReLU(inplace=True),
                                           convbn(nnn, nnn, 3, 1, 1, 1))

        self.dres_2d_seg_3 = nn.Sequential(convbn(nnn, nnn, 3, 1, 1, 1),
                                           nn.ReLU(inplace=True),
                                           convbn(nnn, nnn, 3, 1, 1, 1))

        # Not referenced by forward(); kept for the same reason as dres4_2d.
        self.classify = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1, bias=False))

        self.classify_seg = nn.Conv2d(nnn, n_classes_seg, kernel_size=3, padding=1, stride=1, bias=False)

        self.seg_up = nn.UpsamplingBilinear2d(scale_factor=4)

        if self.do_predict_hmap:
            self.reg_hmap = nn.Conv2d(nnn, 1, kernel_size=3, padding=1, stride=1, bias=False)

        self.unet = unet(inp_shape=bev_size, n_channels=nnn)

    def forward(self, imgs):
        """Run the network.

        Args:
            imgs: sequence of tensors, consumed in this order:
                ``imgs[0]`` left image, ``imgs[1]`` right image, then the IPM
                image (only if ``do_ipm_rgb``), then the per-sample camera
                parameters of shape ``(batch, 4)`` (only if not
                ``fixed_cam_confs``), then the IPM matrices of shape
                ``(batch, 11)`` (only if ``do_ipm_feats``).

        Returns:
            ``pred_seg`` — log-softmax over the class channel (dim 1) — or the
            tuple ``(pred_seg, pred_hmap)`` when ``do_predict_hmap`` is set.
        """
        left = imgs[0]
        right = imgs[1]

        # Index of the next optional input to consume.
        iq = 2

        if self.do_ipm_rgb:
            img_ipm = imgs[iq]
            iq += 1

        if not self.fixed_cam_confs:
            cam_confs = imgs[iq]
            assert cam_confs.shape[-1] == 4
            assert len(cam_confs.shape) == 2
            iq += 1
        else:
            cam_conf = [self.sys_confs['f'], self.sys_confs['cx'], self.sys_confs['cy'], self.sys_confs['tx']]
            bs = left.shape[0]
            cam_confs = [cam_conf] * bs

        if self.do_ipm_feats:
            ipm_m = imgs[iq]
            iq += 1
            assert ipm_m.shape[-1] == 3*3 + 2
            assert len(ipm_m.shape) == 2

        refimg_fea = self.feature_extraction(left)
        targetimg_fea = self.feature_extraction(right)

        if self.do_ipm_feats:
            feat_ipm = warp_p_scale(refimg_fea, ipm_m, self.sys_confs)

        # Matching: concatenation cost volume of shape
        # (batch, 2 * C, maxdisp // 4, H, W). Allocated from refimg_fea so it
        # inherits that tensor's device and dtype; gradient tracking is
        # governed by the caller's autograd mode (the old ``volatile`` flag
        # has no effect in current PyTorch).
        n_feat = refimg_fea.size(1)
        n_disp = self.maxdisp // 4
        cost = refimg_fea.new_zeros(refimg_fea.size(0), n_feat * 2, n_disp,
                                    refimg_fea.size(2), refimg_fea.size(3))

        for i in range(n_disp):
            if i > 0:
                # Pair left column x with right column x - i; the first i
                # columns have no counterpart and stay zero.
                cost[:, :n_feat, i, :, i:] = refimg_fea[:, :, :, i:]
                cost[:, n_feat:, i, :, i:] = targetimg_fea[:, :, :, :-i]
            else:
                cost[:, :n_feat, i, :, :] = refimg_fea
                cost[:, n_feat:, i, :, :] = targetimg_fea
        cost = cost.contiguous()

        cost0 = self.dres0(cost)
        cost0 = self.dres1(cost0) + cost0
        cost0 = self.dres2(cost0) + cost0
        cost0 = self.dres3(cost0) + cost0
        cost0 = self.dres4(cost0) + cost0

        # (B, C, D, H, W) -> (B, C, H, W, D), then fold the height axis into
        # the channels: (B, C * H, W, D).
        fea = cost0
        fea = fea.permute(0, 1, 3, 4, 2)
        fea = fea.contiguous()
        fea = fea.view(fea.size(0), fea.size(1) * fea.size(2), fea.size(3), fea.size(4))

        fea = self.dres4_2d_top(fea)
        fea = self.seg_up(fea)
        fea = self.dres4_2d_top_2(fea)

        fea = pt_costvol_to_hmap(fea, cam_confs, sys_confs=self.sys_confs)

        if self.do_ipm_rgb:
            fea = torch.cat([fea, img_ipm], dim=1)

        if self.do_ipm_feats:
            fea = torch.cat([fea, feat_ipm], dim=1)

        fea = self.unet(fea)
        fea = self.dres_2d_seg_1(fea) + fea
        fea = self.dres_2d_seg_2(fea) + fea
        fea = self.dres_2d_seg_3(fea) + fea

        if self.do_predict_hmap:
            fea_hmap = self.reg_hmap(fea)
            pred_hmap = torch.sigmoid(fea_hmap[:, 0])

        fea = self.classify_seg(fea)

        # dim=1 is what the implicit-dim fallback selected for 4-D inputs;
        # stating it avoids the deprecation warning.
        pred_seg = F.log_softmax(fea, dim=1)

        if self.do_predict_hmap:
            return pred_seg, pred_hmap
        else:
            return pred_seg
    
