import cv2
import random
import numpy as np
from torch.utils.data import Dataset
import os
from torchvision import transforms

# load the yaw, pitch angles for the street-view images and yaw angles for the aerial view
import scipy.io as sio
# import ldm.modules.image_degradation.utils_image as util
from PIL import Image


def min_max_norm(x):
    """Center an array on its mean and scale it by its value range.

    Computes ``(x - mean(x)) / (max(x) - min(x))``, with every statistic taken
    over all elements of ``x``. Despite the name, the result is mean-centred
    rather than shifted to start at 0, so the values lie in [-1, 1].

    Args:
        x: numpy array of any shape (e.g. an ``H x W x C`` image). The original
            implementation only accepted ``256 x 256 x 3``; such inputs give
            the same result as before.

    Returns:
        Floating-point array with the same shape as ``x``. If every element is
        equal (zero range), an all-zero array is returned instead of NaNs.
    """
    x = np.asarray(x)
    flat = x.ravel()  # statistics over all elements, independent of shape
    mean_ = np.mean(flat)
    # Vectorised reductions; the Python builtins min/max iterate element by element.
    min_ = np.min(flat)
    max_ = np.max(flat)
    # Subtract in float so narrow signed ints (e.g. int8: 127 - -128) cannot wrap.
    value_range = float(max_) - float(min_)
    centered = x.astype(np.float32) - mean_
    if value_range == 0.0:
        # Constant input: avoid 0/0 -> NaN; a constant is all zeros after centering.
        return np.zeros_like(centered)
    return centered / value_range


class CVUSAData(Dataset):
    """Paired satellite / street-view image dataset for CVUSA.

    Each non-empty line of ``datalist`` is a comma-separated record. The first
    field is the satellite (aerial) image path and the second is the ground
    (street-view) image path, both relative to ``root``. Any further fields are
    ignored.

    Args:
        root: Prefix concatenated verbatim in front of every relative path, so
            it should end with a path separator (e.g. ``'/data/CVUSA/'``).
        datalist: Path to the split CSV file.
        crop: If true, satellite images are center-cropped to 256x256;
            otherwise they are resized to 256x256.
    """

    def __init__(self, root, datalist, crop=False):
        self.img_root = root
        self.data_list = datalist

        print('InputData::__init__: load %s' % self.data_list)
        self.__cur_id = 0  # for training
        # Each entry: [grd, sat, g2a, a2g, polar] absolute paths. Only grd and
        # sat are read by __getitem__; the others are kept for get_file_list().
        self.id_list = []
        self.id_idx_list = []
        with open(self.data_list, 'r') as file:
            for line_no, line in enumerate(file, start=1):
                # Strip the newline (and any '\r') so it never ends up inside
                # the last path field; tolerate blank / trailing empty lines.
                line = line.strip()
                if not line:
                    continue
                data = [field.strip() for field in line.split(',')]
                if len(data) < 2:
                    raise ValueError(
                        '%s:%d: expected at least 2 comma-separated fields, got %r'
                        % (self.data_list, line_no, line))
                # File stem of the satellite image, e.g. 'bingmap/19/0000001.jpg' -> '0000001'.
                pano_id = (data[0].split('/')[-1]).split('.')[0]

                grd = self.img_root + data[1]
                sat = self.img_root + data[0]
                g2a = self.img_root + 'g2a/' + pano_id + '.png'
                a2g = self.img_root + 'a2g/' + pano_id + '.png'
                polar = self.img_root + data[0].replace('bing', 'polar').replace('jpg', 'png')

                self.id_list.append([grd, sat, g2a, a2g, polar])
                self.id_idx_list.append(len(self.id_idx_list))
        self.data_size = len(self.id_list)
        print('InputData::__init__: load', self.data_list, ' data_size =', self.data_size)

        # No Normalize step is applied, so tensors keep ToTensor's [0, 1] range.
        sat_geometry = transforms.CenterCrop(256) if crop else transforms.Resize(size=[256, 256])
        self.satimage_transform = transforms.Compose([sat_geometry, transforms.ToTensor()])
        self.grdimage_transform = transforms.Compose([
            transforms.Resize(size=[128, 512]),
            transforms.ToTensor(),
        ])
        # Full-resolution variants: tensor conversion only, no resizing.
        self.satimage_transform_noresize = transforms.Compose([transforms.ToTensor()])
        self.grdimage_transform_noresize = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.id_list)

    def get_file_list(self):
        """Return the per-sample path lists ``[grd, sat, g2a, a2g, polar]``."""
        return self.id_list

    def __getitem__(self, idx):
        """Load the satellite/ground pair at ``idx``.

        Returns:
            ``(sat_resized, grd_resized, sat_full, grd_full, sat_path, grd_path)``.
            The ``*_resized`` tensors come from the fixed-size transforms
            (satellite 256x256, ground 128x512). The ``*_full`` tensors keep
            the source resolution.
        """
        path_grd = self.id_list[idx][0]
        path_sat = self.id_list[idx][1]

        # convert() returns a fully loaded image, so it outlives the file handle.
        with Image.open(path_grd, 'r') as grd_file:
            img_grd = grd_file.convert('RGB')
        img_grd_resize = self.grdimage_transform(img_grd)
        img_grd_noresize = self.grdimage_transform_noresize(img_grd)

        with Image.open(path_sat, 'r') as sat_file:
            img_sat = sat_file.convert('RGB')
        img_sat_resize = self.satimage_transform(img_sat)
        img_sat_noresize = self.satimage_transform_noresize(img_sat)

        return img_sat_resize, img_grd_resize, img_sat_noresize, img_grd_noresize, path_sat, path_grd


class CVUSATrain(CVUSAData):
    """CVUSA training split (``train-19zl.csv``) rooted at ``/data/CVUSA/``."""

    def __init__(self, root='/data/CVUSA/', datalist='/data/CVUSA/splits/train-19zl.csv', crop=False):
        # Only the defaults differ from the base loader; forward everything.
        super(CVUSATrain, self).__init__(root, datalist, crop)


class CVUSAVal(CVUSAData):
    """CVUSA validation split (``val-19zl.csv``) rooted at ``/data/CVUSA/``."""

    def __init__(self, root='/data/CVUSA/', datalist='/data/CVUSA/splits/val-19zl.csv', crop=False):
        # Only the defaults differ from the base loader; forward everything.
        super(CVUSAVal, self).__init__(root, datalist, crop)
