from ast import arg
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import random
import string
import os
import sys
sys.path.append('./model_api')
from BaseAPI import BaseAPI


class GLM4API(BaseAPI):

    def __init__(self, model_path, tokenizer_path, generation_config={}):
        super().__init__(generation_config)
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.device = torch.device('cuda')
        print('loading model...')
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(self.device).eval()
        self.tokenizer.padding_side = 'left'
        print('finish loading')
        self.sys_prompt = self.basic_sys_prompt
        self.no_tool_sys_prompt = """You are a helpful assistant"""

    def get_system_prompt(self, d=None):
        if d is None:
            return self.sys_prompt
        else:
            # check if no tool is provided
            if not d['environments'] or 'name' not in d['environments'][0] or d['environments'][0]['name'] == '':
                return self.no_tool_sys_prompt
            else:
                return self.sys_prompt

    def parse_tool_str(self, tool_str):
        try:
            res = {}
            if isinstance(tool_str, list):
                res = tool_str[0]
            else:
                if '[' not in tool_str or ']' not in tool_str:
                    return {'type': 'error', 'message': f'Wrong tool call result here1: {tool_str}'}
                res = self.parse_json(tool_str)[0]
            if 'name' not in res or 'arguments' not in res:
                print(f"Wrong tool call result: {res}")
                return {'type': 'error', 'message': f'Wrong tool call result: {res}'}
            tool_call_id = ''.join(random.sample(string.ascii_letters + string.digits, 9))
            tool_name = res['name']
            arguments = self.parse_json(res['arguments'].replace('\n', '\\n')) if isinstance(res['arguments'], str) else res['arguments']
            return {'type': 'tool', 'tool_call_id': tool_call_id, 'tool_name': tool_name, 'arguments': arguments}
        except Exception:
            return {'type': 'error', 'message': f'Wrong tool call result here2: {tool_str}'}

    def response(self, messages, tools):
        if str(tools) not in messages[0]['content']:
            messages[0]['content'] = f"{messages[0]['content']}\n\n{tools}"
        ipt_messages = []
        for msg in messages:
            if msg['role'] == 'assistant' and 'tool_calls' in msg:
                msg['content'] = [tool_call['function'] for tool_call in msg['tool_calls']]
                msg.pop('tool_calls')
            ipt_messages.append(msg)
        inputs = self.tokenizer.apply_chat_template(ipt_messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs, pad_token_id=self.tokenizer.eos_token_id, **self.generation_config)
        return self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

    def generate_response(self, messages, tools):
        completion = self.response(messages, tools)
        completion = completion.replace('```json', '').replace('```', '').strip('\n')

        ## tool call part
        if self.is_json(completion):
            completion_new = self.parse_json(completion)
            if isinstance(completion_new, dict) and 'name' in completion_new:
                completion_new = [completion_new]
            if not isinstance(completion_new, list) or not isinstance(completion_new[0], dict) or 'name' not in completion_new[0]:
                return {'type': 'content', 'content': completion}
            completion = completion_new
            res = self.parse_tool_str(completion)
            return res

        ## normal content part
        elif '[' in completion and ']' in completion:
            start = completion.index('[')
            end = completion.rindex(']')
            tool_str = completion[start:end + 1]
            # tool_str = tool_str.replace('"{', '{').replace('}"', '}')
            if self.is_json(tool_str) and isinstance(self.parse_json(tool_str), list) and isinstance(self.parse_json(tool_str)[0], dict) and 'name' in self.parse_json(tool_str)[0]:
                res = self.parse_tool_str(tool_str)
                return res
            else:
                return {'type': 'content', 'content': completion}
        elif '\n{' in completion and '}' in completion:
            try:
                name, arguments = completion.split('\n')
                if self.is_json(arguments):
                    res = self.parse_tool_str([{'name': name, 'arguments': arguments}])
                    return res
                else:
                    return {'type': 'content', 'content': completion}
            except:
                return {'type': 'content', 'content': completion}
        else:
            return {'type': 'content', 'content': completion}
