from typing import Dict, List, Optional
import pandas as pd

# Task names accepted by `create_train_example`, which checks its `task`
# argument against this list.
# NOTE: "rewritting" (sic) is the literal key compared at runtime, so the
# spelling must stay as-is unless every caller is updated as well.
_TASK_NAMES = [
	"tuple",
	"dependency",
	"question",
	"expansion", 
	"rewritting", 
	"decorating"
]

# Default `item_id` values used by `get_tifa_examples` to pick rows from the
# annotation table; one example is built per id, in this order.
# The prefixes (coco / drawbench / partiprompt / paintskill) presumably name
# the source prompt collection of each item -- TODO confirm.
_TIFA160_EXAMPLES_IDS = [
    'coco_361740',
    'drawbench_155',
    'partiprompt_86',
    'paintskill_374',
    'coco_552592',
    'partiprompt_1414',
    'coco_627537',
    'coco_744388',
    'partiprompt_1108',
    'coco_397109',
    'coco_666114',
    'coco_62896',
    'paintskill_235',
    'drawbench_159',
    'partiprompt_893',
    'coco_322041',
    'coco_292534',
    'drawbench_57',
    'partiprompt_555',
    'coco_488166',
    'partiprompt_726',
    'coco_323167',
    'coco_625027',
]

# Annotation table used as the default `data_df` of `get_tifa_examples`.
# NOTE: this is read when the module is imported, and the path is relative to
# the current working directory, so the import fails if the file cannot be
# found from there.
# Columns read by `tifa_id2example` include: item_id, text, tuple, dependency,
# question_natural_language and answer.
_TRAIN_EXAMPLES = pd.read_csv("dsg/data/tifa-examples.csv")


def _numbered_lines(items, count=None):
	"""Render `items` as "<n> | <item>" lines, numbered from 1.

	Args:
		items: indexable sequence of values to render.
		count: number of leading entries to render; defaults to len(items).
			Indexing past the end of `items` raises IndexError.

	Returns:
		List of strings, each with runs of whitespace collapsed to one space.
	"""
	if count is None:
		count = len(items)
	# " ".join(s.split()) collapses repeated whitespace and trims both ends.
	return [" ".join(f"{i + 1} | {items[i]}".split()) for i in range(count)]


def create_train_example(
	prompt: str,
	task: str = "tuple",
	tuples: Optional[List[str]] = None,
	dependencies: Optional[List[str]] = None,
	questions: Optional[List[str]] = None,
	answer: Optional[List[str]] = None,
	expansion_tuples: Optional[List[str]] = None,
	rewritten_prompt: Optional[str] = None
) -> Dict[str, str]:
	"""Create a training (shown in-context) example for one generation task.

	Tasks (one of _TASK_NAMES):
		tuple:      prompt -> tuples
		dependency: prompt + tuples -> dependencies
		question:   prompt + tuples -> questions
		expansion:  prompt + tuples + answers -> expanded tuples
		rewritting: prompt + tuples + expanded tuples -> rewritten prompt
			(the misspelled name is the actual task key)
		decorating: prompt -> rewritten (decorated) prompt

	Every list is rendered one item per line as "<1-based index> | <item>"
	with repeated whitespace collapsed. The prompt and the rewritten prompt
	are emitted verbatim. Sections of the input are introduced by the header
	lines "PROMPT", "TUPLES", "ANSWERS" and "EXPANDED TUPLES".

	Args:
		prompt: input text prompt.
		task: one of the pre-defined tasks in _TASK_NAMES.
		tuples: semantic tuples; None is treated as an empty list.
		dependencies: dependencies, one per tuple ("dependency" task).
		questions: natural language questions, one per tuple ("question" task).
		answer: answers, one per tuple ("expansion" task).
		expansion_tuples: expanded tuples ("expansion" / "rewritting" tasks).
		rewritten_prompt: target prompt ("rewritting" / "decorating" tasks).

	Returns:
		{
			"input": str - newline-joined input lines
			"output": str - newline-joined task-specific target lines
		}

	Raises:
		AssertionError: if `task` is not in _TASK_NAMES.
		TypeError: if a list the task needs is None, or if `prompt` /
			`rewritten_prompt` is None where the task emits it.
		IndexError: if a per-tuple list is shorter than `tuples`.
	"""

	# task should be one of the pre-defined tasks
	assert task in _TASK_NAMES, f"task == {task}"

	# `tuples` defaults to None; taking len(None) used to raise TypeError even
	# for the "decorating" task, which never reads the tuples.
	if tuples is None:
		tuples = []
	n_outputs = len(tuples)

	prompt_lines = ["PROMPT", prompt]
	tuple_lines = ["TUPLES"] + _numbered_lines(tuples)

	inputs = []
	outputs = []

	# Task 1 - tuple generation: prompt -> tuples
	if task == "tuple":
		inputs = prompt_lines
		outputs = _numbered_lines(tuples)

	# Task 2 - dependency generation: prompt + tuples -> dependencies
	elif task == "dependency":
		inputs = prompt_lines + tuple_lines
		outputs = _numbered_lines(dependencies, n_outputs)

	# Task 3 - question generation: prompt + tuples -> natural language questions
	elif task == "question":
		inputs = prompt_lines + tuple_lines
		outputs = _numbered_lines(questions, n_outputs)

	# Task 4 - tuple expansion: prompt + tuples + answers -> expanded tuples
	elif task == "expansion":
		# answers are per-tuple, so they are numbered up to len(tuples)
		answer_lines = ["ANSWERS"] + _numbered_lines(answer, n_outputs)
		inputs = prompt_lines + tuple_lines + answer_lines
		outputs = _numbered_lines(expansion_tuples)

	# Task 5 - prompt rewriting: prompt + tuples + expanded tuples -> rewritten prompt
	elif task == "rewritting":
		expanded_lines = ["EXPANDED TUPLES"] + _numbered_lines(expansion_tuples)
		inputs = prompt_lines + tuple_lines + expanded_lines
		outputs = [rewritten_prompt]

	# Task 6 - prompt decorating: prompt -> decorated prompt
	elif task == "decorating":
		inputs = prompt_lines
		outputs = [rewritten_prompt]

	return {
		"input": "\n".join(inputs),
		"output": "\n".join(outputs),
	}



