import random

import numpy as np
import os
os.sys.path.append('.')
from PIL import Image
from torch.utils.data import Dataset
import math

import torch
import pandas as pd
import dataLoader.utils as utils
import torchvision.transforms.functional as TF
from torchvision import transforms
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms


# Root directory of the dataset; ground images, oxts, calibration files and
# satellite maps are all looked up underneath it.
root_dir = '/backup/dataset/Kitti1'

# CSV file names; not referenced elsewhere in this module.
test_csv_file_name = 'test.csv'
ignore_csv_file_name = 'ignore.csv'
# Sub-directory names joined onto root_dir / a drive directory when loading.
satmap_dir = 'satmap'
grdimage_dir = 'raw_data'
left_color_camera_dir = 'image_02/data'  # alternative backslash form: 'image_02\\data'
right_color_camera_dir = 'image_03/data'  # alternative backslash form: 'image_03\\data'
oxts_dir = 'oxts/data'  # alternative backslash form: 'oxts\\data'

# Size (pixels) of the satellite window, centred on the GPS position, inside
# which random crops are taken.
SatCropRange_H = 1200
SatCropRange_W = 1200

# Ground images are resized to GrdImg_H x GrdImg_W; the original image size
# (GrdOriImg_*) is used to rescale the camera intrinsics to the resized image.
GrdImg_H = 256  # 256 # original: 375 #224, 256
GrdImg_W = 1024  # 1024 # original:1242 #1248, 1024
GrdOriImg_H = 375
GrdOriImg_W = 1242
# Number of DataLoader worker processes.
num_thread_workers = 2

# Text files listing sample paths (the first token of each line is used).
train_file = './dataLoader/train_files_noisy.txt'
#train_file = './dataLoader/train_files_with_sat_GPS.txt'
test1_file = './dataLoader/test1_files.txt'
test2_file = './dataLoader/test2_files.txt'
# Same path as train_file; not referenced elsewhere in this module.
train_files_noisy = './dataLoader/train_files_noisy.txt'

