import pickle
from tqdm import tqdm
from collections import defaultdict
from .base_dataset import DatasetWithPerturbation
from typing import Dict, List
from copy import deepcopy


class POPE_DS:
    """Empty placeholder class; it defines no attributes or methods."""


class SceneParsingOcclusion(DatasetWithPerturbation):
    """Dataset of perturbed items loaded from a pickle file.

    Every item is expected to be a mapping that provides at least the keys
    ``'perturb_group_idx'`` and ``'perturb_strength'``. Items that share the
    same ``'perturb_group_idx'`` value are indexed together as one group.
    """

    def __init__(
        self,
        floc: str = './evaluation/sceneparsing_perturbed.pkl'
    ):
        """Load the pickled dataset and build the group index.

        Args:
            floc: Path to the pickle file holding the dataset. The unpickled
                object must be iterable and support ``len()`` and indexing.

        Raises:
            OSError: If ``floc`` cannot be opened.
            KeyError: If an item lacks the ``'perturb_group_idx'`` key.
        """
        # NOTE(review): DatasetWithPerturbation.__init__ is not invoked here
        # (unchanged from the previous behaviour) -- confirm the base class
        # needs no setup of its own.

        # SECURITY: pickle.load can execute arbitrary code while
        # deserialising. Only point ``floc`` at files from a trusted source.
        with open(floc, 'rb') as f:
            self.ds = pickle.load(f)

        # Build the group id index: perturb_group_idx -> dataset positions,
        # in dataset order.
        group_idxs_map = defaultdict(list)
        for idx, item in enumerate(self.ds):
            group_idxs_map[item['perturb_group_idx']].append(idx)

        self.group_idxs_map = group_idxs_map

    def __getitem__(self, item):
        """Return the item at position ``item``, structured for the rollout scripts.

        This is currently the same independent copy that ``full_item`` returns.
        """
        return self.full_item(item)

    def full_item(self, item):
        """Return a deep copy of the stored item at position ``item``.

        The copy keeps callers from mutating the backing dataset.
        """
        return deepcopy(self.ds[item])

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.ds)

    def get_perturbation_strength(self, item) -> Dict:
        """Return the ``'perturb_strength'`` entry of the item at ``item``.

        The value is deep-copied, consistent with ``full_item``, so that
        mutating the result cannot alter the backing dataset.
        """
        return deepcopy(self.ds[item]['perturb_strength'])

    # NOTE: the method name is misspelled ("perutrbation"); it is kept as-is
    # because it is part of the public interface callers already use.
    def get_perutrbation_group_ids(self, item) -> List:
        """Return the dataset positions of all items in ``item``'s group.

        The group is the set of items sharing the ``'perturb_group_idx'`` of
        the item at position ``item``. A fresh list is returned so callers
        cannot corrupt the internal index by mutating the result.
        """
        group_idx = self.ds[item]['perturb_group_idx']
        # ``.get`` (rather than ``[]``) avoids inserting an empty entry into
        # the defaultdict when the group id is not in the index.
        return list(self.group_idxs_map.get(group_idx, ()))

    def get_problem_system_instruction(self):
        """Return the system instruction for this problem (an empty string)."""
        return ''