from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import random
import string
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import sys
sys.path.append('./model_api')
from BaseAPI import BaseAPI


class Llama3API(BaseAPI):
    def __init__(self, model_path, tokenizer_path, generation_config={}):
        super().__init__(generation_config)
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.device = torch.device('cuda')
        print('loading model...')
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(self.device).eval()
        self.tokenizer.padding_side = 'left'
        print('finish loading')
        self.sys_prompt = self.basic_sys_prompt
        self.no_tool_sys_prompt = """You are a helpful assistant"""

    def get_system_prompt(self, d=None):
        if d is None:
            return self.sys_prompt
        else:
            # check if no tool is provided
            if not d['environments'] or 'name' not in d['environments'][0] or d['environments'][0]['name'] == '':
                return self.no_tool_sys_prompt
            else:
                return self.sys_prompt

    def parse_tool_str(self, tool_str):
        # print(f'tool_str: {tool_str}, is_json: {self.is_json(tool_str)}')
        try:
            if '{' not in tool_str or '}' not in tool_str:
                return {'type': 'error', 'message': f'Wrong tool call result: {tool_str}'}
            res = json.loads(tool_str)
            if 'name' not in res or ('arguments' not in res and 'parameters' not in res):
                print(f"Wrong tool call result: {res}")
                return {'type': 'error', 'message': f'Wrong tool call result: {res}'}
            tool_call_id = ''.join(random.sample(string.ascii_letters + string.digits, 9))
            tool_name = res['name']
            if 'parameters' in res:
                res['arguments'] = res['parameters']
            arguments = json.loads(res['arguments']) if isinstance(res['arguments'], str) else res['arguments']
            return {'type': 'tool', 'tool_call_id': tool_call_id, 'tool_name': tool_name, 'arguments': arguments}
        except Exception:
            return {'type': 'error', 'message': f'Wrong tool call result: {tool_str}'}

    def response(self, messages, tools):
        if tools:
            inputs = self.tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.device)
        else:
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device)
        # print(f'tools: {tools}\ninput_prompt: {self.tokenizer.decode(inputs["input_ids"][0])}')
        out = self.model.generate(**inputs, pad_token_id=self.tokenizer.eos_token_id, **self.generation_config)
        return self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

    def generate_response(self, messages, tools):
        completion = self.response(messages, tools)
        print(f'completion: {completion}')
        completion = completion.replace('```json', '').replace('```', '').strip('\n')

        ## tool call part
        if self.is_json(completion) and 'name' in self.parse_json(completion):
            # inputs = self.tokenizer.apply_chat_template(
            #     messages,
            #     tools=tools,
            #     add_generation_prompt=True,
            #     return_dict=True,
            #     return_tensors="pt",
            # )
            # print(f'input_prompt: {self.tokenizer.decode(inputs["input_ids"][0])}')
            res = self.parse_tool_str(completion)
            return res

        ## normal content part
        elif '{' in completion and '}' in completion:
            start = completion.index('{')
            end = completion.rindex('}')
            tool_str = completion[start:end + 1]
            tool_str = tool_str.replace("'", '"').replace('"{', '{').replace('}"', '}')
            # print('tool_str: ', tool_str)
            if self.is_json(tool_str) and 'name' in self.parse_json(tool_str):
                # inputs = self.tokenizer.apply_chat_template(
                #     messages,
                #     tools=tools,
                #     add_generation_prompt=True,
                #     return_dict=True,
                #     return_tensors="pt",
                # )
                # print(f'input_prompt: {self.tokenizer.decode(inputs["input_ids"][0])}')
                res = self.parse_tool_str(tool_str)
                return res
            else:
                return {'type': 'content', 'content': completion}
        else:
            return {'type': 'content', 'content': completion}
