import random
from torch.utils.data import Dataset, ConcatDataset, DataLoader
import numpy as np
import torch
from NLPTasks_wo_SuperGLUE import *
from tqdm import tqdm
import itertools
import os
import h5py
import math
import copy
from pathlib import Path
import re
import json
import nltk
from difflib import SequenceMatcher


class DAInContextDataset(Dataset):
    """Base dataset producing tokenized prompts, optionally prefixed with in-context examples.

    Subclasses supply the task-specific pieces: ``get_data_set`` (file loading),
    ``is_identical`` (instance equality) and the ``gen_from_tag_sequence`` /
    ``gen_from_nlu`` prompt builders.

    Each item is the tuple ``(input_ids, target_ids, gt_x, gt_y, index, task_type)``
    where ``task_type`` is 0 when the sampled mode is "nlu" and 1 otherwise.
    """
    SKIP_ATTRIBUTES = ['gt_x', 'gt_y']

    def __init__(self, config, data_path, train_path, tokenizer, is_training=False, is_root=True,
                 random_in_context=True, duplication=1):
        """Load the evaluation/training instances and select the generation modes.

        Args:
            config: object providing ``training_da_mode`` / ``eval_da_mode``,
                ``prefix_set_number``, ``lm_gen_train_path_list``,
                ``in_context_instance_count`` and ``max_length``.
            data_path: file whose instances are served by ``__getitem__``.
            train_path: file whose instances form the in-context example pool.
            tokenizer: callable tokenizer exposing ``eos_token_id``.
            is_training: selects ``training_da_mode`` and enables loading of
                ``config.lm_gen_train_path_list``.
            is_root: when true, the final data size is printed.
            random_in_context: sample in-context examples at random instead of
                ranking them with ``get_in_context_samples``.
            duplication: number of times each instance of ``data_path`` is repeated.
        """
        self.tokenizer = tokenizer
        self.is_training = is_training
        self.config = config
        self.sep_token_id = 0
        self.random_in_context = random_in_context
        self.duplication = duplication

        self.data_list = []
        data_instance_list, data_index_list = self.get_data_set(data_path)
        for index, data_instance in zip(data_index_list, data_instance_list):
            # for _ in range(self.config.oversample if self.is_training else self.config.eval_data_replication):
            for _ in range(self.duplication):
                self.data_list.append((index, data_instance))

        train_data_instance_list, train_data_index_list = self.get_data_set(train_path)
        self.train_data_list = list(zip(train_data_index_list, train_data_instance_list))
        self.train_data_index = list(range(len(self.train_data_list)))

        if is_training:
            assert len(self.config.training_da_mode) > 0
            self.da_mode = self.config.training_da_mode
        else:
            assert len(self.config.eval_da_mode) > 0
            self.da_mode = self.config.eval_da_mode

        # Maps a mode name from ``da_mode`` to the generator that builds its prompt.
        self.mode_func = {
            "tag": self.gen_from_tag_sequence,
            "nlu": self.gen_from_nlu,
        }

        if is_training:
            assert self.config.prefix_set_number == 0 or len(config.lm_gen_train_path_list) == 0 or len(
                config.lm_gen_train_path_list) == self.config.prefix_set_number
            # Extra training files are appended after the instances loaded above.
            for d_path in config.lm_gen_train_path_list:
                current_instance_list, current_index_list = self.get_data_set(d_path, filtering=True,
                                                                              index_starts=len(self.data_list))
                for index, data_instance in zip(current_index_list, current_instance_list):
                    self.data_list.append((index, data_instance))

        if is_root:
            print("Data Size %d" % len(self.data_list))

    def __len__(self):
        """Return the number of served instances (including duplicates)."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Build the tokenized prompt for instance ``idx``.

        A mode is drawn at random from ``self.da_mode``; the query is generated
        with ``mask_token=True`` and up to ``config.in_context_instance_count``
        training examples are prepended while the total stays within
        ``config.max_length - 1`` tokens. The tokenizer's EOS id is appended last.

        Returns:
            ``(input_ids, target_ids, gt_x, gt_y, index, task_type)`` with the
            two id sequences as ``int64`` numpy arrays.
        """
        mode = random.choice(self.da_mode)
        data_generator = self.mode_func[mode]

        index, data_instance = self.data_list[idx]
        x_ids, y_ids, gt_x, gt_y = data_generator(data_instance, mask_token=True)
        input_ids = x_ids

        context_count = self.config.in_context_instance_count
        if context_count > 0:
            budget = self.config.max_length - 1
            total_length = len(input_ids)
            if self.random_in_context:
                # random.sample raises ValueError when k exceeds the population,
                # so never ask for more examples than the training pool holds.
                sample_size = min(context_count, len(self.train_data_index))
                in_context_samples = random.sample(self.train_data_index, k=sample_size)
            else:
                in_context_samples = self.get_in_context_samples(idx, k=context_count)

            selected_instances = []
            for d_index in in_context_samples:
                _, train_data_instance = self.train_data_list[d_index]
                if self.is_identical(train_data_instance, data_instance):
                    continue
                if total_length > budget:
                    break
                context_ids, _, _, _ = data_generator(train_data_instance, add_seperator=True)
                if len(context_ids) + total_length <= budget:
                    selected_instances.append(context_ids)
                    total_length += len(context_ids)

            if selected_instances:
                if self.random_in_context:
                    # Uniformly keep between zero and all of the fitting examples.
                    selected_instance_num = random.choice(range(len(selected_instances) + 1))
                else:
                    selected_instance_num = len(selected_instances)
                # Each kept example is prepended, so the first one ends up next to the query.
                for instance in selected_instances[:selected_instance_num]:
                    input_ids = instance + input_ids

        input_ids.append(self.tokenizer.eos_token_id)

        input_np = np.array(input_ids).astype(np.int64)
        output_np = np.array(y_ids).astype(np.int64)

        return input_np, output_np, gt_x, gt_y, index, 0 if mode == "nlu" else 1

    def get_in_context_samples(self, idx, k):
        """Rank training instances against instance ``idx`` and return ``k`` positions.

        The returned values are positions in ``self.train_data_list`` (the list
        ``__getitem__`` indexes with them). Instances identical to the query are
        skipped. Similarity is ``SequenceMatcher.ratio`` over the space-joined
        fields of each instance, excluding the last field.
        """
        _, data_instance = self.data_list[idx]
        # str() keeps the join working for instances with non-string fields.
        data_instance_str = ' '.join(str(part) for part in data_instance[:-1])
        scores = []
        for train_idx, (_, train_data) in enumerate(self.train_data_list):
            if self.is_identical(train_data, data_instance):
                continue
            train_data_str = ' '.join(str(part) for part in train_data[:-1])
            scores.append((train_idx, SequenceMatcher(None, data_instance_str, train_data_str).ratio()))
        # NOTE(review): ascending order returns the *least* similar instances first;
        # kept as in the original - confirm whether most-similar-first was intended.
        scores.sort(key=lambda pair: pair[1], reverse=False)
        return [train_idx for train_idx, _ in scores[:k]]

    def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "tag" mode prompt; must be provided by subclasses that use it."""
        raise NotImplementedError

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" mode prompt; must be provided by subclasses."""
        raise NotImplementedError

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Load ``path`` and return ``(instance_list, index_list)``; subclass hook."""
        raise NotImplementedError

    def is_identical(self, instance_a, instance_b):
        """Return whether two instances describe the same example; subclass hook."""
        raise NotImplementedError


class BoolQInContext(DAInContextDataset):
    """BoolQ prompts built from ``(question, passage, label)`` triples."""

    def _encode(self, text, add_seperator):
        """Tokenize ``text``, keep at most ``config.max_length`` ids, optionally add the separator."""
        ids = self.tokenizer(text, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
        if add_seperator:
            ids.append(self.sep_token_id)
        return ids

    def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "tag" prompt: generate the question given label and article.

        With ``mask_token`` the question is replaced by ``<extra_id_0>`` in the
        input and becomes the target; outside training the label is drawn at
        random. Without ``mask_token`` the fully specified example is returned
        (used as an in-context example).

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)``.
        """
        question, article, tag = data_instance

        if mask_token:
            if self.is_training:
                new_tag = tag
            else:
                new_tag = 'True' if random.random() < 0.5 else 'False'

            input_x = "Questions: <extra_id_0> Label: %s Article: %s" % (new_tag, article)
            input_y = "<extra_id_0> %s <extra_id_1>" % question
            gt_x = "Questions: <extra_id_0> |*| Label: %s |*| Article: %s" % (new_tag, article)
        else:
            input_x = "Questions: %s Article: %s" % (question, article)
            input_y = tag
            input_x = "Label: %s %s" % (tag, input_x)
            # gt_x was previously left unassigned on this path, which made every
            # in-context (mask_token=False) call fail with UnboundLocalError.
            gt_x = "Questions: %s |*| Label: %s |*| Article: %s" % (question, tag, article)

        gt_y = input_y
        y_ids = self._encode(input_y, False)
        x_ids = self._encode(input_x, add_seperator)

        return x_ids, y_ids, gt_x, gt_y

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label given question and article.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the
            ``(question, article)`` pair and ``gt_y`` the label string.
        """
        question, article, tag = data_instance
        input_x = "Questions: %s Article: %s" % (question, article)
        input_y = tag

        gt_x, gt_y = (question, article), tag

        if mask_token:
            input_x = "Answer: <extra_id_0> %s" % input_x
            input_y = "<extra_id_0> %s <extra_id_1>" % input_y
        else:
            input_x = "Answer: %s %s" % (input_y, input_x)
            input_y = ""

        y_ids = self._encode(input_y, False)
        x_ids = self._encode(input_x, add_seperator)

        return x_ids, y_ids, gt_x, gt_y

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        Each line must provide ``question``, ``passage``, ``label`` and ``idx``.
        ``filtering`` and ``index_starts`` are accepted for compatibility with
        the base-class call signature and are not used here.
        """
        data_list = []
        data_index = []
        with open(path) as out:
            for l in out:
                items = json.loads(l)
                data_list.append((items['question'], items['passage'], str(items['label'])))
                data_index.append(items['idx'])
        return data_list, data_index

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when question and passage both match."""
        return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]


class RTEInContext(DAInContextDataset):
    """RTE prompts built from ``(premise, hypothesis, label)`` triples."""

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        Each line must provide ``premise``, ``hypothesis``, ``label`` and
        ``idx``; underscores in the label are replaced by spaces.
        ``filtering`` and ``index_starts`` are accepted so the base class can
        call this loader with its full keyword set; they are not used here.
        """
        data_list = []
        data_index = []
        with open(path) as out:
            for l in out:
                items = json.loads(l)
                label = ' '.join(items['label'].split('_'))
                data_list.append((items['premise'], items['hypothesis'], label))
                data_index.append(items['idx'])
        return data_list, data_index

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when premise and hypothesis both match."""
        return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

    def _encode(self, text, add_seperator):
        """Tokenize ``text``, keep at most ``config.max_length`` ids, optionally add the separator."""
        ids = self.tokenizer(text, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
        if add_seperator:
            ids.append(self.sep_token_id)
        return ids

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label from hypothesis and premise.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        premise, hypothesis, tag = data_instance
        input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
        input_y = tag

        gt_x, gt_y = input_x, input_y

        if mask_token:
            input_x = "Answer: <extra_id_0> %s" % input_x
            input_y = "<extra_id_0> %s <extra_id_1>" % input_y
        else:
            input_x = "Answer: %s %s" % (input_y, input_x)
            input_y = ""

        y_ids = self._encode(input_y, False)
        x_ids = self._encode(input_x, add_seperator)

        return x_ids, y_ids, gt_x, gt_y

    def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "tag" prompt: generate text conditioned on the label.

        With ``mask_token`` a random draw decides whether the premise, the
        hypothesis, or both are replaced by ``<extra_id_N>`` placeholders and
        moved to the target. Without it the labelled example is returned whole.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)``.
        """
        premise, hypothesis, tag = data_instance
        input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
        input_y = tag

        gt_x, gt_y = input_x, input_y

        if mask_token:
            seed = random.random()
            if seed <= 0.33:
                # Mask the premise only.
                input_x = "Hypothesis: %s Label: %s Premise: <extra_id_0>" % (hypothesis, tag)
                input_y = "<extra_id_0> %s <extra_id_1>" % premise
            elif seed <= 0.67:
                # Mask the hypothesis only.
                input_x = "Hypothesis: <extra_id_0> Label: %s Premise: %s" % (tag, premise)
                input_y = "<extra_id_0> %s <extra_id_1>" % hypothesis
            else:
                # Mask both fields.
                input_x = "Hypothesis: <extra_id_0> Label: %s Premise: <extra_id_1>" % tag
                input_y = "<extra_id_0> %s <extra_id_1> %s <extra_id_2>" % (hypothesis, premise)

        else:
            input_x = "Label: %s %s" % (input_y, input_x)
            input_y = ""

        y_ids = self._encode(input_y, False)
        x_ids = self._encode(input_x, add_seperator)

        return x_ids, y_ids, gt_x, gt_y


