import sys, os
# import pathlib
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath #hack fix to get it work on windows
from pathlib import Path
sys.path.append(os.path.dirname(sys.path[0]))
from abc import ABC
import pickle as pkl

'''
Abstract class defining dataset properties and functions

Datasets must be structured as follows:
# dataset_path / <sequence_id> / raw_images / <image files> (sorted in ascending filename order)
# dataset_path / <sequence_id> / gt_data / <ground truth data files> (sorted in ascending filename order)
# dataset_path / <sequence_id> / label.txt (files sorted in ascending filename order, or a single file for the entire sequence)
# dataset_path / <sequence_id> / metadata.txt (files sorted in ascending filename order, or a single file for the entire sequence)

All directories under dataset_path will be considered to be sequences containing data and labels.

The resulting RawImageDataset will be stored in the following location:
# dataset_path / <image_dataset_path>.pkl

The resulting SceneGraphDataset will be stored in the following location:
# dataset_path / <sg_dataset_path>.pkl
'''
class BaseDataset(ABC):
    """Shared state and pickle persistence for every dataset type.

    Paths and the dataset type are read from ``config``; the payload
    containers (``data``, ``labels``, ``action_types``, ``folder_names``)
    start out empty (``None``) and are filled in by subclasses or loaders.
    """

    def __init__(self, config):
        locations = config.location_data
        self.dataset_path = locations["input_path"]
        self.config = config
        # Payload containers; subclasses replace these with real structures.
        self.data = None
        self.labels = None
        self.dataset_save_path = locations["data_save_path"]
        self.dataset_type = config.dataset_type
        self.action_types = None
        self.ignore = []
        self.folder_names = None

    def save(self):
        """Pickle this whole dataset object to ``self.dataset_save_path``."""
        with open(self.dataset_save_path, mode='wb') as out_file:
            pkl.dump(self, out_file)

    def load(self):
        """Unpickle and return the object stored at ``self.dataset_save_path``.

        The loaded object is returned as a new instance; ``self`` is not modified.
        """
        # NOTE(review): unpickling can execute arbitrary code; only load files you trust.
        with open(self.dataset_save_path, mode='rb') as in_file:
            return pkl.load(in_file)

'''
Dataset containing image data and associated information only.
'''
class RawImageDataset(BaseDataset):
    """Dataset that contains only raw images and their associated information.

    Intended to be consumed by the scene-graph extractors and directly by
    CNN-based approaches. Constructing with ``config=None`` produces a bare
    instance with no attributes set (presumably for callers that populate it
    themselves -- TODO confirm).
    """

    def __init__(self, config=None):
        # Identity check: ``!= None`` would dispatch to a custom ``__ne__`` on config.
        if config is not None:
            super().__init__(config)
            # Image geometry; unknown until frames are read.
            self.im_height = None
            self.im_width = None
            self.color_channels = None
            self.frame_limit = config.frame_data["frames_limit"]
            self.dataset_type = 'image'
            self.data = {}          # {sequence: {frame: frame_data}}
            self.labels = {}        # {sequence: label}
            self.action_types = {}  # {sequence: action}
            self.ignore = []        # sequences to ignore


