from transformers import (
    AutoTokenizer,
    NoBadWordsLogitsProcessor
)
import datasets
from multiprocessing import Pool
#import outlines.models as models
#from outlines.text.generate.regex import Regex
import torch
import pysbd
import tqdm
import string
import numpy as np
import re
from tqdm import tqdm
from vllm import LLM, SamplingParams

# Paraphrase model checkpoint (alternative checkpoint: "Vamsi/T5_Paraphrase_Paws").
model_name = "Ujjawal/llama2-paws-paraphrase"

# Plain tokenizer (used for eos_token_id) plus a second instance created with
# add_prefix_space=True, which is used to tokenize the individual keywords.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(
    model_name, add_prefix_space=True
)

# English sentence splitter with pysbd's text cleaning enabled.
seg = pysbd.Segmenter(language="en", clean=True)

ds = datasets.load_from_disk("imdb_horribleincredible")

# Keyword groups (earlier variant: [["bad", "Bad"], ["good", "Good"]]).
# The disabled second pass at the end of the file indexes these as
# keys[gold_label].
keys = [
    ["horrible", "Horrible"],
    # NOTE(review): "Incredible" was deliberately commented out of this group;
    # the reason is not recorded here — confirm before re-adding it.
    ["incredible"],
]
# Every keyword from every group in one flat list, group order preserved.
keysflat = [word for group in keys for word in group]


# Deletes every ASCII punctuation character; built once rather than per call.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def covered(s, words):
    """Return True if sentence *s* contains any of *words* as a whole word.

    The sentence and each keyword are normalised the same way before
    matching. Both are stripped of ``string.punctuation``, lower-cased and
    whitespace-stripped. A ``\\b``-anchored regex search then looks for each
    keyword, so the match is case-insensitive on both sides.

    Because punctuation is deleted rather than replaced by a space,
    "horrible's" becomes "horribles" and does not match "horrible".

    Args:
        s: The sentence to search.
        words: Iterable of keywords. Keywords that are empty after
            normalisation are ignored, since ``\\b\\b`` would match any
            sentence.

    Returns:
        True if at least one keyword occurs in *s*, otherwise False.
    """
    sentence = s.translate(_PUNCT_TABLE).lower().strip()
    for word in words:
        # Normalise the keyword exactly like the sentence; previously only the
        # sentence was lower-cased, so capitalised keywords could never match.
        key = word.translate(_PUNCT_TABLE).lower().strip()
        if key and re.search(r'\b%s\b' % re.escape(key), sentence):
            return True
    return False


def get_tokens_as_list(word_list):
    """Return the token ids of each word in *word_list*, one list per word.

    Each word is encoded on its own, as a one-element batch, with the
    module-level ``tokenizer_with_prefix_space`` and
    ``add_special_tokens=False``. That tokenizer was created with
    ``add_prefix_space=True``, presumably so the ids match a word that appears
    mid-sentence after a space.

    The result preserves the order of *word_list*.
    """
    return [
        tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
        for word in word_list
    ]

# Token ids for every keyword, one list per entry of keysflat.
bad_words_ids = get_tokens_as_list(word_list=keysflat)
# Ban only the FIRST sub-token of each keyword, so the model can never start
# one. NOTE(review): banning a token id blocks it in every context, so this
# also blocks any other word that begins with the same token.
bad_words_ids_firsttok = [[word_ids[0]] for word_ids in bad_words_ids]
# NOTE(review): vLLM invokes logits processors as fn(token_ids, logits) for a
# single sequence; confirm this transformers processor works with that
# calling convention.
processor = NoBadWordsLogitsProcessor(bad_words_ids_firsttok, tokenizer.eos_token_id)

# Prompt for the S --> T pass; {text} is a single sentence to paraphrase.
template = (
    "Paraphrase the following text.\n"
    "Text to paraphrase: {text}\n"
    "Paraphrase:"
)

# Decoding settings for the S --> T pass: one nucleus-sampled completion per
# prompt, stopped at the first newline and capped at 100 new tokens, with the
# keyword-banning processor applied to every step.
sampling_params = SamplingParams(
    n=1,
    temperature=0.8,
    top_p=0.95,
    max_tokens=100,
    stop="\n",
    spaces_between_special_tokens=False,
    logits_processors=[processor],
)

