# %%
import spacy
import torch

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from SGM.utils import format_scene_graph

class TextGraphParsor:
    '''
    Parse textual image descriptions into scene graphs with a seq2seq model.

    Each description is prefixed with the prompt ``'Generate Scene Graph: '``
    and fed to the parser. The decoded output has that prompt removed and is
    post-processed by ``format_scene_graph``.
    '''

    # Prompt prepended to every input; also stripped from the decoded output.
    _PROMPT = 'Generate Scene Graph:'

    def __init__(self, device='cuda:0', max_input_length=64, max_output_length=64, beam_size=1, lemma=False, lowercase=False, parser_checkpoint='lizhuang144/flan-t5-base-factual-sg', text_encoder=None, multi_modal_encoder_checkpoint=None):
        '''
        Load the parser tokenizer/model and record default settings.

        :param device: torch device the parser is moved to, e.g. 'cuda:0' or 'cpu'
        :param max_input_length: stored as ``self.max_input_length``
        :param max_output_length: stored as ``self.max_output_length``
        :param beam_size: stored as ``self.beam_size``
        :param lemma: stored as ``self.lemma``; when True the spaCy
            ``en_core_web_sm`` pipeline is loaded eagerly here
        :param lowercase: stored as ``self.lowercase``
        :param parser_checkpoint: Hugging Face checkpoint of the seq2seq parser
        :param text_encoder: accepted for API compatibility; currently unused
        :param multi_modal_encoder_checkpoint: accepted for API compatibility;
            currently unused

        NOTE(review): ``parse`` takes its own length/beam/lowercase/lemma
        arguments and does not read the values stored above.
        '''
        self.device = device
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        self.beam_size = beam_size
        self.lemma = lemma
        self.lowercase = lowercase

        # Set up the parsing model in inference mode on the requested device.
        self.parser_tokenizer = AutoTokenizer.from_pretrained(parser_checkpoint)
        self.parser = AutoModelForSeq2SeqLM.from_pretrained(parser_checkpoint)
        self.parser.eval()
        self.parser.to(device)

        # spaCy pipeline for lemmatization; loaded lazily by parse() if needed.
        self.lemmatizer = None
        if self.lemma:
            # Requires the model: python -m spacy download en_core_web_sm
            self.lemmatizer = spacy.load("en_core_web_sm")

    def _get_lemmatizer(self):
        '''Return the spaCy lemmatizer, loading ``en_core_web_sm`` on first use.'''
        # getattr: tolerate instances created without running __init__.
        lemmatizer = getattr(self, 'lemmatizer', None)
        if lemmatizer is None:
            lemmatizer = spacy.load("en_core_web_sm")
            self.lemmatizer = lemmatizer
        return lemmatizer

    def parse(self, text_input, max_input_length=128, max_output_length=128, beam_size=1, lowercase=False, lemma=False):
        '''
        Parse one or more textual image descriptions into scene graphs.

        :param text_input: a single description (str) or an iterable of them
        :param max_input_length: inputs are truncated to this many tokens
        :param max_output_length: ``max_length`` passed to ``generate``
        :param beam_size: number of beams used for generation
        :param lowercase: lowercase each description before parsing
        :param lemma: replace each token by its spaCy lemma before parsing;
            the spaCy model is loaded on first use if ``__init__`` did not load it
        :return: a list with one formatted scene graph per description, in
            input order (always a list, even for a single string input)
        '''
        if isinstance(text_input, str):
            text_input = [text_input]
        else:
            # Materialise once so generators and other iterables work.
            text_input = list(text_input)

        # Nothing to parse: skip the tokenizer/model for an empty batch.
        if not text_input:
            return []

        if lowercase:
            text_input = [text.lower() for text in text_input]

        if lemma:
            # Previously raised AttributeError unless __init__ had lemma=True.
            lemmatizer = self._get_lemmatizer()
            text_input = [' '.join(token.lemma_ for token in lemmatizer(text)) for text in text_input]

        prompts = [self._PROMPT + ' ' + text for text in text_input]
        with torch.no_grad():
            encoded = self.parser_tokenizer(
                prompts,
                max_length=max_input_length,
                truncation=True,
                padding=True,
                return_tensors='pt')
            input_ids = encoded['input_ids'].to(self.device)
            attention_mask = encoded['attention_mask'].to(self.device)

            generated_ids = self.parser.generate(
                input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                # Decoding starts from the pad token (T5 convention; the
                # default checkpoint is Flan-T5).
                decoder_start_token_id=self.parser_tokenizer.pad_token_id,
                num_beams=beam_size,
                max_length=max_output_length,
                early_stopping=True
            )

        decoded = self.parser_tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # Strip any echoed prompt before formatting the graph text.
        return [format_scene_graph(text.replace(self._PROMPT, '').strip()) for text in decoded]




