import csv
import pickle
import requests
import sys

from collections import namedtuple
from urllib.parse import urlparse

from vilbert_embeddings_lib import *

from embeddings_serialize import Image_Caption_Embedding, serialize, deserialize

# Immutable record of the parsed command-line configuration. parse_args() builds it;
# validate_options() and main() read its fields.
Options = namedtuple(
	'Options',
	[
		# Input / output locations
		'input_path', 'output_path', 'image_dir',
		# Model selection
		'embedding_type', 'pretrained_model', 'config_file',
		# Run behaviour
		'save_every', 'resume', 'verbose',
		# Single image-caption pair given via -ic
		'image', 'caption',
		# Batching and overall parse status
		'batch_size', 'is_valid',
	],
)

def uri_validator(x):
	"""Return True if ``x`` parses as an absolute URL with both a scheme and a network location.

	Local paths such as ``images/cat.jpg`` return False. Values that ``urlparse`` cannot
	handle (e.g. a malformed IPv6 host or a non-string argument) also return False.
	"""
	try:
		result = urlparse(x)
	except (AttributeError, TypeError, ValueError):
		# Only parsing failures mean "not a URL"; KeyboardInterrupt etc. must propagate.
		return False
	return all([result.scheme, result.netloc])

def print_short_help():
	"""Print the one-line usage summary followed by a pointer to the full ``-h`` help."""
	lines = (
		f"Usage: python {sys.argv[0]} <input_path> [Options]",
		"",
		f"For further information use python {sys.argv[0]} -h",
	)
	for line in lines:
		print(line)

def print_help():
	"""Print the full help text: input/output file formats and every command-line option.

	The list of valid embedding types, and the default one, come from ``embedding_types``.
	"""
	print("Extracts ViLBERT embeddings for image-caption pairs.")
	print("")
	print("The input is a comma-separated values (.csv) file with two columns, labelled")
	print("'image' and 'caption'. It may optionally have a third column, labelled 'rating'")
	print("for use in creating training, test, or validation data.")
	print("Alternatively, it may have a column labelled 'num_ratings', followed by a column")
	print("labelled 'ratings', in which case each row may hold a variable number of ratings.")
	print("If the 'ratings' column and the variable-length data make up the last column in")
	print("each row, the 'num_ratings' column may be omitted.")
	print("This variant that lists individual ratings is useful for calculating Kendall tau")
	print("correlation by 'Method A'.")
	print("The image may be given as a URL or a path to a local file.")
	print("")
	print("The output is a .emb file containing serialized Image_Caption_Embedding objects,")
	print("which have the following fields:")
	print("  image, caption, image_embedding, caption_embedding, ratings,")
	print("where ratings is a potentially empty list of ratings for the image-caption pair.")
	print("")
	print(f"Usage: python {sys.argv[0]} <input_path> [Options]")
	print("")
	print("Options:")
	print("")
	print("-o, --output-path <path>")
	print("  Specifies the path to the output .emb file. If not provided, the default is")
	print("  embeddings/<input_filestem>-<embedding_type>-<model_name>-vilbert.emb,")
	print("  where <model_name> is the filename of the pretrained model without the extension")
	print("-i, --image-dir <path>")
	print("  Specifies a path, either absolute or relative to the current working directory")
	print("  to be prepended to all local image files in the input csv")
	print("-e, --embedding-type <type>")
	print("  Specifies what types of image and text embeddings to extract from ViLBERT")
	print(f"  Valid embedding types are {embedding_types}")
	print(f"  Defaults to '{embedding_types[0]}' if not given")
	print("-p, --pretrained <path>")
	print("  Specifies the path to a specific pretrained model to load")
	print("  Defaults to 'vilbert-multi-task/save/multi_task_model.bin'")
	print("-c, --config <path>")
	print("  Specifies the path to the config file to use when loading a model")
	print("  Defaults to 'vilbert-multi-task/config/bert_base_6layer_6conect.json'")
	print("-s, --save-every <number>")
	print("  Specifies how often to write the currently processed items to the output file")
	print("  For saving progress in case the program is stopped before finishing")
	print("  A value of 0 disables periodic saving")
	print("  Defaults to 500")
	print("-r, --resume")
	print("  Flag that, if present, loads the output .emb file and continues from the")
	print("  last item stored")
	print("  For use when only a portion of the input data was previously processed")
	print("-v, --verbose")
	print("  Flag that, if present, prints information to the console as embeddings are")
	print("  being produced")
	print("-b, --batch-size <number>")
	print("  Specifies how many image-caption pairs to process at a time")
	print("  Defaults to 1")
	print("-ic, --image-caption <image> <caption>")
	print("  The next two arguments specify the image and caption, respectively, to be")
	print("  directly turned into ViLBERT embeddings")
	print("  If this option is given, no input_path needs to be specified, and the default")
	print("  output path is embeddings/<image_name>-<embedding_type>-<model_name>-vilbert.emb,")
	print("  where <image_name> is the filename of the image without the extension")
	print("-h, --help")
	print("  Flag that, if present, prints this message")

