import pandas as pd
import json
import ast
from numpy import array 
import numpy as np
import glob
import os
import re


# Multiple-choice prompt formats. The variants differ only in surface form:
# header case ("Answers" vs "ANSWERS"), colon count (":" vs "::") and
# option-label style ("A." vs "A)"). Each plural header is paired with the
# singular form on the final line ("Answers:" ... "Answer:").
PROMPT_TEMPLATES = [
    {
        "name": "template_1",
        "template": "{question_text}\nAnswers:\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"
    },
    {
        "name": "template_2",
        "template": "{question_text}\nANSWERS:\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nANSWER:"
    },
    {
        "name": "template_3",
        "template": "{question_text}\nAnswers::\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer::"
    },
    {
        "name": "template_4",
        # Final line is the singular "ANSWER::", matching templates 1-3 (was "ANSWERS::").
        "template": "{question_text}\nANSWERS::\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nANSWER::"
    },
    {
        "name": "template_5",
        "template": "{question_text}\nAnswers:\nA) {A}\nB) {B}\nC) {C}\nD) {D}\nAnswer:"
    }
]


def _load_parquet_records(parquet_folder):
    """Read every ``*.parquet`` file in ``parquet_folder`` and return all rows as dicts.

    Files are read in sorted path order so the output order is reproducible.
    A file that fails to load is reported and skipped. Returns an empty list
    when no file is found or none could be read.
    """
    all_parquet_files = sorted(glob.glob(os.path.join(parquet_folder, "*.parquet")))
    if not all_parquet_files:
        print(f"No .parquet files found in {parquet_folder}")
        return []

    df_list = []
    for f_path in all_parquet_files:
        try:
            df_list.append(pd.read_parquet(f_path))
            print(f"Successfully read {f_path}")
        except Exception as e:  # best effort: one unreadable file should not abort the run
            print(f"Error reading Parquet file {f_path}: {e}")

    if not df_list:
        print("No dataframes were loaded. Exiting.")
        return []

    merged_df = pd.concat(df_list, ignore_index=True)
    # Count the files actually loaded, not merely the ones found on disk.
    print(f"Merged {len(df_list)} parquet files into a single DataFrame with {len(merged_df)} rows.")
    return merged_df.to_dict(orient='records')


def _extract_choice_lists(choices_data_parsed):
    """Return ``(option_texts, option_labels)`` from a parsed ``choices`` dict.

    ``choices_data_parsed['text']`` and ``['label']`` must each be a list or a
    numpy array (numpy arrays are converted with ``tolist()``). Labels are
    stringified and stripped.

    Raises:
        ValueError: if either field is missing or is not a list/numpy array.
    """
    text_data = choices_data_parsed.get('text')
    label_data = choices_data_parsed.get('label')

    if isinstance(text_data, np.ndarray):
        option_texts_list = text_data.tolist()
    elif isinstance(text_data, list):
        option_texts_list = list(text_data)  # copy so callers never alias the row's data
    else:
        raise ValueError("'text' field is not a list or numpy array")

    if isinstance(label_data, np.ndarray):
        label_data = label_data.tolist()
    elif not isinstance(label_data, list):
        raise ValueError("'label' field is not a list or numpy array")
    original_labels_list = [str(label).strip() for label in label_data]

    return option_texts_list, original_labels_list


