""" Stolen from DSG (https://github.com/j-min/DSG). 
"""

import string
from typing import Any, Callable, Dict, List
from tqdm import tqdm

########## template ########
# Overall prompt layout: the task preamble, the in-context examples and the
# test query, separated from each other by one blank line.
_PROMPT_TEMPLATE = string.Template(
	"\n\n".join(["$preamble", "$examples", "$test_input_output"])
)

# In-context examples
# One solved example: each verbalizer is followed by ": " and a newline, and
# the corresponding text sits on the line below it.
_EXAMPLES_TEMPLATE = string.Template(
	"\n".join(["$input_name: ", "$input", "$output_name: ", "$output"])
)

# The test query: same layout as an example, but it ends with the dangling
# output verbalizer (trailing space kept) for the LM to complete.
_TEST_TEMPLATE = string.Template(
	"\n".join(["$input_name: ", "$test_input", "$output_name: "])
)
############################


########## preamble ########
# Task 1 (tuple generation) instruction: used by generate_dsg and as the
# default preamble of make_prompt. Asks the LM for one "id | tuple" line per
# extracted entity / style / attribute / relation.
# NOTE(review): this text is sent to the LM verbatim; any wording change may
# shift outputs, so re-check against the in-context examples before editing.
_TUPLE_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt, please generate skill-specific tuples from the prompt according to the specified task requirements and provided examples. The last element of a tuple should be enclosed in parentheses. 
First, extract all entities from the prompt and describe them using triples that start with "entity". If the entity represents a part of another entity, use "part" as the second element of the triple; otherwise, use "whole". And the last element of the triple is the name of the entity. For entity that is a part of a whole one, the name should include its main entity.  
Second, extract all the descriptions of style and describe them using tuples that start with "other". Descriptions with specific quantities should also be recorded. 
Third, extract all attributes of the entities such as color, material, shape and size. Descriptive phrase of appearance that contain verbs is also treated as an attribute. However, descriptions of location is not considered in this task. Describe them using triples that start with "attribute". For the second element of the triple, specify the category of the attribute. And for the final element, provide the entity and its attribute. 
Last, extract the spatial and action relationships between entities and describe them using triples or tuples that start with "relation". For spatial relationship, use "spatial" as the second element, followed by a third element that details the two entities and their spatial relation. For action-based relationship, use a tuple to describe it. And the last element describes two entities and the action relationship. 
Only give me those tuples without the prompt. Do not generate tuples describing the same thing among all the tasks. Do not generate tuples that are not explicitly described in the prompts. 
output format: id | tuple
""".strip()


# Task 3 (dependency generation) instruction used by generate_dsg. The LM
# sees the prompt plus the Task 1 tuples and answers one "id | dependencies"
# line per tuple, using 0 for "no parent".
_DEPENDENCY_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt and skill-specific tuples. Please describe the parent tuples of each tuple according to the examples. 
Each tuple is preceded by a numerical id. If the fact within a tuple depends on another tuple, include the id of the dependency following the original tuple id in the output. Separate multiple dependencies with commas. Use 0 to indicate that a tuple has no dependencies. 
The numbering in the output should correspond one-to-one with the numbering of the input tuples. 
output format: id | dependencies
""".strip()


# Task 2 (question generation) instruction used by generate_dsg. The LM sees
# the prompt plus the Task 1 tuples and turns each tuple into a yes/no
# question, answering one "id | question" line per tuple.
_QUESTION_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt and skill-specific tuples. Please transform each tuple into a natural language question according to the examples. 
Each tuple is preceded by a numerical id. Transform the tuple to a yes-or-no question and output the question after its id according to the given examples. 
The numbering in the output should correspond one-to-one with the numbering of the input tuples. 
output format: id | question
""".strip()


# Task 4 (tuple expansion) instruction used by generate_new_prompt. The LM
# receives the prompt, the original tuples and a yes/no answer list, and must
# return the original tuples plus new ones for the 'no' entries, as
# "id | tuple" lines. (The output format previously read "id | question",
# which was copied from the question task and contradicted "only give me the
# tuples".)
_EXPANSION_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt, skill-specific tuples and a yes or no list. Please add more tuples with additional information which do not conflict with the previous information. 
First, rewrite all the given tuples. Here are some rules for these tuples. 
Triples that start with "entity" describe the entity in the given prompt. If the entity represents a part of another entity, use "part" as the second element of the triple; otherwise, use "whole". And the last element of the triple is the name of the entity. For entity that is a part of a whole one, the name should include its main entity.  
Triples that start with "other" describe the total style of the image. Descriptions with specific quantities should also be recorded. 
Triples that start with "attribute" describe the attribute of the entity in them. For the second element of the triple, specify the category of the attribute. And for the final element, provide the entity and its attribute. 
Triples that start with "relation" or "action" describe the relation and action of the entities. For spatial relationship, use "spatial" as the second element, followed by a third element that details the two entities and their spatial relation. For action-based relationship, use a tuple to describe it. And the last element describes two entities and the action relationship. 
Second, if the list specifies 'no', expand the corresponding tuple with the same id. Generate several additional detailed descriptions such as objects, attributes and actions and write them after the tuples with sequentially continuing ids. You can create content that is not present in the given prompt.
Each tuple is preceded by a numerical id. If there is no false answer in the answer list, just give back the input tuples. 
You should only give me the tuples. Do not give explanations. 
output format: id | tuple
""".strip()


