# -*- coding: utf-8 -*-

import argparse
import base64
import json
import os
import pickle
import re
import threading

import torch
import uvicorn
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel, conbytes
from transformers import AutoModelForCausalLM, AutoTokenizer

# Command-line configuration for the inference server.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:0",
                    help="torch device the model is moved to")
# Required: every later step (tokenizer/model loading, the log line below)
# needs a real path or hub id.
parser.add_argument("--model_name_or_path", type=str, required=True,
                    help="model directory or hub id passed to from_pretrained")
parser.add_argument("--port", type=int, default=8080,
                    help="port to serve on; also prefixes the temp image filenames")
args = parser.parse_args()

print(args)

# Log the second-to-last "/"-separated component of the model path; fall back
# to the whole string when it contains no "/" (avoids an IndexError).
_model_path_parts = args.model_name_or_path.split("/")
_model_display_name = (
    _model_path_parts[-2] if len(_model_path_parts) >= 2 else _model_path_parts[-1]
)
print("Current loaded model:", _model_display_name)

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

class MyDataCollator:
    """Convert request examples into query strings for the model.

    Each example's ``query`` may contain ``<image>`` placeholders. The i-th
    placeholder becomes an ``{"image": <path of image i>}`` entry, and every
    other non-blank text segment becomes a ``{"text": ...}`` entry. The
    resulting list is rendered with ``tokenizer.from_list_format``.
    """

    _IMAGE_TOKEN = "<image>"
    # Capturing group so re.split keeps the placeholder tokens in its output.
    _SPLIT_RE = re.compile(r'(<image>)')

    def __init__(self, tokenizer, image_path_template=None):
        """
        Args:
            tokenizer: object exposing ``from_list_format(content_list)``.
            image_path_template: ``str.format`` template with one positional
                slot for the image index. Defaults to ``"./<port>_{}.png"``
                built from the server's ``--port``. This is the same location
                ``predict`` saves the decoded request images to.
        """
        self.tokenizer = tokenizer
        if image_path_template is None:
            image_path_template = f"./{args.port}_{{}}.png"
        self.image_path_template = image_path_template

    def _build_content_list(self, question):
        """Split *question* on ``<image>`` and map each segment to a content dict."""
        segments = [seg.strip() for seg in self._SPLIT_RE.split(question) if seg.strip()]
        content_list, image_cnt = [], 0
        for seg in segments:
            if seg == self._IMAGE_TOKEN:
                content_list.append({"image": self.image_path_template.format(image_cnt)})
                image_cnt += 1
            else:
                content_list.append({"text": seg})
        return content_list

    def __call__(self, example_list):
        """Return one rendered query per example, in input order."""
        return [
            self.tokenizer.from_list_format(self._build_content_list(example["query"]))
            for example in example_list
        ]

# Load the causal LM in float16 (custom model code allowed), put it in
# eval mode and move it onto the configured device.
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = model.eval().to(args.device)

# Single shared collator instance used by the /predict endpoint.
data_collator = MyDataCollator(tokenizer)


# Canonical, ordered list of the agent's discrete actions. All lookup tables
# below are derived from it so they cannot drift out of sync with each other.
_ACTIONS = ("MoveAhead", "RotateRight", "RotateLeft", "LookUp", "Done")

# Exact action names (used for whole-token matching in parse_action).
action_set = set(_ACTIONS)
# Lower-cased action name -> integer id (its position in _ACTIONS).
a2i = {name.lower(): idx for idx, name in enumerate(_ACTIONS)}
# Lower-cased action name -> canonical name. Insertion order is the priority
# order of parse_action's substring fallback.
lower_action_list = {name.lower(): name for name in _ACTIONS}

def map_action_to_id(action_list):
    """Translate action names into integer ids via the ``a2i`` table.

    Matching is case-insensitive; any name missing from ``a2i`` yields -1.
    """
    lookup = a2i.get
    return [lookup(name.lower(), -1) for name in action_list]


