import argparse
import collections
import json
import os
import random
import sys
import time
import six
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
if six.PY2:
	import cPickle as cPickle
else:
	import pickle

# Special vocabulary tokens and the integer ids reserved for them.
# load_vocab() in this file seeds the vocabulary with _PAD mapped to PAD_ID.
_PAD ='_PAD'
# NOTE(review): _GO / GO_ID are not referenced in the visible part of this
# file; presumably '<s>' is a start-of-sequence marker expected at id 1,
# supplied by the vocab file -- confirm against the callers.
_GO = '<s>'

PAD_ID = 0
GO_ID = 1

# NOTE(review): presumably the number of reserved special tokens above (_PAD
# and _GO); not used in the visible code -- confirm.
START_VOCAB_SIZE = 2

def load_vocab(args):
	"""Load the vocabulary file named by ``args.vocab``.

	The vocabulary always starts with the padding token ``_PAD`` at id
	``PAD_ID``; each line of the file (stripped of surrounding whitespace)
	then receives the current size of the mapping as its id.

	Args:
		args: object exposing a ``vocab`` attribute holding the file path.

	Returns:
		A ``(vocab, vocab_list)`` tuple: the token -> id dict and the list of
		tokens in the order they were read (padding token first).
	"""
	vocab = {_PAD: PAD_ID}
	vocab_list = [_PAD]
	with open(args.vocab) as vocab_file:
		for raw_line in vocab_file:
			token = raw_line.strip()
			# The id is the mapping size at insertion time.
			vocab[token] = len(vocab)
			vocab_list.append(token)
	print('Loaded vocab %s of size %d' % (args.vocab, len(vocab)))
	return vocab, vocab_list

def str_to_io_tensor(raw_str):
	"""Decode a sparse ``"idx:value idx:value ..."`` string into a dense grid.

	Only the integer before each ``:`` is read; the part after the colon is
	ignored and every listed flat index is set to 1.0.

	Args:
		raw_str: whitespace-separated ``index:value`` items. An empty or
			whitespace-only string yields an all-zero grid.

	Returns:
		A float32 numpy array of shape (16, 18, 18).

	Raises:
		ValueError: if the index part of an item is not an integer.
		IndexError: if an index falls outside the 16 * 18 * 18 flat grid.
	"""
	grid_tensor = np.zeros(16 * 18 * 18, dtype=np.float32)
	# split() without an argument skips leading/trailing/repeated whitespace
	# and returns [] for an empty string; split(' ') produced '' items that
	# made int() raise.
	for item in raw_str.split():
		# partition() keeps the whole item when there is no ':'; slicing with
		# item.find(':') == -1 silently dropped the last character instead.
		idx = int(item.partition(':')[0])
		grid_tensor[idx] = 1.0
	return np.reshape(grid_tensor, (16, 18, 18))

def code_to_token_ids(raw_prog, vocab):
	"""Translate a token sequence into vocabulary ids.

	Args:
		raw_prog: iterable of tokens.
		vocab: mapping queried with ``vocab.get(token)``.

	Returns:
		A list with one entry per token: the result of ``vocab.get`` for that
		token (``None`` for tokens missing from a plain dict).
	"""
	return [vocab.get(token) for token in raw_prog]

def np_to_tensor(inp, output_type, cuda_flag, volatile_flag=False):
	"""Wrap array-like data in an autograd ``Variable`` of the requested type.

	Args:
		inp: data accepted by the ``torch.FloatTensor`` / ``torch.LongTensor``
			constructors.
		output_type: ``'float'`` builds a FloatTensor, ``'int'`` a LongTensor.
		cuda_flag: when truthy, the result is moved with ``.cuda()``.
		volatile_flag: forwarded to ``Variable`` as ``volatile=``.

	Returns:
		The ``Variable`` wrapping the converted tensor.

	Raises:
		ValueError: if ``output_type`` is neither ``'float'`` nor ``'int'``.
	"""
	if output_type == 'float':
		tensor_cls = torch.FloatTensor
	elif output_type == 'int':
		tensor_cls = torch.LongTensor
	else:
		# The old code only printed a message here and then failed later with
		# an UnboundLocalError on inp_tensor; fail immediately and clearly.
		raise ValueError('undefined tensor type: %r' % (output_type,))
	inp_tensor = Variable(tensor_cls(inp), volatile=volatile_flag)
	if cuda_flag:
		inp_tensor = inp_tensor.cuda()
	return inp_tensor

