import json
import nltk

import torch
from torch.utils.data import Dataset
from typing import Dict

from mix_eval.prompts.evaluation_prompts import (
construct_prompt_multichoice, 
construct_prompt_freeform,
)



def get_eval_dataset(args):
    """Return the evaluation dataset that matches ``args.split``.

    Close-ended splits (``close_freeform``, ``close_multichoice`` and their
    ``_hard`` variants) produce an ``EvalDatasetCloseended``; the ``open`` and
    ``open_hard`` splits produce an ``EvalDatasetOpenended``.

    Raises:
        ValueError: if ``args.split`` is not one of the splits listed above.
    """
    split = args.split
    closed_splits = (
        'close_freeform',
        'close_multichoice',
        'close_freeform_hard',
        'close_multichoice_hard',
    )
    open_splits = ('open', 'open_hard')

    if split in closed_splits:
        return EvalDatasetCloseended(args)
    if split in open_splits:
        return EvalDatasetOpenended(args)
    raise ValueError(f"Split {split} not supported in {get_eval_dataset.__name__}.")
        

class EvalDatasetCloseended(Dataset):
    """Close-ended (freeform or multiple-choice) evaluation dataset.

    Each item is one record from the split's JSON file (an object mapping
    record id -> record dict), augmented in place with:

    * ``'formated_input'``: the prompt built by the split's prompt constructor;
    * ``'id'``: the record's key in the JSON file.

    Items are ordered by descending prompt length, measured in NLTK word
    tokens, so that batches group similarly sized inputs.
    """

    # split name -> (attribute of ``args`` holding the JSON path,
    #                prompt style, log message printed while loading).
    # Only plain strings live here so that the prompt builders are resolved
    # lazily inside ``__init__``.
    _SPLITS = {
        'close_freeform': (
            'data_path_freeform',
            'freeform',
            "Loading close-ended freeform data.",
        ),
        'close_multichoice': (
            'data_path_multiplechoice',
            'multichoice',
            "Loading close-ended multichoice data.",
        ),
        'close_freeform_hard': (
            'data_path_freeform_hard',
            'freeform',
            "Loading close-ended freeform hard data.",
        ),
        'close_multichoice_hard': (
            'data_path_multiplechoice_hard',
            'multichoice',
            "Loading close-ended multichoice hard data.",
        ),
    }

    def __init__(self, args):
        """Load and prompt-format the close-ended split named by ``args.split``.

        Args:
            args: namespace providing ``split`` and, for that split, the
                matching path attribute (``data_path_freeform``,
                ``data_path_multiplechoice``, ``data_path_freeform_hard`` or
                ``data_path_multiplechoice_hard``).

        Raises:
            ValueError: if ``args.split`` is not a close-ended split.
        """
        super().__init__()

        self.args = args
        if args.split not in self._SPLITS:
            raise ValueError(f"Split {args.split} not supported in {self.__class__.__name__}")
        path_attr, prompt_style, message = self._SPLITS[args.split]
        build_prompt = (
            construct_prompt_multichoice
            if prompt_style == 'multichoice'
            else construct_prompt_freeform
        )

        print(message)
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(getattr(args, path_attr), 'r', encoding='utf-8') as f:
            data = json.load(f)

        for item_id, record in data.items():
            # Build the prompt before adding 'id' so the builder sees the
            # record exactly as stored in the file.
            record['formated_input'] = build_prompt(record)
            record['id'] = item_id
        raw_inputs = list(data.values())

        # Sort by token length (longest first) to make batching more efficient.
        print("Sorting data based on input length.")
        raw_inputs.sort(
            key=lambda x: len(nltk.word_tokenize(x['formated_input'])),
            reverse=True,
        )

        self.raw_inputs = raw_inputs

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.raw_inputs)

    def __getitem__(self, i) -> Dict[str, dict]:
        """Return ``{'raw_inputs': record}`` for the ``i``-th record (a raw dict, not a tensor)."""
        return dict(
            raw_inputs=self.raw_inputs[i],
        )
            
class EvalDatasetOpenended(Dataset):
    """Open-ended evaluation dataset.

    Each item is one record from the split's JSON file (an object mapping
    record id -> record dict), with an ``'id'`` key added in place holding the
    record's key. Records keep the file's order; no prompt formatting or
    sorting is applied.
    """

    # split name -> (attribute of ``args`` holding the JSON path,
    #                log message printed while loading).
    _SPLITS = {
        'open': ('data_path_open', "Loading open-ended data."),
        'open_hard': ('data_path_open_hard', "Loading open-ended-hard data."),
    }

    def __init__(self, args):
        """Load the open-ended split named by ``args.split``.

        Args:
            args: namespace providing ``split`` and, for that split, the
                matching path attribute (``data_path_open`` or
                ``data_path_open_hard``).

        Raises:
            ValueError: if ``args.split`` is not an open-ended split (rather
                than silently building an empty dataset).
        """
        super().__init__()

        self.args = args
        if args.split not in self._SPLITS:
            raise ValueError(f"Split {args.split} not supported in {self.__class__.__name__}")
        path_attr, message = self._SPLITS[args.split]

        print(message)
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(getattr(args, path_attr), 'r', encoding='utf-8') as f:
            data = json.load(f)

        for item_id, record in data.items():
            record['id'] = item_id
        # dict preserves insertion order, so this keeps the file's record order.
        self.raw_inputs = list(data.values())

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.raw_inputs)

    def __getitem__(self, i) -> Dict[str, dict]:
        """Return ``{'raw_inputs': record}`` for the ``i``-th record (a raw dict, not a tensor)."""
        return dict(
            raw_inputs=self.raw_inputs[i],
        )