import sys
import os
sys.path.insert(1,os.getcwd())

import pickle
from torch.utils.data import DataLoader

# Load the pickled dataset once at import time; the functions below index the
# result by split name ('train', 'valid', 'test').
# The `with` block closes the file handle as soon as unpickling finishes
# (previously the handle stayed open for the life of the process).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point this at a trusted dataset file.
with open('/data/workspace/pengjie/MultiBench/datasets/data/humor.pkl','rb') as f:
    data=pickle.load(f)

def getdata(traindata,shuf=False,rate=1,batch_size=32,repeat=1,num_workers=0, debug=False, flip_label = False):
    """Wrap the leading samples of one dataset split in a ``DataLoader``.

    Every sample is the list ``[vision, audio, text, label]`` built from the
    same index ``i`` of each modality, with ``label`` taken from
    ``traindata['labels'][i][0]``.

    Args:
        traindata: Mapping with indexable ``'vision'``, ``'audio'``, ``'text'``
            and ``'labels'`` entries.
        shuf: Forwarded to ``DataLoader`` as ``shuffle``.
        rate: Multiplier on ``len(traindata['vision'])`` that decides how many
            leading samples are kept (result truncated with ``int``).
        batch_size: Forwarded to ``DataLoader``.
        repeat: The sample list is concatenated with itself this many times.
        num_workers: Forwarded to ``DataLoader``.
        debug: When true, the sample count is additionally scaled by ``0.05``.
        flip_label: When true, each label is replaced by ``1 - label``.

    Returns:
        A ``DataLoader`` over the assembled list of samples.
    """
    vision = traindata['vision']
    audio = traindata['audio']
    text = traindata['text']
    labels = traindata['labels']

    # Multiply left to right exactly as (len * rate) * 0.05 so the float
    # rounding, and therefore the truncated count, does not drift.
    scaled = len(vision) * rate
    if debug:
        scaled = scaled * 0.05
    count = int(scaled)

    def label_at(i):
        raw = labels[i][0]
        return 1 - raw if flip_label else raw

    # Index-based access is deliberate: an over-large `rate` still raises
    # IndexError instead of being silently truncated.
    samples = [[vision[i], audio[i], text[i], label_at(i)] for i in range(count)] * repeat
    return DataLoader(samples, shuffle=shuf, num_workers=num_workers, batch_size=batch_size)

def get_dataloader(rate=1,train_batch_size=32,repeat=1, num_workers = 0, debug=False, flip_label = False):
    """Build the train, validation and test loaders from the module-level ``data``.

    Args:
        rate: Forwarded to ``getdata`` for the training split only.
        train_batch_size: Batch size of the training loader; the other two
            loaders use ``getdata``'s default batch size.
        repeat: Forwarded to ``getdata`` for the training split only.
        num_workers: Forwarded to ``getdata`` for all three splits.
        debug: Forwarded to ``getdata`` for all three splits.
        flip_label: Forwarded to ``getdata`` for the training and validation
            splits.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader)``; only the
        training loader is created with shuffling enabled.
    """
    # Options that every split receives.
    shared = dict(num_workers=num_workers, debug=debug)

    train_loader = getdata(
        data['train'],
        True,
        rate=rate,
        batch_size=train_batch_size,
        repeat=repeat,
        flip_label=flip_label,
        **shared,
    )
    valid_loader = getdata(data['valid'], flip_label=flip_label, **shared)
    # NOTE(review): the test split is never label-flipped, unlike train/valid;
    # presumably intentional (clean evaluation labels) - confirm.
    test_loader = getdata(data['test'], **shared)
    return train_loader, valid_loader, test_loader
