
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os
import argparse
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy.sparse import csr_matrix
from transformers import AutoTokenizer, AutoProcessor
import time
from PIL import Image
from tqdm import tqdm
import gc
import random
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM
import json

# Topics used to pick questions. Each inner list is one topic: a group of
# concept/subject ids that must ALL appear in a question's `SubjectId` field
# for the question to count as belonging to that topic.
# NOTE: the name is a misspelling of "CONCEPTS"; it is kept because the rest
# of this file refers to it by this exact name.
CONCETPS =[
[3, 32, 38, 237],
[3, 32, 39, 225],
[3, 32, 39, 228],
[3, 32, 39, 230],
[3, 32, 39, 231],
[3, 32, 39, 232],
[3, 32, 42, 211],
[3, 32, 42, 212],
[3, 32, 144, 204],
[3, 49, 62, 251],
]

def preprocess_data(data_path, output_dir, test_split=0.5, random_seed=42, n_humans=None, use_misconceptions_annotations=False):
    """Build per-user train/test response lists for the selected questions.

    Pipeline:
      1. Return the pickled cache from ``output_dir`` if one with six entries exists.
      2. Load responses and question metadata, keeping questions whose
         ``SubjectId`` matches one of the ``CONCETPS`` topics.
      3. Take the question set from ``selected_questions.csv`` when present,
         otherwise pick the four most-answered questions per topic.
      4. Keep users who answered at least one selected question in every
         topic that has selected questions, ordered by how many selected
         questions they answered, optionally truncated to ``n_humans``.
      5. Split each topic's questions into train/test, format the question
         text, write the two split JSON files, and collect every selected
         user's answer (or an empty placeholder) for each question.

    Args:
        data_path: Directory holding ``train_task_3_4.csv``,
            ``metadata/question_metadata_task_3_4.csv`` and
            ``selected_questions.csv``.
        output_dir: Directory for the pickle cache and ``question_splits/``.
        test_split: Fraction of each topic's questions assigned to the test
            split (rounded up; values outside [0, 1] are clamped).
        random_seed: Seed for the per-topic shuffles.
        n_humans: Maximum number of users to keep; ``None`` or <= 0 keeps all.
        use_misconceptions_annotations: Append the ``option_misconceptions``
            column of ``selected_questions.csv`` to each formatted question.

    Returns:
        ``(user_responses_train, user_responses_test, question_ids,
        sampled_user_ids, train_question_ids, test_question_ids)``, or
        ``None`` when the input files cannot be read. Each response is a
        tuple ``(question_id, question_text, answer_formatted, answer_letter,
        answer_ordinal, max_val, min_val)``.
    """
    preprocessed_data_file = os.path.join(output_dir, 'preprocessed_data.pkl')

    # NOTE(review): the cache file name does not encode n_humans, random_seed
    # or test_split, so a cache written with other settings is returned as-is.
    # Delete the file to force a rebuild.
    if os.path.exists(preprocessed_data_file):
        print(f"Loading preprocessed data for n_humans {n_humans or 'all'} from {preprocessed_data_file}")
        try:
            with open(preprocessed_data_file, 'rb') as f:
                data = pickle.load(f)
            if len(data) == 6:
                return data
            print("Cached data format mismatch, recomputing...")
        except Exception as e:
            print(f"Error loading cached data: {e}. Recomputing...")

    responses_path = os.path.join(data_path, 'train_task_3_4.csv')
    question_metadata_path = os.path.join(data_path, 'metadata', 'question_metadata_task_3_4.csv')
    selected_questions_path = os.path.join(data_path, 'selected_questions.csv') #key, question, option_mapping, option_ordinal

    def matches_any_topic_loosely(subject_ids):
        # Coarse prefilter. NOTE: this is a plain substring test ("3" also
        # matches "32"); the exact per-topic match is subject_matches_topic.
        # Non-string cells (e.g. NaN) simply do not match.
        if not isinstance(subject_ids, str):
            return False
        return any(all(str(concept_id) in subject_ids for concept_id in topic) for topic in CONCETPS)

    def subject_matches_topic(subject_id_str, current_topic_concepts):
        # True when every concept id of the topic appears as a whole list
        # element in the stringified SubjectId list.
        if pd.isna(subject_id_str):
            return False
        for concept_id in current_topic_concepts:
            cid = str(concept_id)
            element_forms = (f"'{cid}'", f" {cid},", f" {cid}]", f"[{cid},", f"[{cid}]")
            if not any(form in subject_id_str for form in element_forms):
                return False
        return True

    def split_index(n_items):
        # Number of leading items that go to train; the rest go to test.
        # Equals n_items // 2 for the default test_split of 0.5.
        n_test = int(np.ceil(round(n_items * test_split, 9)))
        n_test = min(n_items, max(0, n_test))
        return n_items - n_test

    try:
        pd_responses = pd.read_csv(responses_path) #QuestionId,UserId,AnswerId,IsCorrect,CorrectAnswer,AnswerValue
        # Drop rows without a user BEFORE the string cast: afterwards a
        # missing id would be the literal string 'nan' and pass any notna().
        pd_responses = pd_responses[pd_responses['UserId'].notna()].copy()
        pd_responses['QuestionId'] = pd_responses['QuestionId'].astype(str)
        pd_responses['UserId'] = pd_responses['UserId'].astype(str)
        pd_questions = pd.read_csv(question_metadata_path)
        pd_questions['QuestionId'] = pd_questions['QuestionId'].astype(str)
        pd_questions = pd_questions[pd_questions['SubjectId'].apply(matches_any_topic_loosely)]
        print("pd_questions", pd_questions)
        print(f"Filtered questions shape: {pd_questions.shape}")
        print("question_ids", pd_questions['QuestionId'].unique())
        # Keep only responses to questions that survived the topic filter.
        pd_responses = pd_responses[pd_responses['QuestionId'].isin(pd_questions['QuestionId'])]
        print(f"number of unique users in responses: {len(pd_responses['UserId'].unique())}")
        print(f"Loaded responses data from {responses_path}")
        print(f"Data shape: {pd_responses.shape}")
    except FileNotFoundError:
        print(f"Error: Data files not found in {data_path}. Ensure {responses_path} and {question_metadata_path} exist.")
        return None
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return None

    question_ids = pd_questions['QuestionId'].unique()
    num_questions = len(question_ids)
    num_users = pd_responses['UserId'].nunique()
    print(f"Found {num_questions} unique questions and {num_users} total unique users")

    # --- Question Sampling ---
    selected_questions_df = None
    if os.path.exists(selected_questions_path):
        print(f"Found existing selected questions file at {selected_questions_path}. Using these questions instead of sampling.")
        selected_questions_df = pd.read_csv(selected_questions_path)
        selected_questions_df['key'] = selected_questions_df['key'].astype(str)
        question_ids_shuffled = np.array(selected_questions_df['key'].tolist())
        print(f"Loaded {len(question_ids_shuffled)} questions from existing file.")
    else:
        question_response_counts = pd_responses['QuestionId'].value_counts()
        print("Overall question_response_counts for questions matching any CONCETPS topic (before per-topic selection):", question_response_counts.shape[0])
        if not question_response_counts.empty:
            print(f"Overall response counts stats: min={question_response_counts.min()}, max={question_response_counts.max()}, mean={question_response_counts.mean():.2f}")
        else:
            print("No question responses found to calculate statistics.")
        final_selected_question_ids = set()
        print(f"Selecting top 4 questions per topic from {len(CONCETPS)} topics...")

        for i, topic_concept_ids in enumerate(CONCETPS):
            questions_for_this_topic_df = pd_questions[
                pd_questions['SubjectId'].apply(lambda x: subject_matches_topic(x, topic_concept_ids))
            ]
            current_topic_qids = questions_for_this_topic_df['QuestionId'].unique()
            relevant_topic_qids_with_responses = [qid for qid in current_topic_qids if qid in question_response_counts.index]
            topic_q_response_counts = question_response_counts[question_response_counts.index.isin(relevant_topic_qids_with_responses)].sort_values(ascending=False)
            top_n_for_topic = topic_q_response_counts.head(4).index.tolist()

            print(f"Topic {i+1} (Concepts: {topic_concept_ids}): Found {len(current_topic_qids)} questions ({len(relevant_topic_qids_with_responses)} with responses after exclusions). Selected top {len(top_n_for_topic)}: {top_n_for_topic}")
            final_selected_question_ids.update(top_n_for_topic)
        question_ids_shuffled = np.array(list(final_selected_question_ids))

    pd_responses_filtered_questions = pd_responses[pd_responses['QuestionId'].isin(question_ids_shuffled)]
    print(f"Filtered responses to {pd_responses_filtered_questions.shape[0]} interactions involving the {len(question_ids_shuffled)} selected questions.")

    print("Starting new user selection: top users who answered most selected questions AND at least one per relevant topic.")

    # Map topic index -> selected questions belonging to that topic. The same
    # map drives both the user coverage requirement and the train/test split.
    topic_to_selected_questions_map = {}
    selected_questions_details = pd_questions[pd_questions['QuestionId'].isin(question_ids_shuffled)][['QuestionId', 'SubjectId']].set_index('QuestionId')

    for topic_idx, topic_concept_ids in enumerate(CONCETPS):
        qids_for_this_topic_in_shuffled = [
            qid for qid in question_ids_shuffled
            if qid in selected_questions_details.index
            and subject_matches_topic(selected_questions_details.loc[qid, 'SubjectId'], topic_concept_ids)
        ]
        if qids_for_this_topic_in_shuffled:
            topic_to_selected_questions_map[topic_idx] = qids_for_this_topic_in_shuffled
        else:
            print(f"Info: Topic {topic_idx+1} (Concepts: {topic_concept_ids}) has no questions within the final question_ids_shuffled set. It won't be required for user topic coverage.")

    num_required_topics = len(topic_to_selected_questions_map)

    if CONCETPS and num_required_topics == 0:
        print("Warning: CONCETPS is defined, but no selected questions mapped to any topic. User topic coverage criterion will be trivially met by all users, or effectively ignored.")
    elif not CONCETPS:
        print("Info: CONCETPS is empty. User topic coverage criterion is trivially met.")
    else:
        print(f"Number of topics with questions in the selected set (required for coverage): {num_required_topics}")

    # A user is eligible when they answered at least one question of every
    # required topic (trivially true when there are no required topics).
    topic_question_sets = [set(qids) for qids in topic_to_selected_questions_map.values()]
    eligible_users_data = []
    for user_id, user_responses_df in pd_responses_filtered_questions.groupby('UserId'):
        answered_qids_by_user = set(user_responses_df['QuestionId'].unique())
        if all(answered_qids_by_user & topic_set for topic_set in topic_question_sets):
            eligible_users_data.append((user_id, len(answered_qids_by_user)))

    if not eligible_users_data:
        print("Warning: No users found who cover all required topics. user_ids_shuffled will be empty.")
        user_ids_shuffled = np.array([])
    else:
        sorted_eligible_users = sorted(eligible_users_data, key=lambda x: x[1], reverse=True)

        print(f"Found {len(sorted_eligible_users)} users covering all {num_required_topics} required topic(s).")
        final_user_ids_list = [user_data[0] for user_data in sorted_eligible_users]

        if n_humans and n_humans > 0 and n_humans < len(final_user_ids_list):
            user_ids_shuffled = np.array(final_user_ids_list[:n_humans])
            print(f"Selected top {len(user_ids_shuffled)} users (out of {len(final_user_ids_list)} eligible and topic-covering users) based on n_humans limit.")
        else:
            user_ids_shuffled = np.array(final_user_ids_list)
            if n_humans and n_humans > 0:
                print(f"Requested n_humans ({n_humans}) is >= eligible/topic-covering users ({len(final_user_ids_list)}). Selecting all {len(final_user_ids_list)} such users.")
            else:
                print(f"No user limit (or n_humans <= 0), selecting all {len(final_user_ids_list)} eligible and topic-covering users.")

    # Responses restricted to both the selected questions and selected users.
    selected_user_keys = [str(uid) for uid in user_ids_shuffled]
    selected_responses = pd_responses_filtered_questions[pd_responses_filtered_questions['UserId'].isin(selected_user_keys)]

    print(f"question_ids_shuffled: {len(question_ids_shuffled)}")
    print(f"Final number of users selected: {len(user_ids_shuffled)}")
    print(f"Final number of responses after user and question filtering: {selected_responses.shape[0]}")

    # --- Train/Test Split ---
    train_question_keys = []
    test_question_keys = []

    # Bound up front so the unassigned-question shuffle below also works when
    # no topic has any question.
    rng = np.random.default_rng(random_seed)
    already_split = set()
    for topic_idx, questions in topic_to_selected_questions_map.items():
        # Re-seeded per topic so one topic's split does not depend on the others.
        rng = np.random.default_rng(random_seed)
        # A question matching several topics is split once, with its first
        # topic, so it can never land in both train and test.
        questions_shuffled = [qid for qid in questions if qid not in already_split]
        already_split.update(questions_shuffled)
        rng.shuffle(questions_shuffled)

        split_idx = split_index(len(questions_shuffled))
        topic_train = questions_shuffled[:split_idx]
        topic_test = questions_shuffled[split_idx:]

        print(f"Topic {topic_idx}: Assigned {len(topic_train)} questions to train and {len(topic_test)} to test")

        train_question_keys.extend(topic_train)
        test_question_keys.extend(topic_test)

    unassigned_questions = [qid for qid in question_ids_shuffled if qid not in already_split]

    if unassigned_questions:
        print(f"Found {len(unassigned_questions)} questions not assigned to any topic. Splitting randomly.")
        rng.shuffle(unassigned_questions)
        split_idx = split_index(len(unassigned_questions))
        train_question_keys.extend(unassigned_questions[:split_idx])
        test_question_keys.extend(unassigned_questions[split_idx:])

    print(f"Split questions into {len(train_question_keys)} for training and {len(test_question_keys)} for testing")

    if selected_questions_df is None:
        # Question text and option mappings only exist in selected_questions.csv,
        # so nothing can be formatted without it.
        print(f"Error: {selected_questions_path} not found; it is required to format questions. Sampled question ids: {[str(qid) for qid in question_ids_shuffled]}")
        return None

    split_dir = os.path.join(output_dir, 'question_splits')
    os.makedirs(split_dir, exist_ok=True)

    train_file = os.path.join(split_dir, 'train_questions.json')
    test_file = os.path.join(split_dir, 'test_questions.json')

    option_mappings = {}  # question id (str) -> parsed option_mapping dict

    def format_question(q_id):
        # Returns the prompt text for one question, or None if it is unknown.
        question_info = selected_questions_df[selected_questions_df['key'] == str(q_id)]
        if question_info.empty:
            print(f"Question {q_id} not found in selected_questions_df")
            return None

        question = question_info['question'].values[0]
        # SECURITY NOTE: eval() executes whatever the CSV cell contains. Only
        # use trusted selected_questions.csv files (ast.literal_eval would be
        # safer if the cells are plain literals).
        option_mapping = eval(question_info['option_mapping'].values[0])
        option_mappings[str(q_id)] = option_mapping

        options_str = "\n".join([f'{{"{chr(65 + int(k) - 1)}":"{v}"}}' for k, v in option_mapping.items()])

        if use_misconceptions_annotations:
            # SECURITY NOTE: same eval() caveat as above.
            option_misconception = eval(question_info['option_misconceptions'].values[0])
            misconceptions_str = "Misconceptions corresponding to each answer option:\n" + "\n".join([f'{chr(65 + i)}-{misconception}' for i, misconception in enumerate(option_misconception)])
            return f"{question}\n{options_str}\n{misconceptions_str}\n"
        return f"{question}\n{options_str}\n"

    def format_split(question_keys):
        formatted = {}
        for question_id in question_keys:
            formatted_q = format_question(question_id)
            if formatted_q:
                formatted[str(question_id)] = formatted_q
        return formatted

    train_questions_dict = format_split(train_question_keys)
    test_questions_dict = format_split(test_question_keys)

    print(f"Formatted {len(train_questions_dict)} train questions")
    print(f"Formatted {len(test_questions_dict)} test questions")

    try:
        with open(train_file, 'w') as f:
            json.dump(train_questions_dict, f, indent=4)
        with open(test_file, 'w') as f:
            json.dump(test_questions_dict, f, indent=4)
        print(f"Saved train questions ({len(train_questions_dict)}) to {train_file}")
        print(f"Saved test questions ({len(test_questions_dict)}) to {test_file}")
    except Exception as e:
        print(f"Error saving question split JSON files: {e}")

    # First recorded answer per (user, question), built once instead of
    # scanning the whole response table for every pair.
    first_answers = selected_responses.drop_duplicates(subset=['UserId', 'QuestionId'], keep='first')
    answer_lookup = dict(zip(
        zip(first_answers['UserId'].values, first_answers['QuestionId'].values),
        first_answers['AnswerValue'].values,
    ))

    max_val = 4
    min_val = 1

    def collect_responses(user_id, questions_dict, log_missing):
        # One tuple per question; unanswered questions get an empty placeholder.
        collected = []
        for question_id, question_text in questions_dict.items():
            answer_val = answer_lookup.get((str(user_id), question_id))
            if answer_val is None or pd.isna(answer_val):
                if log_missing:
                    print(f"No answer found for question {question_id}")
                collected.append((question_id, question_text, '', None, None, max_val, min_val))
                continue

            print(f"answer_val: {answer_val}")
            try:
                answer_ordinal = answer_val
                answer_letter = chr(65 + int(answer_ordinal) - 1)
                answer_text = option_mappings[question_id][answer_ordinal]
                answer_formatted = f'{{"{answer_letter}":"{answer_text}"}}'
            except Exception as e:
                # Best effort: a malformed answer drops this one response only.
                print(f"Error processing question {question_id} for user {user_id}: {e}")
                continue
            collected.append((question_id, question_text, answer_formatted, answer_letter, answer_ordinal, max_val, min_val))
        return collected

    user_responses_train = {}
    user_responses_test = {}

    print(f"Preprocessing responses for {len(user_ids_shuffled)} sampled users...")

    for user_id in tqdm(user_ids_shuffled, desc="Preprocessing user responses"):
        train_rows = collect_responses(user_id, train_questions_dict, log_missing=True)
        if train_rows:
            user_responses_train[user_id] = train_rows
        test_rows = collect_responses(user_id, test_questions_dict, log_missing=False)
        if test_rows:
            user_responses_test[user_id] = test_rows

    final_sampled_user_ids = [uid for uid in user_ids_shuffled if uid in user_responses_train or uid in user_responses_test]

    print(f"Finished preprocessing. Found valid responses for {len(final_sampled_user_ids)} users out of {len(user_ids_shuffled)} sampled.")
    print(f"Created train responses for {len(user_responses_train)} users")
    print(f"Created test responses for {len(user_responses_test)} users")

    result = (
        user_responses_train,
        user_responses_test,
        question_ids,
        final_sampled_user_ids,
        list(train_questions_dict.keys()),
        list(test_questions_dict.keys()),
    )

    print(f"Saving preprocessed data for {len(final_sampled_user_ids)} users to {preprocessed_data_file}")
    try:
        with open(preprocessed_data_file, 'wb') as f:
            pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        print(f"Error saving preprocessed data cache: {e}")

    print(f"Final number of users selected: {len(user_ids_shuffled)}")
    print(f"Final number of responses after user and question filtering: {selected_responses.shape[0]}")

    return result