class CBInContext(DAInContextDataset):
    """CB prompts built from ``(premise, hypothesis, label)`` triples.

    Only the "nlu" generator is overridden in this class.
    """

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        Each line must provide ``premise``, ``hypothesis``, ``label`` and ``idx``.
        ``filtering`` and ``index_starts`` are accepted so the base class can
        call this loader with its full keyword set; they are not used here.
        """
        data_list = []
        data_index = []
        with open(path) as out:
            for l in out:
                items = json.loads(l)
                label = items['label']
                data_list.append((items['premise'], items['hypothesis'], label))
                data_index.append(items['idx'])
        return data_list, data_index

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when premise and hypothesis both match."""
        return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label from hypothesis and premise.

        With ``mask_token`` the answer slot holds ``<extra_id_0>`` and the label
        becomes the target; otherwise the label is written into the input and
        the target text is empty.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        premise, hypothesis, tag = data_instance
        input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
        input_y = tag

        gt_x, gt_y = input_x, input_y

        if mask_token:
            input_x = "Answer: <extra_id_0> %s" % input_x
            input_y = "<extra_id_0> %s <extra_id_1>" % input_y
        else:
            input_x = "Answer: %s %s" % (input_y, input_x)
            input_y = ""

        # Both sequences are truncated to config.max_length token ids.
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class COPAIncontext(DAInContextDataset):
    """COPA prompts built from ``(premise, choice1, choice2, question, label)`` tuples."""

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        The label is ``'2'`` when the record's ``label`` equals 1 and ``'1'``
        otherwise. In training mode every record is additionally stored with
        its two choices swapped and the label flipped accordingly.
        ``filtering`` and ``index_starts`` are not used here.
        """
        instances = []
        indices = []
        with open(path) as handle:
            for raw_line in handle:
                record = json.loads(raw_line)
                # label = items['label']
                # label = 'choice2' if int(items['label']) == 1 else 'choice1'
                answer = '2' if int(record['label']) == 1 else '1'
                instances.append(
                    (record['premise'], record['choice1'], record['choice2'], record['question'], answer))
                indices.append(record['idx'])
                if not self.is_training:
                    continue
                swapped_answer = '1' if int(record['label']) == 1 else '2'
                instances.append(
                    (record['premise'], record['choice2'], record['choice1'], record['question'], swapped_answer))
                # NOTE(review): the swapped copy is indexed by the running list
                # length, which may coincide with another record's 'idx' - confirm.
                indices.append(len(indices))

        return instances, indices

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when their first four fields all match."""
        return all(instance_a[pos] == instance_b[pos] for pos in range(4))

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict which solution fits the premise.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        premise, choice1, choice2, question, tag = data_instance
        # input_x = "choice1: %s choice2: %s premise: %s question: %s" % (choice1, choice2, premise, question)
        prompt = "Solution1: %s Solution2: %s Premise: %s Question: What is the %s for this?" % \
                 (choice1, choice2, premise, question)
        gt_x, gt_y = prompt, tag

        if mask_token:
            source_text = "Answer: <extra_id_0> %s" % prompt
            target_text = "<extra_id_0> %s <extra_id_1>" % tag
        else:
            source_text = "Answer: %s %s" % (tag, prompt)
            target_text = ""

        limit = self.config.max_length
        y_ids = self.tokenizer(target_text, return_tensors="np")['input_ids'][0, :limit].tolist()
        x_ids = self.tokenizer(source_text, return_tensors="np")['input_ids'][0, :limit].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class WiCInContext(DAInContextDataset):
    """WiC prompts built from ``(sentence1, sentence2, start1, word, label)`` tuples."""

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        The label is the lower-cased string form of the record's ``label``.
        ``filtering`` and ``index_starts`` are not used here.
        """
        instances = []
        indices = []
        with open(path) as handle:
            for raw_line in handle:
                record = json.loads(raw_line)
                answer = str(record['label']).lower()
                instances.append(
                    (record['sentence1'], record['sentence2'], record['start1'], record['word'], answer))
                indices.append(record['idx'])
        return instances, indices

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when their first four fields all match."""
        return all(instance_a[pos] == instance_b[pos] for pos in range(4))

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label for the word in both sentences.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        sentence1, sentence2, start1, word, tag = data_instance
        prompt = "pos: %s sentence1: %s sentence2: %s word: %s" % (start1, sentence1, sentence2, word)
        gt_x, gt_y = prompt, tag

        if mask_token:
            source_text = "Answer: <extra_id_0> %s" % prompt
            target_text = "<extra_id_0> %s <extra_id_1>" % tag
        else:
            source_text = "Answer: %s %s" % (tag, prompt)
            target_text = ""

        limit = self.config.max_length
        y_ids = self.tokenizer(target_text, return_tensors="np")['input_ids'][0, :limit].tolist()
        x_ids = self.tokenizer(source_text, return_tensors="np")['input_ids'][0, :limit].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class WSCInContext(DAInContextDataset):
    """WSC prompts built from ``(marked_text, label)`` pairs."""

    def mark_span(self, text, span1_index, span1_text, span2_index, span2_text):
        """Return ``text`` with the second span wrapped in asterisks.

        ``text`` is split on single spaces and the span indices are word
        positions in that split. Both spans must occur (case-insensitively)
        inside the words they point at; the first span is located but left
        unmarked.

        Raises:
            AssertionError: if either span text is not found at its position.
        """
        words = text.split(' ')
        span1_text = span1_text.replace('\n', ' ')
        span2_text = span2_text.replace('\n', ' ')
        span1_len = len(span1_text.split(' '))
        span2_len = len(span2_text.split(' '))
        unmarked_entity_1 = ' '.join(words[span1_index: span1_index + span1_len])
        unmarked_entity_2 = ' '.join(words[span2_index: span2_index + span2_len])
        # The span texts are literal strings, not patterns: escape them so that
        # characters such as '.', '?', '(' or ')' neither change the match nor
        # raise re.error.
        match1 = re.search(re.escape(span1_text.lower()), unmarked_entity_1.lower())
        match2 = re.search(re.escape(span2_text.lower()), unmarked_entity_2.lower())
        assert match1 is not None and match2 is not None, \
            f'{span1_text}, {" ".join(words[span1_index: span1_index + span1_len])}\n' \
            f'{span2_text}, {" ".join(words[span2_index: span2_index + span2_len])}'
        match2 = match2.span()

        # Span 1 is re-inserted without any marker characters.
        marked_1 = unmarked_entity_1

        marked_2 = unmarked_entity_2[:match2[0]] + "*" + unmarked_entity_2[match2[0]: match2[1]] \
                   + "*" + unmarked_entity_2[match2[1]:]

        if span1_index < span2_index:
            words = words[:span1_index] + [marked_1] + words[span1_index + span1_len: span2_index] + \
                    [marked_2] + words[span2_index + span2_len:]
        else:
            words = words[:span2_index] + [marked_2] + words[span2_index + span2_len: span1_index] + \
                    [marked_1] + words[span1_index + span1_len:]
        return ' '.join(words)

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        Each instance is ``(marked_text, label)`` where the label is the first
        span's text, wrapped in asterisks when the record's ``label`` is false.
        ``filtering`` and ``index_starts`` are not used here.
        """
        data_list = []
        data_index = []
        with open(path) as out:
            for l in out:
                items = json.loads(l)
                # label = items['label']
                text = items['text']
                span1_index = items['target']['span1_index']
                span1_text = items['target']['span1_text']
                span2_index = items['target']['span2_index']
                span2_text = items['target']['span2_text']
                marked_text = self.mark_span(text, span1_index, span1_text, span2_index, span2_text)
                # label = str(items['label'])
                label = span1_text
                true_label = str(items['label']).lower()
                if true_label == 'false':
                    label = '*' + label + '*'
                data_list.append((marked_text, label))
                data_index.append(items['idx'])
        return data_list, data_index

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when their marked texts match."""
        return instance_a[0] == instance_b[0]

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label text from the marked passage.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the marked text
            and ``gt_y`` the label.
        """
        marked_text, tag = data_instance
        input_x = marked_text
        input_y = tag

        gt_x, gt_y = input_x, input_y

        if mask_token:
            input_x = "Answer: <extra_id_0> Premise: %s" % input_x
            input_y = "<extra_id_0> %s <extra_id_1>" % input_y
        else:
            input_x = "Answer: %s %s" % (input_y, input_x)
            input_y = ""

        # Both sequences are truncated to config.max_length token ids.
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class MultiRCInContext(DAInContextDataset):
    """MultiRC prompts built from ``(passage_text, question, answer_text, label)`` tuples."""

    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        Every candidate answer of every question yields one instance whose
        label is ``'True'`` when the answer's ``label`` equals 1 and ``'False'``
        otherwise. The index joins passage, question and answer ids with ``_``.
        ``filtering`` and ``index_starts`` are not used here.
        """
        instances = []
        indices = []
        with open(path) as handle:
            for raw_line in handle:
                record = json.loads(raw_line)
                passage = record['passage']
                passage_id = record['idx']
                passage_text = passage['text']
                for question_entry in passage['questions']:
                    question_text = question_entry['question']
                    question_id = question_entry['idx']
                    for answer_entry in question_entry['answers']:
                        answer_id = answer_entry['idx']
                        answer_text = answer_entry['text']
                        verdict = 'True' if answer_entry['label'] == 1 else 'False'
                        instances.append((passage_text, question_text, answer_text, verdict))
                        indices.append('_'.join(str(part) for part in (passage_id, question_id, answer_id)))
        return instances, indices

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when passage, question and answer all match."""
        return all(instance_a[pos] == instance_b[pos] for pos in range(3))

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: predict the label of an answer to a question.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        passage_text, question, answer_text, tag = data_instance
        prompt = " Question: %s Answer: %s Article: %s" % (question, answer_text, passage_text)
        gt_x, gt_y = prompt, tag

        if mask_token:
            source_text = "Label: <extra_id_0> %s" % prompt
            target_text = "<extra_id_0> %s <extra_id_1>" % tag
        else:
            source_text = "Label: %s %s" % (tag, prompt)
            target_text = ""

        limit = self.config.max_length
        y_ids = self.tokenizer(target_text, return_tensors="np")['input_ids'][0, :limit].tolist()
        x_ids = self.tokenizer(source_text, return_tensors="np")['input_ids'][0, :limit].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class ReCoRDInContext(DAInContextDataset):
    """ReCoRD prompts built from ``(source, passage, query, entities, label)`` tuples."""

    # generative
    def get_data_set(self, path, filtering=False, index_starts=0):
        """Read a JSON-lines file into ``(instances, indices)``.

        For every query, training mode yields one instance per distinct answer
        text, and evaluation mode one instance per distinct passage entity.
        Candidates are visited in sorted order: iterating the underlying set
        directly would make the instance order, and the entity counter embedded
        in each index, depend on string hash randomisation.
        ``filtering`` and ``index_starts`` are not used here.
        """
        data_list = []
        data_index = []
        with open(path) as out:
            for l in out:
                items = json.loads(l)
                source = items['source']
                passage = items['passage']['text']
                passage_id = items['idx']
                entities = items['passage']['entities']
                # Entity offsets are treated as inclusive, hence the ``+ 1``.
                entities_token = {passage[entity['start']: entity['end'] + 1] for entity in entities}
                qas = items['qas']
                for q_as in qas:
                    query = q_as['query']
                    q_id = q_as['idx']
                    answers = q_as['answers']
                    answers_token = {ans_span['text'] for ans_span in answers}
                    candidates = answers_token if self.is_training else entities_token
                    for entity_id, label in enumerate(sorted(candidates)):
                        data_list.append((source, passage, query, entities_token, label))
                        data_index.append('_'.join([str(passage_id), str(q_id), str(entity_id)]))

        return data_list, data_index

    def is_identical(self, instance_a, instance_b):
        """Two instances are identical when source, passage and query all match."""
        return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1] \
               and instance_a[2] == instance_b[2]

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
        """Build the "nlu" prompt: produce the answer text for a query and passage.

        Returns:
            ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` is the unlabelled
            prompt text and ``gt_y`` the label.
        """
        source, passage, query, entities_token, tag = data_instance

        input_x = "Query: %s Passage: %s" % (query, passage)
        input_y = tag

        gt_x, gt_y = input_x, input_y

        if mask_token:
            input_x = "Answer: <extra_id_0> %s" % input_x
            input_y = "<extra_id_0> %s <extra_id_1>" % input_y
        else:
            input_x = "Answer: %s %s" % (input_y, input_x)
            input_y = ""

        # Both sequences are truncated to config.max_length token ids.
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

        if add_seperator:
            x_ids.append(self.sep_token_id)
        return x_ids, y_ids, gt_x, gt_y


def process_tensor(tensor_list, last_dim, output_mask=False):
    """Zero-pad a list of numpy arrays to a common length and stack them.

    Args:
        tensor_list: arrays whose first dimension is the sequence length.
        last_dim: size of a trailing feature dimension; values ``<= 0`` give a
            two-dimensional ``(batch, length)`` result.
        output_mask: also return a float32 mask with 1 at real positions.

    Returns:
        The padded torch tensor (dtype taken from the first array), or the pair
        ``(padded, mask)`` when ``output_mask`` is true.
    """
    lengths = [item.shape[0] for item in tensor_list]
    longest = max(lengths)
    batch_size = len(tensor_list)

    padded_shape = (batch_size, longest, last_dim) if last_dim > 0 else (batch_size, longest)
    padded = np.zeros(padded_shape, dtype=tensor_list[0].dtype)
    mask = np.zeros((batch_size, longest), dtype=np.float32)

    for row, (item, length) in enumerate(zip(tensor_list, lengths)):
        if length <= 0:
            # Empty sequences stay all-zero in both outputs.
            continue
        padded[row, :length] = item
        mask[row, :length] = 1

    padded_tensor = torch.from_numpy(padded)
    if not output_mask:
        return padded_tensor
    return padded_tensor, torch.from_numpy(mask)


def _data_wrapper(dataset):
    """Collate a list of dataset items into one batch dictionary.

    Each item is indexed positionally: 0 = encoder ids, 1 = decoder ids,
    2 = gt_x, 3 = gt_y, 4 = data index, 5 = task type. Padded decoder
    positions are set to -100; ``task_ids`` is all zeros.
    """
    encoder_input_ids, encoder_mask = process_tensor(
        [sample[0] for sample in dataset], 0, output_mask=True)
    decoder_input_ids, decoder_mask = process_tensor(
        [sample[1] for sample in dataset], 0, output_mask=True)
    # Overwrite padding in the decoder ids in place.
    decoder_input_ids.masked_fill_(decoder_mask == 0, -100)

    batch = {
        "encoder_input_ids": encoder_input_ids,
        "encoder_mask": encoder_mask,
        "decoder_input_ids": decoder_input_ids,
        "task_ids": torch.tensor([0 for _ in dataset]).long(),
        "task_type_ids": torch.tensor([sample[5] for sample in dataset]).long(),
        "gt_x": [sample[2] for sample in dataset],
        "gt_y": [sample[3] for sample in dataset],
        "data_index": [sample[4] for sample in dataset],
    }
    return batch


# Lookup table from a task name to the dataset class that builds its prompts.
# get_single_h5py_nlp_data indexes it with config.running_task.
TASK_DA_MAPPING = {
    'rte': RTEInContext,
    'boolq': BoolQInContext,
    'cb': CBInContext,
    'copa': COPAIncontext,
    'wsc': WSCInContext,
    'wic': WiCInContext,
    'multirc': MultiRCInContext,
    'record': ReCoRDInContext,
}


def get_single_h5py_nlp_data(config, path, train_path, split, batch_size, tokenizer, max_length, shuffle=False,
                             distributed=False, is_root=True, is_train=True, random_in_context=True, duplication=1):
    """Create the DataLoader for the task named by ``config.running_task``.

    Args:
        config: configuration object; ``running_task`` must be a key of
            ``TASK_DA_MAPPING``.
        path: data file served by the dataset.
        train_path: data file used as the in-context example pool.
        split: label printed next to the dataset size.
        batch_size: loader batch size.
        tokenizer: tokenizer handed to the dataset.
        max_length: accepted but not used inside this function.
        shuffle: shuffle flag, forwarded to the sampler when ``distributed``
            and to the loader otherwise.
        distributed: use a ``DistributedSampler`` for the dataset.
        is_root: print size information when true.
        is_train: forwarded to the dataset as ``is_training``.
        random_in_context: forwarded to the dataset.
        duplication: forwarded to the dataset.

    Returns:
        A ``DataLoader`` using ``_data_wrapper`` as its collate function.
    """
    task_name = config.running_task
    assert task_name in TASK_DA_MAPPING, "Cannot find Task %s" % task_name
    dataset_cls = TASK_DA_MAPPING[task_name]
    combined_dataset = dataset_cls(config, path, train_path, tokenizer, is_training=is_train, is_root=is_root,
                                   random_in_context=random_in_context, duplication=duplication)

    if is_root:
        print("%s Data Size %d" % (split, len(combined_dataset)))

    loader_options = {
        "pin_memory": True,
        "batch_size": batch_size,
        "num_workers": 8,
        "collate_fn": _data_wrapper,
    }
    if distributed:
        # Ordering is delegated to the sampler, so the loader gets no shuffle flag.
        loader_options["sampler"] = torch.utils.data.distributed.DistributedSampler(combined_dataset,
                                                                                    shuffle=shuffle)
    else:
        loader_options["shuffle"] = shuffle
    return DataLoader(combined_dataset, **loader_options)
