import os
import json
import nltk
nltk.download('punkt', quiet=True)

import torch
from torch.utils.data import Dataset
from typing import Dict

from mix_eval.prompts.evaluation_prompts import (
construct_prompt_text2action
)



def get_eval_dataset(args):
    """Build the evaluation dataset selected by ``args.benchmark``.

    Args:
        args: Object exposing a ``benchmark`` attribute; it is forwarded
            unchanged to the dataset constructor.

    Returns:
        An ``EvalDatasetCloseended`` instance when ``args.benchmark`` is
        ``'mixeval_x_text2action'``.

    Raises:
        ValueError: If ``args.benchmark`` is any other value.
    """
    # Guard clause: reject unknown benchmarks first so the supported path
    # reads straight through.
    if args.benchmark != 'mixeval_x_text2action':
        raise ValueError(
            f"Benchmark {args.benchmark} not supported in {get_eval_dataset.__name__}."
        )
    return EvalDatasetCloseended(args)
        

class EvalDatasetCloseended(Dataset):
    """Dataset of text2action evaluation samples loaded from a JSON file.

    The file at ``args.data_path`` is parsed as a JSON object whose items are
    ``(sample id, sample dict)`` pairs. Every sample dict is augmented in
    place with two keys: ``'formated_input'`` (the value returned by
    ``construct_prompt_text2action`` for that sample) and ``'id'`` (its key
    in the JSON object). The augmented dicts are kept in ``self.raw_inputs``
    in file order.
    """

    def __init__(self, args):
        """Load and prepare all samples.

        Args:
            args: Object exposing a ``data_path`` attribute that points to
                the JSON file to read. It is also stored as ``self.args``.
        """
        super().__init__()

        self.args = args

        print("Loading text2action data.")
        # Decode explicitly as UTF-8: without ``encoding`` the locale's
        # default is used, which can fail on or misread non-ASCII JSON on
        # platforms whose default encoding is not UTF-8.
        with open(args.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # The file is fully parsed above, so the handle need not stay open
        # while the prompts are being built.
        raw_inputs = []
        for sample_id, sample in data.items():
            # The prompt is built before 'id' is written into the sample, so
            # the prompt constructor sees the sample as stored in the file.
            sample['formated_input'] = construct_prompt_text2action(sample)
            sample['id'] = sample_id
            raw_inputs.append(sample)

        self.raw_inputs = raw_inputs

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.raw_inputs)

    def __getitem__(self, i) -> Dict[str, dict]:
        """Return the entry at position ``i`` wrapped in a one-key dict.

        Args:
            i: Index into ``self.raw_inputs``.

        Returns:
            ``{'raw_inputs': self.raw_inputs[i]}``. The stored object itself
            is returned, not a copy.
        """
        return dict(
            raw_inputs=self.raw_inputs[i],
        )