def get_embedding_eedi(model, tokenizer, user_responses_list, selected_questions_df):
    """Encode one user's responses as a correctness vector.

    Args:
        model: Unused; kept so the signature stays compatible with callers.
        tokenizer: Unused; kept so the signature stays compatible with callers.
        user_responses_list: Iterable of 7-tuples ``(question_id,
            question_text, answer_formatted, answer_letter, answer_ordinal,
            max_val, min_val)``. Only ``question_id`` and ``answer_ordinal``
            are read.
        selected_questions_df: DataFrame with a string ``key`` column and a
            ``CorrectAnswer`` column. Only consulted when at least one
            response has an answer.

    Returns:
        ``np.float32`` array with one entry per response: 1.0 if the answer
        equals the question's correct answer, 0.0 if not, NaN if unanswered.

    Raises:
        IndexError: an answered question's id is not in ``selected_questions_df``.
    """
    correct_by_key = None  # built lazily: unanswered-only input never touches the frame
    corrects = []

    for question_id, _text, _formatted, _letter, answer_ordinal, _max_val, _min_val in user_responses_list:
        if answer_ordinal is None:
            # Unanswered question: NaN marks "no information".
            corrects.append(np.nan)
            continue

        if correct_by_key is None:
            # One pass over the frame instead of a filter per response; the
            # first row wins for a duplicated key.
            correct_by_key = {}
            for key, correct in zip(selected_questions_df['key'].values,
                                    selected_questions_df['CorrectAnswer'].values):
                correct_by_key.setdefault(key, correct)

        try:
            correct_answer = correct_by_key[str(question_id)]
        except KeyError:
            raise IndexError(f"question {question_id} not found in selected_questions_df") from None

        # float() first so values stringified from a float column ('3.0') parse.
        corrects.append(int(answer_ordinal) == int(float(correct_answer)))

    return np.array(corrects, dtype=np.float32)



