import torch
from torch.utils.data.dataset import Dataset

class RelationExtractionDataset(Dataset):
    """Token-classification dataset that tags relation spans inside prompts.

    Every prompt is tokenized (padded to the longest prompt, no truncation)
    and paired with a one-hot label tensor of shape ``(seq_len, num_classes)``.
    Class ``j + 1`` marks tokens matched by the ``j``-th relation's character
    spans. Class 0 marks every position that no relation matched (including
    special and padding tokens).

    Samples with more than two relations are dropped. For ``'mquake'``,
    samples whose labels use fewer than three distinct classes are dropped
    too. Token tensors and labels are filtered together, so
    ``self.tokens[k][i]`` and ``self.labels[i]`` always describe the same sample.

    Args:
        data: Iterable of raw records. For ``'mquake'`` each record needs
            ``'questions'`` (list of prompts) and ``'relationship_indices'``
            (mapping question-key -> mapping relation-key -> list of
            ``(start, end)`` character spans). For ``'rippleedit'`` each record
            needs ``'prompt'`` and ``'relationship_indices'`` (list of
            relations, each a list of ``(start, end)`` spans).
        tokenizer: Callable tokenizer with a ``decode`` method. The offset
            logic below assumes BERT-style word pieces (``[CLS]``/``[SEP]``/
            ``[PAD]`` specials, ``##`` continuations). TODO confirm for other
            tokenizers.
        device: Device the token tensors are moved to. Labels stay on the CPU.
        ds_name: ``'mquake'`` or ``'rippleedit'``.
        num_classes: Width of each label row. Must be at least 3.

    Raises:
        ValueError: If ``ds_name`` is not a supported dataset.
    """

    def __init__(self, data, tokenizer, device, ds_name='mquake', num_classes=5):
        all_prompts, str_idxs = self._collect_prompts(data, ds_name)

        self.tokens = tokenizer(all_prompts, padding=True, truncation=False, return_tensors='pt').to(device)
        labels = []
        keep = []  # row indices into self.tokens whose sample survives filtering
        for i in range(self.tokens.input_ids.shape[0]):
            tok = self.tokens.input_ids[i]
            idxs = str_idxs[i]
            # Only up to two relations are supported; the sample is dropped
            # entirely (tokens included) so tokens and labels stay aligned.
            if len(idxs) > 2:
                continue

            tok_pos = self._char_offsets(tok, tokenizer)
            token_locs = self._match_spans(idxs, tok_pos)
            label = self._build_label(token_locs, tok.shape[0], num_classes)

            # MQuAKE samples must exhibit at least three distinct classes
            # (background plus two relations) to be kept.
            if ds_name == 'mquake' and torch.unique(torch.argmax(label, dim=-1)).shape[0] < 3:
                continue
            labels.append(label)
            keep.append(i)

        seq_len = self.tokens.input_ids.shape[1]
        if labels:
            self.labels = torch.stack(labels)
        else:
            # Every sample was filtered out: keep a well-formed empty tensor.
            self.labels = torch.zeros((0, seq_len, num_classes))

        # Filter every token tensor with a single gather. Write back by key so
        # the encoding's own mapping (not just an attribute) is updated.
        for key in list(self.tokens.keys()):
            value = self.tokens[key]
            if not torch.is_tensor(value):
                continue
            index = torch.tensor(keep, dtype=torch.long, device=value.device)
            self.tokens[key] = value.index_select(0, index)
        self.tokenizer = tokenizer

    @staticmethod
    def _collect_prompts(data, ds_name):
        """Flatten raw records into parallel lists of prompts and relation spans."""
        all_prompts = []
        str_idxs = []

        if ds_name == 'mquake':
            for d in data:
                all_prompts.extend(d['questions'])
                # One entry per question: the list of that question's relations
                # (each relation being a list of character spans).
                for relations in d['relationship_indices'].values():
                    str_idxs.append(list(relations.values()))

        elif ds_name == 'rippleedit':
            for d in data:
                all_prompts.append(d['prompt'])
                str_idxs.append(d['relationship_indices'])

        else:
            raise ValueError(f'Not a useable dataset: {ds_name!r}.')

        return all_prompts, str_idxs

    @staticmethod
    def _char_offsets(tok, tokenizer):
        """Reconstruct approximate ``(start, end)`` character offsets per non-special token.

        Special tokens are skipped, so entry ``k`` is the ``k``-th non-special
        token. Each word is assumed to be followed by one space. ``'``, ``"``
        and ``s`` attach without a space, and ``##`` pieces attach to the
        previous piece.
        NOTE(review): the span recorded for a ``##`` piece includes the two
        ``##`` characters, so these offsets are heuristic. Confirm they match
        how the dataset spans were produced.
        """
        tok_pos = []
        pos = 0
        for t in (tokenizer.decode(x) for x in tok):
            if t in ('[CLS]', '[SEP]', '[PAD]'):
                continue
            tok_pos.append((pos, pos + len(t)))
            if t in ('\'', '"', 's'):
                pos += len(t)
            elif '##' in t:
                pos += len(t) - 2
            else:
                pos += len(t) + 1
        return tok_pos

    @staticmethod
    def _match_spans(idxs, tok_pos):
        """Map each relation's character spans to token positions.

        Returns ``{class_id: [token_position, ...]}`` with ``class_id = j + 1``
        for the ``j``-th relation. Token positions are ``k + 1``, shifting past
        the leading special token. This assumes exactly one ([CLS]) precedes
        the text; TODO confirm.

        Matching rules, per span ``r_idx`` against token offset ``tp``:
        - An exact match is taken and the scan stops.
        - A span starting at the token start but ending inside the token is
          taken and sets ``a_flag`` for the rest of this relation. While
          ``a_flag`` is set, a later span of the same relation that lies
          strictly inside a token, or ends at its end, is also taken.
        - A span starting at the token start but extending past it opens a
          multi-token run (``b_flag``, reset for every span). Following tokens
          inside the span are taken, and the run ends at the token ending
          exactly at the span end.
        """
        token_locs = {1: [], 2: [], 3: [], 4: []}
        for j, r in enumerate(idxs):
            a_flag = False
            for r_idx in r:
                b_flag = False
                for k, tp in enumerate(tok_pos):
                    if r_idx[0] == tp[0] and r_idx[1] == tp[1]:
                        token_locs[j+1].append(k+1)
                        break
                    elif r_idx[0] == tp[0] and r_idx[1] < tp[1]:
                        a_flag = True
                        token_locs[j+1].append(k+1)
                        break
                    elif r_idx[0] > tp[0] and r_idx[1] < tp[1] and a_flag:
                        token_locs[j+1].append(k+1)
                        break
                    elif r_idx[0] > tp[0] and r_idx[1] == tp[1] and a_flag:
                        token_locs[j+1].append(k+1)
                        break
                    elif r_idx[0] == tp[0] and r_idx[1] > tp[1]:
                        token_locs[j+1].append(k+1)
                        b_flag = True
                    elif r_idx[0] < tp[0] and r_idx[1] > tp[1] and b_flag:
                        token_locs[j+1].append(k+1)
                    elif r_idx[0] < tp[0] and r_idx[1] == tp[1] and b_flag:
                        token_locs[j+1].append(k+1)
                        break
        return token_locs

    @staticmethod
    def _build_label(token_locs, seq_len, num_classes):
        """Build a ``(seq_len, num_classes)`` one-hot label from matched positions."""
        label = torch.zeros((seq_len, num_classes))
        for cls_id, positions in token_locs.items():
            for t_pos in positions:
                label[t_pos, cls_id] = 1.0
        # Rows no relation touched fall back to the background class 0.
        untagged = label.sum(dim=-1) == 0
        label[untagged, 0] = 1.0
        return label

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return the tokenized sample ``idx`` together with its label tensor."""
        item = {
            key: self.tokens[key][idx]
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
            if key in self.tokens
        }
        item['label'] = self.labels[idx]
        return item