llm = LLM(model=model_name)
raw2 = 'Oh my god what a story! This movie is very good and it had to be God who had this happen! You did a awesome job.The acting was really good you picked the right actors for sure. This movie is so good I am really glad you made this because if you had not then I would have never ever known about this story because I am not a big golf fan and I think it is kinda boring so thank you. I really enjoyed it and that is why I gave the movie a 10/10. I liked Shia Labouf too he was perfect for the roll of Fransis Quimet. I hope most of that stuff you put in there was true also. Oh and some parts were funny and others I was just really happy.'


def _segment_with_fallback(text):
    """Split *text* into sentences with the module-level pysbd segmenter ``seg``.

    If segmentation raises, the fixed review ``raw2`` is segmented instead, so
    that one bad example does not abort the whole split.

    NOTE(review): the fallback replaces the example's text with an unrelated
    review, so that row's "Tpara1" no longer corresponds to its own text.
    Consider logging these cases or choosing a different fallback.
    """
    try:
        return seg.segment(text)
    except Exception:
        # `Exception` rather than a bare `except:` so that Ctrl-C and
        # SystemExit still stop the run instead of being swallowed.
        return seg.segment(raw2)


# S --> T: for every split whose name starts with "S", paraphrase each sentence
# that mentions any keyword in `keysflat`, and store the re-assembled review
# in a new "Tpara1" column.
for k in ds:
    if not k.startswith('S'):
        continue

    print(k)
    dsk = ds[k]
    if 'Tpara1' in dsk.column_names:
        continue  # this split already has paraphrases

    # Pass 1: segment each review once and record which sentences need a
    # paraphrase. All prompts then go to vLLM in one batch, which has ~30x
    # higher throughput than one .generate() call per example.
    segmented = []  # one (sentences, needs_paraphrase_flags) pair per review
    to_paraphrase = []
    for ex in tqdm(dsk):
        sentences = _segment_with_fallback(ex['text'])
        flags = [covered(s, keysflat) for s in sentences]
        segmented.append((sentences, flags))
        to_paraphrase.extend(s for s, flag in zip(sentences, flags) if flag)

    prompts = [template.format(text=p) for p in to_paraphrase]
    # Skip the engine call entirely when no sentence in this split matched.
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True) if prompts else []
    # sampling_params uses n=1, so each request has exactly one completion.
    paraphrases = [output.outputs[0].text.strip() for output in outputs]
    if len(paraphrases) != len(to_paraphrase):
        raise RuntimeError(
            f"split {k}: expected {len(to_paraphrase)} paraphrases, got {len(paraphrases)}"
        )

    # Pass 2: splice the paraphrases back in, in the order they were collected.
    # Reusing the cached segmentation avoids running pysbd again, and no longer
    # depends on it returning identical output twice.
    finished = []
    position = 0
    for sentences, flags in segmented:
        new_sentences = []
        for s, flag in zip(sentences, flags):
            if flag:
                new_sentences.append(paraphrases[position])
                position += 1
            else:
                new_sentences.append(s)
        finished.append(" ".join(new_sentences))

    ds[k] = dsk.add_column("Tpara1", finished)
    # Checkpoint after every split so completed work survives a later crash.
    ds.save_to_disk("imdb_horribleincredible_withpara")

# Save before the debugger stop: pdb treats EOF on stdin as "quit", so in a
# non-interactive run nothing after breakpoint() would execute.
ds.save_to_disk("imdb_horribleincredible_withpara")
# Interactive inspection point; run with PYTHONBREAKPOINT=0 to skip it.
breakpoint()