def _first_text(rows, item_id):
	"""Return the `text` value of the first row in `rows`; IndexError if `rows` is empty."""
	texts = rows.text.tolist()
	if not texts:
		raise IndexError(f"no rows with item_id == {item_id!r}")
	return texts[0]


def tifa_id2example(
	df: pd.DataFrame,
	id: str,
	task: str = "tuple",
) -> Dict[str, str]:
	"""Create a training in-context example from TIFA annotation dataframe.

	Rows are selected by exact match on `item_id`. Besides the rows of `id`
	itself, some tasks read companion rows stored under a prefixed id:
		"expansion":  tuples of "REWRITTEN_<id>" are the expanded tuples
		"rewritting": text and tuples of "REWRITTEN_<id>"
		"decorating": text of "REWRITTEN_<id>" is the input prompt and
			text of "DECORATED_<id>" is the target prompt

	Args:
		df: pandas dataframe with columns: [item_id, text, tuple, dependency,
			question_natural_language]; the "expansion" task additionally
			needs an `answer` column.
		id: unique prompt id (item_id). The name shadows the builtin but is
			kept so existing keyword calls keep working.
		task: one of the tasks accepted by `create_train_example`:
			"tuple", "dependency", "question", "expansion", "rewritting",
			"decorating".

	Returns:
		{
			'input': str - text prompt (plus task-specific context)
			'output': str - task-specific target output
		}

	Raises:
		IndexError: if no row matches `id`, or a required "REWRITTEN_" /
			"DECORATED_" companion row is missing.
	"""

	# Filter once and reuse the selection for every column.
	rows = df[df.item_id == id]

	# Reading columns (prompt, tuples, dependencies, questions)
	prompt = _first_text(rows, id)
	all_tuples = rows.tuple.tolist()
	all_dependencies = rows.dependency.tolist()
	all_questions = rows.question_natural_language.tolist()

	all_answers = None
	all_expansion_tuples = None
	rewritten_prompt = None

	if task == "expansion":
		# Answers are consumed only by this task, so the `answer` column is
		# not required for any other task.
		all_answers = rows.answer.tolist()
		all_expansion_tuples = df[df.item_id == "REWRITTEN_" + id].tuple.tolist()
	elif task == "rewritting":
		rewritten_id = "REWRITTEN_" + id
		rewritten_rows = df[df.item_id == rewritten_id]
		rewritten_prompt = _first_text(rewritten_rows, rewritten_id)
		all_expansion_tuples = rewritten_rows.tuple.tolist()
	elif task == "decorating":
		# The rewritten prompt is the input here; the decorated one is the target.
		rewritten_id = "REWRITTEN_" + id
		decorated_id = "DECORATED_" + id
		prompt = _first_text(df[df.item_id == rewritten_id], rewritten_id)
		rewritten_prompt = _first_text(df[df.item_id == decorated_id], decorated_id)

	# Create an example
	return create_train_example(
		prompt=prompt,
		task=task,
		tuples=all_tuples,
		dependencies=all_dependencies,
		questions=all_questions,
		answer=all_answers,
		expansion_tuples=all_expansion_tuples,
		rewritten_prompt=rewritten_prompt
	)


def get_tifa_examples(data_df=_TRAIN_EXAMPLES, ids=_TIFA160_EXAMPLES_IDS, task='tuple'):
	"""Build one in-context example per id, in the order the ids are given.

	Args:
	data_df: annotation dataframe handed to `tifa_id2example`
	ids: iterable of item ids to convert
	task: task name forwarded to `tifa_id2example`

	Returns:
	list of {"input": ..., "output": ...} dicts, one per id
	"""
	return [tifa_id2example(data_df, item_id, task=task) for item_id in ids]

