import time
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import argparse
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from six.moves import cPickle as pkl
import os

'''
Reference: https://github.com/martiansideofthemoon/ai-detection-paraphrases
'''

class DipperParaphraser(object):
    """Paraphrase text with the DIPPER model (a T5-XXL fine-tune).

    The model skeleton is created without allocating weights, then the real
    weights are loaded from a local checkpoint and placed across the available
    devices by ``accelerate``.
    """

    def __init__(self, model="kalpeshk2011/dipper-paraphraser-xxl", verbose=True, checkpoint=None):
        """Load the tokenizer and dispatch the model weights.

        Args:
            model (str): Hub id or local path whose config describes the
                architecture.
            verbose (bool): If True, print how long loading took.
            checkpoint (str, optional): Checkpoint path handed to
                ``load_checkpoint_and_dispatch``. When omitted, the
                module-level ``CHECKPOINT`` name is used (the script entry
                point defines it).

        Raises:
            ValueError: If no checkpoint is passed and no module-level
                ``CHECKPOINT`` exists.
        """
        from transformers import T5Config

        if checkpoint is None:
            # Backward compatibility: callers used to set a module-level
            # CHECKPOINT before constructing this class.
            checkpoint = globals().get("CHECKPOINT")
        if checkpoint is None:
            raise ValueError(
                "No checkpoint given: pass checkpoint=... or define a module-level CHECKPOINT."
            )

        time1 = time.time()
        self.tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')

        # Only the config is needed to build an empty skeleton. Calling
        # from_pretrained here would fetch and read the full weights just to
        # discard them, since load_checkpoint_and_dispatch supplies the weights.
        config = T5Config.from_pretrained(model)
        with init_empty_weights():
            self.model = T5ForConditionalGeneration(config)

        # Tie weights before dispatch so accelerate accounts for tied
        # parameters when it builds the device map.
        self.model.tie_weights()

        # no_split_module_classes expects *class* names. Attribute names such
        # as "encoder"/"decoder" match no module class, so they had no effect.
        # T5Block contains residual connections and should stay on one device.
        # model_parallel is deliberately left unset: it belongs to the
        # deprecated T5 parallelize() API, whereas accelerate's hooks move
        # tensors between devices for a dispatched model.
        self.model = load_checkpoint_and_dispatch(
            self.model,
            checkpoint,
            device_map="auto",
            no_split_module_classes=["T5Block"],
        )

        if verbose:
            print(f"{model} model loaded in {time.time() - time1}")
        # No self.model.cuda() here. Moving a dispatched model onto a single
        # device defeats the device map, and accelerate refuses it when
        # modules are offloaded.
        self.model.eval()
        self.model.tie_weights()

    def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
        """Paraphrase a text using the DIPPER model.

        The text is split into sentences and paraphrased ``sent_interval``
        sentences at a time. Each generated chunk is appended to the running
        prefix, so later windows are conditioned on earlier output.

        Args:
            input_text (str): Plain text to paraphrase. The method wraps each
                sentence window in ``<sent> ... </sent>`` itself, so callers
                should not add those markers.
            lex_diversity (int): Lexical diversity, one of 0, 20, 40, 60, 80,
                100. 0 means no diversity, 100 means maximum diversity.
            order_diversity (int): Order diversity, one of 0, 20, 40, 60, 80,
                100. 0 means no diversity, 100 means maximum diversity.
            prefix (str): Optional context placed before the first window.
                Newlines and repeated whitespace are collapsed.
            sent_interval (int): Number of sentences per ``generate`` call.
            **kwargs: Forwarded to ``self.model.generate`` (e.g. do_sample,
                top_p, top_k, max_length).

        Returns:
            str: The concatenated paraphrased chunks. Each chunk is preceded
            by a single space, so a non-empty result starts with a space.
        """
        assert lex_diversity in [0, 20, 40, 60, 80, 100], "Lexical diversity must be one of 0, 20, 40, 60, 80, 100."
        assert order_diversity in [0, 20, 40, 60, 80, 100], "Order diversity must be one of 0, 20, 40, 60, 80, 100."

        # The control codes in the prompt are the complement of the requested
        # diversity (100 - diversity).
        lex_code = int(100 - lex_diversity)
        order_code = int(100 - order_diversity)

        # Normalise whitespace before sentence splitting.
        input_text = " ".join(input_text.split())
        sentences = sent_tokenize(input_text)
        prefix = " ".join(prefix.replace("\n", " ").split())
        output_text = ""

        for sent_idx in range(0, len(sentences), sent_interval):
            curr_sent_window = " ".join(sentences[sent_idx:sent_idx + sent_interval])
            final_input_text = f"lexical = {lex_code}, order = {order_code}"
            if prefix:
                final_input_text += f" {prefix}"
            final_input_text += f" <sent> {curr_sent_window} </sent>"

            final_input = self.tokenizer([final_input_text], return_tensors="pt", truncation=True, max_length=750)
            # Inputs go to the default CUDA device; accelerate's hooks move
            # them further if the first layers live elsewhere.
            final_input = {k: v.cuda() for k, v in final_input.items()}

            with torch.inference_mode():
                outputs = self.model.generate(**final_input, **kwargs)
            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # Feed the generated chunk back as context for the next window.
            prefix += " " + outputs[0]
            output_text += " " + outputs[0]

        return output_text
    