def _map_options_to_abcd(option_texts_list, original_labels_list, answer_key_original):
    """Normalize option labels (``1``-``4`` or ``a``/``A``-``D``) to ``A``-``D``.

    Labels that are all in ``1``-``4`` are mapped positionally to ``A``-``D`` and
    the answer key is mapped the same way. Otherwise every label is upper-cased
    and must be one of ``A``-``D``; the answer key is upper-cased.

    Returns:
        ``(options, answer)`` where ``options`` maps each of ``A``-``D`` to its
        text, or ``None`` if the labels do not cover exactly ``A``-``D`` or the
        answer is not one of them.
    """
    label_map_to_abcd = {'1': 'A', '2': 'B', '3': 'C', '4': 'D'}
    expected_final_labels = ['A', 'B', 'C', 'D']

    if all(label in label_map_to_abcd for label in original_labels_list):
        options = {
            label_map_to_abcd[label]: text
            for label, text in zip(original_labels_list, option_texts_list)
        }
        answer = label_map_to_abcd.get(answer_key_original, answer_key_original)
    else:
        upper_labels = [label.upper() for label in original_labels_list]
        if not all(label in expected_final_labels for label in upper_labels):
            return None
        options = dict(zip(upper_labels, option_texts_list))
        answer = answer_key_original.upper()

    # Duplicate labels (e.g. 'A','A','B','C') leave a letter missing.
    if not all(label in options for label in expected_final_labels):
        return None
    if answer not in expected_final_labels:
        return None
    return options, answer


def process_parquet_folder_to_json_prompts(parquet_folder, required_word_count=9):
    """Load multiple-choice rows from a folder of Parquet files and normalize them.

    Each row must provide ``question``, ``choices`` and ``answerKey`` columns;
    ``id`` is optional and only used in log messages. ``choices`` may be a dict
    with ``text``/``label`` entries (lists or numpy arrays) or the string repr
    of such a dict. A row is kept only if all of the following hold:

    - it has exactly four options;
    - its labels are ``1``-``4`` or ``A``-``D`` (case-insensitive);
    - its answer key maps to one of ``A``-``D``;
    - its question has exactly ``required_word_count`` whitespace-separated words.

    Problem rows are reported with ``print`` and skipped.

    Args:
        parquet_folder (str): Directory searched (non-recursively) for ``*.parquet``.
        required_word_count (int | None): Exact question length, in words, that a
            row must have to be kept. Defaults to 9 (the original hard-coded
            filter). Pass ``None`` to disable the filter.

    Returns:
        list[dict]: Items ``{"question": str, "current_options": {"A": ..., "B": ...,
        "C": ..., "D": ...}, "answer": "A"|"B"|"C"|"D"}``. The list is empty (never
        ``None``) when nothing is loaded or no row passes the filters.
    """
    output_data = []
    records = _load_parquet_records(parquet_folder)

    for row_number, row in enumerate(records, 1):
        row_id = row.get('id', f'RowIndex_{row_number}')
        # Reset per row so an error message never shows the previous row's value.
        choices_input = None
        try:
            question_text = row['question']
            choices_input = row['choices']  # This is the field from Parquet
            answer_key_original = str(row['answerKey']).strip()

            # 1. Parse the 'choices' input.
            if isinstance(choices_input, str):
                # SECURITY: eval() with an emptied __builtins__ is NOT a sandbox and
                # can still be escaped. Only run this on trusted dataset files.
                # The allowed names cover numpy reprs such as
                # "{'text': array(['x', 'y'], dtype=object), 'label': ...}".
                safe_globals = {"array": array, "object": object, "__builtins__": {},
                                "None": None, "True": True, "False": False}
                try:
                    choices_data_parsed = eval(choices_input, safe_globals, {})
                except Exception as e_eval:
                    print(f"Skipping row {row_number} (ID: {row_id}) due to eval error on 'choices' string: {e_eval}. Problematic 'choices' string: '{choices_input}'")
                    continue
            elif isinstance(choices_input, dict):
                # Pandas/Parquet already deserialized the struct (e.g. np.ndarray values).
                choices_data_parsed = choices_input
            else:
                print(f"Skipping row {row_number} (ID: {row_id}): 'choices' column is of unexpected type {type(choices_input)}. Value: {choices_input}")
                continue
            choices_str_for_error_msg = str(choices_input)

            if not isinstance(choices_data_parsed, dict):
                print(f"Skipping row {row_number} (ID: {row_id}): Parsed 'choices' is not a dict. Parsed data: {choices_data_parsed}, Original: '{choices_str_for_error_msg}'")
                continue

            try:
                option_texts_list, original_labels_list = _extract_choice_lists(choices_data_parsed)
            except (KeyError, AttributeError, ValueError) as e_struct:
                print(f"Skipping row {row_number} (ID: {row_id}) due to structure error in parsed 'choices': {e_struct}. Parsed data: {choices_data_parsed}, Original: '{choices_str_for_error_msg}'")
                continue

            # 2. Keep only rows with exactly 4 options.
            if len(option_texts_list) != 4 or len(original_labels_list) != 4:
                continue

            # 3. Normalize labels to A-D and validate the answer key.
            mapped = _map_options_to_abcd(option_texts_list, original_labels_list, answer_key_original)
            if mapped is None:
                continue
            current_options, final_answer_key = mapped

            # 4. Keep only questions with exactly `required_word_count` words.
            if required_word_count is not None and len(question_text.split()) != required_word_count:
                continue

            # 5. Store the result.
            output_data.append({
                "question": question_text,
                "current_options": current_options,
                "answer": final_answer_key
            })

        # Consolidated error handling for one row's processing.
        except KeyError as ke:
            # Raised when 'question', 'choices' or 'answerKey' is missing from the row.
            print(f"Skipping row {row_number} (ID: {row_id}) due to KeyError: {ke}. Check Parquet column names. Row data (partial): { {k: v for k, v in row.items() if k in ['question', 'choices', 'answerKey', 'id']} }")
        except Exception as e:
            # Any other unexpected error (e.g. a non-string question) skips just this row.
            print(f"Skipping row {row_number} (ID: {row_id}) due to an unexpected error: {e}. Choices input was: '{'N/A' if choices_input is None else choices_input}'")
    return output_data