def get_dataset(filename, prog_format, target_vocab, target_vocab_list, args):
	"""Read a JSON-lines dataset, optionally filtering programs by control flow.

	Each line is expected to be a JSON object with a ``'program_tokens'``
	entry. When none of the filter flags on ``args`` is set the raw lines
	are returned unchanged (and unparsed).

	Args:
		filename: path of the dataset file.
		prog_format: unused here; kept for interface compatibility.
		target_vocab: unused here; kept for interface compatibility.
		target_vocab_list: unused here; kept for interface compatibility.
		args: object with boolean attributes ``seq_only``, ``repeat_only``
			and ``no_branch``. The first one that is set wins:
			``seq_only`` keeps programs without WHILE/REPEAT/IF/IFELSE,
			``repeat_only`` keeps programs without WHILE/IF/IFELSE,
			``no_branch`` keeps programs without IFELSE.

	Returns:
		The list of (still JSON-encoded) lines that pass the filter.
	"""
	# Use a context manager so the file handle is closed; the previous
	# version opened the file and never closed it.
	with open(filename, 'r') as data_file:
		raw_lines = data_file.readlines()
	print('Number of samples in ' + filename + ': ', len(raw_lines))
	if not (args.seq_only or args.repeat_only or args.no_branch):
		return raw_lines

	# Pick the filter once instead of re-testing the flags for every line.
	if args.seq_only:
		banned_tokens = ('WHILE', 'REPEAT', 'IF', 'IFELSE')
		description = 'sequential programs'
	elif args.repeat_only:
		banned_tokens = ('WHILE', 'IF', 'IFELSE')
		description = 'sequential and repeat programs'
	else:
		banned_tokens = ('IFELSE',)
		description = 'programs without IFELSE'

	lines = []
	for line in raw_lines:
		prog = json.loads(line)['program_tokens']
		if all(prog.count(token) == 0 for token in banned_tokens):
			lines.append(line)
	print('Number of ' + description + ' in ' + filename + ': ', len(lines))
	return lines

def cal_data_stat(data):
	"""Print control-flow histograms for a dataset and return the longest program.

	For every JSON-encoded sample the number of WHILE, REPEAT and IF/IFELSE
	tokens in ``'program_tokens'`` is counted. Four histograms (count ->
	number of programs with that count) are printed: while, repeat, if
	(IF and IFELSE combined) and their total.

	Args:
		data: iterable of JSON strings, each with a ``'program_tokens'`` list.

	Returns:
		The maximum ``len(program_tokens)`` seen (0 for empty input).
	"""
	while_hist, repeat_hist, if_hist, total_hist = {}, {}, {}, {}
	longest = 0
	for raw_sample in data:
		tokens = json.loads(raw_sample)['program_tokens']
		longest = max(longest, len(tokens))
		n_while = tokens.count('WHILE')
		n_repeat = tokens.count('REPEAT')
		n_if = tokens.count('IF') + tokens.count('IFELSE')
		n_total = n_while + n_repeat + n_if
		tallies = (
			(while_hist, n_while),
			(repeat_hist, n_repeat),
			(if_hist, n_if),
			(total_hist, n_total),
		)
		for histogram, bucket in tallies:
			histogram[bucket] = histogram.get(bucket, 0) + 1
	print('number of while: ')
	print(while_hist)
	print('number of repeat: ')
	print(repeat_hist)
	print('number of if: ')
	print(if_hist)
	print('total number of complex control flows: ')
	print(total_hist)
	return longest

def read_prog(data_item, prog_format, prog_vocab):
	"""Extract a sample's program and convert it to vocabulary ids.

	Args:
		data_item: decoded sample dict; ``'program_tokens'`` is read.
		prog_format: program representation selector; only ``'P'`` is
			supported.
		prog_vocab: vocabulary passed on to ``code_to_token_ids``.

	Returns:
		The list of ids produced by ``code_to_token_ids``.

	Raises:
		ValueError: if ``prog_format`` is not ``'P'``.
	"""
	if prog_format != 'P':
		# The old code only printed a message here and then failed with an
		# UnboundLocalError on `prog`; report the real problem instead.
		raise ValueError('prog format undefined: %r' % (prog_format,))
	return code_to_token_ids(data_item['program_tokens'], prog_vocab)