# Options that consume exactly one following command-line argument.
_VALUE_OPTIONS = frozenset((
	'-o', '--output-path',
	'-i', '--image-dir',
	'-e', '--embedding-type',
	'-p', '--pretrained',
	'-c', '--config',
	'-s', '--save-every',
	'-b', '--batch-size',
))


def _filestem(path):
	"""Return the last '/'-separated component of ``path`` without its final extension.

	A name without any '.' is returned unchanged: 'save/model.bin' -> 'model',
	'save/model' -> 'model', 'data/a.b.csv' -> 'a.b'.
	"""
	name = path.split('/')[-1]
	stem, dot, _ = name.rpartition('.')
	return stem if dot else name


def parse_args():
	"""Parse ``sys.argv`` into an ``Options`` tuple.

	Parsing stops at the first error, in which case the returned ``is_valid`` is False.
	A short usage message is printed on error or when neither an input csv nor a
	complete ``-ic <image> <caption>`` pair was given. When ``-o`` is absent the output
	path defaults to ``embeddings/<stem>-<embedding_type>-<model_stem>-vilbert.emb``,
	where ``<stem>`` comes from the input csv or, without one, from the ``-ic`` image.
	"""
	# Defaults
	input_path = ''
	output_path = ''
	image_dir = ''
	embedding_type = embedding_types[0]
	pretrained_model = "vilbert-multi-task/save/multi_task_model.bin"
	config_file = "vilbert-multi-task/config/bert_base_6layer_6conect.json"
	save_every = 500
	resume = False
	verbose = False
	batch_size = 1
	image = ''
	caption = ''
	is_valid = True

	argv = sys.argv
	arg_index = 1
	while arg_index < len(argv) and is_valid:
		arg = argv[arg_index]

		value = None
		if arg in _VALUE_OPTIONS:
			arg_index += 1
			if arg_index >= len(argv):
				print(f"Missing argument to {arg}")
				is_valid = False
				break
			value = argv[arg_index]

		if arg in ('-o', '--output-path'):
			output_path = value
		elif arg in ('-i', '--image-dir'):
			image_dir = value
		elif arg in ('-e', '--embedding-type'):
			if value in embedding_types:
				embedding_type = value
			else:
				# Parsing aborts here, so do not claim to fall back to a default.
				print(f"Invalid embedding type provided: {value}")
				print(f"  Valid embedding types are {embedding_types}")
				is_valid = False
		elif arg in ('-p', '--pretrained'):
			pretrained_model = value
		elif arg in ('-c', '--config'):
			config_file = value
		elif arg in ('-s', '--save-every'):
			try:
				save_every = int(value)
			except ValueError:
				print(f"Argument to {arg} was not an integer")
				is_valid = False
			else:
				if save_every < 0:
					print("Saving interval cannot be negative")
					is_valid = False
		elif arg in ('-b', '--batch-size'):
			try:
				batch_size = int(value)
			except ValueError:
				print(f"Argument to {arg} was not an integer")
				is_valid = False
			else:
				if batch_size < 1:
					print("Batch size must be positive")
					is_valid = False
		elif arg in ('-r', '--resume'):
			resume = True
		elif arg in ('-v', '--verbose'):
			verbose = True
		elif arg in ('-ic', '--image-caption'):
			if arg_index + 2 < len(argv):
				image = argv[arg_index + 1]
				caption = argv[arg_index + 2]
				# Land on the caption; the shared increment below moves past it, so the
				# argument following the caption is still parsed.
				arg_index += 2
			else:
				print(f"Missing argument to {arg}")
				is_valid = False
		elif arg in ('-h', '--help'):
			print_help()
		elif input_path == '':
			input_path = arg
		else:
			print(f"Unrecognized option: {arg}")
			is_valid = False
		arg_index += 1

	if not is_valid or (input_path == '' and (image == '' or caption == '')):
		print_short_help()

	if output_path == '':
		source = input_path if input_path != '' else image
		if source != '':
			model_stem = _filestem(pretrained_model)
			output_path = 'embeddings/' + _filestem(source) + '-' + embedding_type + '-' + model_stem + '-vilbert.emb'

	return Options(input_path, output_path, image_dir, embedding_type, pretrained_model, config_file, save_every, resume, verbose, image, caption, batch_size, is_valid)