# Task 5 (prompt rewriting) instruction used by generate_new_prompt. The LM
# sees the prompt, the original tuples and the expanded tuples from Task 4,
# and returns a single rewritten prompt sentence.
# NOTE(review): the "relation" or "action" rule below differs from
# _TUPLE_PREAMBLE, which only asks for "relation" tuples; confirm which
# prefixes the in-context examples actually use.
_REWRITTING_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt, skill-specific tuples and tuples after supplementation. Please provide me with a complete, natural sentence based on the given prompt, which incorporates the content of the new tuples. 
Here are some rules for these tuples. 
Triples that start with "entity" describe the entity in the given prompt. If the entity represents a part of another entity, use "part" as the second element of the triple; otherwise, use "whole". And the last element of the triple is the name of the entity. For entity that is a part of a whole one, the name should include its main entity.  
Triples that start with "other" describe the total style of the image. Descriptions with specific quantities should also be recorded. 
Triples that start with "attribute" describe the attribute of the entity in them. For the second element of the triple, specify the category of the attribute. And for the final element, provide the entity and its attribute. 
Triples that start with "relation" or "action" describe the relation and action of the entities. For spatial relationship, use "spatial" as the second element, followed by a third element that details the two entities and their spatial relation. For action-based relationship, use a tuple to describe it. And the last element describes two entities and the action relationship. 
You should only give me the rewritten text. Do not add any irrelevant content that is not in the new tuples. 
""".strip()


# Task 6 (prompt decorating) instruction used by decorate_new_prompt: keep the
# prompt intact and append comma-separated quality / style / background /
# lighting words.
_DECORATING_PREAMBLE = """Task: I will give several examples that contain the input and output. Then I will give a prompt. Please generate a sentence enriched with additional descriptive adjectives based on the prompt I provide, ensuring the original sentence remains unaltered and intact. 
Here are some rules for this task. 
You should first write the prompt completely, then add some extra words to make it more vivid. The extra words should not conflict with the previous description and should be separated by commas. 
Add at most two words to enhance the quality, such as: best quality, 4k, 8k, highres, masterpiece, fantasy art, highly detailed and so on. 
Add at most one word to describe the style. If the description is about a real scene, you can add realistic or photo-realistic; and if the description is about an art, add portraits, digital art or landscape. You can also select style such as impressionist, Vincent van Gogh, anime, concept artists and so on. 
Add at most one word about the background, like blue sky, crowded road, beautifully decorated wall, high mountain and so on. The background should not conflict with the given prompt. 
Add at most two words to describe the light, such as studio lighting, soft lighting, sun lighting, diffuse lighting and so on. 
You can add more descriptive words, but do not introduce additional objects or create conflicts with the aforementioned. 
You should only give me the rewritten text. Do not add any irrelevant content. 
""".strip()
############################


def make_prompt(
	examples: List[Dict[str, str]],
	test_input: str,
	preamble: str = _TUPLE_PREAMBLE,
	input_name: str = "input",
	output_name: str = "output",
	verbose: bool = False,
) -> str:
	"""Make a prompt by composing preamble, examples, and text input.

	The prompt consists of ``preamble``, then the in-context examples
	(separated from each other by blank lines), then the test input ending
	with a dangling ``"<output_name>: "`` cue for the LM to complete. Those
	three parts are themselves separated by blank lines.

	Args:
	examples: list of examples - each example has keys ['input', 'output'].
		``None`` is treated as an empty list (zero-shot prompt).
	test_input: test input string to generate output
	preamble: a task description for language model
	input_name: a verbalizer for input
	output_name: a verbalizer for output
	verbose: whether to print the prompt details (character and word counts
		of the preamble, the examples and the full prompt)

	Returns:
	prompt (str)

	Example output (illustrative; the text after each verbalizer starts on
	the following line):

	<preamble>

	input:
	A red motorcycle parked by paint chipped doors.
	output:
	0 | attribute - color (motorcycle, red)
	...

	input:
	A dignified beaver wearing glasses, a vest, and colorful neck tie.
	output:
	"""
	# Example text is stripped so each example renders with uniform spacing.
	examples_str = "\n\n".join(
		_EXAMPLES_TEMPLATE.substitute(
			input_name=input_name,
			output_name=output_name,
			input=example["input"].strip(),
			output=example["output"].strip(),
		)
		for example in (examples or [])
	)

	test_input_str = _TEST_TEMPLATE.substitute(
		input_name=input_name,
		output_name=output_name,
		test_input=test_input,
	)

	prompt = _PROMPT_TEMPLATE.substitute(
		preamble=preamble,
		examples=examples_str,
		test_input_output=test_input_str,
	)

	if verbose:
		print(f"len(preamble): {len(preamble)}chars & {len(preamble.split())}words")
		print(f"len(examples): {len(examples_str)}chars & {len(examples_str.split())}words")
		print(f"len(total): {len(prompt)}chars & {len(prompt.split())}words")

	return prompt


def parse_with_input_name(text: str, input_name="input") -> str:
	"""Return the part of ``text`` before the first ``"<input_name>:"`` marker.

	This drops anything the LM generated after starting a new input block,
	e.g. a hallucinated next example. If the marker is absent, ``text`` is
	returned unchanged.
	"""
	head, _marker, _rest = text.partition(input_name + ":")
	return head


def generate_with_in_context_examples(
	generate_fn: Callable[[str], str],
	id2inputs: Dict[str, Dict[str, str]],
	train_examples: List[Dict[str, Any]],
	preamble: str,
	input_name: str = "input",
	output_name: str = "output",
	parse_fn: Callable[[str], str] = parse_with_input_name,
	num_workers: int = 1,
	verbose: bool = False,
) -> Dict[str, Dict[str, str]]:
	"""Generate output with a language model with in-context examples.

	Args:
	generate_fn: a method that calls language model with a text input.
		With ``num_workers > 1`` it is sent to a ``multiprocessing.Pool``,
		so it must be picklable (e.g. a module-level function, not a lambda).
	id2inputs: a input dictionary with following structure "id" (str) -> {
		"input": "test input prompt" (str) }
	train_examples: list of examples. Each example is a dict('input', 'output')
	preamble: a task description for language model
	input_name: a verbalizer for input
	output_name: a verbalizer for output
	parse_fn: a method that parses the output of language model. When left
		at the default, the output is cut at ``"<input_name>:"`` (using the
		``input_name`` given here).
	num_workers: number of workers for parallel call; values <= 1 run the
		calls sequentially in this process
	verbose: whether to print tqdm output / intermediate steps

	Returns:
	id2outputs: output dictionary with following structure
		"id" (str) -> {
		"id": the same id (str),
		"input": the LM test input taken from id2inputs (str),
		"output": parsed and stripped LM output (str)
		}
	"""
	ids = list(id2inputs.keys())

	# 1) Build one few-shot prompt per input id.
	prompts = [
		make_prompt(
			examples=train_examples,
			test_input=id2inputs[id_]["input"],
			preamble=preamble,
			input_name=input_name,
			output_name=output_name,
			verbose=False,
		)
		for id_ in tqdm(
			ids,
			dynamic_ncols=True,
			ncols=80,
			disable=not verbose,
			desc="Preparing LM inputs",
		)
	]

	# 2) Run LM calls (progress bars honour `verbose`, like the other stages).
	if verbose:
		print(f"Running LM calls with {num_workers} workers.")
	if num_workers <= 1:
		total_output = [
			generate_fn(prompt)
			for prompt in tqdm(prompts, disable=not verbose)
		]
	else:
		from multiprocessing import Pool
		with Pool(num_workers) as pool:
			# imap preserves input order, so outputs line up with `ids`.
			total_output = list(
				tqdm(
					pool.imap(generate_fn, prompts),
					total=len(prompts),
					disable=not verbose,
				)
			)

	# 3) Postprocess LM outputs.
	def _default_parse(text: str) -> str:
		# Cut at the verbalizer actually used in the prompt, not a hard-coded "input".
		return parse_with_input_name(text, input_name=input_name)

	parse = _default_parse if parse_fn is parse_with_input_name else parse_fn

	id2outputs = {}
	for id_, raw_prediction in zip(
		tqdm(
			ids,
			dynamic_ncols=True,
			ncols=80,
			disable=not verbose,
			desc="Postprocessing LM outputs",
		),
		total_output,
	):
		id2outputs[id_] = {
			"id": id_,
			"input": id2inputs[id_]["input"],
			"output": parse(raw_prediction).strip(),
		}

	return id2outputs


def generate_dsg(
	id2prompts: Dict[str, Dict[str, str]],
	generate_fn: Callable[[str], str],
	tuple_train_examples=None,
	dependency_train_examples=None,
	question_train_examples=None,
	N_parallel_workers=1,
	verbose=False
):
	"""Generate DSG with a LM in three steps with in-context examples.

	Task 1 extracts skill-specific tuples from each prompt. Tasks 2 and 3
	feed those tuples back to the LM to obtain a yes/no question and the
	parent dependencies for every tuple.

	Args:
		id2prompts: a input dictionary with following structure
			"id" (str) -> {
				"input": text prompt (str)
				"source": (str; optional)
			}
		generate_fn: a method that calls language model with a text input

		tuple_train_examples: list of examples for tuple generation task
		dependency_train_examples: list of examples for dependency generation task
		question_train_examples: list of examples for question generation task
			(for all three, ``None`` means no in-context examples)
		N_parallel_workers: number of workers for parallel call
		verbose: whether to print tqdm output / intermediate steps

	Returns:
		(id2tuple_outputs, id2question_outputs, id2dependency_outputs), each
		with the structure
			"id" (str) -> {
				"id": the same id (str),
				"input": LM task input ("PROMPT\\n<text prompt>", followed by
					"\\nTUPLES\\n<generated tuples>" for tasks 2 and 3) (str),
				"output": generated tuples / questions / dependencies (str)
			}
	"""
	# Original text prompt per id; iteration order follows id2prompts.
	id2text = {
		prompt_id: input_dict['input']
		for prompt_id, input_dict in id2prompts.items()
	}

	def run_task(step, task, preamble, train_examples, build_input):
		"""Build one LM input per id with ``build_input`` and run one few-shot task."""
		if verbose:
			print(f'Task {step}: ', task)

		id2inputs = {
			prompt_id: {'input': build_input(prompt_id, text)}
			for prompt_id, text in id2text.items()
		}

		if verbose:
			print('Run inference')
		return generate_with_in_context_examples(
			generate_fn=generate_fn,
			id2inputs=id2inputs,
			# None (the default) would crash prompt building; treat it as zero-shot.
			train_examples=[] if train_examples is None else train_examples,
			preamble=preamble,
			num_workers=N_parallel_workers,
			verbose=verbose,
		)

	def prompt_only(prompt_id, text):
		return "\n".join(["PROMPT", text])

	# Task 1: tuple generation; used as inputs to task 2 (question gen) & task 3 (dependency gen).
	id2tuple_outputs = run_task(
		1, 'tuple', _TUPLE_PREAMBLE, tuple_train_examples, prompt_only
	)

	def prompt_with_tuples(prompt_id, text):
		gen_tuple = id2tuple_outputs[prompt_id]['output'].strip()
		return "\n".join(["PROMPT", text, "TUPLES", gen_tuple])

	# Task 2: question generation.
	id2question_outputs = run_task(
		2, 'question', _QUESTION_PREAMBLE, question_train_examples, prompt_with_tuples
	)

	# Task 3: dependency generation.
	id2dependency_outputs = run_task(
		3, 'dependency', _DEPENDENCY_PREAMBLE, dependency_train_examples, prompt_with_tuples
	)

	return id2tuple_outputs, id2question_outputs, id2dependency_outputs



def generate_new_prompt(
	id2prompts: Dict[str, Dict[str, str]],
	generate_fn: Callable[[str], str],
	expansion_train_examples=None,
	rewritting_train_examples=None,
	N_parallel_workers=1,
	verbose=False
):
	"""Generate new prompts with a LM in two steps with in-context examples.

	Task 4 expands the given tuples using the yes/no answers. Task 5 rewrites
	the prompt so that it incorporates the expanded tuples.

	Args:
		id2prompts: input dictionary with the structure
			"id" (str) -> {
				"input": text prompt (str),
				"tuple": tuples for the prompt (str),
				"answer": answer list for the tuples (str; per
					_EXPANSION_PREAMBLE presumably one yes/no per tuple id)
			}
		generate_fn: a method that calls language model with a text input
		expansion_train_examples: in-context examples for task 4
		rewritting_train_examples: in-context examples for task 5
			(for both, ``None`` means no in-context examples)
		N_parallel_workers: number of workers for parallel call
		verbose: whether to print tqdm output / intermediate steps

	Returns:
		(id2newtuples_outputs, id2newprompt_outputs): dictionaries mapping
		"id" -> {"id", "input" (LM task input), "output" (expanded tuples /
		rewritten prompt)}
	"""
	# (prompt, tuples, answers) per id; iteration order follows id2prompts.
	records = {
		prompt_id: (input_dict['input'], input_dict["tuple"], input_dict["answer"])
		for prompt_id, input_dict in id2prompts.items()
	}

	def run_task(step, task, preamble, train_examples, build_input):
		"""Build one LM input per id with ``build_input`` and run one few-shot task."""
		if verbose:
			print(f'Task {step}: ', task)

		id2inputs = {
			prompt_id: {'input': build_input(prompt_id, *record)}
			for prompt_id, record in records.items()
		}

		if verbose:
			print('Run inference')
		return generate_with_in_context_examples(
			generate_fn=generate_fn,
			id2inputs=id2inputs,
			# None (the default) would crash prompt building; treat it as zero-shot.
			train_examples=[] if train_examples is None else train_examples,
			preamble=preamble,
			num_workers=N_parallel_workers,
			verbose=verbose,
		)

	def expansion_input(prompt_id, prompt, tuples, answers):
		return "\n".join(["PROMPT", prompt, "TUPLES", tuples, "ANSWERS", answers])

	# Task 4: generation of new tuples.
	id2newtuples_outputs = run_task(
		4, "expansion", _EXPANSION_PREAMBLE, expansion_train_examples, expansion_input
	)

	def rewriting_input(prompt_id, prompt, tuples, answers):
		new_tuples = id2newtuples_outputs[prompt_id]["output"].strip()
		return "\n".join(["PROMPT", prompt, "TUPLES", tuples, "EXPANDED TUPLES", new_tuples])

	# Task 5: generation of new prompts.
	id2newprompt_outputs = run_task(
		5, 'rewritting', _REWRITTING_PREAMBLE, rewritting_train_examples, rewriting_input
	)

	return id2newtuples_outputs, id2newprompt_outputs


def decorate_new_prompt(
	id2prompts: Dict[str, Dict[str, str]],
	generate_fn: Callable[[str], str],
	decorating_train_examples=None,
	N_parallel_workers=1,
	verbose=False
):
	"""Decorate prompts with a LM in a single step (task 6) with in-context examples.

	Each prompt is sent to the LM with _DECORATING_PREAMBLE, which asks it to
	keep the prompt intact and append descriptive words.

	Args:
		id2prompts: input dictionary "id" (str) -> {"input": text prompt (str)}
		generate_fn: a method that calls language model with a text input
		decorating_train_examples: in-context examples for the decorating task;
			``None`` means no in-context examples
		N_parallel_workers: number of workers for parallel call
		verbose: whether to print tqdm output / intermediate steps

	Returns:
		id2decoratedprompt_outputs: "id" -> {"id", "input" ("PROMPT\\n<text>"),
		"output" (decorated prompt)}
	"""
	id2inputs = {
		prompt_id: {'input': "\n".join(["PROMPT", input_dict['input']])}
		for prompt_id, input_dict in id2prompts.items()
	}

	# =====================================
	# Task 6: generation of decorated prompts
	# =====================================
	task = 'decorating'
	if verbose:
		print('Task 6: ', task)
		print('Run inference')

	return generate_with_in_context_examples(
		generate_fn=generate_fn,
		id2inputs=id2inputs,
		# None (the default) would crash prompt building; treat it as zero-shot.
		train_examples=[] if decorating_train_examples is None else decorating_train_examples,
		preamble=_DECORATING_PREAMBLE,
		num_workers=N_parallel_workers,
		verbose=verbose,
	)

