# %%
from curses import meta
from hashlib import sha256
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import pickle
import token
from typing import List

from networkx import union
# from transformers import BertTokenizer, BertForSequenceClassification,GPT2Tokenizer, GPT2ForSequenceClassification, T5Tokenizer, T5ForConditionalGeneration,
# from transformers import AutoTokenizer,AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer
from modelscope import AutoModelForCausalLM, AutoTokenizer,GenerationConfig
from utils.llm_predict import QwenPredictor, LLamaPredictor
from utils.text_predict import BertPredictor

import numpy as np
import torch
import socket
import subprocess
import multiprocessing
import json
import sys
from tqdm.auto import tqdm
import time
import argparse

from abc import ABCMeta, abstractmethod

def get_args():
    """Parse the command-line options that configure this run.

    Recognised options (all optional):
        --model_name        model identifier handed to the predictor.
        --model_short_name  short tag used to pick the predictor class and
                            to build cache/output file names.
        --sample_path       directory that holds the sample CSV files.
        --start_idx         first row position to process.

    Arguments that are not recognised are ignored instead of causing an
    error, because parsing uses ``parse_known_args``.

    Returns:
        argparse.Namespace holding the four options above.
    """
    option_specs = (
        ('--model_name', str, 'qwen/Qwen2.5-1.5B-Instruct'),
        ('--model_short_name', str, 'qwen_1.5B'),
        ('--sample_path', str, 'lime_samples'),
        ('--start_idx', int, 0),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    # Only the recognised part is returned; leftovers are discarded.
    known, _leftover = cli.parse_known_args()
    return known


# Parse the command line once; `args` is kept because later cells read
# `args.start_idx` from it.
args = get_args()
model_name, model_short_name, sample_path = (
    args.model_name,
    args.model_short_name,
    args.sample_path,
)
# %%
# Echo the effective configuration so it shows up in the run output.
print(f"Model Name: {model_name}")
print(f"Model Short Name: {model_short_name}")
print(f"Sample Path: {sample_path}")


# %%
import pandas as pd
# Load the SST test samples table (tab-separated; the first column is used
# as the DataFrame index).
samples_df = pd.read_csv(f'./{sample_path}/sst_test_samples.csv', sep='\t',index_col=0)

# Replace missing sentences with empty strings so every row holds text.
# The filled column is assigned back explicitly: calling
# `fillna(..., inplace=True)` on the intermediate `.loc[:, col]` selection
# is chained in-place modification, which pandas warns about and which no
# longer updates `samples_df` once Copy-on-Write is enabled.
samples_df['sample_sentence'] = samples_df['sample_sentence'].fillna('')

# Notebook-style preview of the first rows (no effect when run as a script).
samples_df.head()

# %%
# Choose the predictor implementation from the prefix of the short model
# name (case-insensitive). Prefixes are tried in the listed order and the
# first match wins; an unknown prefix is rejected.
short_name_lower = model_short_name.lower()
predictor_by_prefix = (
    ('qwen', QwenPredictor),
    ('llama', LLamaPredictor),
    ('bert', BertPredictor),
)
for name_prefix, predictor_cls in predictor_by_prefix:
    if short_name_lower.startswith(name_prefix):
        llm_predictor = predictor_cls(model_name)
        break
else:
    # No prefix matched.
    raise ValueError(f"Model {model_short_name} not supported")
# %%

# Initialise both logit columns with the placeholder -1000.0. Later code
# skips any row whose `logits_positive` is greater than this value, so the
# placeholder marks rows that still need a prediction.
for logit_column in ('logits_positive', 'logits_negative'):
    samples_df[logit_column] = -1000.0

# %%
# Resume from a previous run: if a per-model result file already exists it
# replaces the freshly loaded samples (including any logits it contains).
cached_samples_path = f'./{sample_path}/sst_test_samples_{model_short_name}.csv'
if os.path.exists(cached_samples_path):
    samples_df = pd.read_csv(cached_samples_path, sep='\t',index_col=0,keep_default_na=False)

# Independently of the resume file, reuse logits from the shared pool for
# rows that have the same sentence index and binary representation.
pool_path = f'./samples_pools/sst_test_cached_{model_short_name}.csv'
if os.path.exists(pool_path):
    pool_df = pd.read_csv(pool_path, sep='\t',index_col=0,keep_default_na=False)
    print("get data from pool")
    sampless = []
    for idx in tqdm(range(max(samples_df['sentence_index'].tolist())+1)):
        # reset_index returns a new frame, so the writes below do not act on
        # a slice of samples_df (avoids SettingWithCopyWarning).
        local_samples = samples_df[samples_df['sentence_index'] == idx].reset_index(drop=True)
        local_pool = pool_df[pool_df['sentence_index'] == idx]

        # Build the lookup once per sentence instead of scanning the pool for
        # every sample row. setdefault keeps the first pool row for each
        # representation, matching the previous `.values[0]` choice.
        pool_logits = {}
        for representation, positive, negative in zip(
            local_pool['binary_representation'].tolist(),
            local_pool['logits_positive'].tolist(),
            local_pool['logits_negative'].tolist(),
        ):
            pool_logits.setdefault(representation, (positive, negative))

        for i, representation in enumerate(local_samples['binary_representation'].tolist()):
            # Rows above the -1000.0 placeholder already have logits.
            if local_samples.at[i, 'logits_positive'] > -1000.0:
                continue
            cached = pool_logits.get(representation)
            if cached is None:
                continue
            local_samples.at[i, 'logits_positive'] = cached[0]
            local_samples.at[i, 'logits_negative'] = cached[1]
        sampless.append(local_samples)
    # Rows come back grouped by sentence index with a fresh 0..n-1 index.
    samples_df = pd.concat(sampless, ignore_index=True)

    
# `idx` below is a row *position* (it comes from enumerate), so rows are
# addressed positionally with .iloc. Label-based .loc would only be correct
# if the index happened to be exactly 0..n-1, and would otherwise read the
# wrong row, fail, or silently append new rows on assignment.
positive_col = samples_df.columns.get_loc('logits_positive')
negative_col = samples_df.columns.get_loc('logits_negative')

# Runs started at a non-zero offset write to their own file, tagged with
# the offset; a run from position 0 writes the main per-model file.
if args.start_idx==0:
    checkpoint_path = f'./{sample_path}/sst_test_samples_{model_short_name}.csv'
else:
    checkpoint_path = f'./{sample_path}/sst_test_samples_{model_short_name}_{args.start_idx}.csv'

for idx, text in tqdm(enumerate(samples_df['sample_sentence'].tolist()),total=len(samples_df)):
    # Skip everything before the requested start position.
    if idx < args.start_idx:
        continue
    # Skip rows that already carry logits (value above the -1000.0 placeholder).
    if samples_df.iloc[idx, positive_col] > -1000.0:
        continue
    # One sentence per call; res[0,1] is stored as the positive logit and
    # res[0,0] as the negative one.
    # NOTE(review): assumes predict() returns a 2-D result indexed
    # [sample, class] - confirm against the predictor implementations.
    res = llm_predictor.predict([text])
    samples_df.iloc[idx, positive_col] = res[0,1]
    samples_df.iloc[idx, negative_col] = res[0,0]
    # Periodic checkpoint. It is only reached for rows that were actually
    # predicted, so a skipped row at a multiple of 10000 writes nothing.
    if idx % 10000 == 0:
        samples_df.to_csv(checkpoint_path, sep='\t')
# %%
# Final write of the full table. A run from position 0 writes the main
# per-model file; any other start position writes a file tagged with it.
final_output_path = (
    f'./{sample_path}/sst_test_samples_{model_short_name}.csv'
    if args.start_idx==0
    else f'./{sample_path}/sst_test_samples_{model_short_name}_{args.start_idx}.csv'
)
samples_df.to_csv(final_output_path, sep='\t')