def validate_options(options):
	"""Return True when ``options`` parsed cleanly and name something to embed.

	Something to embed means either an input csv path or a complete -ic image/caption pair.
	"""
	if not options.is_valid:
		return False
	has_pair = options.image != '' and options.caption != ''
	return options.input_path != '' or has_pair

def _parse_ratings(row, rating_col, num_ratings_col, ratings_col):
	"""Return the ratings stored in ``row`` as floats (an empty list when there are none).

	A single 'rating' column takes precedence. Otherwise the values start at the
	'ratings' column and span 'num_ratings' entries when that column exists, or run
	to the end of the row when it does not.
	"""
	if rating_col != -1:
		return [float(row[rating_col])]
	if ratings_col == -1:
		return []
	if num_ratings_col != -1:
		num_ratings = int(row[num_ratings_col])
		return [float(value) for value in row[ratings_col:ratings_col + num_ratings]]
	return [float(value) for value in row[ratings_col:]]


def _write_embeddings(result, output_path):
	"""Serialize ``result`` to ``output_path``, creating the parent directory if needed."""
	import os
	directory = os.path.dirname(output_path)
	if directory:
		# The default output path lives under 'embeddings/', which may not exist yet.
		os.makedirs(directory, exist_ok=True)
	with open(output_path, 'wb') as outf:
		serialize(result, outf)


def _close_responses(responses):
	"""Close every streamed HTTP response in ``responses`` and empty the list."""
	for response in responses:
		response.close()
	responses.clear()


def _embed_batch(options, result, last_index, offset, image_names, images, captions, ratings_batch):
	"""Embed one batch with get_vilbert_embeddings and add one Image_Caption_Embedding per pair to ``result``.

	``last_index`` is the input-row index of the last pair in the batch. When resuming,
	it is used together with ``offset`` to choose the position at which each new
	embedding is inserted into ``result``.
	"""
	batch_size = len(images)
	embeddings_output = get_vilbert_embeddings(images, captions, verbose=options.verbose)
	if options.embedding_type == 'alignment':
		image_embeddings = embeddings_output
		# NOTE(review): the count of empty caption embeddings is len(embeddings_output[0]);
		# presumably this should equal the batch size - confirm the shape returned for 'alignment'.
		caption_embeddings = np.asarray([[] for _ in range(len(embeddings_output[0]))])
	else:
		# As of the December 9th, 2022 check-in, embeddings_output is ordered (caption, image).
		caption_embeddings, image_embeddings = embeddings_output

	for j in range(batch_size):
		embedding = Image_Caption_Embedding(image_names[j], captions[j], image_embeddings[j], caption_embeddings[j], ratings_batch[j])
		position = last_index - batch_size + 1 + j - offset
		if options.resume and position < len(result):  # TODO: Fix offset calculation
			result.insert(position, embedding)
		else:
			result.append(embedding)


