# Generating communication data for the agents

# Generate the training data 
import functools
import itertools
import random
import torch
from utils.conf import *

class DataLoader:
    """Build the attribute dataset and the vertex pairs, and serve random batches.

    Attributes set by the dataset builders:
        numInst: dict mapping a split name to its number of rows.
        attrVocab / invAttrVocab: attribute string -> index and the inverse map.
        data: after ``createDataset2`` / ``createDataset3`` a dict mapping a
            split name to a LongTensor of shape (rows, numAttrs); after
            ``createDataset1`` a bare (4, 1) LongTensor holding the only split.

    Attributes set by ``generatePairs``:
        src_trg_pairs: LongTensor with one ordered (source, target) pair per row.
        numPairs: number of rows in ``src_trg_pairs``.

    NOTE(review): only ``createDataset3`` builds the all-zero ``'void'`` split,
    so ``getNullBatch`` / ``get_vertex_batch`` (and the two-tensor form of
    ``getBatch``) raise ``KeyError`` for the other datasets. The other
    vocabularies have no ``'***'`` entry at index 0, so no ``'void'`` split is
    invented for them here.
    """

    def __init__(self):
        """Enumerate the vertex pairs, then build the dataset chosen by ``NUM_ATTRS``."""
        self.generatePairs()
        # Any other NUM_ATTRS value leaves the loader without a dataset.
        if NUM_ATTRS == 1:
            self.createDataset1()
        elif NUM_ATTRS == 2:
            self.createDataset2()
        elif NUM_ATTRS == 3:
            self.createDataset3()

    def getInstCount(self):
        """Return the dict mapping each split name to its number of rows."""
        return self.numInst

    def getCompleteData(self, dtype):
        """Return a copy of every row of split ``dtype`` (e.g. ``'train'``).

        Raises:
            KeyError: if ``dtype`` is not a split of the current dataset.
        """
        indices = torch.arange(0, self.numInst[dtype])
        return self._splitTensor(dtype)[indices]

    def _splitTensor(self, dtype):
        """Return the tensor of split ``dtype`` under either storage layout."""
        if torch.is_tensor(self.data):
            # createDataset1 stores its single split directly as a tensor.
            return self.data
        return self.data[dtype]

    def _sampleIndices(self, batchSize, upper):
        """Draw ``batchSize`` indices uniformly from ``[0, upper)`` with replacement."""
        return torch.empty(batchSize, dtype=torch.long).random_(0, upper)

    def _encodeRows(self, rows, numAttrs):
        """Encode attribute tuples as a (len(rows), numAttrs) LongTensor of vocab indices."""
        encoded = [[self.attrVocab[at] for at in attrSet] for attrSet in rows]
        # The reshape keeps the (0, numAttrs) shape when ``rows`` is empty.
        return torch.tensor(encoded, dtype=torch.long).reshape(len(rows), numAttrs)

    def _buildSplits(self, attrList, attrVals, trainFraction):
        """Enumerate all attribute combinations and store the train/test splits.

        Args:
            attrList: one list of possible values per attribute.
            attrVals: every attribute value; its position is its vocab index.
            trainFraction: fraction of the combinations kept for training.
        """
        dataVerbose = list(itertools.product(*attrList))
        numImgs = len(dataVerbose)
        numTrain = int(trainFraction * numImgs)
        self.numInst = {'train': numTrain, 'test': numImgs - numTrain}

        # Randomly hold out the test rows; the training rows keep the product
        # order so the split is reproducible under a fixed random seed (a set
        # difference would order them by string hash, which varies per process).
        testRows = random.sample(dataVerbose, self.numInst['test'])
        heldOut = set(testRows)
        trainRows = [row for row in dataVerbose if row not in heldOut]

        self.attrVocab = {value: ii for ii, value in enumerate(attrVals)}
        self.invAttrVocab = {index: attr for attr, index in self.attrVocab.items()}

        numAttrs = len(attrList)
        # Both splits are always stored, even when one of them is empty.
        self.data = {
            'train': self._encodeRows(trainRows, numAttrs),
            'test': self._encodeRows(testRows, numAttrs),
        }

    def createDataset3(self):
        """Build the three-attribute dataset (segments, sectors, colors).

        Each attribute takes one of three values or the shared ``'***'``
        placeholder, which has vocab index 0. ``TRAINING_SIZE`` is the fraction
        of combinations used for training. Also stores a ``'void'`` split: an
        all-zero tensor with the same shape as the training split.
        """
        attrList = [
            ['***', 'seg1', 'seg2', 'seg3'],
            ['***', 'sec1', 'sec2', 'sec3'],
            ['***', 'col1', 'col2', 'col3'],
        ]
        attrVals = ['***', 'seg1', 'seg2', 'seg3', 'sec1', 'sec2', 'sec3',
                    'col1', 'col2', 'col3']
        self._buildSplits(attrList, attrVals, TRAINING_SIZE)
        self.data['void'] = torch.zeros(
            (self.numInst['train'], len(attrList)), dtype=torch.long)

    def createDataset1(self):
        """Build the one-attribute dataset: the four segment values in random order.

        ``self.data`` is a bare (4, 1) LongTensor here, not a dict of splits.
        """
        attrList = ['seg1', 'seg2', 'seg3', 'seg4']
        self.numInst = {'train': len(attrList)}
        self.attrVocab = {value: ii for ii, value in enumerate(attrList)}
        self.invAttrVocab = {index: attr for attr, index in self.attrVocab.items()}

        attrs = random.sample(attrList, len(attrList))
        self.data = self._encodeRows([(attr,) for attr in attrs], 1)

    def createDataset2(self):
        """Build the two-attribute dataset (segments x sectors).

        Every combination goes to the training split; the test split is stored
        as an empty (0, 2) tensor.
        """
        attrList = [
            ['seg1', 'seg2', 'seg3', 'seg4'],
            ['sec1', 'sec2', 'sec3', 'sec4'],
        ]
        attrVals = ['seg1', 'seg2', 'seg3', 'seg4', 'sec1', 'sec2', 'sec3', 'sec4']
        self._buildSplits(attrList, attrVals, 1)

    def generatePairs(self):
        """Enumerate every ordered pair of distinct vertices out of ``N_VERTEX``."""
        pairs = list(itertools.permutations(range(N_VERTEX), r=2))
        # The reshape keeps a (0, 2) shape when there are fewer than two vertices.
        self.src_trg_pairs = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)
        self.numPairs = len(self.src_trg_pairs)

    def getBatch(self, batchSize):
        """Sample ``batchSize`` training rows uniformly with replacement.

        Returns:
            For the dict layout, a pair ``(batch, batch2)`` where ``batch2``
            holds the matching rows of the ``'void'`` split. For the bare-tensor
            layout of ``createDataset1``, just the sampled ``batch``.

        Raises:
            KeyError: if the dict layout has no ``'void'`` split.
        """
        indices = self._sampleIndices(batchSize, self.numInst['train'])
        if torch.is_tensor(self.data):
            return self.data[indices]
        batch = self.data['train'][indices]
        batch2 = self.data['void'][indices]
        return batch, batch2

    def getNullBatch(self, batchSize):
        """Return ``batchSize`` rows sampled from the ``'void'`` split.

        Raises:
            KeyError: if the current dataset has no ``'void'`` split.
        """
        indices = self._sampleIndices(batchSize, self.numInst['train'])
        return self.data['void'][indices]

    def get_vertex_batch(self, batchSize):
        """Sample ``batchSize`` vertex pairs plus a mostly-null attribute batch.

        The first ``ZERO_DATA`` rows of the null batch (at most ``batchSize``)
        are overwritten with ``[k, 0, 0]``, ``k`` drawn uniformly from 1..3.
        The three-element row assumes the three-attribute dataset.

        Returns:
            ``(batch, batch2)``: the sampled pairs and the attribute batch.
        """
        indices = self._sampleIndices(batchSize, self.numPairs)
        batch = self.src_trg_pairs[indices]
        # getNullBatch returns a fresh tensor, so writing to it is safe.
        batch2 = self.getNullBatch(batchSize)
        # Clamp so a batch smaller than ZERO_DATA does not index past the end.
        for i in range(min(ZERO_DATA, batchSize)):
            batch2[i] = torch.LongTensor([random.choice(range(1, 4)), 0, 0])
        return batch, batch2
    
   