import json
import random
import sys

import torch
from torch.autograd import Variable


class DatasetWoz(object):
	'''
	Data container for the WOZ dataset.

	Loads the vocabulary, the semantic-feature template and the
	train/valid/test turns named in the [DATA] section of the config, and
	serves padded, length-sorted mini-batches through next_batch().
	'''
	def __init__(self, config, percentage=1.0, use_cuda=False):
		'''
		config     -- ConfigParser-like object; its [DATA] section must define
		              feat_file, text_file, dataSplit_file, vocab_file,
		              template_file, batch_size (int) and shuffle (bool).
		percentage -- fraction of the training turns to keep (valid/test are
		              always used in full).
		use_cuda   -- move the tensors returned by next_batch() to the GPU.
		'''
		# setup
		feat_file = config['DATA']['feat_file']
		text_file = config['DATA']['text_file']
		dataSplit_file = config['DATA']['dataSplit_file']
		vocab_file = config['DATA']['vocab_file']
		template_file = config['DATA']['template_file']
		self.template = template_file # for further scoring
		self.USE_CUDA = use_cuda

		# hyper-params
		self.batch_size = config.getint('DATA', 'batch_size')
		self.percentage = percentage # percentage of data used
		self.data = {'train': [], 'valid': [], 'test': []}
		self.data_index = {'train': 0, 'valid': 0, 'test': 0} # read cursor per split
		self.n_batch = {}
		self.shuffle = config.getboolean('DATA', 'shuffle')

		# load vocab from file
		self._loadVocab(vocab_file)

		# set input feature cardinality; self.dfs holds cumulative section offsets
		self._setCardinality(template_file)
		self.do_size = self.dfs[1] - self.dfs[0]
		self.da_size = self.dfs[2] - self.dfs[1]
		self.sv_size = self.dfs[3] - self.dfs[2]

		# initialise dataset (the training subset is cut before reset() shuffles it)
		self._setupData(text_file, feat_file, dataSplit_file)
		self.reset()


	def reset(self):
		'''
		Rewind every split to its first batch and, if shuffling is enabled,
		reshuffle the training turns in place.
		'''
		self.data_index = {'train': 0, 'valid': 0, 'test': 0}
		if self.shuffle:
			random.shuffle(self.data['train'])


	def next_batch(self, data_type='train'):
		'''
		Return the next mini-batch of split `data_type` and advance its cursor.

		Samples are sorted by token length, longest first. Returns
			(input_var, label_var, feats_var, lengths, refs, featStrs, sv_indexes)
		where
			input_var  -- FloatTensor (batch, max_len, len(word2index)), one-hot tokens
			label_var  -- LongTensor (batch, max_len), token ids ending in EOS_token,
			              right-padded with PAD_token
			feats_var  -- FloatTensor (batch, do_size + da_size + sv_size), binary
			              [domain | dial act | slot-value] features
			lengths    -- list of unpadded lengths (EOS included)
			refs       -- tuple of delexicalised reference sentences
			featStrs   -- tuple of feature strings from getFeatIdx()
			sv_indexes -- tuple of slot-value index lists from getFeatIdx()
		The last batch of a split may hold fewer than batch_size samples.

		Raises ValueError when the split has no samples left at the cursor;
		call reset() to rewind.
		'''
		def indexes_from_sentence(sentence, add_eos=False):
			# words outside the vocabulary map to UNK_token
			indexes = [self.word2index[word] if word in self.word2index else self.word2index['UNK_token'] for word in sentence.split(' ')]
			if add_eos:
				return indexes + [self.word2index['EOS_token']]
			return indexes

		def pad_seq(seq, max_length):
			# right-pad with PAD_token; builds a new list rather than mutating seq
			return seq + [self.word2index['PAD_token'] for _ in range(max_length - len(seq))]

		# reading a batch
		start = self.data_index[data_type]
		data = self.data[data_type][start:start + self.batch_size]
		if not data:
			# otherwise the empty batch fails further down with an opaque unpacking error
			raise ValueError('no {} samples left at index {}; call reset() to start a new epoch'.format(data_type, start))
		self.data_index[data_type] += self.batch_size

		sentences, refs, feats, featStrs, sv_indexes = [], [], [], [], []
		for _dial_idx, _turn_idx, text, meta in data:
			text_delex = text['delex']
			sentences.append(indexes_from_sentence(text_delex, add_eos=True))
			refs.append(text_delex)

			# get semantic feature as binary indicator vectors
			do_idx, da_idx, sv_idx, featStr = self.getFeatIdx(meta)
			do_hit, da_hit, sv_hit = set(do_idx), set(da_idx), set(sv_idx)
			do_cond = [1 if i in do_hit else 0 for i in range(self.do_size)] # domain condition
			da_cond = [1 if i in da_hit else 0 for i in range(self.da_size)] # dial act condition
			sv_cond = [1 if i in sv_hit else 0 for i in range(self.sv_size)] # slot/value condition
			feats.append(do_cond + da_cond + sv_cond)
			featStrs.append(featStr)
			sv_indexes.append(sv_idx)

		# Zip into tuples, sort by length (descending), unzip. sorted() is stable,
		# so equal-length samples keep their data order; all outputs share one order.
		seq_pairs = sorted(zip(sentences, refs, feats, featStrs, sv_indexes), key=lambda p: len(p[0]), reverse=True)
		sentences, refs, feats, featStrs, sv_indexes = zip(*seq_pairs)

		# Pad to the longest sentence of the batch
		lengths = [len(s) for s in sentences]
		max_len = max(lengths)
		sentences_padded = [pad_seq(s, max_len) for s in sentences]

		label_tensor = torch.LongTensor(sentences_padded)
		# (batch, max_len) -> (batch, max_len, n_vocab) one-hot with a single scatter,
		# instead of materialising batch*max_len*n_vocab Python ints
		onehot = torch.FloatTensor(len(sentences_padded), max_len, len(self.word2index)).zero_()
		onehot.scatter_(2, label_tensor.unsqueeze(2), 1.0)

		input_var = Variable(onehot)
		label_var = Variable(label_tensor)
		feats_var = Variable(torch.FloatTensor(feats))

		if self.USE_CUDA:
			input_var = input_var.cuda()
			label_var = label_var.cuda()
			feats_var = feats_var.cuda()

		return input_var, label_var, feats_var, lengths, refs, featStrs, sv_indexes


	def _setCardinality(self, template_file):
		'''
		Read the feature template into self.cardinality (one entry per line,
		newline removed) and set
			self.dfs = [0, #d, #d + #d-a, #d + #d-a + #d-a-s-v]
		i.e. cumulative counts of lines prefixed 'd:', 'd-a:' and 'd-a-s-v:'.

		getFeatIdx() converts a line's position in the file into a
		section-relative index by subtracting these offsets. That is only
		correct when the template lists every 'd:' line first, then every
		'd-a:' line, then every 'd-a-s-v:' line.
		'''
		self.cardinality = []
		counts = [0, 0, 0, 0]
		with open(template_file) as f:
			for line in f:
				self.cardinality.append(line.replace('\n', ''))
				if line.startswith('d:'):
					counts[1] += 1
				elif line.startswith('d-a:'):
					counts[2] += 1
				elif line.startswith('d-a-s-v:'):
					counts[3] += 1
		# running sum turns per-section counts into section start offsets
		for i in range(1, len(counts)):
			counts[i] += counts[i - 1]
		self.dfs = counts


	def printDataInfo(self):
		'''
		Print split sizes and batch counts to stdout, and a copy to stderr.
		Requires self.n_batch to be populated (done by _setupData()).
		'''
		print('***** DATA INFO *****')
		print('Using {}% of training data'.format(self.percentage*100))
		print('BATCH SIZE:', self.batch_size)

		print('Train:', len(self.data['train']), 'turns')
		print('Valid:', len(self.data['valid']), 'turns')
		print('Test:', len(self.data['test']), 'turns')
		print('# of turns', file=sys.stderr)
		print('Train:', len(self.data['train']), file=sys.stderr)
		print('Valid:', len(self.data['valid']), file=sys.stderr)
		print('Test:', len(self.data['test']), file=sys.stderr)
		print('# of batches: Train {} Valid {} Test {}'.format(self.n_batch['train'], self.n_batch['valid'], self.n_batch['test']))
		print('# of batches: Train {} Valid {} Test {}'.format(self.n_batch['train'], self.n_batch['valid'], self.n_batch['test']), file=sys.stderr)
		print('*************************\n')


	def _setupData(self, text_file, feat_file, dataSplit_file):
		'''
		Populate self.data[split] with (dial_idx, turn_idx, text, meta) tuples.

		dataSplit_file is JSON mapping each split name to a list of
		[dial_idx, turn_idx, <ignored>] entries. text and meta are looked up by
		those keys in the text and feature JSON files. Turns missing from the
		feature file are skipped. Afterwards the training split is truncated
		to its first `self.percentage` fraction, self.n_batch is set to the
		number of full batches per split, and a summary is printed.
		'''
		with open(text_file) as f:
			dial2text = json.load(f)
		with open(feat_file) as f:
			dial2meta = json.load(f)
		with open(dataSplit_file) as f:
			dataSet_split = json.load(f)

		for data_type in ['train', 'valid', 'test']:
			for dial_idx, turn_idx, _ in dataSet_split[data_type]:
				# might have empty feat turn which is not in feat file
				if turn_idx not in dial2meta[dial_idx]:
					continue

				meta = dial2meta[dial_idx][turn_idx]
				text = dial2text[dial_idx][turn_idx]
				self.data[data_type].append((dial_idx, turn_idx, text, meta))

		# percentage of training data (keeps the first turns in split-file order)
		if self.percentage < 1:
			_len = len(self.data['train'])
			self.data['train'] = self.data['train'][:int(_len*self.percentage)]

		# setup number of (full) batches
		for _type in ['train', 'valid', 'test']:
			self.n_batch[_type] = len(self.data[_type]) // self.batch_size

		self.printDataInfo()


	def _loadVocab(self, vocab_file):
		'''
		Build self.word2index / self.index2word from vocab_file.

		Each line holds one entry whose word is the first tab-separated field.
		The word's index is its 0-based line number; a word repeated in the
		file keeps the index of its last occurrence in word2index.
		'''
		self.word2index = {}
		self.index2word = {}
		with open(vocab_file) as fin:
			for idx, line in enumerate(fin):
				word = line.strip().split('\t')[0]
				self.word2index[word] = idx
				self.index2word[idx] = word


	def getFeatIdx(self, meta):
		'''
		Map one turn's dialogue-act annotation to feature indices.

		meta -- dict from dialogue-act name ('Domain-Act...') to a list of
		        slots. Only each slot's first two items are used (presumably
		        slot name and value, e.g. ('Day', '1', 'Wednesday ')).

		Returns (do_idx, da_idx, sv_idx, feat):
			do_idx -- unique domain indices, in first-seen order
			da_idx -- one dialogue-act index per key of meta
			sv_idx -- one slot-value index per slot
			feat   -- the 'act-slot-value' strings sorted and joined with '|'
		Each index is relative to its section of the template (see
		_setCardinality). Raises ValueError if an entry is not in the template.
		'''
		feat_container = []
		do_idx, da_idx, sv_idx = [], [], []
		for da, slots in meta.items():
			do = da.split('-')[0]
			_do_idx = self.cardinality.index('d:'+do) - self.dfs[0]
			if _do_idx not in do_idx:
				do_idx.append(_do_idx)
			da_idx.append(self.cardinality.index('d-a:'+da) - self.dfs[1])
			for _slot in slots: # e.g. ('Day', '1', 'Wednesday ')
				sv_idx.append(self.cardinality.index('d-a-s-v:'+da+'-'+_slot[0]+'-'+_slot[1]) - self.dfs[2])
				feat_container.append(da+'-'+_slot[0]+'-'+_slot[1])

		feat_container = sorted(feat_container) # sort SVs across DAs to make sure universal order
		feat = '|'.join(feat_container)

		return do_idx, da_idx, sv_idx, feat

