
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from src.utils.misc.schema import load_schema

# Single checkpoint id shared by both loaders, so the tokenizer and the model
# are always taken from the same pretrained release.
_GODEL_CHECKPOINT = "microsoft/GODEL-v1_1-base-seq2seq"

tokenizer = AutoTokenizer.from_pretrained(_GODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_GODEL_CHECKPOINT)

def generate(instruction, knowledge, dialog):
    """Sample one reply from the model for the given prompt pieces.

    The prompt is assembled as
    ``"<instruction> [CONTEXT] <turn> EOS <turn> ... [KNOWLEDGE] <knowledge>"``
    and run through the module-level ``tokenizer`` and ``model``.

    Args:
        instruction: Task instruction placed at the start of the prompt.
        knowledge: Grounding text appended after a ``[KNOWLEDGE]`` marker.
            ``None`` or ``''`` means "no grounding"; the marker is omitted.
        dialog: Iterable of dialog turns (strings), joined with ``' EOS '``
            in the order given.

    Returns:
        The decoded reply with special tokens stripped. Sampling is enabled
        (``do_sample=True``, ``top_p=0.9``), so repeated calls with the same
        arguments may return different text.
    """
    # Truthiness check covers both '' and None; the previous `!= ''` test let
    # None through to a string concatenation, which raised TypeError.
    knowledge_part = f"[KNOWLEDGE] {knowledge}" if knowledge else ""
    context = " EOS ".join(dialog)
    # The separator before the knowledge part is kept even when that part is
    # empty, so the prompt text is identical to what was sent before.
    query = f"{instruction} [CONTEXT] {context} {knowledge_part}"
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        max_length=128,
        min_length=8,
        top_p=0.9,
        do_sample=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def get_init_dialog_context():
    """Build the initial prompt pieces for the forecasting-formulation dialog.

    Returns:
        A ``(instruction, knowledge, dialog)`` tuple: the task instruction
        string, an empty knowledge string, and a list with the two opening
        dialog turns.

    Raises:
        OSError: If the ``flight_delay`` schema file cannot be opened. The
            path is relative to the current working directory.
    """
    schema = "flight_delay"
    # The context manager closes the file handle even if load_schema raises;
    # the handle used to be left open.
    with open(f'src/data/test_data/schema/{schema}.schema', "r") as schema_file:
        # Only consumed by the schema-grounded variant below. Still loaded so
        # a missing or unreadable schema file is reported here, as before.
        loaded_schema = load_schema(schema_file.read(), name=schema)  # noqa: F841

    # Alternative, schema-grounded setup (currently disabled): passes the
    # loaded schema as knowledge and seeds a complete example conversation.
    # instruction = 'Instruction: given a dialog context, you need to formulate a Time Series Forecasting problem. Extract Prediction window, aggregator, target attribute, filter attribute and filter operation from user.'
    # knowledge = str(loaded_schema)
    # dialog = [
    #     'Can you formulate a Time Series Forecasting problem for me?',
    #     'For formulating a Time Series Forecasting problem you need to give me an utterance of your forecasting task',
    #     'Can you tell me the average aircraft delay for Qatar Airlines where the tail number starts from 4500 and will start within the next week?',
    #     'It looks like your target attribute is aircraft delay, is that correct?',
    #     'yes that is correct',
    #     'Thanks for confirming, moving forward, you are using average as an aggregator',
    #     'yeah you got it correct',
    #     'thanks, and you want to use tail number as your filtering attribute?',
    #     'yes that will be used for filtering data',
    #     'Thanks, finally you want to use next week as the prediction window correct?',
    #     'yeah!',
    #     'Thank you for your patience, I will build your machine learning model right away.',
    # ]

    # Active setup: task instruction only, no grounding knowledge, and just
    # the opening exchange as dialog history.
    instruction = 'Instruction: given a dialog context, you need to formulate a Time Series Forecasting problem. Extract Prediction window, aggregator, target attribute, filter attribute and filter operation from user utterances.'
    knowledge = ''
    dialog = [
        'Can you formulate a Time Series Forecasting problem for me?',
        'For formulating a Time Series Forecasting problem you need to give me an utterance of your forecasting task',
    ]

    return instruction, knowledge, dialog

def get_response(instruction=None, knowledge=None, dialog=None):
    """Generate one model reply, echo it to stdout, and return it with the inputs.

    Args:
        instruction: Forwarded unchanged to ``generate``.
        knowledge: Forwarded unchanged to ``generate``.
        dialog: Forwarded unchanged to ``generate``.

    Returns:
        ``(instruction, knowledge, dialog, response)``: the three arguments
        as received, followed by the generated reply.
    """
    reply = generate(instruction, knowledge, dialog)
    print('\nGodel:> ', reply, '\n')
    return instruction, knowledge, dialog, reply



# Interactive driver: seed the dialog with one forecasting request, then
# alternate model replies and user input for a bounded number of turns.
MAX_TURNS = 15

instruction, knowledge, dialog = get_init_dialog_context()
uttr = 'Predict the total late aircraft delay for flights which will land in Atlanta International Airport which will have scheduled time is in between 7 Am and 11 Am in central time  and will start within next week.'
dialog.append(uttr)
print('\nUser:> ', uttr, '\n')
for turn in range(MAX_TURNS):
    instruction, knowledge, dialog, response = get_response(instruction, knowledge, dialog)
    dialog.append(response)

    # After the final reply there is no further model turn, so do not ask
    # for input that would never be answered.
    if turn == MAX_TURNS - 1:
        break

    try:
        user_input = input("User:> ")
    except (EOFError, KeyboardInterrupt):
        # End of stdin or Ctrl-C at the prompt ends the session cleanly
        # instead of surfacing a traceback.
        print()
        break
    dialog.append(user_input)
    
    