# https://github.com/amaralibey/gsv-cities

import pandas as pd
from pathlib import Path
from PIL import Image, ImageFile, UnidentifiedImageError
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T

# Default preprocessing applied to every loaded image: convert it to a tensor,
# then normalize each channel with the mean/std below (the values match the
# commonly used ImageNet channel statistics).
default_transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# NOTE: Hard-coded path to the dataset root folder. It is used as the default
# `base_path` of GSVCitiesDataset; the trailing slash matters to any code that
# builds sub-paths by plain string concatenation.
BASE_PATH = '/database/dkim/gsv_cities/'


# Fail fast at import time when the hard-coded dataset root does not exist.
if not Path(BASE_PATH).exists():
    raise FileNotFoundError(
        'BASE_PATH is hardcoded, '
        'please adjust to point to gsv_cities'
    )

class GSVCitiesDataset(Dataset):
    '''
        Place-level dataset over the GSV-Cities images.

        One item of the dataset is one *place* (not one image): indexing
        returns `img_per_place` images of that place stacked in a single
        tensor, together with the place id repeated once per image.

        Layout read from `base_path`:
            Dataframes/<city>.csv          one metadata CSV per city
            Images/<city_id>/<name>.jpg    image files, see get_img_name

        The CSV columns used here are: place_id, city_id, panoid, year,
        month, northdeg, lat and lon.

        Args:
            cities: names of the cities to load (one CSV per name).
            img_per_place: number of images returned for each place.
            min_img_per_place: places with fewer images are discarded.
            random_sample_from_each_place: if True, images of a place are
                sampled randomly on every access; otherwise the first
                `img_per_place` rows after sorting by (year, month, lat)
                in descending order are always returned.
            transform: callable applied to every loaded PIL image. Its
                outputs are combined with torch.stack, so it has to
                return tensors of identical shape.
            base_path: root folder of the dataset (with or without a
                trailing slash).

        Raises:
            AssertionError: if img_per_place > min_img_per_place.
            ValueError: if `cities` is empty.
    '''

    def __init__(self,
                 cities=['London', 'Boston'],
                 img_per_place=4,
                 min_img_per_place=4,
                 random_sample_from_each_place=True,
                 transform=default_transform,
                 base_path=BASE_PATH
                 ):
        super().__init__()
        self.base_path = base_path
        # Copy, so that neither the shared default list nor the caller's
        # list is aliased by this instance.
        self.cities = list(cities)
        if not self.cities:
            raise ValueError('cities must contain at least one city name')

        assert img_per_place <= min_img_per_place, \
            f"img_per_place should be less than or equal to {min_img_per_place}"
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.random_sample_from_each_place = random_sample_from_each_place
        self.transform = transform

        # generate the dataframe containing the images metadata
        self.dataframe = self.__getdataframes()

        # get all unique place ids
        self.places_ids = pd.unique(self.dataframe.index)
        self.total_nb_images = len(self.dataframe)

    def __getdataframes(self):
        '''
            Return one dataframe containing
            all info about the images from all cities

            This requires DataFrame files to be in a folder
            named Dataframes, containing a DataFrame
            for each city in self.cities

            The result is indexed by place_id and only keeps places
            depicted by at least self.min_img_per_place images.
        '''
        csv_dir = Path(self.base_path) / 'Dataframes'

        frames = []
        for prefix, city in enumerate(self.cities):
            city_df = pd.read_csv(csv_dir / f'{city}.csv')

            # We add a prefix to place_id, so that we
            # don't confuse, say, place number 13 of NewYork
            # with place number 13 of London ==> (0000013 and 0500013)
            # We suppose that there is no city with more than
            # 99999 images and there won't be more than 99 cities.
            # The first city gets prefix 0, i.e. its ids are unchanged.
            # TODO: rename the dataset and hardcode these prefixes
            city_df['place_id'] = city_df['place_id'] + (prefix * 10**5)

            # shuffle the city dataframe
            frames.append(city_df.sample(frac=1))

        # a single concat instead of re-copying the growing frame per city
        df = pd.concat(frames, ignore_index=True)

        # keep only places depicted by at least min_img_per_place images
        imgs_in_place = df.groupby('place_id')['place_id'].transform('size')
        res = df[imgs_in_place >= self.min_img_per_place]
        return res.set_index('place_id')

    def __getitem__(self, index):
        '''
            Return (images, labels) for the place at position `index`.

            images: the img_per_place transformed images, stacked
            labels: the place id repeated img_per_place times
        '''
        place_id = self.places_ids[index]

        # get the place in form of a dataframe (each row corresponds to one
        # image). Selecting with a list keeps the result a DataFrame even
        # when the place has a single image; a scalar label would return a
        # Series in that case and break the sampling/iteration below.
        place = self.dataframe.loc[[place_id]]

        # sample K images (rows) from this place
        # we can either sort and take the most recent k images
        # or randomly sample them
        if self.random_sample_from_each_place:
            place = place.sample(n=self.img_per_place)
        else:  # always get the same most recent images
            place = place.sort_values(
                by=['year', 'month', 'lat'], ascending=False)
            place = place.iloc[:self.img_per_place]

        images_dir = Path(self.base_path) / 'Images'

        imgs = []
        for _, row in place.iterrows():
            img_path = images_dir / row['city_id'] / self.get_img_name(row)

            # the loader is handed a plain string path, as before
            img = self.image_loader(str(img_path))

            if self.transform is not None:
                img = self.transform(img)

            imgs.append(img)

        return torch.stack(imgs), torch.tensor(place_id).repeat(self.img_per_place)

    def __len__(self):
        '''Denotes the total number of places (not images)'''
        return len(self.places_ids)

    @staticmethod
    def image_loader(path):
        '''
            Load the image at `path` as an RGB PIL image.

            If PIL cannot identify the file, a message is printed and a
            blank 224x224 RGB image is returned instead of raising.
        '''
        try:
            # convert() returns an independent image, so the file handle
            # can be released as soon as the conversion is done
            with Image.open(path) as img:
                return img.convert('RGB')
        except UnidentifiedImageError:
            print(f'Image {path} could not be loaded')
            return Image.new('RGB', (224, 224))

    @staticmethod
    def get_img_name(row):
        '''
            Given a row from the dataframe (indexed by place_id),
            return the corresponding image file name:

            <city>_<place id>_<year>_<month>_<northdeg>_<lat>_<lon>_<panoid>.jpg
        '''
        city = row['city_id']

        # now remove the city prefix we added to the id
        # it is artificially added to make ids different
        # for different cities
        # row.name is the index of the row (the place id),
        # not to be confused with the image name
        pl_id = str(row.name % 10**5).zfill(7)

        panoid = row['panoid']
        year = str(row['year']).zfill(4)
        month = str(row['month']).zfill(2)
        northdeg = str(row['northdeg']).zfill(3)
        lat, lon = str(row['lat']), str(row['lon'])
        return f'{city}_{pl_id}_{year}_{month}_{northdeg}_{lat}_{lon}_{panoid}.jpg'