class SimpleDatasetWoz(DatasetWoz):
	'''
	Lightweight variant of DatasetWoz that loads only the vocabulary and the
	feature template (no text, feature or split files), with batch_size 1.

	DatasetWoz.__init__ is deliberately not called. Every attribute the
	inherited methods (next_batch, reset, printDataInfo) read is therefore
	initialised here with empty/neutral values, so those methods do not fail
	with AttributeError.
	'''
	def __init__(self, config, use_cuda=False):
		'''
		config   -- object whose [DATA] section defines vocab_file and template_file.
		use_cuda -- move tensors returned by next_batch() to the GPU.
		'''
		vocab_file = config['DATA']['vocab_file']
		template_file = config['DATA']['template_file']
		self.template = template_file # for further scoring, as in DatasetWoz
		self.USE_CUDA = use_cuda
		self.batch_size = 1
		self.percentage = 1.0
		self.shuffle = False
		# no splits are loaded; the containers exist so inherited methods still work
		self.data = {'train': [], 'valid': [], 'test': []}
		self.data_index = {'train': 0, 'valid': 0, 'test': 0}
		self.n_batch = {'train': 0, 'valid': 0, 'test': 0}

		# load vocab from file
		self._loadVocab(vocab_file)

		# set input feature cardinality
		self._setCardinality(template_file)
		self.do_size = self.dfs[1] - self.dfs[0]
		self.da_size = self.dfs[2] - self.dfs[1]
		self.sv_size = self.dfs[3] - self.dfs[2]