def generate_and_save_embeddings(model, tokenizer, user_responses_dict, output_dir, user_ids, force_recompute, desc_suffix="", use_llm=False, data_dir=None, selected_questions_df=None):
    """Compute and save one ``.npy`` embedding per user.

    An existing, loadable embedding file is reused unless ``force_recompute``
    is set; an unreadable file is regenerated.

    Args:
        model: Forwarded to ``get_embedding_eedi``.
        tokenizer: Forwarded to ``get_embedding_eedi``.
        user_responses_dict: Maps user id to that user's response list.
        output_dir: Directory for the ``.npy`` files (created if missing).
        user_ids: Users to process.
        force_recompute: Regenerate embeddings even when a file exists.
        desc_suffix: Label used in progress and log messages.
        use_llm: Unused in this function.
        data_dir: Unused in this function.
        selected_questions_df: Forwarded to ``get_embedding_eedi``.

    Returns:
        ``(processed_count, error_count)``: users whose embedding was loaded
        or generated, and users for which generation failed.
    """
    os.makedirs(output_dir, exist_ok=True)
    processed_count = 0
    error_count = 0
    print(f"Generating embeddings for {len(user_ids)} users ({desc_suffix})...")
    for user_id in tqdm(user_ids, desc=f"Generating embeddings ({desc_suffix})"):

        # Ids that already carry the "user_" prefix are used verbatim.
        if 'user_' in str(user_id):
            user_file = os.path.join(output_dir, f"{user_id}.npy")
        else:
            user_file = os.path.join(output_dir, f"user_{user_id}.npy")

        if os.path.exists(user_file) and not force_recompute:
            try:
                np.load(user_file)
            except Exception as load_err:
                print(f"Could not load existing embedding for user {user_id} ({desc_suffix}): {load_err}. Regenerating...")
            else:
                # Valid cached embedding: count it once and skip regeneration.
                processed_count += 1
                continue

        try:
            embedding = get_embedding_eedi(model, tokenizer, user_responses_dict[user_id], selected_questions_df)
            if embedding is not None:
                np.save(user_file, embedding)
                processed_count += 1
            else:
                print(f"Failed to generate embedding for user {user_id} ({desc_suffix}). Skipping.")
                error_count += 1

            # Periodically release memory during long runs.
            if processed_count % 50 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        except Exception as e:
            # Best effort per user: record the failure and move on.
            print(f"Error processing user {user_id} during embedding ({desc_suffix}): {e}")
            error_count += 1
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"Finished embedding generation ({desc_suffix}). Successfully generated/loaded embeddings for {processed_count} users.")
    if error_count > 0:
        print(f"Encountered errors during embedding generation ({desc_suffix}) for {error_count} users.")
    return processed_count, error_count


