from pathlib import Path
import json
import random
from torch.utils.data import Dataset
import torch
import util

# Record fields of an SNLI example that this module names explicitly
# (per the variable name, the parse fields are left out).
snli_keys_no_parses = [
    'captionID',
    'pairID',
    'annotator_labels',
    'gold_label',
    'sentence1',
    'sentence2',
]

# The SNLI class labels, in the order defined by this module.
snli_labels = [
    'entailment',
    'contradiction',
    'neutral',
]

class SNLIDataset(Dataset):
    """
    Map-style dataset over the records of an SNLI-format JSONL file.

    Args:
        path: location of the JSONL file; read with ``util.load_jsonl``.
        label_key: record key holding the label (e.g. ``'gold_label'``).
            If falsy, every returned item has ``label`` set to ``None``
            and no label-based filtering is done.
        exclude_no_gold: when True (the default) and ``label_key`` is set,
            records whose label is ``'-'`` are dropped. Pass False to keep
            them. (This flag used to be ignored and such records were
            always dropped.)
        exclude_ids: optional collection of ``pairID`` values to drop.

    Each item is a dict with the keys ``'pairID'``, ``'sentence1'``,
    ``'sentence2'`` and ``'label'``.
    """

    def __init__(self, path, label_key, exclude_no_gold=True, exclude_ids=None):
        self.path = Path(path)
        self.label_key = label_key
        self.items = self._filter_items(
            util.load_jsonl(self.path), label_key, exclude_no_gold, exclude_ids
        )

    @staticmethod
    def _filter_items(items, label_key, exclude_no_gold=True, exclude_ids=None):
        """Return `items` without no-gold-label records and excluded pairIDs."""
        if label_key and exclude_no_gold:
            items = [item for item in items if item[label_key] != '-']
        if exclude_ids:
            # Build the lookup once so each membership test is O(1).
            excluded = set(exclude_ids)
            items = [item for item in items if item['pairID'] not in excluded]
        return items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        raw_item = self.items[idx]
        item = {k: raw_item[k] for k in ['pairID', 'sentence1', 'sentence2']}
        item['label'] = raw_item[self.label_key] if self.label_key else None
        return item

def add_roberta_se_fields(batch, sep_token_id=2):
    """
    Add batch fields needed by the RoBERTa Self-Explaining model.

    Replicates this:
    https://github.com/ShannonAI/Self_Explaining_Structures_Improve_NLP_Models/blob/master/datasets/collate_functions.py#L19

    Every span ``(start, end)`` with ``1 <= start <= end <= max_len - 2`` is
    enumerated, where ``max_len`` is the largest attention-mask sum in the
    batch. For each example a span is usable (mask value 0) when both ends
    are at most ``length - 2`` and the span lies entirely before or entirely
    after the first separator token; every other span gets the mask value
    1000000. No cap on span width is applied.

    NOTE(review): sequence lengths are taken as ``attention_mask.sum(dim=1)``,
    so this presumably expects right-padded inputs -- confirm against the
    tokenizer configuration.

    Args:
        batch: mapping holding 2-D integer tensors ``'input_ids'`` and
            ``'attention_mask'``. It is modified in place.
        sep_token_id: token id that separates the two segments. The default
            of 2 is the value this function previously hard-coded.

    Returns:
        The same ``batch`` with three added long tensors: ``'start_indexs'``
        and ``'end_indexs'`` of shape ``(num_spans,)`` and ``'span_masks'``
        of shape ``(batch_size, num_spans)``.

    Raises:
        ValueError: if some row of ``input_ids`` does not contain
            ``sep_token_id``.
    """
    masked_value = 1000000  # integer form of the 1e6 used by the reference code

    input_ids = batch['input_ids']
    lengths = batch['attention_mask'].sum(dim=1)
    device = lengths.device

    # Single host sync for the batch maximum; the span enumeration and the
    # mask below stay on-device instead of looping in Python per span.
    max_sentence_length = int(lengths.max())
    num_positions = max(max_sentence_length - 2, 0)

    # Upper-triangular (row <= col) coordinates in row-major order, shifted by
    # one so that position 0 (the leading special token) is never in a span.
    span_bounds = torch.triu_indices(num_positions, num_positions, device=device) + 1
    start_indexs = span_bounds[0]
    end_indexs = span_bounds[1]

    is_sep = input_ids == sep_token_id
    if not bool(is_sep.any(dim=1).all()):
        raise ValueError(
            f'{sep_token_id} is not in input_ids for every example of the batch'
        )
    # Index of the first separator = number of positions seen before any
    # separator has occurred.
    middle_index = (is_sep.long().cumsum(dim=1) == 0).sum(dim=1).to(device)

    # Broadcast (batch, 1) per-example bounds against (1, num_spans) spans.
    last_position = (lengths - 2).unsqueeze(1)
    middle = middle_index.unsqueeze(1)
    starts = start_indexs.unsqueeze(0)
    ends = end_indexs.unsqueeze(0)
    usable = (
        (starts <= last_position)
        & (ends <= last_position)
        & ((starts > middle) | (ends < middle))
    )
    span_masks = (~usable).long() * masked_value

    # add to output
    batch['start_indexs'] = start_indexs
    batch['end_indexs'] = end_indexs
    batch['span_masks'] = span_masks
    return batch