def read_traces(data_item, prog_vocab):
	"""Split a sample's action traces into training and validation parts.

	Each entry of ``data_item['examples']`` contributes its ``'actions'``
	sequence, converted with ``code_to_token_ids``. The example at index 5
	goes to the validation list; all others go to the training list.

	Args:
		data_item: decoded sample dict with an ``'examples'`` sequence.
		prog_vocab: vocabulary passed on to ``code_to_token_ids``.

	Returns:
		A ``(train_traces, val_trace)`` tuple of lists.
	"""
	train_traces, val_trace = [], []
	for position, example in enumerate(data_item['examples']):
		token_ids = code_to_token_ids(example['actions'], prog_vocab)
		# Index 5 is the held-out example.
		destination = val_trace if position == 5 else train_traces
		destination.append(token_ids)
	return train_traces, val_trace

def read_ios(data_item):
	"""Split a sample's input/output grids into training and validation parts.

	Each entry of ``data_item['examples']`` yields an ``(input, output)``
	pair decoded by ``str_to_io_tensor`` from its ``'inpgrid_tensor'`` and
	``'outgrid_tensor'`` strings. The example at index 5 goes to the
	validation list; all others go to the training list.

	Args:
		data_item: decoded sample dict with an ``'examples'`` sequence.

	Returns:
		A ``(train_ios, val_io)`` tuple of lists of grid pairs.
	"""
	train_ios, val_io = [], []
	for position, example in enumerate(data_item['examples']):
		grid_pair = (
			str_to_io_tensor(example['inpgrid_tensor']),
			str_to_io_tensor(example['outgrid_tensor']),
		)
		# Index 5 is the held-out example.
		destination = val_io if position == 5 else train_ios
		destination.append(grid_pair)
	return train_ios, val_io

def int_to_io_tensor(data):
	"""Build a dense (16, 18, 18) grid from the flat indices held in ``data``.

	Args:
		data: object whose ``.numpy()`` result is iterated; each element is
			used as an index into the flattened grid and set to 1.0.

	Returns:
		A float32 numpy array of shape (16, 18, 18).
	"""
	grid_tensor = np.zeros((16, 18, 18), dtype=np.float32)
	# Write through a flat view so indices address the 16 * 18 * 18 layout.
	flat_view = grid_tensor.reshape(16 * 18 * 18)
	for flat_idx in data.numpy():
		flat_view[flat_idx] = 1.0
	return grid_tensor

def read_random_ios(data_item):
	"""Decode ``(input, output)`` index pairs into dense grid pairs.

	Args:
		data_item: iterable of 2-element ``(inp, out)`` items, each passed to
			``int_to_io_tensor``.

	Returns:
		A list of ``(input_grid, output_grid)`` tuples in input order.
	"""
	return [
		(int_to_io_tensor(inp), int_to_io_tensor(out))
		for inp, out in data_item
	]

def get_batch(data, batch_size, prog_format, prog_vocab, start_idx=None):
	"""Assemble up to ``batch_size`` parsed samples from ``data``.

	Args:
		data: sequence of JSON-encoded samples.
		batch_size: maximum number of samples in the batch.
		prog_format: forwarded to ``read_prog``.
		prog_vocab: vocabulary forwarded to ``read_prog`` / ``read_traces``.
		start_idx: when given, samples are taken consecutively starting at
			this index and the batch stops early at the end of ``data``;
			when ``None``, each sample is drawn with ``random.choice``.

	Returns:
		A list of ``(train_ios, val_io, prog, train_traces, val_trace)``
		tuples.
	"""
	batch_data = []
	for offset in range(batch_size):
		if start_idx is None:
			raw_sample = random.choice(data)
		else:
			position = offset + start_idx
			if position >= len(data):
				# Ran off the end of the data: return a short batch.
				break
			raw_sample = data[position]
		sample = json.loads(raw_sample)
		prog = read_prog(sample, prog_format, prog_vocab)
		train_traces, val_trace = read_traces(sample, prog_vocab)
		train_ios, val_io = read_ios(sample)
		batch_data.append((train_ios, val_io, prog, train_traces, val_trace))

	return batch_data

def get_data(data, prog_format, prog_vocab):
	"""Parse every JSON-encoded sample in ``data``.

	Args:
		data: iterable of JSON-encoded samples.
		prog_format: forwarded to ``read_prog``.
		prog_vocab: vocabulary forwarded to ``read_prog`` / ``read_traces``.

	Returns:
		A list of ``(train_ios, val_io, prog, train_traces, val_trace)``
		tuples, one per sample, in input order.
	"""
	def parse_sample(raw_sample):
		sample = json.loads(raw_sample)
		prog = read_prog(sample, prog_format, prog_vocab)
		train_traces, val_trace = read_traces(sample, prog_vocab)
		train_ios, val_io = read_ios(sample)
		return (train_ios, val_io, prog, train_traces, val_trace)

	return [parse_sample(raw_sample) for raw_sample in data]