def tokenize_text(text):
    """
    Return the whitespace-separated words of ``text``.

    Any run of whitespace counts as a single separator; leading and trailing
    whitespace produce no empty tokens. A falsy ``text`` (``None`` or ``""``)
    yields an empty list.
    """
    # str.split() with no separator already collapses whitespace runs and
    # drops empty strings, matching a strip + re.split(r'\s+') + filter.
    return text.split() if text else []


def generate_prompts_and_indices(json_data, templates, output_dir, print_sample_count=1):
    """
    Render every item with every template and record word-level indices.

    For each template a ``<name>.jsonl`` file is written to ``output_dir``. The
    prompt is built word by word. Template literal text and placeholder contents
    are split with ``tokenize_text`` and re-joined with single spaces, so template
    line breaks are not preserved. Every index below is a position in the
    prompt's space-separated word list.

    Args:
        json_data (list | None): Items with 'question', 'answer' and either
            'options' or 'current_options' (a dict keyed by 'A'-'D').
            ``None`` is treated as an empty list.
        templates (list): Dicts with a 'template' string using the
            {question_text}/{A}/{B}/{C}/{D} placeholders and an optional 'name'
            (default: template_<1-based position>).
        output_dir (str): The output directory; created if missing.
        print_sample_count (int): The number of samples with indices to print for each template.

    Each written JSON line has these keys:
        prompt: the space-joined prompt string.
        answer: the answer letter.
        question_indices: ``[[i], [j], ...]``, one single-element list per question word.
        options_indices: ``{"A": [[i, ...]], ...}``; options with no words are omitted.
        answer_indices: the answer's entry from options_indices, or ``[]``.
    """
    os.makedirs(output_dir, exist_ok=True)
    if json_data is None:
        json_data = []

    placeholders = ["{question_text}", "{A}", "{B}", "{C}", "{D}"]
    # Escape the braces explicitly instead of relying on `re` treating "{A}"
    # as a literal because it is not a valid quantifier.
    placeholder_regex = "(" + "|".join(re.escape(p) for p in placeholders) + ")"
    option_placeholders = {"{A}", "{B}", "{C}", "{D}"}

    for template_info_idx, t_info in enumerate(templates):
        template_str = t_info["template"]
        template_name = t_info.get("name", f"template_{template_info_idx+1}")
        output_filepath = os.path.join(output_dir, f"{template_name}.jsonl")
        # The split depends only on the template, so do it once per template.
        # The capturing group keeps the placeholders in the result.
        template_parts = [part for part in re.split(placeholder_regex, template_str) if part]

        processed_items_for_file = []
        num_samples_printed_for_template = 0

        for item_data_idx, item_data in enumerate(json_data):
            question_content = item_data['question']
            options_content = item_data.get('options')
            if options_content is None and 'current_options' in item_data:
                options_content = item_data['current_options']

            if not options_content:
                print(f"Warning: Missing options for question in template {template_name}: {question_content[:50]}...")
                continue

            answer_letter = item_data['answer']

            fill_values = {
                "{question_text}": question_content,
                "{A}": options_content.get('A', ''),
                "{B}": options_content.get('B', ''),
                "{C}": options_content.get('C', ''),
                "{D}": options_content.get('D', '')
            }

            prompt_words_list = []
            question_indices = []
            all_options_indices_map = {}  # option letter -> [[word indices]]

            for part in template_parts:
                if part in fill_values:  # a placeholder
                    start_idx = len(prompt_words_list)
                    words_to_insert = tokenize_text(fill_values[part])
                    prompt_words_list.extend(words_to_insert)
                    inserted_indices = list(range(start_idx, len(prompt_words_list)))

                    if part == "{question_text}":
                        # One single-element list per question word.
                        question_indices.extend([idx] for idx in inserted_indices)
                    elif part in option_placeholders and inserted_indices:
                        option_key = part[1:-1]  # 'A', 'B', 'C', or 'D'
                        all_options_indices_map[option_key] = [inserted_indices]
                else:  # the template's literal text
                    prompt_words_list.extend(tokenize_text(part))

            complete_prompt = " ".join(prompt_words_list)

            # [[idx1, idx2, ...]], or [] if the answer option was empty or absent.
            final_answer_indices = all_options_indices_map.get(answer_letter, [])

            # --- Sample Output Logic ---
            if num_samples_printed_for_template < print_sample_count:
                print(f"\n--- Sample for Template: '{template_name}', Item #{item_data_idx + 1} ---")
                annotated_prompt_str = " ".join([f"{word}[{i}]" for i, word in enumerate(prompt_words_list)])
                print(f"Annotated Prompt: {annotated_prompt_str}")
                print(f"Question Indices: {question_indices}")
                print(f"All Option Indices: {all_options_indices_map}")
                print(f"Answer Letter: {answer_letter}")
                print(f"Derived Answer Indices: {final_answer_indices}")
                print("--- End Sample ---")
                num_samples_printed_for_template += 1

            processed_items_for_file.append({
                "prompt": complete_prompt,
                "answer": answer_letter,
                "question_indices": question_indices,
                "options_indices": all_options_indices_map,
                "answer_indices": final_answer_indices
            })

        with open(output_filepath, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(entry) + '\n' for entry in processed_items_for_file)
        print(f"Processed data for {template_name} saved to {output_filepath}")


if __name__ == "__main__":
    # Paths are relative to the current working directory.
    base_dir = "LLM-Sensitivity/explain_demo/data"
    dataset = "arc"
    parquet_source_folder = os.path.join(base_dir, dataset, "parquet")
    output_directory = os.path.join(base_dir, dataset, "raw_data")

    json_data = process_parquet_folder_to_json_prompts(parquet_source_folder)

    # The loader may yield nothing usable (None or an empty list). Skip prompt
    # generation rather than crashing on None or writing empty .jsonl files.
    if not json_data:
        print("No items to convert into prompts. Exiting.")
    else:
        generate_prompts_and_indices(json_data, PROMPT_TEMPLATES, output_directory)