if __name__ == "__main__":
    # Command-line driver: paraphrase every text in a pickled list, then save
    # the outputs as a pickle under DUMP and as a human-readable log.

    DIVERSITY_LEVELS = [0, 20, 40, 60, 80, 100]

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--context_file', type=str, default='')
    parser.add_argument('--output_file', type=str, required=True)
    # Validate diversity levels up front instead of after the expensive model load.
    parser.add_argument('--L', type=int, default=60, choices=DIVERSITY_LEVELS)
    parser.add_argument('--O', type=int, default=0, choices=DIVERSITY_LEVELS)
    args = parser.parse_args()

    print(args)

    # Placeholder cache directory: edit before running.
    DUMP = "DIR/TO/CACHE/"

    # NOTE: pickle.load can execute arbitrary code; only load files you
    # produced yourself.
    input_path = os.path.join(DUMP, args.input_file)
    with open(input_path, "rb") as f:
        inp = pkl.load(f)
        print("loaded input:", input_path)

    # The context file is optional. Without one, every text gets an empty
    # prefix. A provided file that cannot be read falls back to empty
    # contexts, but with a visible message.
    context = None
    if args.context_file:
        context_path = os.path.join(DUMP, args.context_file)
        try:
            with open(context_path, "rb") as f:
                context = pkl.load(f)
                print("loaded context:", context_path)
        except (OSError, EOFError, pkl.UnpicklingError) as err:
            print(f"could not load context {context_path} ({err}); using empty contexts")
    if context is None:
        context = [""] * len(inp)
    if len(context) < len(inp):
        raise ValueError(
            f"context has {len(context)} entries but input has {len(inp)}"
        )

    # DipperParaphraser loads its weights from this module-level CHECKPOINT.
    CHECKPOINT = "DIR/TO/CACHE/MODEL/CHECKPOINT"
    dp = DipperParaphraser(model="kalpeshk2011/dipper-paraphraser-xxl")
    out = []

    # Log path: output name with its final extension swapped for .txt.
    # splitext only drops the last extension, whereas split(".")[0] broke on
    # names such as "./out.pkl" or "run.v2.pkl".
    log_path = os.path.splitext(args.output_file)[0] + ".txt"

    st = time.time()

    for i, (input_text, ctx) in enumerate(zip(inp, context)):
        output = dp.paraphrase(input_text, lex_diversity=args.L, order_diversity=args.O, prefix=ctx, do_sample=True, top_p=0.75, top_k=None, max_length=750)
        out.append(output)

        # Append after every item so partial progress survives a crash.
        with open(log_path, "a", encoding="utf-8") as f:
            f.write("Input: " + input_text + "\n" + "-"*50 + "\n")
            f.write("Output: " + output + "\n" + "*"*50 + "\n")

        print("{:2d}/{:2d} Time:{:.2f} min".format(i+1, len(inp), (time.time()-st)/60), end="\r", flush=True)

    # Finish the carriage-return progress line.
    print()

    with open(os.path.join(DUMP, args.output_file), "wb") as f:
        pkl.dump(out, f)