from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import pandas as pd
from tqdm import tqdm
import json
import argparse

# Load the FLAN-T5 Large checkpoint and its tokenizer once, at import time.
# Both objects are module-level globals that generate_flan() reads directly.
_FLAN_CHECKPOINT = "google/flan-t5-large"

print('loading model and tokenizer')
model = AutoModelForSeq2SeqLM.from_pretrained(_FLAN_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_FLAN_CHECKPOINT)
print('loading finished')


# Maps each supported ``--task`` value to the task_helper() runs it triggers.
# Every entry is a tuple of keyword-argument dicts, executed in order.
_TASK_RUNS = {
    # toxicity task
    'toxicity': (
        {
            'task_name': 'toxicity',
            'task_data': 'nontoxic_prompts-10k.jsonl',
            'prompt_wrap': 'identity',
            'output_filename': 'toxicity_result',
        },
    ),
    # sentiment task, to_positive
    'sentiment_to_positive': (
        # neutral to positive
        {
            'task_name': 'sentiment (neutral to positive)',
            'task_data': 'sentiment_prompts-10k/neutral_prompts.jsonl',
            'prompt_wrap': 'to_positive',
            'output_filename': 'sentiment_neutral_positive_result',
        },
        # negative to positive
        {
            'task_name': 'sentiment (negative to positive)',
            'task_data': 'sentiment_prompts-10k/negative_prompts.jsonl',
            'prompt_wrap': 'to_positive',
            'output_filename': 'sentiment_negative_positive_result',
        },
    ),
    # sentiment task, to_negative
    'sentiment_to_negative': (
        # neutral to negative
        {
            'task_name': 'sentiment (neutral to negative)',
            'task_data': 'sentiment_prompts-10k/neutral_prompts.jsonl',
            'prompt_wrap': 'to_negative',
            'output_filename': 'sentiment_neutral_negative_result',
        },
        # positive to negative
        {
            'task_name': 'sentiment (positive to negative)',
            'task_data': 'sentiment_prompts-10k/positive_prompts.jsonl',
            'prompt_wrap': 'to_negative',
            'output_filename': 'sentiment_positive_negative_result',
        },
    ),
}


def main(args):
    """Run every generation job that belongs to the selected task.

    Args:
        args: Object with a ``task`` attribute (e.g. the parsed argparse
            namespace) naming one of the keys of ``_TASK_RUNS``.

    Raises:
        ValueError: If ``args.task`` is not a known task. This includes
            ``None``, which is what argparse yields when ``--task`` is
            omitted; previously that case silently did nothing.
    """
    runs = _TASK_RUNS.get(args.task)
    if runs is None:
        known = ', '.join(sorted(_TASK_RUNS))
        raise ValueError(f'unknown task: {args.task!r}; expected one of: {known}')

    for run_kwargs in runs:
        task_helper(**run_kwargs)


def task_helper(task_name, task_data, prompt_wrap, output_filename):
    """Generate FLAN completions for each prompt in a JSON-lines file.

    Results are accumulated as ``{'text': ..., 'result': ...}`` records and
    written to ``{output_filename}.json``. The file is rewritten in full at
    each checkpoint and once more at the end if anything is still unsaved.

    Args:
        task_name: Label used only in the start-up message.
        task_data: Path to a JSON-lines file. Each record must have a
            ``prompt`` entry that supports ``['text']`` lookup.
        prompt_wrap: Wrapping mode forwarded to ``generate_flan``.
        output_filename: Output path without the ``.json`` extension.
    """
    print(f'The task of {task_name} is running')
    print(f'loading data from {task_data}')
    df = pd.read_json(task_data, lines=True)
    texts = df['prompt'].map(lambda x: x['text'])

    flan_results = []
    saved_count = 0

    def dump_results():
        # Rewrite the whole file so it always holds valid JSON.
        with open(f'{output_filename}.json', 'w') as f:
            json.dump(flan_results, f, indent=4)
        print(f'dumping to: {output_filename}')

    # Wrapping the Series itself (rather than a generator) lets tqdm read its
    # length and show a real progress bar with an ETA.
    for count, text in enumerate(tqdm(texts), start=1):
        result = generate_flan(text, prompt_wrap=prompt_wrap)
        flan_results.append({
            'text': text,
            'result': result
        })
        # Checkpoint every 500 prompts, plus one early dump after the first 10.
        if count % 500 == 0 or count == 10:
            dump_results()
            saved_count = count

    # Without this, trailing results are lost whenever the number of prompts
    # is not a multiple of 500.
    if saved_count != len(flan_results):
        dump_results()


def generate_flan(text, prompt_wrap='identity'):
    """Sample completions for ``text`` from the module-level FLAN model.

    Args:
        text: The prompt to complete.
        prompt_wrap: ``'identity'`` sends ``text`` unchanged; ``'to_positive'``
            and ``'to_negative'`` prepend an instruction asking for a
            completion with that sentiment.

    Returns:
        The decoded generations (special tokens stripped), as returned by
        ``tokenizer.batch_decode``; 25 sequences are requested per call.

    Raises:
        Exception: If ``prompt_wrap`` is not one of the supported modes.
    """
    # Instruction prefix per wrapping mode; None means "use the text as-is".
    instruction_by_wrap = {
        'identity': None,
        'to_positive': 'Complete this sentence so that it embodies a positive sentiment: ',
        'to_negative': 'Complete this sentence so that it embodies a negative sentiment: ',
    }
    if prompt_wrap not in instruction_by_wrap:
        raise Exception('prompt_wrap not implemented')

    instruction = instruction_by_wrap[prompt_wrap]
    prompt = text if instruction is None else f'{instruction}{text}'

    sampling_kwargs = {
        'do_sample': True,
        'top_p': 0.9,
        'max_length': 20,
        'num_return_sequences': 25,
    }
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded, **sampling_kwargs)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Add arguments
    # --task is required: without it args.task would be None and no
    # generation job would be selected.
    parser.add_argument(
        '--task',
        type=str,
        required=True,
        help='task to run',
        choices=['toxicity', 'sentiment_to_positive', 'sentiment_to_negative'],
    )

    # Parse the arguments
    args = parser.parse_args()
    main(args)