class SatGrdDataset(Dataset):
    """Pairs of left-camera ground images and satellite-map crops.

    Each item is ``(sat_map, grd_img_left, position_grid, grid_info)``:

    * ``sat_map``: ``np.float32`` array of the RGB satellite crop
      (``crop_size`` x ``crop_size``). It is NOT passed through
      ``satmap_transform``, which is stored but not applied.
    * ``grd_img_left``: left colour camera image, after ``grdimage_transform``
      when one is given (otherwise the RGB PIL image).
    * ``position_grid``: ``(2, G, G)`` float64 tensor of normalised UTM x/y
      coordinates spanning the crop, with ``G = crop_size // 8``.
    * ``grid_info``: dict with the UTM and normalised crop centre, the
      normalised crop corner coordinates, grid size, metres per pixel, the
      3x3 left-camera intrinsics rescaled to ``GrdImg_W`` x ``GrdImg_H`` and
      the heading (0-d float64 tensor, radians).

    Args:
        root: dataset root directory.
        file: text file listing sample paths (first whitespace-separated
            token of each non-blank line). Ignored when ``test_only`` is True.
        transform: optional ``(satmap_transform, grdimage_transform)`` pair.
        shift_range_lat, shift_range_lon: shift ranges in metres; stored as
            attributes, not otherwise used by this class.
        rotation_range: rotation range in degrees; stored only.
        crop_size: side length in pixels of the satellite crop (required).
        rand_crop: crop at a random position inside the
            ``SatCropRange_H`` x ``SatCropRange_W`` window centred on the GPS
            position instead of at the centre.
        align_heading: rotate the satellite map by the oxts heading first.
        test_only: read the ``test1_file`` list instead of ``file``.
    """

    def __init__(self, root, file,
                 transform=None, shift_range_lat=20, shift_range_lon=20, rotation_range=10, crop_size=None, rand_crop=False, align_heading=False, test_only=False):
        self.root = root

        self.meter_per_pixel = utils.get_meter_per_pixel(scale=1)
        self.shift_range_meters_lat = shift_range_lat  # in terms of meters
        self.shift_range_meters_lon = shift_range_lon  # in terms of meters
        self.shift_range_pixels_lat = shift_range_lat / self.meter_per_pixel
        self.shift_range_pixels_lon = shift_range_lon / self.meter_per_pixel
        self.crop_size = crop_size
        self.rand_crop = rand_crop
        self.align_heading = align_heading

        self.rotation_range = rotation_range  # in terms of degree

        self.skip_in_seq = 2  # skip 2 in sequence: 6,3,1~

        # Always define both transforms: __getitem__ compares them with None,
        # which raised AttributeError when no transform pair was passed.
        if transform is not None:
            self.satmap_transform = transform[0]
            self.grdimage_transform = transform[1]
        else:
            self.satmap_transform = None
            self.grdimage_transform = None

        self.pro_grdimage_dir = 'raw_data'
        self.satmap_dir = satmap_dir

        # test_only keeps reading the test-1 split; otherwise honour `file`
        # (it used to be ignored, so every non-test dataset read train_file).
        list_file = test1_file if test_only else file
        self.file_name = self._read_file_list(list_file)

    @staticmethod
    def _read_file_list(path):
        """Return the first whitespace-separated token of every non-blank line of *path*."""
        with open(path, 'r') as f:
            return [line.split()[0] for line in f if line.strip()]

    @staticmethod
    def _read_left_camera_k(calib_file_name, img_w, img_h, ori_w, ori_h):
        """Read the 3x3 intrinsics of the left colour camera from a calibration file.

        Uses the first line containing ``P_rect_02`` and rescales fx/cx by
        ``img_w / ori_w`` and fy/cy by ``img_h / ori_h``.

        Returns:
            A float32 tensor ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.

        Raises:
            ValueError: if no ``P_rect_02`` line is present.
        """
        with open(calib_file_name, 'r') as f:
            for line in f:
                if 'P_rect_02' in line:
                    values = line.split(':')[1].strip().split(' ')
                    fx = float(values[0]) * img_w / ori_w
                    cx = float(values[2]) * img_w / ori_w
                    fy = float(values[5]) * img_h / ori_h
                    cy = float(values[6]) * img_h / ori_h
                    camera_k = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                    return torch.from_numpy(np.asarray(camera_k, dtype=np.float32))
        raise ValueError(f"'P_rect_02' not found in calibration file {calib_file_name}")

    def __len__(self):
        return len(self.file_name)

    def get_file_list(self):
        """Return the list of sample paths backing this dataset."""
        return self.file_name

    def normalize_xy(self, x, y):
        """Map UTM x/y linearly so utils.min_utm_* -> -1 and utils.max_utm_* -> 1."""
        norm_x = (x - utils.min_utm_x) / (utils.max_utm_x - utils.min_utm_x) * 2.0 - 1.0
        norm_y = (y - utils.min_utm_y) / (utils.max_utm_y - utils.min_utm_y) * 2.0 - 1.0

        return norm_x, norm_y

    def __getitem__(self, idx):
        """Load sample ``idx``; see the class docstring for the returned tuple."""
        file_name = self.file_name[idx]
        # The slicing assumes fixed-width paths: the first 10 characters are
        # the day directory, the first 38 the drive directory, the rest the
        # image file name.
        day_dir = file_name[:10]
        drive_dir = file_name[:38]
        image_no = file_name[38:]

        # =================== left colour camera intrinsics ===================
        calib_file_name = os.path.join(self.root, grdimage_dir, day_dir, 'calib_cam_to_cam.txt')
        left_camera_k = self._read_left_camera_k(
            calib_file_name, GrdImg_W, GrdImg_H, GrdOriImg_W, GrdOriImg_H)

        # =================== GPS position and heading (oxts) ===================
        # lat0 is read from the oxts file of frame (frame index // 20) and is
        # handed to utils.gps2utm together with this frame's lat/lon.
        reference_oxts_file_no = str(int(image_no[:-4]) // 20).zfill(10)
        reference_file_name = os.path.join(self.root, grdimage_dir, drive_dir, oxts_dir,
                                           reference_oxts_file_no + ".txt")
        with open(reference_file_name, 'r') as f:
            lat0 = float(f.readline().split(' ')[0])

        # oxts file of this frame, e.g. 0000000000.txt
        oxts_file_name = os.path.join(self.root, grdimage_dir, drive_dir, oxts_dir,
                                      image_no.lower().replace('.png', '.txt'))
        with open(oxts_file_name, 'r') as f:
            content = f.readline().split(' ')
        heading_rad = float(content[5])
        heading = torch.from_numpy(np.asarray(heading_rad))
        utm_center_x, utm_center_y = utils.gps2utm(float(content[0]), float(content[1]), lat0)

        # =================== ground image ===================
        left_img_name = os.path.join(self.root, self.pro_grdimage_dir, drive_dir, left_color_camera_dir,
                                     image_no.lower())
        with Image.open(left_img_name, 'r') as grd_file:
            grd_img_left = grd_file.convert('RGB')
        if self.grdimage_transform is not None:
            grd_img_left = self.grdimage_transform(grd_img_left)

        # =================== satellite map ===================
        satmap_path = os.path.join(self.root, self.satmap_dir, file_name)
        with Image.open(satmap_path, 'r') as sat_file:
            sat_full = sat_file.convert('RGB')
        if self.align_heading:
            sat_full = sat_full.rotate(-heading_rad / np.pi * 180)
        # Shift by the camera-to-GPS offset (metres converted to pixels).
        sat_align_cam = sat_full.transform(sat_full.size, Image.AFFINE,
                                           (1, 0, utils.CameraGPS_shift_left[0] / self.meter_per_pixel,
                                            0, 1, utils.CameraGPS_shift_left[1] / self.meter_per_pixel),
                                           resample=Image.BILINEAR)

        mpp = self.meter_per_pixel
        if self.rand_crop:
            # top/left and the UTM extents below are expressed relative to a
            # SatCropRange_H x SatCropRange_W window centred on the GPS
            # position, so the random crop must be taken from that window
            # (not from the full map) for the extents to match the pixels.
            sat_window = TF.center_crop(sat_align_cam, [SatCropRange_H, SatCropRange_W])
            shift_x = np.random.uniform(0, 1)  # --> right as positive, parallel to the heading direction
            shift_y = np.random.uniform(0, 1)  # --> up as positive, vertical to the heading direction
            top = int(shift_y * (SatCropRange_H - self.crop_size))
            left = int(shift_x * (SatCropRange_W - self.crop_size))
            # UTM coordinates of the centre of the crop's top-left pixel.
            start_x = utm_center_x - mpp * (SatCropRange_W / 2 - left - 0.5)
            start_y = utm_center_y + mpp * (SatCropRange_H / 2 - top - 0.5)
            sat_map = TF.crop(sat_window, top, left, self.crop_size, self.crop_size)
        else:
            start_x = utm_center_x - mpp * (self.crop_size // 2 - 0.5)
            start_y = utm_center_y + mpp * (self.crop_size // 2 - 0.5)
            sat_map = TF.center_crop(sat_align_cam, self.crop_size)
        # Centre of the bottom-right pixel; UTM y decreases from top row to bottom row.
        end_x = start_x + mpp * (self.crop_size - 1)
        end_y = start_y - mpp * (self.crop_size - 1)

        sat_map = np.array(sat_map, dtype=np.float32)

        # =================== normalised position grid ===================
        position_grid_size = int(self.crop_size // 8)
        start_x, start_y = self.normalize_xy(start_x, start_y)
        end_x, end_y = self.normalize_xy(end_x, end_y)
        xs = torch.linspace(start=start_x, end=end_x, steps=position_grid_size, dtype=torch.float64)
        ys = torch.linspace(start=start_y, end=end_y, steps=position_grid_size, dtype=torch.float64)
        position_grid = torch.stack(torch.meshgrid(xs, ys, indexing='xy'))

        norm_center_x, norm_center_y = self.normalize_xy(utm_center_x, utm_center_y)
        grid_info = {"center_x": utm_center_x, "center_y": utm_center_y, "norm_center_x": norm_center_x, "norm_center_y": norm_center_y,
                     "start_x": start_x, "start_y": start_y, "end_x": end_x, "end_y": end_y, "grid_size": position_grid_size, "meter_per_pixel": self.meter_per_pixel,
                     "left_camera_k": left_camera_k, "heading": heading}

        return sat_map, grd_img_left, position_grid, grid_info
        
def load_train_data(batch_size, shift_range_lat=20, shift_range_lon=20, rotation_range=10, crop_size=None, rand_crop=False, align_heading=False, test_only=False):
    """Create the DataLoader that iterates the training split.

    Args:
        batch_size: samples per batch.
        shift_range_lat, shift_range_lon: shift ranges in metres, forwarded
            to SatGrdDataset.
        rotation_range: rotation range in degrees, forwarded to SatGrdDataset.
        crop_size: satellite crop side length in pixels; falls back to
            ``utils.SatMap_process_sidelength`` when None.
        rand_crop: take the satellite crop at a random position.
        align_heading: rotate the satellite map by the oxts heading.
        test_only: make the dataset read the test-1 file list instead.

    Returns:
        A non-shuffling DataLoader over SatGrdDataset.
    """
    sat_side = utils.get_process_satmap_sidelength()
    sat_tf = transforms.Compose([
        transforms.Resize(size=[sat_side, sat_side]),
        transforms.ToTensor(),
    ])
    grd_tf = transforms.Compose([
        transforms.Resize(size=[GrdImg_H, GrdImg_W]),
        transforms.ToTensor(),
    ])

    effective_crop = crop_size if crop_size is not None else utils.SatMap_process_sidelength

    dataset = SatGrdDataset(
        root=root_dir,
        file=train_file,
        transform=(sat_tf, grd_tf),
        shift_range_lat=shift_range_lat,
        shift_range_lon=shift_range_lon,
        rotation_range=rotation_range,
        crop_size=effective_crop,
        rand_crop=rand_crop,
        align_heading=align_heading,
        test_only=test_only,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=num_thread_workers,
        drop_last=False,
    )

def load_test_data(batch_size, shift_range_lat=20, shift_range_lon=20, rotation_range=10, crop_size=None, rand_crop=False, align_heading=False):
    """Create a non-shuffling DataLoader over the test-1 split.

    Args:
        batch_size: samples per batch.
        shift_range_lat, shift_range_lon: shift ranges in metres, forwarded
            to SatGrdDataset.
        rotation_range: rotation range in degrees, forwarded to SatGrdDataset.
        crop_size: satellite crop side length in pixels; defaults to
            ``utils.SatMap_process_sidelength`` when None (same as
            load_train_data).
        rand_crop: take the satellite crop at a random position.
        align_heading: rotate the satellite map by the oxts heading.

    Returns:
        DataLoader over SatGrdDataset built from ``test1_file``.
    """
    SatMap_process_sidelength = utils.get_process_satmap_sidelength()

    satmap_transform = transforms.Compose([
        transforms.Resize(size=[SatMap_process_sidelength, SatMap_process_sidelength]),
        transforms.ToTensor(),
    ])

    grdimage_transform = transforms.Compose([
        transforms.Resize(size=[GrdImg_H, GrdImg_W]),
        transforms.ToTensor(),
    ])

    # The dataset needs a concrete crop size (it computes crop_size // 8 and
    # center-crops to it), so None would fail on the first item.
    if crop_size is None:
        crop_size = utils.SatMap_process_sidelength

    # # Please keep the following two lines!!! These are for fair test comparison.
    # np.random.seed(2022)
    # torch.manual_seed(2022)

    # test_only=True makes the dataset read the test-1 list; without it the
    # dataset fell back to the training list and `file` had no effect.
    test_set = SatGrdDataset(root=root_dir, file=test1_file,
                             transform=(satmap_transform, grdimage_transform),
                             shift_range_lat=shift_range_lat,
                             shift_range_lon=shift_range_lon,
                             rotation_range=rotation_range,
                             crop_size=crop_size,
                             rand_crop=rand_crop,
                             align_heading=align_heading,
                             test_only=True)

    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True,
                             num_workers=num_thread_workers, drop_last=False)
    return test_loader

if __name__ == '__main__':
    # Smoke test: build the training loader, report its size and the shape of
    # the ground-image tensor in the first batch, then stop.
    batch_size = 2
    train_loader = load_train_data(batch_size, crop_size=1200)
    print(f"total: {len(train_loader)} batches")
    for step, data in enumerate(train_loader):
        sat_map, grd_img, pos_grid, grid_info = data
        print(grd_img.shape)
        # Only the first batch is inspected. A plain `break` replaces the old
        # `assert False`, which exited with a traceback and, being stripped
        # under `python -O`, would otherwise iterate the whole dataset.
        break

    
    