from curses import meta
from hashlib import sha256
import sys
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import pickle
import token
from typing import List

from utils.text_predict import LLMSentimentPredictor

import numpy as np
import torch
import socket
import subprocess
import multiprocessing
import json
import sys
from tqdm.auto import tqdm
import time
import argparse
import shutil

from abc import ABCMeta, abstractmethod


def get_args():
    """Parse this script's command-line options.

    Options the parser does not recognise are ignored instead of raising an
    error (``parse_known_args``).

    Returns:
        argparse.Namespace with string attributes ``model_name``,
        ``model_short_name`` and ``sample_path``.
    """
    option_defaults = (
        ('--model_name', 'gpt-4o-2024-11-20'),
        ('--model_short_name', 'gpt-4o'),
        ('--sample_path', 'lime_samples'),
    )
    cli = argparse.ArgumentParser()
    for flag, default in option_defaults:
        cli.add_argument(flag, type=str, default=default)
    known, _unrecognised = cli.parse_known_args()
    return known


args = get_args()
# Expose the parsed options as the module-level names used by the rest of the script.
model_name, model_short_name, sample_path = (
    args.model_name,
    args.model_short_name,
    args.sample_path,
)
# %%
# Echo the effective configuration so each run's output records it.
print("Model Name:", model_name)
print("Model Short Name:", model_short_name)
print("Sample Path:", sample_path)


# %%
import openai
from openai import OpenAI




