# -*- coding: utf-8 -*-

import os
import re
import json
import torch
import pickle
import argparse
from fastapi import FastAPI
from pydantic import BaseModel, conbytes
import base64
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--model_name_or_path", type=str,
                    default=None)
parser.add_argument("--port", type=int, default=8080)
parser.add_argument("--full_path", type=int, default=1)
args = parser.parse_args()
# Fail with a usage message instead of an AttributeError further down.
if not args.model_name_or_path:
    parser.error("--model_name_or_path is required")

# Format template that expands a short model name into a full path
# (e.g. "/models/{}"). It is empty here; formatting an empty template would
# throw the model name away, so --full_path only takes effect once it is set.
full_model_path = ""
if args.full_path and full_model_path:
    args.model_name_or_path = full_model_path.format(args.model_name_or_path)

print(args)

# Last path component, ignoring a trailing slash on directory paths.
print("Current loaded model:", args.model_name_or_path.rstrip("/").split("/")[-1])

tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
    
class GemmaDataCollator:
    """Turn request examples into a tokenized batch ready for generation.

    Each example's ``"query"`` string is wrapped as a single user chat turn
    and rendered with the tokenizer's chat template, with the generation
    prompt appended. Other keys of the example are ignored.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, example_list):
        """Render and tokenize ``example_list``.

        Args:
            example_list: Sequence of mappings, each with a ``"query"`` key.

        Returns:
            The tokenizer's output for the rendered prompts, with
            ``return_tensors="pt"``.
        """
        texts = []
        for example in example_list:
            messages = [{"role": "user", "content": example["query"]}]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            print(prompt)
            texts.append(prompt)

        # Chat templates commonly render the BOS token themselves; letting the
        # tokenizer add special tokens on top of that would duplicate it. Only
        # skip them when every rendered prompt already starts with BOS.
        bos = getattr(self.tokenizer, "bos_token", None)
        template_has_bos = bool(texts) and bool(bos) and all(
            text.startswith(bos) for text in texts
        )
        # Use the tokenizer this collator was built with, not a global one.
        return self.tokenizer(
            texts,
            return_tensors="pt",
            add_special_tokens=not template_has_bos,
        )

# Load the causal LM weights in float16 and move them to the requested device.
# attn_implementation="flash_attention_2" needs the flash-attn package and a
# GPU generation that it supports (not only A100/H100).
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
model = model.to(args.device)

# Tokenizers without a pad token borrow EOS for padding and for the other
# auxiliary special tokens; the matching pad id is recorded on both the model
# config and its generation config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.cls_token = tokenizer.eos_token
    tokenizer.sep_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id

# Shared collator used by the /predict handler.
data_collator = GemmaDataCollator(tokenizer)


# Action names recognised in model output (case-sensitive).
action_set = {"MoveAhead", "RotateRight", "RotateLeft", "LookUp", "Done"}
# Lower-cased action name -> integer class id.
a2i = {
    name: idx
    for idx, name in enumerate(
        ("moveahead", "rotateright", "rotateleft", "lookup", "done")
    )
}

def map_action_to_id(action_list):
    """Translate action names into integer ids via ``a2i``.

    Lookup is case-insensitive (names are lower-cased first); any name not
    present in ``a2i`` maps to ``-1``.
    """
    return [a2i.get(name.lower(), -1) for name in action_list]


def parse_action(output_str, actions=None):
    """Return the first action name mentioned in a model response.

    The response is stripped and a leading ``"ASSISTANT:"`` prefix removed.
    If the text contains ``"3)"`` (presumably the final step of a numbered
    answer), only the text after its first occurrence is searched. The
    earliest occurrence of any action name wins; matching is case-sensitive
    and tolerates surrounding punctuation (e.g. ``"MoveAhead."``).

    Args:
        output_str: Decoded model output.
        actions: Action names to look for; defaults to the module-level
            ``action_set``.

    Returns:
        The matched action name, or the string ``"None"`` if nothing matches.
    """
    if actions is None:
        actions = action_set

    text = output_str.strip()
    if text.startswith("ASSISTANT:"):
        text = text[len("ASSISTANT:"):].strip()

    marker = text.find("3)")
    if marker != -1:
        text = text[marker + 2:].strip()

    if not actions:
        return "None"

    # Longest names first so that, if one action were a prefix of another,
    # the longer one wins at the same position; sorting also makes the result
    # independent of set iteration order.
    alternatives = sorted(actions, key=lambda name: (-len(name), name))
    match = re.search("|".join(map(re.escape, alternatives)), text)
    return match.group(0) if match else "None"

# Request body accepted by POST /predict.
class InputData(BaseModel):
    id: str  # caller-supplied identifier; the handler does not read it
    query: str  # user prompt, rendered as a single user chat turn
    image: str  # base64 of a pickle (handler calls it image_list); unpickled on receipt — unsafe for untrusted clients

# Prediction schema (generated_action / generated_text). NOTE(review): the
# /predict handler does not reference this model.
class OutputPrediction(BaseModel):
    generated_action: str
    generated_text: str


app = FastAPI()


@app.post("/predict")
def predict(example: InputData):
    """Generate a completion for one query and extract the chosen action.

    Args:
        example: Request body with ``id``, ``query`` and ``image``; ``image``
            must be a base64-encoded pickle.

    Returns:
        ``{"action": <parsed action name>, "text": <decoded completion>}``.
        NOTE(review): these keys differ from ``OutputPrediction``'s fields;
        they are kept as-is for existing clients.
    """
    # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
    if hasattr(example, "model_dump"):
        payload = example.model_dump()
    else:
        payload = example.dict()

    # SECURITY(review): pickle.loads on request data lets any client execute
    # arbitrary code on this server. GemmaDataCollator currently reads only
    # "query", so the decoded value is unused; switch to a safe encoding
    # (e.g. base64 PNG/JPEG bytes) or drop the field.
    payload["image"] = pickle.loads(base64.b64decode(payload["image"]))

    batch = data_collator([payload])
    batch = {name: tensor.to(args.device) for name, tensor in batch.items()}
    with torch.no_grad():
        # generate() needs the tokenizer to honour stop_strings.
        generated_ids = model.generate(
            **batch,
            max_new_tokens=256,
            min_new_tokens=3,
            tokenizer=tokenizer,
            stop_strings=["<end_of_utterance>"],
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = batch["input_ids"].size(1)
    generated_text = tokenizer.batch_decode(
        generated_ids[:, prompt_len:], skip_special_tokens=True
    )[0]
    generated_action = parse_action(generated_text)
    return {"action": generated_action, "text": generated_text}

if __name__ == "__main__":
    # Serve the FastAPI app on every network interface at the --port value.
    uvicorn.run(app=app, host="0.0.0.0", port=args.port)
