import os
import json
from utils.check_vocab_tokenizer import get_model
from accelerate import Accelerator

# Single Accelerator instance for the script; its free_memory() is called
# after each model has been processed.
accelerator = Accelerator()

# Names of the models to run the preprocessing for.
model_names = ['llama-2-7b']

# Device ids handed to get_model(); the first entry is also passed as the
# `device` argument of the word-to-token mapping helper.
device_num = list(range(4))

def get_prompt_word_to_token_indices_map(prompt_text, tokenizer, device, start_pos_offset):
    """Map each whitespace-separated word of a prompt to its token positions.

    The prompt is tokenized once, then the token sequence is walked from
    ``start_pos_offset`` onwards. Every token is decoded on its own, stripped
    of surrounding whitespace and concatenated until the accumulated text
    equals the current word of ``prompt_text.split()``; the positions consumed
    so far are then recorded for that word. Tokens that decode to nothing but
    whitespace are therefore attributed to the next word that completes.

    Args:
        prompt_text: The prompt string to align.
        tokenizer: Callable tokenizer; ``tokenizer(text, return_tensors="pt")``
            must expose ``input_ids`` (first row is used) and the object must
            provide ``decode`` for a list holding one token id.
        device: Unused. Kept only so existing callers keep working; decoding
            needs plain integer ids, so nothing is moved to an accelerator.
        start_pos_offset: Index of the first token position to consider
            (callers pass 1 to skip a leading special token, 0 otherwise).

    Returns:
        A list with exactly one entry per word of ``prompt_text.split()``.
        Each entry is the list of token positions (indices into the full
        token sequence) that make up the word, or ``[]`` when the word could
        not be aligned.

    Note:
        There is no re-synchronisation: once the decoded text of a word
        diverges from the target word, that word and every following word
        get an empty list.
    """
    # Convert to a plain Python list once instead of slicing a tensor (and
    # potentially syncing with a device) for every single token.
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0].tolist()
    words = prompt_text.split()

    word_to_token_map = []
    pending_positions = []  # token positions collected for the current word
    rebuilt_word = ""       # stripped decoded text collected for the current word

    for position in range(start_pos_offset, len(token_ids)):
        # len(word_to_token_map) doubles as the index of the word being built.
        if len(word_to_token_map) >= len(words):
            break

        pending_positions.append(position)
        rebuilt_word += tokenizer.decode([token_ids[position]]).strip()

        if rebuilt_word == words[len(word_to_token_map)]:
            word_to_token_map.append(pending_positions)
            pending_positions = []
            rebuilt_word = ""

    # Words that were never matched still get an (empty) slot so the result
    # can be indexed by word position.
    while len(word_to_token_map) < len(words):
        word_to_token_map.append([])

    return word_to_token_map


def _collect_token_groups(word_index_groups, word_to_tokens):
    """Translate groups of word indices into groups of token positions.

    Args:
        word_index_groups: Iterable of groups, each an iterable of word
            indices into ``word_to_tokens``.
        word_to_tokens: Per-word lists of token positions, as returned by
            ``get_prompt_word_to_token_indices_map``.

    Returns:
        One list of token positions per input group, in input order. Word
        indices outside ``word_to_tokens`` are ignored, and groups that end
        up without any token are dropped.
    """
    token_groups = []
    for word_indices in word_index_groups:
        group_tokens = []
        for word_idx in word_indices:
            if 0 <= word_idx < len(word_to_tokens):
                group_tokens.extend(word_to_tokens[word_idx])
        if group_tokens:
            token_groups.append(group_tokens)
    return token_groups


# Input and output locations.
main_path_root = 'LLM-Sensitivity'
arc_data_base_path = os.path.join(main_path_root, 'explain_demo', 'data', 'mmlu')
raw_data_dir = os.path.join(arc_data_base_path, 'raw_data')

# Model-name substrings for which token alignment starts at position 0;
# every other model starts at position 1.
# NOTE(review): presumably these tokenizers do not prepend a special token --
# confirm. ('olmo' also covers 'olmoe'.)
_ZERO_START_MODEL_TAGS = ('gpt', 'qwen', 'Phi', 'olmo')


for model_name in model_names:
    output_data_dir_base = os.path.join(arc_data_base_path, 'model_data', model_name)
    os.makedirs(output_data_dir_base, exist_ok=True)

    print(f"Processing for model: {model_name}...")
    # Load the model and its tokenizer.
    model, tokenizer = get_model(device_num, model_name)

    # Pick the first token position to align, based on the model name.
    start_pos = 0 if any(tag in model_name for tag in _ZERO_START_MODEL_TAGS) else 1

    # Sorted so files are processed (and logged) in a deterministic order.
    for filename in sorted(os.listdir(raw_data_dir)):
        if not filename.endswith(".jsonl"):
            continue

        input_filepath = os.path.join(raw_data_dir, filename)
        output_filepath = os.path.join(output_data_dir_base, filename)

        print(f"  Processing file: {filename}...")

        processed_data_lines = []
        with open(input_filepath, 'r', encoding='utf-8') as infile:
            for line_number, line in enumerate(infile, start=1):
                # Blank lines (e.g. a trailing newline) are not records.
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    prompt_word_tokens_map = get_prompt_word_to_token_indices_map(
                        data["prompt"], tokenizer, device_num[0], start_pos
                    )

                    # Question groups come first, then the groups of each
                    # option in sorted key order.
                    all_players_token_indices_groups = _collect_token_groups(
                        data.get("question_indices", ()), prompt_word_tokens_map
                    )
                    options_indices = data.get("options_indices", {})
                    for option_key in sorted(options_indices):
                        all_players_token_indices_groups.extend(
                            _collect_token_groups(
                                options_indices[option_key], prompt_word_tokens_map
                            )
                        )

                    data['all_players'] = all_players_token_indices_groups
                    processed_data_lines.append(data)

                except json.JSONDecodeError:
                    print(f"    Error decoding JSON in file {filename}, line {line_number}")
                except Exception as e:
                    # Best effort: a bad record is reported and left out of
                    # the output, the rest of the file is still processed.
                    print(f"    Error processing line {line_number} in {filename}: {e}")
                    print(f"    Problematic line data: {line.strip()}")

        with open(output_filepath, 'w', encoding='utf-8') as outfile:
            outfile.writelines(
                json.dumps(processed_item, ensure_ascii=False) + '\n'
                for processed_item in processed_data_lines
            )
        print(f"  Finished processing {filename}, saved to {output_filepath}")

    # Drop the references before asking accelerate to release memory, so the
    # next model does not have to coexist with this one.
    del model
    del tokenizer
    accelerator.free_memory()

print("All models and files processed.")