import os

import pandas as pd
import tqdm
from openai import OpenAI

class deepseek:
    """Callable wrapper around the DeepSeek chat-completions endpoint.

    Calling an instance with a list of chat messages returns the model's
    reply as a stripped string. API errors are not raised: they are
    returned as a string starting with "调用失败，错误信息：" so that a
    long batch run is not aborted by a single failed request.
    """

    def __init__(self):
        """Create the OpenAI-compatible client for the DeepSeek beta API.

        Raises:
            RuntimeError: if the DEEPSEEK_API_KEY environment variable is
                missing or empty.
        """
        # The credential is read from the environment so that no secret is
        # kept in source control. Any key that was previously committed in
        # this file should be treated as leaked and rotated.
        api_key = os.environ.get("DEEPSEEK_API_KEY")
        if not api_key:
            raise RuntimeError(
                "DEEPSEEK_API_KEY environment variable is not set"
            )
        self.llm = OpenAI(
            base_url='https://api.deepseek.com/beta',
            api_key=api_key,
        )

    def __call__(self, messages):
        """Send ``messages`` to the model and return its reply text.

        Args:
            messages: list of chat message dicts (``role`` / ``content``).
                The list and its dicts are left unmodified.

        Returns:
            The reply with surrounding whitespace stripped, or a string
            beginning with "调用失败，错误信息：" followed by the exception
            text if anything went wrong while building or sending the
            request.
        """
        try:
            # Set the flag on a copy of the last message so the caller's
            # data is not mutated as a side effect of the call.
            request_messages = list(messages)
            request_messages[-1] = {**request_messages[-1], "prefix": True}
            # NOTE(review): the flag is applied to whatever message is last,
            # regardless of its role -- confirm against the DeepSeek
            # prefix-completion docs that this is the intended usage.
            completion = self.llm.chat.completions.create(
                model="deepseek-chat",
                messages=request_messages,
                max_tokens=8192,
                temperature=0,
                top_p=0.00000001,
            )
            # ``content`` may be None; normalise to an empty string so the
            # strip() below cannot fail outside the try block.
            res_str = completion.choices[0].message.content or ""
        except Exception as e:
            # Deliberate best-effort: report the failure as text instead of
            # raising, so callers can decide how to treat the row.
            res_str = "调用失败，错误信息：" + str(e)
        return res_str.strip()

# Cities to process; the main loop reads ./<city>/attractions_tag.csv for each.
# Full candidate list, currently disabled in favour of the subset below:
# ['beijing','chengdu','chongqing','guangzhou','hangzhou','nanjing','shanghai','shenzhen','suzhou','wuhan']
city_list = ['shanghai','shenzhen','suzhou','wuhan']
# Prompt template with two positional `{}` slots, filled via
# prompt.format(city, attraction_name). It asks the model to classify the
# attraction into one or more of the bracketed labels (comma-separated),
# answering only the last label when none of the others apply, and gives
# three worked examples before the actual question.
prompt = """
我将要提供一个位于{}景点的名称，你需要回答这个景点是否属于[河边,江边,海边,古风写真,约会圣地,其它]中的哪一种或几种类型,如果不属于以上类型，请只回答其它, 如含有多个类型请用逗号隔开。请你仔细考虑实际情况，比如北京不存在任何江边或者海边景点，所以请不要随意回答。请回答以下问题：

Example:
景点名称: 故宫
景点类型: 古风写真
景点名称: 大梅沙
景点类型: 海边, 约会圣地
景点名称: 中国国家博物馆
景点类型: 其它
Example End

景点名称: {}
景点类型:
"""

model = deepseek()

# Labels searched for in the model's reply, in the order they are written to
# the CSV. Each label maps to the cities for which it is rejected even when
# the model returns it.
_LABEL_EXCLUDED_CITIES = {
    '河边': (),
    '江边': ('beijing', 'suzhou'),
    '海边': ('beijing', 'chengdu', 'chongqing', 'hangzhou', 'nanjing', 'suzhou', 'wuhan'),
    '古风写真': (),
    '约会圣地': (),
}
# Label assigned when none of the labels above is accepted.
_FALLBACK_LABEL = '其它'
# Prefix of the text the model wrapper returns instead of raising on error.
_FAILURE_PREFIX = "调用失败，错误信息："


def _labels_from_reply(reply, city):
    """Return the labels contained in ``reply`` that are allowed for ``city``.

    Matching is by substring, so a label counts wherever it appears in the
    reply. The result keeps the order of ``_LABEL_EXCLUDED_CITIES`` and is
    empty when no label is accepted.
    """
    return [
        label
        for label, excluded_cities in _LABEL_EXCLUDED_CITIES.items()
        if label in reply and city not in excluded_cities
    ]


for city in city_list:
    print(city)
    # Per-city tally of how many attractions received each label.
    Type_count = {label: 0 for label in _LABEL_EXCLUDED_CITIES}
    Type_count[_FALLBACK_LABEL] = 0
    failed_rows = 0
    df = pd.read_csv(f'./{city}/attractions_tag.csv')
    # Collected values for the new ood_Type column, one entry per row.
    ood_types = []
    for name in tqdm.tqdm(df['Name']):
        messages = [{"role": "user", "content": prompt.format(city, name)}]
        res = model(messages)
        if res.startswith(_FAILURE_PREFIX):
            # A failed request carries no classification. Leave the cell
            # empty instead of recording it as the fallback label, so the
            # row can be recognised and retried later.
            failed_rows += 1
            ood_types.append("")
            tqdm.tqdm.write(f"{city} / {name}: {res}")
            continue
        labels = _labels_from_reply(res, city) or [_FALLBACK_LABEL]
        for label in labels:
            Type_count[label] += 1
        ood_types.append(','.join(labels))
    df['ood_Type'] = ood_types
    print(Type_count)
    if failed_rows:
        print(f"{failed_rows} request(s) failed; their ood_Type was left empty")
    # The input file is overwritten in place with the added column.
    df.to_csv(f'./{city}/attractions_tag.csv', index=False)