'''
Dataset containing scene-graph representations of the road scenes.
This dataset is generated by the scene graph extractors and saved as a pkl file.
'''
class SceneGraphDataset(BaseDataset):
    """Dataset that contains only scene-graph representations of road scenes.

    Alongside the scene graphs it stores a metadata dict, an action-types dict
    and a labels dict. It also converts per-sequence scene graphs into trainer
    input: one feature dict per frame, for either CARLA or real-image data.
    """

    def __init__(self, config=None, scene_graphs=None, action_types=None, label_data=None, meta_data=None):
        """Initialise from ``config``; with ``config=None`` no attributes are set.

        Args:
            config: project configuration object forwarded to ``BaseDataset``.
            scene_graphs: scene graphs to store; defaults to a new empty dict.
            action_types: action types to store; defaults to a new empty dict.
            label_data: labels to store; defaults to a new empty dict.
            meta_data: metadata to store; defaults to a new empty dict.
        """
        if config is not None:
            super().__init__(config)
            self.dataset_type = 'scenegraph'
            # Build fresh dicts per instance. A ``{}`` default in the signature is
            # evaluated once and would be shared by every SceneGraphDataset.
            self.scene_graphs = {} if scene_graphs is None else scene_graphs
            self.meta = {} if meta_data is None else meta_data
            self.labels = {} if label_data is None else label_data
            self.action_types = {} if action_types is None else action_types

    @staticmethod
    def _build_graph_sequence(scenegraphs, frame_numbers, folder_name, embed_nodes, embed_edges):
        """Turn a ``{frame_key: scenegraph}`` dict into a list of per-frame feature dicts.

        Scene graphs are visited in ascending key order and paired positionally
        with ``frame_numbers``, which defaults to those same sorted keys. Pairing
        uses ``zip``, so the result is truncated to the shorter of the two.
        ``embed_nodes(scenegraph)`` supplies the node features.
        ``embed_edges(scenegraph, node_name2idx)`` supplies ``(edge_index, edge_attr)``.
        """
        ordered_keys = sorted(scenegraphs)
        if frame_numbers is None:
            frame_numbers = ordered_keys
        ordered_graphs = [scenegraphs[key] for key in ordered_keys]

        sequence = []
        for scenegraph, frame_number in zip(ordered_graphs, frame_numbers):
            # Map each node to its position so edges can be expressed by index.
            node_name2idx = {node: idx for idx, node in enumerate(scenegraph.g.nodes)}
            sg_dict = {}
            sg_dict['node_features'] = embed_nodes(scenegraph)
            sg_dict['edge_index'], sg_dict['edge_attr'] = embed_edges(scenegraph, node_name2idx)
            sg_dict['folder_name'] = folder_name
            sg_dict['frame_number'] = frame_number
            sg_dict['node_order'] = node_name2idx
            sequence.append(sg_dict)
        return sequence

    def process_carla_graph_sequences(self, scenegraphs, feature_list, frame_numbers=None, folder_name=None):
        """Build trainer input for one sequence of CARLA scene graphs.

        Args:
            scenegraphs: dict mapping frame key -> scene graph for one sequence.
            feature_list: forwarded to ``get_carla_node_embeddings``.
            frame_numbers: frame numbers to record. They are paired positionally
                with the scene graphs (in ascending key order) and should have
                the same length. Defaults to the sorted dict keys.
            folder_name: sequence identifier copied into every frame dict.

        Returns:
            A list with one dict per frame. Each dict has the keys:
            ``node_features``, ``edge_index`` and ``edge_attr``, which come from
            the scene graph's CARLA embedding methods (presumably tensors);
            ``folder_name``; ``frame_number``; and ``node_order``, the
            node -> index mapping.
        """
        return self._build_graph_sequence(
            scenegraphs,
            frame_numbers,
            folder_name,
            lambda sg: sg.get_carla_node_embeddings(feature_list),
            lambda sg, name2idx: sg.get_carla_edge_embeddings(name2idx),
        )

    def process_real_image_graph_sequences(self, scenegraphs, feature_list, frame_numbers=None, folder_name=None):
        """Build trainer input for one sequence of real-image scene graphs.

        Arguments and return value match ``process_carla_graph_sequences``,
        except that the node and edge features come from the scene graph's
        ``get_real_image_node_embeddings`` and
        ``get_real_image_edge_embeddings`` methods.
        """
        return self._build_graph_sequence(
            scenegraphs,
            frame_numbers,
            folder_name,
            lambda sg: sg.get_real_image_node_embeddings(feature_list),
            lambda sg, name2idx: sg.get_real_image_edge_embeddings(name2idx),
        )

