
import os
import re
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from PIL import Image

import cv2


from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, transforms
from . import readpfm as rp

import numpy as np
import argparse
import math
import scipy
import struct
from scipy import integrate
import cv2  
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import os



class SynDataset(Dataset):
    """Stereo spike-stream dataset returning spikes plus depth/disparity targets.

    Directory layout read by ``__init__`` under ``<root_path>/<split>``::

        <seq>/right/spike/spike/r100/<right spike .npy files>
        <seq>/left/spike/spike/r100/<left spike .npy files>
        <seq>/left/depth/<encoded depth .png files>

    For every right-view file name, the matching left-view name substitutes
    ``rcam`` -> ``lcam`` (and ``right`` -> ``left``).  The depth name also
    substitutes ``img`` -> ``depth`` and ``npy`` -> ``png``.
    """

    def __init__(self, root_path: str, split="train"):   # train / val
        """Index every (left spike, right spike, depth) triple under ``root_path/split``.

        Args:
            root_path: dataset root directory.
            split: sub-directory to read (e.g. ``"train"`` or ``"val"``).

        Raises:
            FileNotFoundError / NotADirectoryError: if the split or a
            sequence's right-spike directory is missing.
        """
        # Geometry used by ``list2img``.
        self.H = 250
        self.W = 400
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.depth_scale = 1000

        self.seq_path = os.path.join(root_path, split)
        # Every entry of seq_path is treated as a sequence directory.
        self.spike_sequences = os.listdir(self.seq_path)

        # Collect (left, right, depth) triples so a single sort keeps each
        # sample's three paths aligned (sorting three lists separately could
        # silently mis-pair them if the name mapping were not order-preserving).
        triples = []
        for seq in self.spike_sequences:
            seq_dir = os.path.join(self.seq_path, seq)
            right_dir = os.path.join(seq_dir, "right", "spike", "spike", "r100")
            left_dir = os.path.join(seq_dir, "left", "spike", "spike", "r100")
            depth_dir = os.path.join(seq_dir, "left", "depth")
            for file_name in os.listdir(right_dir):
                # Substitutions touch only the file name, never the directory
                # part, so roots such as ".../bright/..." are not corrupted.
                right_name = file_name.replace("lcam", "rcam")
                left_name = right_name.replace("right", "left").replace("rcam", "lcam")
                depth_name = (
                    file_name.replace("img", "depth")
                    .replace("npy", "png")
                    .replace("rcam", "lcam")
                )
                triples.append((
                    os.path.join(left_dir, left_name),
                    os.path.join(right_dir, right_name),
                    os.path.join(depth_dir, depth_name),
                ))

        # Sorting by the left path reproduces the left-list ordering.
        triples.sort()
        self.spike_list_left = [left for left, _, _ in triples]
        self.spike_list_right = [right for _, right, _ in triples]
        self.depth_list = [depth for _, _, depth in triples]

        self.split = split

    def __len__(self):
        """Number of indexed stereo samples."""
        return len(self.spike_list_left)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Keys:
            ``left`` / ``right``: spike tensors mapped through ``2 * x - 1``
            and resized with ``transforms.Resize([256, 512])``.
            ``depth``: decoded depth divided by 1000.
            ``disparity``: see ``_depth_to_disparity``.
        """
        depth_np = self.get_gt_depth_maps(self.depth_list[idx])
        disp_gt = torch.FloatTensor(self._depth_to_disparity(depth_np))
        depth_gt = torch.FloatTensor(depth_np) / 1000.0

        # 2 * x - 1 maps 0 -> -1 and 1 -> +1 (assumes binary spike frames -- confirm).
        spike_left = 2 * self.load_np(self.spike_list_left[idx]) - 1
        spike_right = 2 * self.load_np(self.spike_list_right[idx]) - 1

        resize = transforms.Resize([256, 512])

        return {
            "left": resize(spike_left),
            "right": resize(spike_right),
            "depth": depth_gt.squeeze(0),
            "disparity": disp_gt.squeeze(0),
        }

    @staticmethod
    def _decode_depth(pixels):
        """Decode pixel values ``p`` to depth ``((p / 255) ** 4) * (1000 - 0.3) + 0.3``.

        Pixel 0 maps to 0.3 and pixel 255 maps to 1000.
        """
        return ((pixels / 255) ** 4) * (1000 - 0.3) + 0.3

    @staticmethod
    def _depth_to_disparity(depth):
        """Return the disparity target ``51.2 / quantized_depth``.

        Depths <= 1 are used unchanged.  Larger depths are truncated to
        integers and saturated at 255.  A plain ``uint8`` cast is out of
        range for depths >= 256 (undefined in NumPy; typically wraps, e.g.
        256.5 -> 0) and produced ``inf`` disparities.
        """
        coarse = np.floor(np.clip(depth, None, 255.0))
        quantized = np.where(depth <= 1, depth, coarse)
        return 51.2 / quantized

    def get_gt_depth_maps(self, depth_map_path):
        """Load a depth PNG as float32 and decode it with ``_decode_depth``."""
        # ``with`` closes the file handle instead of leaving it to the GC.
        with Image.open(depth_map_path) as img:
            pixels = np.array(img.convert("F"), dtype=np.float32)
        return self._decode_depth(pixels)

    def load_np(self, np_path):
        """Load a ``.npy`` file, cast it to uint8, and return it as a FloatTensor."""
        return torch.FloatTensor(np.load(np_path).astype(np.uint8))

    def random_crop(self, spike, rgb, depth, height, width):  # load numpy format # height = 352 width = 1216
        """Crop the same random ``height`` x ``width`` window from ``spike`` and ``depth``.

        Both arrays are indexed as ``[:, y, x]``.  ``rgb`` is accepted for
        interface compatibility and is not used.

        Raises:
            ValueError: if the crop is larger than ``spike`` (from ``random.randint``).
        """
        x = random.randint(0, spike.shape[2] - width)
        y = random.randint(0, spike.shape[1] - height)
        spike = spike[:, y:y + height, x:x + width]
        depth = depth[:, y:y + height, x:x + width]
        return spike, depth

    def refine(self, mat):
        """Clamp values above 100 to 100 in place and return ``mat``."""
        mat[mat > 100] = 100
        return mat

    def list2img(self, list_, path):
        """Reshape a flat value list to ``(self.H, self.W)`` uint8 and flip it vertically.

        ``path`` is accepted for interface compatibility and is not used.
        """
        mat = np.array(list_).reshape(self.H, self.W)
        return np.ascontiguousarray(np.flipud(np.uint8(mat)))

    def list2np(self, sum, list_, index): # right or left
        """Reshape a flat value list to a ``(24, 375, 1242)`` array.

        ``sum`` (which shadows the builtin) and ``index`` are accepted for
        interface compatibility and are not used.  The old code allocated and
        discarded a large zero array here.
        """
        return np.array(list_).reshape(24, 375, 1242)




class SynRGBDataset(Dataset):
    """Stereo RGB dataset returning images plus inverse-depth, disparity and clamped-depth targets.

    Directory layout read by ``__init__`` under ``<root_path>/<split>``::

        <seq>/right/img/<right image files>
        <seq>/left/img/<left image files>
        <seq>/left/depth/<encoded depth image files>

    For every right-view file name, the matching left-view name substitutes
    ``rcam`` -> ``lcam`` (and ``right`` -> ``left``).  The depth name also
    substitutes ``img`` -> ``depth``.
    """

    def __init__(self, root_path: str, split="train"):   # train / val
        """Index every (left image, right image, depth) triple under ``root_path/split``.

        Args:
            root_path: dataset root directory.
            split: sub-directory to read (e.g. ``"train"`` or ``"val"``).

        Raises:
            FileNotFoundError / NotADirectoryError: if the split or a
            sequence's right-image directory is missing.
        """
        self.H = 250
        self.W = 400
        # NOTE(review): H, W and depth_scale are not read in this class; kept for external users.
        self.depth_scale = 1000

        self.seq_path = os.path.join(root_path, split)
        # Every entry of seq_path is treated as a sequence directory.
        self.spike_sequences = os.listdir(self.seq_path)

        # Collect aligned triples so a single sort keeps each sample's paths paired.
        triples = []
        for seq in self.spike_sequences:
            seq_dir = os.path.join(self.seq_path, seq)
            right_dir = os.path.join(seq_dir, "right", "img")
            left_dir = os.path.join(seq_dir, "left", "img")
            depth_dir = os.path.join(seq_dir, "left", "depth")
            for file_name in os.listdir(right_dir):
                # Substitutions touch only the file name, so the root path is never rewritten.
                right_name = file_name.replace("lcam", "rcam")
                left_name = right_name.replace("right", "left").replace("rcam", "lcam")
                depth_name = file_name.replace("img", "depth").replace("rcam", "lcam")
                triples.append((
                    os.path.join(left_dir, left_name),
                    os.path.join(right_dir, right_name),
                    os.path.join(depth_dir, depth_name),
                ))

        triples.sort()
        self.rgb_list_left = [left for left, _, _ in triples]
        self.rgb_list_right = [right for _, right, _ in triples]
        self.depth_list = [depth for _, _, depth in triples]

        self.split = split

    def __len__(self):
        """Number of indexed stereo samples."""
        return len(self.rgb_list_left)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Keys:
            ``left`` / ``right``: (3, H, W) float images resized with
            ``transforms.Resize([256, 512])``.
            ``depth``, ``disparity``, ``dorn``: see ``_depth_targets``.
            ``dorn`` is additionally resized with ``transforms.Resize([385, 513])``.
        """
        depth_gt = torch.FloatTensor(self.get_gt_depth_maps(self.depth_list[idx]))
        inv_depth, disp_gt, dorn = self._depth_targets(depth_gt)

        # load_image returns (H, W, 3); transpose to channel-first (3, H, W).
        rgb_left = torch.FloatTensor(self.load_image(self.rgb_list_left[idx]).transpose(2, 0, 1))
        # The right view must come from the right path (it was previously read from the left one).
        rgb_right = torch.FloatTensor(self.load_image(self.rgb_list_right[idx]).transpose(2, 0, 1))

        resize = transforms.Resize([256, 512])
        resize_dorn = transforms.Resize([385, 513])

        return {
            "left": resize(rgb_left),
            "right": resize(rgb_right),
            "depth": inv_depth.squeeze(0),
            "disparity": disp_gt.squeeze(0),
            "dorn": resize_dorn(dorn.unsqueeze(0)).squeeze(0),
        }

    @staticmethod
    def _depth_targets(depth):
        """Build ``(inverse_depth, disparity, clamped_depth)`` from a depth tensor.

        - inverse_depth: ``1 / depth``.
        - disparity: ``51.2 / clamp(int16(depth), 1, 128)``.
        - clamped_depth: ``depth`` with values above 100 set to 100.

        ``depth`` is not modified.  Previously the clamp was applied in place
        through an alias, which also clamped the inverse-depth and disparity
        inputs and made the 128 disparity clamp unreachable.
        """
        clamped = depth.clamp(max=100.0)
        disparity = 51.2 / depth.to(torch.int16).clamp(1, 128)
        inverse = 1.0 / depth
        return inverse, disparity, clamped

    def load_image(self, path):
        """Load an image as RGB and return it as a float32 (H, W, 3) array."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'), dtype=np.float32)

    @staticmethod
    def _decode_depth(pixels):
        """Decode pixel values ``p`` to depth ``((p / 255) ** 4) * (1000 - 0.3) + 0.3``.

        Pixel 0 maps to 0.3 and pixel 255 maps to 1000.
        """
        return ((pixels / 255) ** 4) * (1000 - 0.3) + 0.3

    def get_gt_depth_maps(self, depth_map_path):
        """Load a depth image as float32 and decode it with ``_decode_depth``."""
        with Image.open(depth_map_path) as img:
            pixels = np.array(img.convert("F"), dtype=np.float32)
        return self._decode_depth(pixels)

    def load_np(self, np_path):
        """Load a ``.npy`` file, cast it to uint8, and return it as a FloatTensor."""
        return torch.FloatTensor(np.load(np_path).astype(np.uint8))

    def random_crop(self, spike, rgb, depth, height, width):  # load numpy format # height = 352 width = 1216
        """Crop the same random ``height`` x ``width`` window from ``spike`` and ``depth``.

        Both arrays are indexed as ``[:, y, x]``.  ``rgb`` is accepted for
        interface compatibility and is not used.

        Raises:
            ValueError: if the crop is larger than ``spike`` (from ``random.randint``).
        """
        x = random.randint(0, spike.shape[2] - width)
        y = random.randint(0, spike.shape[1] - height)
        spike = spike[:, y:y + height, x:x + width]
        depth = depth[:, y:y + height, x:x + width]
        return spike, depth

    def refine(self, mat):
        """Clamp values above 100 to 100 in place and return ``mat``."""
        mat[mat > 100] = 100
        return mat
    