
import os
import random

import fire
import pandas as pd
from termcolor import colored

def main(
    input_filename: str = "./outputs/transfer_pick/styletransfer_politics_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=Trueaverage_style=1.0_content=1.0_clean=False.jsonl",
    seed: int = 43,
    num_examples: int = 4,
    debug: bool = False,
):
    """Interactively browse style-transfer outputs, one row at a time.

    Rows of the JSON-lines file are visited in a random order fixed by
    ``seed``. For each row the terminal is cleared, then the following are
    printed:

    - the reference author;
    - every reference text;
    - up to ``num_examples`` (original > transfer) pairs;
    - the first paraphrase of the first content text, when the row has one.

    Press Enter to advance. End of input (e.g. Ctrl-D) stops the viewer.

    Args:
        input_filename: Path to a JSON-lines file. The columns read are
            ``reference_author``, ``reference``, ``content_text``,
            ``transfer_pick`` and, optionally, ``paraphrase_content_text``.
        seed: Seed for the order in which rows are visited.
        num_examples: Maximum number of (original, transfer) pairs shown per row.
        debug: If True, call ``breakpoint()`` after each row's pairs are printed.

    Returns:
        0 once every row has been shown or input ends.
    """
    df = pd.read_json(input_filename, lines=True)
    # Optional row filters; uncomment one to restrict which rows are shown.
    # df = df[(df.original_author == "supernamekianpenis") & (df.target_author == "IndianBrit")]

    # formal-ish to informal
    # df = df[(df.content_subreddit == "wallstreetbets") & (df.reference_subreddit == "CasualUK")]
    # df = df[(df.content_subreddit == "australia") & (df.reference_subreddit == "CasualUK")]

    # A private RNG makes the order depend only on `seed` (same order as
    # random.seed + random.shuffle) without touching global random state.
    indices = list(range(len(df)))
    random.Random(seed).shuffle(indices)

    # "clear" does not exist on Windows; "cls" is its equivalent there.
    clear_command = "cls" if os.name == "nt" else "clear"

    for index in indices:
        os.system(clear_command)
        row = df.iloc[index]
        # Older output files named this column "target_author".
        print(colored("Target Author: ", "green"), colored(row["reference_author"], "yellow"))
        for j, reference in enumerate(row["reference"]):
            print(colored(j, "cyan"), ">", reference)
            print("=" * 50)

        pairs = zip(row["content_text"][:num_examples], row["transfer_pick"][:num_examples])
        for j, (original, transfer) in enumerate(pairs):
            print(colored(j, "cyan"), colored(original, "green"), ">", colored(transfer, "yellow"))

        if debug:
            breakpoint()

        # The column may be absent from the file, or NaN for rows lacking it;
        # only print when a nested first element actually exists.
        paraphrases = row.get("paraphrase_content_text")
        if isinstance(paraphrases, list) and paraphrases and paraphrases[0]:
            print(paraphrases[0][0])

        try:
            input()
        except EOFError:
            # stdin closed / Ctrl-D: stop browsing instead of crashing.
            break

    return 0

if __name__ == "__main__":
    # Hand `main` to python-fire so its keyword arguments become CLI flags.
    fire.Fire(component=main)