def main():
	"""Embed the image-caption pairs selected on the command line and save them to the .emb file.

	Pairs come from the input csv, or from the single ``-ic`` pair. They are embedded in
	batches of ``--batch-size``. The list of Image_Caption_Embedding objects is written to
	the output path periodically (controlled by ``--save-every``) and once at the end.
	With ``--resume`` the existing output file is loaded first. A row is then skipped when
	its image and caption match the stored entry at the corresponding position.
	"""
	options = parse_args()
	if not validate_options(options):
		return

	if options.input_path != '':
		with open(options.input_path, 'r', newline='', encoding='utf-8') as in_file:
			reader = csv.reader(in_file, delimiter=',', quotechar='"')
			header = next(reader, None)
			input_rows = [row for row in reader if row]
		if header is None:
			print(f"ERROR: Input csv file {options.input_path} is empty. Aborting.")
			return
	else:
		# validate_options() guarantees both -ic values are present when there is no csv.
		header = ['image', 'caption']
		input_rows = [[options.image, options.caption]]

	# If a heading repeats, its last occurrence wins.
	column_of = {heading: col for col, heading in enumerate(header)}
	image_col = column_of.get('image', -1)
	caption_col = column_of.get('caption', -1)
	rating_col = column_of.get('rating', -1)
	num_ratings_col = column_of.get('num_ratings', -1)
	ratings_col = column_of.get('ratings', -1)

	if image_col == -1 or caption_col == -1:
		print(f"ERROR: Input csv file does not contain columns named 'image' and 'caption'. Aborting.")
		return

	if options.resume:
		try:
			with open(options.output_path, 'rb') as in_file:
				result = deserialize(in_file)
		except FileNotFoundError:
			print(f"ERROR: Cannot resume, output file {options.output_path} does not exist. Aborting.")
			return
	else:
		result = []
	# While resuming, counts input rows dropped because their image could not be fetched,
	# so that input-row indices can be lined up with positions in result.
	offset = 0

	# Load the corresponding ViLBERT model into the library's global instance. This is done
	# only after the input has been validated, so bad input fails fast.
	load_vilbert_model(options.embedding_type, options.pretrained_model, options.config_file, batch_size=options.batch_size)

	image_name_batch = []
	image_batch = []
	caption_batch = []
	ratings_batch = []
	open_responses = []  # streamed downloads backing image_batch; closed once embedded
	last_index = -1
	try:
		for i, row in enumerate(input_rows):
			image = row[image_col]
			caption = row[caption_col]
			if options.resume and i - offset < len(result):
				existing = result[i - offset]
				if image == existing.image and caption == existing.caption:
					if options.verbose:
						print("Skipping:", i, image, caption, "(already done)")
					continue

			if uri_validator(image):
				try:
					response = requests.get(image, stream=True)
				except Exception as e:
					# Best effort: report the failed download and continue with the next row.
					print(f"Exception occurred retrieving image {image}: {str(e)}")
					if options.resume:
						offset += 1
					continue
				if response.status_code != 200:
					print(f"Image {image} could not be retrieved. Status code {response.status_code} returned.")
					response.close()
					if options.resume:
						offset += 1
					continue
				# Set decode_content to True, otherwise the downloaded image file's size will be zero.
				response.raw.decode_content = True
				open_responses.append(response)
				image_source = response.raw
			else:
				image_source = options.image_dir + image

			image_name_batch.append(image)
			image_batch.append(image_source)
			caption_batch.append(caption)
			last_index = i

			if options.verbose:
				print(i, image.split('/')[-1], caption)
			ratings_batch.append(_parse_ratings(row, rating_col, num_ratings_col, ratings_col))

			if len(image_batch) == options.batch_size:
				_embed_batch(options, result, i, offset, image_name_batch, image_batch, caption_batch, ratings_batch)
				# Periodic checkpoint, written once per batch.
				if options.save_every > 0 and (i + 1) % options.save_every < len(image_batch):
					_write_embeddings(result, options.output_path)
				image_name_batch = []
				image_batch = []
				caption_batch = []
				ratings_batch = []
				_close_responses(open_responses)

		# Embed whatever is left over. The last input rows may have been skipped or may have
		# failed to download, so the partial batch must be flushed after the loop.
		if image_batch:
			_embed_batch(options, result, last_index, offset, image_name_batch, image_batch, caption_batch, ratings_batch)
	finally:
		_close_responses(open_responses)

	_write_embeddings(result, options.output_path)


if __name__ == "__main__":
	main()