if __name__ == "__main__":
    # Command-line entry point: preprocess the data, then write train and
    # test embeddings for every selected user and a small JSON run log.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='preprocessed_data/eedi',
                        help='Directory to save all output files')
    parser.add_argument('--data_path', type=str,
                        default='data/eedi/')
    parser.add_argument('--force_recompute', action='store_true', default=False,
                        help='Force recomputation of embeddings even if they exist')
    parser.add_argument('--force_preprocess', action='store_true', default=False,
                        help='Force recomputation of preprocessed data even if cache exists')
    parser.add_argument('--test_split', type=float, default=0.5,
                        help='Proportion of questions to use for testing (default: 0.5)')
    parser.add_argument('--n_humans', type=int, default=10,
                        help='Number of users to select ')
    # NOTE: --n_questions is accepted but not used anywhere below.
    parser.add_argument('--n_questions', type=int, default=40,
                        help='Number of questions')
    parser.add_argument('--random_seed', type=int, default=42,
                        help='Random seed for sampling and splits (default: 42)')
    parser.add_argument('--use_misconceptions_annotations', action='store_true', default=False)
    args = parser.parse_args()

    # Set random seeds for reproducibility
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.random_seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Removing the cache file is what makes preprocess_data recompute.
    preprocessed_data_file = os.path.join(args.output_dir, 'preprocessed_data.pkl')
    if args.force_preprocess and os.path.exists(preprocessed_data_file):
        print(f"Force preprocess: Deleting existing cache file {preprocessed_data_file}")
        os.remove(preprocessed_data_file)

    print(f"Starting preprocessing with n_humans={args.n_humans}, random_seed={args.random_seed}")
    preprocess_result = preprocess_data(
        args.data_path, args.output_dir,
        test_split=args.test_split,
        random_seed=args.random_seed,
        n_humans=args.n_humans,
        use_misconceptions_annotations=args.use_misconceptions_annotations
    )

    if preprocess_result is None:
        print("Preprocessing failed. Exiting.")
        # SystemExit rather than the site-provided exit(), which is meant for
        # the interactive interpreter and may be absent (e.g. python -S).
        raise SystemExit(1)

    user_responses_train, user_responses_test, question_ids, sampled_user_ids, train_questions, test_questions = preprocess_result

    print("user_responses_train", user_responses_train)
    print("user_responses_test", user_responses_test)
    print("question_ids", question_ids)
    print("sampled_user_ids", sampled_user_ids)
    print("train_questions", train_questions)
    print("test_questions", test_questions)

    selected_questions_df = pd.read_csv(os.path.join(args.data_path, 'selected_questions.csv'))
    # Keys and correct answers are compared as strings downstream.
    selected_questions_df['key'] = selected_questions_df['key'].astype(str)
    selected_questions_df['CorrectAnswer'] = selected_questions_df['CorrectAnswer'].astype(str)

    print(f"Preprocessing complete. Obtained data for {len(sampled_user_ids)} users.")
    print(f"Train responses available for {len(user_responses_train)} users.")
    print(f"Test responses available for {len(user_responses_test)} users.")

    # Define output directories for embeddings
    train_embs_dir = os.path.join(args.output_dir, 'human_embeddings_train')
    test_embs_dir = os.path.join(args.output_dir, 'human_embeddings_test')

    if not sampled_user_ids:
        print("No users found after preprocessing. Skipping embedding generation.")
    else:
        # No model is loaded here; None placeholders are passed through.
        model = None
        tokenizer = None

        logs_file = os.path.join(args.output_dir, "logs.json")
        start_time = time.time()

        train_processed_count, train_error_count = generate_and_save_embeddings(
            model, tokenizer, user_responses_train, train_embs_dir, sampled_user_ids, args.force_recompute, desc_suffix="Train", selected_questions_df=selected_questions_df
        )

        test_processed_count, test_error_count = generate_and_save_embeddings(
            model, tokenizer, user_responses_test, test_embs_dir, sampled_user_ids, args.force_recompute, desc_suffix="Test", selected_questions_df=selected_questions_df
        )

        del model
        del tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        total_time = time.time() - start_time
        print(f"Finished all embedding generation in {total_time:.2f} seconds.")

        log_data = {
            "gen_embs_time_seconds": total_time,
            "requested_n_humans": args.n_humans,
            "users_after_preprocessing": len(sampled_user_ids),
            "train_embeddings_processed_or_loaded": train_processed_count,
            "train_embedding_errors": train_error_count,
            "test_embeddings_processed_or_loaded": test_processed_count,
            "test_embedding_errors": test_error_count,
            "random_seed": args.random_seed,
            "test_split_ratio": args.test_split,
        }
        try:
            with open(logs_file, "w") as f:
                json.dump(log_data, f, indent=4)
            print(f"Processing logs saved to {logs_file}")
        except Exception as e:
            print(f"Error saving logs: {e}")

    print("Script finished.")
