import cv2
import random
import numpy as np
from torch.utils.data import Dataset
import os
from torchvision import transforms

# load the yaw, pitch angles for the street-view images and yaw angles for the aerial view
import scipy.io as sio
from PIL import Image


class CVACTData(Dataset):
    """Paired ground-view / satellite-view image dataset driven by a list file.

    Every non-blank line of the list file holds a ground-image path and a
    satellite-image path separated by a tab. Images are only opened when an
    item is requested through ``__getitem__``.

    Args:
        pathdir: Path of the list file (a text file, despite the name).
    """

    # the path of your CVACT dataset
    def __init__(self, pathdir):
        self.data_list = pathdir
        print('InputData::__init__: load %s' % self.data_list)

        self.cur_id = 0  # for training
        self.id_list = self._read_pairs(self.data_list)
        # Running index of every entry, kept alongside id_list.
        self.id_idx_list = list(range(len(self.id_list)))

        self.data_size = len(self.id_list)
        print('InputData::__init__: load', self.data_list, ' data_size =', self.data_size)

        # No Normalize step is applied in any of the pipelines below.
        # Satellite images are resized to 256x256.
        self.satimage_transform = transforms.Compose([
            transforms.Resize(size=[256, 256]),
            transforms.ToTensor(),
        ])

        # Ground images are resized to [128, 512].
        self.grdimage_transform = transforms.Compose([
            transforms.Resize(size=[128, 512]),
            transforms.ToTensor(),
        ])

        # Same conversions without resizing, so the images keep their
        # on-disk resolution.
        self.satimage_transform_noresize = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.grdimage_transform_noresize = transforms.Compose([
            transforms.ToTensor(),
        ])

    @staticmethod
    def _read_pairs(list_path):
        """Parse the list file into ``[ground_path, satellite_path]`` entries.

        Whitespace-only lines (for example a trailing empty line) are skipped
        instead of raising ``IndexError``. A non-blank line without a tab
        still raises ``IndexError``.
        """
        pairs = []
        with open(list_path, 'r') as file:
            for line in file:
                if not line.strip():
                    continue
                data = line.split('\t')
                grd = data[0]
                # Drop the line terminator from the second column.
                sat = data[1].splitlines()[0]
                pairs.append([grd, sat])
        return pairs

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` once and return an RGB image, closing the file handle."""
        with Image.open(path, 'r') as img:
            # convert() returns a new image, so the result is still usable
            # after the context manager has closed the underlying file.
            return img.convert('RGB')

    def __len__(self):
        """Return the number of ground/satellite pairs."""
        return len(self.id_list)

    def get_file_list(self):
        """Return the list of ``[ground_path, satellite_path]`` entries."""
        return self.id_list

    def __getitem__(self, idx):
        """Load and transform the pair stored at position ``idx``.

        Returns:
            A 6-tuple ``(sat_resized, grd_resized, sat_noresize, grd_noresize,
            path_sat, path_grd)``: the four transform outputs followed by the
            two path strings taken from the list file.
        """
        path_grd, path_sat = self.id_list[idx]

        # Each image is opened exactly once; the earlier duplicate open left
        # a file handle unclosed and decoded arrays that were never used.
        img_grd_rgb = self._load_rgb(path_grd)
        img_grd_rgb_trans = self.grdimage_transform(img_grd_rgb)
        grd_noresize = self.grdimage_transform_noresize(img_grd_rgb)

        img_sat_rgb = self._load_rgb(path_sat)
        img_sat_rgb_trans = self.satimage_transform(img_sat_rgb)
        sat_noresize = self.satimage_transform_noresize(img_sat_rgb)

        return (
            img_sat_rgb_trans,
            img_grd_rgb_trans,
            sat_noresize,
            grd_noresize,
            path_sat,
            path_grd,
        )



class CVACTTrain(CVACTData):
    """``CVACTData`` variant whose list file defaults to ``/cvatc/streetview.txt``."""

    def __init__(self, pathdir='/cvatc/streetview.txt'):
        # Only the default list file differs; the base class does all loading.
        super().__init__(pathdir)


class CVACTVal(CVACTData):
    """``CVACTData`` variant whose list file defaults to ``/cvatc/streetview_test.txt``."""

    def __init__(self, pathdir='/cvatc/streetview_test.txt'):
        # Only the default list file differs; the base class does all loading.
        super().__init__(pathdir)