def create_collate_fn(tokenizer, label_stoi, device,
                      roberta_se=False, hypothesis_only=False):
    """
    Build a collate function that turns a list of dataset items into
    ``(item_ids, tokenized_inputs, label_idxs)``.

    Args:
        tokenizer: callable invoked as
            ``tokenizer(inputs, padding=True, return_tensors='pt')``; its
            result must support ``.to(device)`` (presumably a Hugging Face
            tokenizer -- confirm against the caller).
        label_stoi: mapping from label string to integer index.
        device: device the tokenized inputs and label tensor are moved to.
        roberta_se: if True, the tokenized batch is additionally passed
            through ``add_roberta_se_fields``.
        hypothesis_only: if True, only ``sentence2`` is tokenized; otherwise
            the ``(sentence1, sentence2)`` pair is.

    Items with a falsy ``label`` get the label index -1.
    """
    def _model_input(example):
        # Either the hypothesis alone or the (premise, hypothesis) pair.
        if hypothesis_only:
            return example['sentence2']
        return (example['sentence1'], example['sentence2'])

    def _label_index(example):
        return label_stoi[example['label']] if example['label'] else -1

    def collate_fn(batch):
        item_ids = [example['pairID'] for example in batch]
        texts = [_model_input(example) for example in batch]

        encoded = tokenizer(texts, padding=True, return_tensors='pt')
        encoded = encoded.to(device)
        if roberta_se:
            encoded = add_roberta_se_fields(encoded)

        targets = torch.LongTensor([_label_index(example) for example in batch])
        return item_ids, encoded, targets.to(device)

    return collate_fn

def sample_and_supplement(dataset_a, dataset_b, sample_n, replace=False):
    """
    Return dataset_a extended with `sample_n` randomly chosen items of dataset_b.

    With `replace` set, dataset_a is first cut down to a random subset of
    ``len(dataset_a) - sample_n`` items, so the result has the same size as
    the original dataset_a. Here `replace` means "swap items out"; it does
    not mean sampling with replacement -- both draws use ``random.sample``.

    The result is a ``ConcatDataset`` with the dataset_a part first and the
    sampled dataset_b part after it.
    """
    data = torch.utils.data

    # Draw from dataset_b first, then (optionally) from dataset_a.
    picked_from_b = random.sample(range(len(dataset_b)), sample_n)

    base = dataset_a
    if replace:
        keep_n = len(dataset_a) - sample_n
        kept_from_a = random.sample(range(len(dataset_a)), keep_n)
        assert len(picked_from_b) + len(kept_from_a) == len(dataset_a)
        base = data.Subset(dataset_a, kept_from_a)

    supplement = data.Subset(dataset_b, picked_from_b)
    return data.ConcatDataset([base, supplement])