# Prompt for the (currently disabled) second pass below: {exclude} is the word
# to avoid, {include} the word to try to use, {text} the sentence to rewrite.
template = (
    'Paraphrase the following text without using the word "{exclude}". '
    'Try to include the word "{include}" in your paraphrase.\n'
    "Text to paraphrase: {text}\n"
    "Paraphrase:"
)
# ---------------------------------------------------------------------------
# DISABLED: second pass "Sigood/bad --> Sibad/good".
# Everything in the triple-quoted string below is a bare string expression:
# Python builds it and discards it, so none of this code runs.
# It would not run as-is if re-enabled, because it uses `models` and `Regex`,
# whose outlines imports are commented out at the top of the file.
# NOTE(review) issues visible in the disabled code, to fix before enabling:
#   - the Regex constraint hard-codes "good", but `keys` currently holds
#     horrible/incredible;
#   - the inner `for output in output.outputs` rebinds the outer loop variable;
#   - `Sibad` and `failed` are computed but never used, and sentences with no
#     acceptable sample are replaced by an empty string;
#   - `finished` is never added to the dataset or saved, and the pass ends in
#     breakpoint();
#   - only the first 200 examples of each split are processed
#     (`select(range(200))`).
# ---------------------------------------------------------------------------
"""

model = models.transformers(model_name, device="cuda") # can this be a smaller model?
llm = LLM(model=model_name)
# Sigood/bad --> Sibad/good
for k in ds:
    if not k.startswith('S'):
        continue

    print(k)
    dsk = ds[k].select(range(200))

    gold_label = int(k.split('S')[1].split('_')[0])
    bad_words_ids = get_tokens_as_list(word_list=keys[gold_label])
    force_words = [k for k in keysflat if k not in keys[gold_label]]

    nobad_processor = NoBadWordsLogitsProcessor(bad_words_ids, tokenizer.eos_token_id)

    regex_processor = Regex(
        model,
        regex_string=r".*(?:good).*"
    )

    def create_proposal(token_ids, logits):
        token_ids = torch.Tensor([token_ids]).long().to("cuda")
        return regex_processor.create_proposal(token_ids, logits)

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        logits_processors=[nobad_processor, create_proposal],
        max_tokens=100,
        stop="\n",
        spaces_between_special_tokens=False,
        n=50
    )

    Sigood = dsk.filter(lambda ex: ex['weak_label'] == ex['label'])
    Sibad = dsk.filter(lambda ex: ex['weak_label'] != ex['label'])

    # this ordering is more confusing but it has ~30x higher throughput
    # compared to calling .generate() once per example

    to_paraphrase = []
    sentence_counter = 0
    paraphrased_ids = set()
    for ex in tqdm(Sigood):
        text_raw = ex['text']
        sentences = seg.segment(text_raw)
        ids = [i for (i,s) in enumerate(sentences) if covered(s,keys[gold_label])]
        paraphrased_ids.update([sentence_counter + i for i in ids])
        to_paraphrase.extend([sentences[i] for i in ids])
        sentence_counter += len(sentences)

    prompts = [template.format(text=p, include=force_words[0], exclude=keys[gold_label][0]) for p in to_paraphrase]

    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

    paraphrases = []
    failed = 0
    for i,output in enumerate(outputs):
        good_outputs = []
        for output in output.outputs:
            #if covered(output.text, force_words) and not covered(output.text, keys[gold_label]):
            #    good_outputs.append((output.text, output.cumulative_logprob))

            # todo: just enforce that *at least one* sentence has the opposite word
            # might want to break this down into two steps
            if not covered(output.text, keys[gold_label]):
                good_outputs.append((output.text, output.cumulative_logprob))
        if len(good_outputs) == 0:
            # todo: keep resampling here
            good_outputs.append(('',-1))
            failed += 1

        good_outputs.sort(key=lambda x: x[1])
        chosen = good_outputs[-1]
        paraphrases.append(chosen[0].strip())

    # now recombine
    sentence_counter = 0
    paraphrase_counter = 0
    finished = []
    for ex in tqdm(Sigood):
        text_raw = ex['text']
        sentences = seg.segment(text_raw) # this relies on the segmenter being deterministic
        new_sentences = []
        for i,s in enumerate(sentences):
            id = sentence_counter + i
            if id in paraphrased_ids:
                new_sentences.append(paraphrases[paraphrase_counter])
                paraphrase_counter += 1
            else:
                new_sentences.append(s)
        finished.append(" ".join(new_sentences))
        sentence_counter += len(sentences)

    breakpoint()
"""