# class GSVCitiesDataset(Dataset):
#     def __init__(self,
#                  cities=['London', 'Boston'],
#                  img_per_place=4,
#                  min_img_per_place=4,
#                  random_sample_from_each_place=True,
#                  transform=default_transform,
#                  base_path=BASE_PATH,
#                  reindex=True):
#         super(GSVCitiesDataset, self).__init__()
#         self.base_path = base_path
#         self.cities = cities

#         assert img_per_place <= min_img_per_place, \
#             f"img_per_place should be less than {min_img_per_place}"
#         self.img_per_place = img_per_place
#         self.min_img_per_place = min_img_per_place
#         self.random_sample_from_each_place = random_sample_from_each_place
#         self.transform = transform

#         # Build the dataframe
#         self.dataframe = self.__getdataframes()

#         # List of unique place_ids
#         self.places_ids = pd.unique(self.dataframe.index)
#         self.total_nb_images = len(self.dataframe)

#         # Whether to re-index the labels
#         self.reindex_enabled = reindex
#         self.label_map = None
#         if self.reindex_enabled:
#             self.label_map = {pid: i for i, pid in enumerate(self.places_ids)}

#     def get_num_classes(self):
#         if self.label_map is None:
#             return len(set(self.places_ids))
#         else:
#             return len(self.label_map)

#     def __getdataframes(self):
#         df = pd.read_csv(self.base_path + 'Dataframes/' + f'{self.cities[0]}.csv')
#         df = df.sample(frac=1)
#         for i in range(1, len(self.cities)):
#             tmp_df = pd.read_csv(self.base_path + 'Dataframes/' + f'{self.cities[i]}.csv')
#             tmp_df['place_id'] = tmp_df['place_id'] + (i * 10**5)
#             tmp_df = tmp_df.sample(frac=1)
#             df = pd.concat([df, tmp_df], ignore_index=True)
#         res = df[df.groupby('place_id')['place_id'].transform('size') >= self.min_img_per_place]
#         return res.set_index('place_id')

#     def __getitem__(self, index):
#         place_id = self.places_ids[index]
#         place = self.dataframe.loc[place_id]

#         if self.random_sample_from_each_place:
#             place = place.sample(n=self.img_per_place)
#         else:
#             place = place.sort_values(by=['year', 'month', 'lat'], ascending=False)
#             place = place[: self.img_per_place]

#         folder_path = self.base_path

#         imgs = []
#         for _, row in place.iterrows():
#             img_name = self.get_img_name(row)
#             img_path = folder_path + 'Images/' + row['city_id'] + '/' + img_name
#             img = self.image_loader(img_path)
#             if self.transform is not None:
#                 img = self.transform(img)
#             imgs.append(img)

#         label_value = self.label_map[int(place_id)] if self.reindex_enabled else int(place_id)
#         label_tensor = torch.tensor(label_value).repeat(self.img_per_place)

#         return torch.stack(imgs), label_tensor

#     def __len__(self):
#         return len(self.places_ids)

#     @staticmethod
#     def image_loader(path):
#         try:
#             return Image.open(path).convert('RGB')
#         except UnidentifiedImageError:
#             print(f'Image {path} could not be loaded')
#             return Image.new('RGB', (224, 224))

#     @staticmethod
#     def get_img_name(row):
#         city = row['city_id']
#         pl_id = row.name % 10**5
#         pl_id = str(pl_id).zfill(7)
#         panoid = row['panoid']
#         year = str(row['year']).zfill(4)
#         month = str(row['month']).zfill(2)
#         northdeg = str(row['northdeg']).zfill(3)
#         lat, lon = str(row['lat']), str(row['lon'])
#         name = f"{city}_{pl_id}_{year}_{month}_{northdeg}_{lat}_{lon}_{panoid}.jpg"
#         return name