class OpenAIAPIPredictor(LLMSentimentPredictor):
    """Sentiment predictor backed by the OpenAI chat-completions API.

    Each text is sent as its own request limited to one output token with
    ``logprobs`` enabled. The top-20 candidates for that token are matched
    against ``labels`` (inherited from ``LLMSentimentPredictor``) to obtain one
    log-probability per label.
    """

    # Attempts per text before the text is reported as failed.
    api_max_retries = 10
    # Seconds to wait after the first failed attempt. The delay doubles after
    # each further failure, capped at ``api_max_retry_delay``, so transient
    # (e.g. rate-limit) errors get time to clear.
    api_retry_delay = 1.0
    api_max_retry_delay = 30.0
    # Score assigned to a label that is absent from the returned top log-probs.
    missing_label_logprob = -100.0

    def __init__(self, model_name='gpt-4o-2024-11-20', **kwargs):
        """Create the predictor.

        Args:
            model_name: default chat model used by ``_predict``.
            **kwargs: forwarded to ``LLMSentimentPredictor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # Client shared by all requests made through this predictor.
        self.model = OpenAI()

    @classmethod
    def get_probs(cls, response):
        """Extract per-label log-probabilities from a chat-completion response.

        Args:
            response: chat-completion response requested with ``logprobs=True``
                and ``top_logprobs`` set.

        Returns:
            1-D float ``np.ndarray`` aligned with ``cls.labels``. Labels that do
            not appear among the first generated token's top log-probs get
            ``cls.missing_label_logprob``. If a label token occurs more than
            once, the first occurrence is kept.
        """
        label_index = {label: idx for idx, label in enumerate(cls.labels)}
        scores = np.full(len(cls.labels), cls.missing_label_logprob)
        top_logprobs = response.choices[0].logprobs.content[0].top_logprobs
        filled = set()
        # Scan every returned candidate, not only the first len(labels) of
        # them: a label token may be ranked below unrelated tokens.
        for candidate in top_logprobs:
            idx = label_index.get(candidate.token)
            if idx is None or idx in filled:
                continue
            scores[idx] = candidate.logprob
            filled.add(idx)
        return scores

    def _query_label_logprobs(self, text, model_name):
        """Request label log-probabilities for one text, retrying on failure.

        Args:
            text: input text, formatted via ``self.get_user_prompt``.
            model_name: chat model to query.

        Returns:
            The ``get_probs`` vector for ``text``. Returns ``None`` once all
            ``api_max_retries`` attempts have failed; the last error is written
            to stderr.
        """
        user_prompt = self.get_user_prompt(text)[0]
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        delay = self.api_retry_delay
        last_error = None
        for attempt in range(self.api_max_retries):
            try:
                response = self.model.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=1,
                    max_tokens=1,
                    top_p=1.0,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    logprobs=True,
                    top_logprobs=20,
                )
                # Parsing stays inside the try: a response without usable
                # logprobs is retried just like a failed request.
                return self.get_probs(response)
            except Exception as err:  # not bare: KeyboardInterrupt must propagate
                last_error = err
            if attempt + 1 < self.api_max_retries:
                time.sleep(delay)
                delay = min(delay * 2, self.api_max_retry_delay)
        print(f"Error in response for text: {text}, last error: {last_error!r}",
              file=sys.stderr)
        return None

    def _predict(self, texts, logits=True, **kwargs):
        """Score ``texts`` with the chat model.

        Args:
            texts: iterable of input texts.
            logits: if True, return the raw per-label log-probabilities;
                otherwise return a softmax over the label axis.
            **kwargs: a non-None ``model_name`` overrides ``self.model_name``
                for this call.

        Returns:
            ``np.ndarray`` with one row per text and one column per label, or
            ``None`` if any text could not be scored after all retries.
        """
        model_name = self.model_name
        if kwargs.get('model_name') is not None:
            model_name = kwargs['model_name']

        rows = []
        for text in texts:
            scores = self._query_label_logprobs(text, model_name)
            if scores is None:
                # Fail the whole batch so callers never receive rows that are
                # misaligned with ``texts``.
                return None
            rows.append(scores)

        if not rows:
            # np.stack rejects an empty list; return a well-shaped empty result.
            return np.empty((0, len(self.labels)))
        results = np.stack(rows, axis=0)
        if logits:
            return results
        return torch.softmax(torch.tensor(results), dim=1).cpu().numpy()

import pandas as pd
# Load the samples to score. If a scoring run for this model was saved before,
# that table replaces the fresh one, so an interrupted run resumes where it stopped.
samples_df = pd.read_csv(f'./{sample_path}/sst_test_samples.csv', sep='\t', index_col=0)

# Assign the filled column back instead of calling fillna(inplace=True) on a
# `.loc` selection. That chained in-place call can act on a temporary copy
# (and never reaches samples_df under pandas Copy-on-Write), leaving NaN texts.
samples_df['sample_sentence'] = samples_df['sample_sentence'].fillna('')

openai_predictor = OpenAIAPIPredictor(model_name)

# -1000.0 marks a row whose logits have not been computed yet.
samples_df['logits_positive'] = -1000.0
samples_df['logits_negative'] = -1000.0
scored_samples_path = f'./{sample_path}/sst_test_samples_{model_short_name}.csv'
if os.path.exists(scored_samples_path):
    samples_df = pd.read_csv(scored_samples_path, sep='\t', index_col=0, keep_default_na=False)

# %%
pool_path = f'./samples_pools/sst_test_cached_{model_short_name}.csv'
if os.path.exists(pool_path) and not samples_df.empty:
    # Reuse logits cached by earlier runs for this model, so only
    # (sentence_index, binary_representation) pairs without a cached value
    # still need API calls.
    pool_df = pd.read_csv(pool_path, sep='\t', index_col=0, keep_default_na=False)
    print("get data from pool")
    # Group both frames once instead of re-filtering them for every sentence.
    samples_by_sentence = {key: group for key, group in samples_df.groupby('sentence_index', sort=False)}
    pool_by_sentence = {key: group for key, group in pool_df.groupby('sentence_index', sort=False)}
    sampless = []
    for idx in tqdm(range(max(samples_df['sentence_index'].tolist()) + 1)):
        group = samples_by_sentence.get(idx)
        if group is None:
            continue  # no samples carry this sentence index
        # reset_index returns a new frame, so the .at writes below never
        # touch a view of samples_df.
        local_samples = group.reset_index(drop=True)
        local_pool = pool_by_sentence.get(idx)
        if local_pool is not None:
            # O(1) lookup per sample; the first cached row for a mask wins,
            # matching the previous `.values[0]` selection.
            cached = {}
            for rep, pos, neg in zip(local_pool['binary_representation'],
                                     local_pool['logits_positive'],
                                     local_pool['logits_negative']):
                cached.setdefault(rep, (pos, neg))
            reps = local_samples['binary_representation'].tolist()
            for i, rep in enumerate(reps):
                if local_samples.at[i, 'logits_positive'] > -1000.0:
                    continue  # already scored
                hit = cached.get(rep)
                if hit is not None:
                    local_samples.at[i, 'logits_positive'] = hit[0]
                    local_samples.at[i, 'logits_negative'] = hit[1]
        sampless.append(local_samples)
    samples_df = pd.concat(sampless, ignore_index=True)
            
# %%
import concurrent.futures

# Scoring task for a single row, run by the worker threads.
def process_row(idx, text, df=None):
    """Score one sample row unless it already has logits.

    Args:
        idx: positional row index (``.iloc`` semantics) into ``df``.
        text: the sample sentence to score.
        df: frame whose ``logits_positive`` column marks already-scored rows.
            Defaults to the module-level ``samples_df``.

    Returns:
        ``(idx, logits_positive, logits_negative)``. Returns ``None`` when the
        row is already scored (value > -1000.0) or the prediction failed.
    """
    frame = samples_df if df is None else df
    # Positional lookup, consistent with the positional indices paired with
    # each text by the chunked caller.
    if frame['logits_positive'].iloc[idx] > -1000.0:
        return None  # skip rows that are already scored
    res = openai_predictor.predict([text])
    if res is None:
        return None
    # NOTE(review): assumes the predictor's label order is (negative, positive);
    # confirm against LLMSentimentPredictor.labels.
    return idx, res[0, 1], res[0, 0]

# Chunked, multithreaded scoring with a checkpoint after every chunk.
def _write_csv_atomically(df, path):
    """Write ``df`` as tab-separated CSV to ``path`` without leaving a torn file.

    The data is first written to a sibling temporary file, which is then moved
    over ``path`` with ``os.replace``. An interruption mid-write therefore
    leaves the previous checkpoint intact.
    """
    tmp_path = f'{path}.tmp'
    df.to_csv(tmp_path, sep='\t')
    os.replace(tmp_path, path)


def process_in_chunks(df, chunk_size=1000, max_workers=10, output_path=None):
    """Score every unscored row of ``df`` in place, checkpointing per chunk.

    Rows are addressed positionally. A row counts as scored when its
    ``logits_positive`` is greater than -1000.0, and only unscored rows are
    submitted to ``process_row``. After each chunk that produced new logits,
    the whole frame is saved to ``output_path``.

    Args:
        df: frame with ``sample_sentence``, ``logits_positive`` and
            ``logits_negative`` columns; updated in place.
        chunk_size: rows per checkpoint.
        max_workers: number of threads issuing requests concurrently.
        output_path: checkpoint file. Defaults to
            ``./{sample_path}/sst_test_samples_{model_short_name}.csv``.

    Returns:
        True if every row ended up scored, otherwise False.
    """
    if output_path is None:
        output_path = f'./{sample_path}/sst_test_samples_{model_short_name}.csv'
    text_col = df.columns.get_loc('sample_sentence')
    pos_col = df.columns.get_loc('logits_positive')
    neg_col = df.columns.get_loc('logits_negative')

    def is_unscored(row):
        # `not (x > -1000.0)` also treats NaN as unscored, like process_row's check.
        return not (df.iat[row, pos_col] > -1000.0)

    finish_flag = True
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for start in range(0, len(df), chunk_size):
            chunk_end = min(start + chunk_size, len(df))
            pending = [row for row in range(start, chunk_end) if is_unscored(row)]

            # Score the chunk's pending rows in parallel.
            results = []
            if pending:
                futures = [
                    executor.submit(process_row, row, df.iat[row, text_col])
                    for row in pending
                ]
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
                    try:
                        result = future.result()
                    except Exception as err:
                        # Keep the rest of the chunk so it still gets saved;
                        # the row is reported as unfinished below.
                        print(f"Row scoring failed: {err!r}", file=sys.stderr)
                        continue
                    if result is not None:
                        results.append(result)

            # Write the new logits back into the frame.
            for row, logit_pos, logit_neg in results:
                df.iat[row, pos_col] = logit_pos
                df.iat[row, neg_col] = logit_neg
            if any(is_unscored(row) for row in pending):
                finish_flag = False

            # Checkpoint only when this chunk produced something new.
            if results:
                _write_csv_atomically(df, output_path)
            print(f"Processed and saved rows {start} to {chunk_end - 1}, total {len(df)} rows.")

    if finish_flag:
        print("All rows processed successfully.")
    else:
        print("Some rows were not processed successfully. Please check the logs.")
    return finish_flag

# Score every sample that still lacks logits, using 64 concurrent workers.
process_in_chunks(samples_df, max_workers=64, chunk_size=1000)

# Persist the final table (written even when no new rows were scored).
samples_df.to_csv(
    f'./{sample_path}/sst_test_samples_{model_short_name}.csv',
    sep='\t',
)