def parse_action(output_str):
    """Extract an action name from a raw model response.

    Steps:
      1. Strip whitespace and an optional leading ``"Assistant:"``.
      2. If the text contains ``"3)"``, keep only what follows its first
         occurrence.
      3. Return the first whitespace-separated token that, once surrounding
         punctuation is removed, is exactly a name in ``action_set``.
      4. Otherwise remove spaces, lower-case the text, and return the first
         entry of ``lower_action_list`` (in dict order) whose key is a
         substring of it.

    Returns:
        The canonical action name, or the string ``"None"`` if nothing matched.
    """
    text = output_str.strip()
    if text.startswith("Assistant:"):
        text = text[len("Assistant:"):].strip()

    # partition() instead of index() + blanket except: no exception-driven
    # control flow, and other errors are no longer silently swallowed.
    _, marker, after = text.partition("3)")
    if marker:
        text = after.strip()

    # Pass 1: exact token match, tolerant of trailing/leading punctuation
    # such as "LookUp." or "(RotateLeft)".
    for token in text.split():
        token = token.strip(".,;:!?'\"()[]{}*`")
        if token in action_set:
            return token

    # Pass 2: fuzzy substring match on the squashed, lower-cased text
    # (catches e.g. "move ahead" or "Rotate left").
    squashed = text.replace(" ", "").lower()
    for lowered, action in lower_action_list.items():
        if lowered in squashed:
            return action
    return "None"

# Request/response schemas for the /predict endpoint.
class InputData(BaseModel):
    """Request body accepted by ``/predict``."""
    # Caller-supplied identifier; predict only echoes it via print(example).
    id: str
    # Prompt text; each "<image>" placeholder refers to the next decoded image.
    query: str
    # Base64 of a pickled sequence of image objects (decoded in predict).
    # NOTE(review): unpickling client data is unsafe; see predict.
    image: str

class OutputPrediction(BaseModel):
    """Intended response schema for ``/predict``.

    NOTE(review): currently unused. predict returns a plain dict with keys
    "action"/"text", not these field names. Confirm before wiring it in as a
    response_model.
    """
    generated_action: str
    generated_text: str


app = FastAPI()

# Serializes request handling. Every request writes its images to the same
# per-port filenames that data_collator points the model at. FastAPI runs
# plain `def` endpoints in a thread pool, so without this lock two concurrent
# requests could overwrite each other's images mid-inference.
_predict_lock = threading.Lock()


@app.post("/predict")
def predict(example: InputData):
    """Run one chat turn on the model and extract the predicted action.

    The request's ``image`` field is decoded (base64, then pickle) into a
    sequence of image objects supporting ``.save(path)``. They are written to
    ``./<port>_<idx>.png`` so the ``<image>`` placeholders in ``query`` can
    reference them.

    Returns:
        dict with ``"action"`` (from ``parse_action``; the string ``"None"``
        if no action was recognised) and ``"text"`` (the raw model response).
    """
    example = example.dict()
    image_list_bin = base64.b64decode(example["image"])
    # SECURITY: pickle.loads on request data allows arbitrary code execution
    # by any client that can reach this endpoint (the server binds 0.0.0.0).
    # Only expose it on a trusted network, or switch the wire format to plain
    # base64-encoded image bytes.
    image_list = pickle.loads(image_list_bin)
    example["image"] = image_list
    with _predict_lock:
        for idx, image in enumerate(image_list):
            image.save(f"./{args.port}_{idx}.png")
        print(example)
        query = data_collator([example])[0]
        print(query)
        with torch.no_grad():
            response, _history = model.chat(
                tokenizer,
                query=query,
                history=None,
                max_new_tokens=128,
                do_sample=True,
                min_new_tokens=3,
            )
    generated_text = response
    generated_action = parse_action(generated_text)
    return {"action": generated_action, "text": generated_text}

if __name__ == "__main__":
    # Listen on all interfaces, on the port given by --port.
    serve_options = {"host": "0.0.0.0", "port": args.port}
    uvicorn.run(app, **serve_options)
