import os
import pandas as pd
import torch
from torch.utils.data import Dataset
import json

# Module-level configuration for the multimedia dataset.

# Model identifier; not referenced by the code visible in this file
# (presumably the text encoder used when the graphs were built -- TODO confirm).
model_name = 'sbert'
# Dataset root. It is a relative path, so it resolves against the current
# working directory of the process.
path = 'dataset/multimedia'
# Sub-directories derived from the dataset root. `path_nodes` and `path_edges`
# are not used by the code visible in this file; `MultimediaDataset.__getitem__`
# rebuilds the graphs location from `path` directly instead of using
# `path_graphs`.
path_nodes = f'{path}/nodes'
path_edges = f'{path}/edges'
path_graphs = f'{path}/graphs'


class MultimediaDataset(Dataset):
    """Map-style dataset over the JSON-lines records in ``{path}/data.json``.

    Item ``index`` combines the ``index``-th record with the graph stored in
    ``{path}/graphs/{index}.pt`` and a textual dump of ``{path}/nodes.csv``.
    The record's ``task_links`` are turned into an ordered ``ToolN: name``
    label.
    """

    def __init__(self):
        """Read every line of ``data.json`` into memory (parsed lazily)."""
        super().__init__()
        self.prompt = None
        self.graph = None
        # nodes.csv does not depend on the sample index, so its rendered text
        # is built on first use and reused for every later item.
        self._desc = None
        # JSON text is UTF-8 by specification; being explicit avoids relying
        # on the platform's default encoding.
        with open(path + '/data.json', "r", encoding="utf-8") as f:
            self.raw_data = f.readlines()

    def __len__(self):
        """Return the number of lines read from ``data.json``."""
        return len(self.raw_data)

    @staticmethod
    def _order_tools(task_links):
        """Return the tool names of ``task_links`` in execution order.

        Args:
            task_links: iterable of mappings with ``"source"`` and
                ``"target"`` keys. A later link with the same source
                overrides an earlier one.

        Returns:
            A list starting at the first source that is never a target and
            following the links from there. Empty when there are no links.
            The walk takes at most one step per distinct source and stops
            early when the current tool has no outgoing link, so link sets
            that are not a single chain yield a best-effort partial order
            instead of raising.
        """
        successor = {link["source"]: link["target"] for link in task_links}
        if not successor:
            return []

        targets = set(successor.values())
        current = next((src for src in successor if src not in targets), None)
        if current is None:
            # Every source is also a target (pure cycle): there is no natural
            # root, so start from the first link's source.
            current = next(iter(successor))

        ordered = [current]
        for _ in range(len(successor)):
            if current not in successor:
                break
            current = successor[current]
            ordered.append(current)
        return ordered

    @staticmethod
    def _build_label(task_links):
        """Format the ordered tools as ``"Tool1: a\\nTool2: b\\n..."``."""
        tools = MultimediaDataset._order_tools(task_links)
        return "".join(f"Tool{i + 1}: {tool}\n" for i, tool in enumerate(tools))

    def _tool_description(self):
        """Return the tool-list text built from ``nodes.csv``, reading it once."""
        # getattr keeps this working for instances created without __init__.
        desc = getattr(self, '_desc', None)
        if desc is None:
            nodes = pd.read_csv(f'{path}/nodes.csv')
            desc = "and a list of tools:\n " + nodes.to_csv(index=False)
            self._desc = desc
        return desc

    def __getitem__(self, index):
        """Build the sample dictionary for position ``index``.

        Returns:
            dict with keys ``id`` (the index), ``image_id`` (the record's
            ``id`` field), ``question`` (the record's ``user_request`` as
            text), ``label`` (ordered tool list, empty string when the record
            has no links), ``graph`` (object loaded from the ``.pt`` file)
            and ``desc`` (tool-list text).

        Raises:
            KeyError: if the record lacks ``user_request``, ``task_links``
                or ``id``.
        """
        data = json.loads(self.raw_data[index])
        question = f'{data["user_request"]}'
        # NOTE(review): torch.load deserialises via pickle; only load graph
        # files from a trusted source. Newer PyTorch releases may default to
        # weights_only=True, which can reject non-tensor objects -- confirm
        # against the installed version.
        graph = torch.load(f'{path}/graphs/{index}.pt')

        return {
            'id': index,
            'image_id': data['id'],
            'question': question,
            'label': self._build_label(data["task_links"]),
            'graph': graph,
            'desc': self._tool_description(),
        }

    @staticmethod
    def _read_indices(split_name):
        """Read ``{path}/split/{split_name}_indices.txt`` as a list of ints.

        Blank lines (for example a trailing empty line) are skipped.
        """
        with open(f'{path}/split/{split_name}_indices.txt', 'r', encoding="utf-8") as file:
            return [int(line) for line in file if line.strip()]

    def get_idx_split(self):
        """Return the saved ``train``/``val``/``test`` index lists as a dict."""
        return {
            'train': self._read_indices('train'),
            'val': self._read_indices('val'),
            'test': self._read_indices('test'),
        }




if __name__ == '__main__':
    # Smoke check: print every field of the first sample, then the size of
    # each saved index split.
    dataset = MultimediaDataset()

    first_sample = dataset[0]
    for field in first_sample:
        print(f'{field}: {first_sample[field]}')

    for split_name, indices in dataset.get_idx_split().items():
        print(f'# {split_name}: {len(indices)}')
