import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os
import argparse
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy.sparse import csr_matrix
import time
from PIL import Image
from tqdm import tqdm
import gc
import random
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from code.utils import get_embedding, get_embedding_llm, generate_and_save_embeddings

def _parse_literal(text):
    """Parse a Python literal (list, dict, number, string) stored as text in info.csv."""
    # Function-scope import keeps this block self-contained.
    import ast

    # NOTE(security): this replaces a plain eval() on CSV cell contents.
    # literal_eval accepts the same list/dict literals but cannot execute
    # arbitrary expressions embedded in the data file.
    return ast.literal_eval(text)


def _format_question(q_id, info_df):
    """Return the question text followed by one '{"<letter>":"<label>"}' line per option.

    Option key 1 maps to letter "A", key 2 to "B", and so on. Returns None when
    ``q_id`` has no row in ``info_df`` or its option mapping cannot be processed.
    """
    question_info = info_df[info_df['key'] == q_id]
    if question_info.empty:
        return None
    question = question_info['question'].values[0]
    option_mapping_str = question_info['option_mapping'].values[0]
    try:
        option_mapping = _parse_literal(option_mapping_str)
        options_str = "\n".join(
            f'{{"{chr(65 + int(k) - 1)}":"{v}"}}'
            for k, v in option_mapping.items()
            # Key 99.0 and NaN keys are left out of the formatted options.
            if k != 99.0 and pd.notna(k)
        )
        return f"{question}\n{options_str}"
    except Exception as e:
        # Best effort: a malformed mapping drops this question, not the whole run.
        print(f"Warning: Could not process options for question {q_id}: {e}. Skipping question formatting.")
        return None


def _load_answer_scales(question_ids, info_df):
    """Parse the option labels and ordinal values of each question exactly once.

    Returns a dict mapping question id -> (options, option_values, max_val, min_val).
    Questions whose metadata cannot be parsed are reported and left out, so
    they contribute no responses for any user.
    """
    scales = {}
    for q_id in question_ids:
        question_info = info_df[info_df['key'] == q_id]
        try:
            options = list(_parse_literal(question_info['references'].values[0]))
            option_values = list(_parse_literal(question_info['option_ordinal'].values[0]))
            scales[q_id] = (options, option_values, max(option_values), min(option_values))
        except Exception as e:
            # Best effort: keep going with the remaining questions.
            print(f"Error processing question {q_id}: {e}")
    return scales


def _encode_user_answers(user_id, answers, questions_dict, scales):
    """Build one user's response tuples for the questions in ``questions_dict``.

    ``answers`` is the responses table indexed by user id (one row per user).
    Returns ``(rows, skipped)`` where ``rows`` is a list of
    (question_id, question_text, answer_formatted, answer_letter,
    answer_ordinal, max_val, min_val) tuples and ``skipped`` counts answers
    that could not be matched to a listed option.
    """
    rows = []
    skipped = 0
    for question_id, question_text in questions_dict.items():
        scale = scales.get(question_id)
        if scale is None:
            # No usable metadata or no response column for this question.
            continue
        options, option_values, max_val, min_val = scale
        answer_val = answers.at[user_id, question_id]

        if pd.isna(answer_val) or answer_val == "Refused":
            # Missing/refused answers are kept with an empty answer so the
            # question still appears in the user's record.
            rows.append((question_id, question_text, '', None, None, max_val, min_val))
            continue

        try:
            answer_idx = options.index(answer_val)
            answer_ordinal = option_values[answer_idx]
        except (ValueError, IndexError):
            # Answer is not a listed option, or has no ordinal value.
            skipped += 1
            continue

        answer_letter = chr(65 + answer_idx)
        answer_formatted = f'{{"{answer_letter}":"{answer_val}"}}'
        rows.append((question_id, question_text, answer_formatted, answer_letter, answer_ordinal, max_val, min_val))
    return rows, skipped


def preprocess_data(data_path, output_dir, num_train_questions=None, random_seed=42, n_humans=None):
    """Load survey CSVs, split questions into train/test and encode user answers.

    Reads ``info.csv`` and ``responses.csv`` from ``data_path``, optionally
    samples ``n_humans`` users, shuffles the question keys with ``random_seed``
    and splits them, writes the formatted questions of each split to
    ``<output_dir>/question_splits/{train,test}_questions.json`` and caches the
    result in ``<output_dir>/preprocessed_data.pkl``.

    NOTE: the cache file name does not depend on ``n_humans``, ``random_seed``
    or ``num_train_questions``. An existing cache is returned as-is even when
    these arguments differ from the run that produced it; delete the file to
    force recomputation.

    Args:
        data_path: Directory containing ``info.csv`` and ``responses.csv``.
        output_dir: Directory for the cache file and the question split files.
        num_train_questions: Number of questions in the train split. Defaults
            to half of the questions; clamped to ``[0, num_questions - 1]``.
        random_seed: Seed for user sampling (global ``random``) and for the
            question shuffle.
        n_humans: Number of users to sample; ``None`` uses all users.

    Returns:
        A 6-tuple ``(user_responses_train, user_responses_test, question_ids,
        final_sampled_user_ids, train_question_keys, test_question_keys)``, or
        ``None`` when the CSV files cannot be read. The response dicts map a
        user id to a list of (question_id, question_text, answer_formatted,
        answer_letter, answer_ordinal, max_val, min_val) tuples.
    """
    preprocessed_data_file = os.path.join(output_dir, 'preprocessed_data.pkl')

    if os.path.exists(preprocessed_data_file):
        print(f"Loading preprocessed data for sample size {n_humans or 'all'} from {preprocessed_data_file}")
        try:
            # NOTE(security): pickle.load can execute code from the file; the
            # cache is assumed to be a local file written by this function.
            with open(preprocessed_data_file, 'rb') as f:
                data = pickle.load(f)
            if len(data) == 6:
                return data
            print("Cached data format mismatch, recomputing...")
        except Exception as e:
            # Any unreadable cache falls back to a full recomputation.
            print(f"Error loading cached data: {e}. Recomputing...")

    info_path = os.path.join(data_path, 'info.csv')
    responses_path = os.path.join(data_path, 'responses.csv')

    try:
        pd_info = pd.read_csv(info_path)
        pd_responses = pd.read_csv(responses_path)
        if 'QKEY' not in pd_responses.columns or 'key' not in pd_info.columns:
            raise ValueError("Required columns ('QKEY', 'key') not found in CSV files.")
        pd_responses['QKEY'] = pd_responses['QKEY'].astype(int)
    except FileNotFoundError:
        print(f"Error: Data files not found in {data_path}. Ensure info.csv and responses.csv exist.")
        return None
    except Exception as e:
        print(f"Error reading CSV files: {e}")
        return None

    question_ids = pd_info['key'].unique()
    all_user_ids = [int(user_id) for user_id in pd_responses['QKEY'].unique()]
    num_questions = len(question_ids)
    print(f"Found {num_questions} unique questions and {len(all_user_ids)} total unique users")

    # --- User Sampling ---
    random.seed(random_seed)
    if n_humans is not None and n_humans < len(all_user_ids):
        sampled_user_ids = random.sample(all_user_ids, n_humans)
        print(f"Sampling {n_humans} users out of {len(all_user_ids)} total users.")
    else:
        sampled_user_ids = all_user_ids
        if n_humans is not None:
            print(f"Requested sample size {n_humans} >= total users {len(all_user_ids)}. Using all users.")
        else:
            print("Processing all users (n_humans=None).")
    # --- End User Sampling ---

    # Shuffle a copy so the returned question_ids keep their original order.
    rng = np.random.RandomState(random_seed)
    question_ids_shuffled = np.array(question_ids)
    rng.shuffle(question_ids_shuffled)

    if num_train_questions is None:
        split_idx = len(question_ids_shuffled) // 2
    else:
        # Keep at least one test question and never let the index go
        # negative (negative request or empty question list).
        split_idx = max(0, min(num_train_questions, len(question_ids_shuffled) - 1))

    print(f"Using {split_idx} questions for training and {len(question_ids_shuffled) - split_idx} for testing")
    train_question_keys = question_ids_shuffled[:split_idx]
    test_question_keys = question_ids_shuffled[split_idx:]

    print(f"Split questions into {len(train_question_keys)} for training and {len(test_question_keys)} for testing")

    split_dir = os.path.join(output_dir, 'question_splits')
    os.makedirs(split_dir, exist_ok=True)

    train_file = os.path.join(split_dir, 'train_questions.json')
    test_file = os.path.join(split_dir, 'test_questions.json')

    train_questions_dict = {}
    test_questions_dict = {}
    for keys, target in ((train_question_keys, train_questions_dict),
                         (test_question_keys, test_questions_dict)):
        for question_id in keys:
            formatted_q = _format_question(question_id, pd_info)
            if formatted_q:
                target[question_id] = formatted_q

    try:
        with open(train_file, 'w', encoding='utf-8') as f:
            json.dump(train_questions_dict, f, indent=4)
        with open(test_file, 'w', encoding='utf-8') as f:
            json.dump(test_questions_dict, f, indent=4)
        print(f"Saved train questions ({len(train_questions_dict)}) to {train_file}")
        print(f"Saved test questions ({len(test_questions_dict)}) to {test_file}")
    except (OSError, TypeError, ValueError) as e:
        # The split files are informational; failing to write them is not fatal.
        print(f"Error saving question split JSON files: {e}")

    # Questions listed in info.csv but absent from responses.csv cannot be
    # answered by anyone; skip them instead of failing on the column lookup.
    formatted_ids = list(train_questions_dict) + list(test_questions_dict)
    response_columns = set(pd_responses.columns)
    missing_columns = [str(q_id) for q_id in formatted_ids if q_id not in response_columns]
    if missing_columns:
        print(f"Warning: {len(missing_columns)} questions have no column in responses.csv "
              f"and will be skipped: {missing_columns}")
    scales = _load_answer_scales(
        [q_id for q_id in formatted_ids if q_id in response_columns], pd_info
    )

    # One row per user (the first one), indexed by user id, so each answer is
    # a direct lookup instead of a full-table mask per user.
    answers = pd_responses.drop_duplicates(subset='QKEY', keep='first').set_index('QKEY')

    user_responses_train = {}
    user_responses_test = {}
    skipped_answers = 0

    print(f"Preprocessing responses for {len(sampled_user_ids)} sampled users...")
    for user_id in tqdm(sampled_user_ids, desc="Preprocessing user responses"):
        train_rows, train_skipped = _encode_user_answers(user_id, answers, train_questions_dict, scales)
        test_rows, test_skipped = _encode_user_answers(user_id, answers, test_questions_dict, scales)
        skipped_answers += train_skipped + test_skipped
        # Only users with at least one encoded response get an entry.
        if train_rows:
            user_responses_train[user_id] = train_rows
        if test_rows:
            user_responses_test[user_id] = test_rows

    if skipped_answers:
        print(f"Skipped {skipped_answers} answers that did not match any listed option.")

    final_sampled_user_ids = [
        uid for uid in sampled_user_ids
        if uid in user_responses_train or uid in user_responses_test
    ]

    print(f"Finished preprocessing. Found valid responses for {len(final_sampled_user_ids)} users out of {len(sampled_user_ids)} sampled.")
    print(f"Created train responses for {len(user_responses_train)} users")
    print(f"Created test responses for {len(user_responses_test)} users")

    result = (
        user_responses_train,
        user_responses_test,
        question_ids,
        final_sampled_user_ids,
        list(train_questions_dict.keys()),
        list(test_questions_dict.keys()),
    )

    print(f"Saving preprocessed data for {len(final_sampled_user_ids)} users to {preprocessed_data_file}")
    try:
        with open(preprocessed_data_file, 'wb') as f:
            pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)
    except (OSError, pickle.PicklingError) as e:
        # Caching is an optimisation only; the computed result is still returned.
        print(f"Error saving preprocessed data cache: {e}")

    return result


if __name__ == "__main__":
    # Command-line entry point: preprocess the survey data, then generate and
    # save per-user embeddings for the train and test question splits.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='preprocessed_data/opinionqa',
                        help='Directory to save all output files')
    parser.add_argument('--data_path', type=str,
                        default='data/opinionqa/American_Trends_Panel_W92/')
    parser.add_argument('--force_recompute', action='store_true', default=False,
                        help='Force recomputation of embeddings even if they exist')
    parser.add_argument('--force_preprocess', action='store_true', default=False,
                        help='Force recomputation of preprocessed data even if cache exists')
    parser.add_argument('--num_train_questions', type=int, default=None,
                        help='Number of questions to use for training. Remaining will be for testing.')
    parser.add_argument('--n_humans', type=int, default=500,
                        help='Number of users to sample')
    parser.add_argument('--random_seed', type=int, default=42,
                        help='Random seed for sampling and splits')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE: the cache file name does not encode n_humans, random_seed or
    # num_train_questions, so --force_preprocess is needed after changing them.
    preprocessed_data_file = os.path.join(args.output_dir, 'preprocessed_data.pkl')
    if args.force_preprocess and os.path.exists(preprocessed_data_file):
        print(f"Force preprocess: Deleting existing cache file {preprocessed_data_file}")
        os.remove(preprocessed_data_file)

    print(f"Starting preprocessing with n_humans={args.n_humans}, random_seed={args.random_seed}")
    preprocess_result = preprocess_data(
        args.data_path, args.output_dir,
        num_train_questions=args.num_train_questions,
        random_seed=args.random_seed,
        n_humans=args.n_humans
    )

    if preprocess_result is None:
        print("Preprocessing failed. Exiting.")
        # SystemExit instead of exit(): exit() is injected by the site module
        # for interactive use and is not available under `python -S`.
        raise SystemExit(1)

    user_responses_train, user_responses_test, question_ids, sampled_user_ids, train_questions, test_questions = preprocess_result
    print(f"Preprocessing complete. Obtained data for {len(sampled_user_ids)} users.")
    print(f"Train responses available for {len(user_responses_train)} users.")
    print(f"Test responses available for {len(user_responses_test)} users.")

    # Define output directories for embeddings
    train_embs_dir = os.path.join(args.output_dir, 'human_embeddings_train')
    test_embs_dir = os.path.join(args.output_dir, 'human_embeddings_test')

    if not sampled_user_ids:
        print("No users found after preprocessing. Skipping embedding generation.")
    else:
        # Both are passed as None; presumably generate_and_save_embeddings
        # (code.utils) does not need a local model here — TODO confirm.
        model = None
        tokenizer = None

        logs_file = os.path.join(args.output_dir, "logs.json")
        start_time = time.time()

        # Generate embeddings for the training set
        train_processed_count, train_error_count = generate_and_save_embeddings(
            model, tokenizer, user_responses_train, train_embs_dir, sampled_user_ids, args.force_recompute, desc_suffix="Train"
        )

        # Generate embeddings for the test set
        test_processed_count, test_error_count = generate_and_save_embeddings(
            model, tokenizer, user_responses_test, test_embs_dir, sampled_user_ids, args.force_recompute, desc_suffix="Test"
        )

        # Clean up model and clear memory before logging
        del model
        del tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        total_time = time.time() - start_time
        print(f"Finished all embedding generation in {total_time:.2f} seconds.")

        # Run summary written next to the outputs.
        log_data = {
            "gen_embs_time_seconds": total_time,
            "requested_n_humans": args.n_humans,
            "users_after_preprocessing": len(sampled_user_ids),
            "train_embeddings_processed_or_loaded": train_processed_count,
            "train_embedding_errors": train_error_count,
            "test_embeddings_processed_or_loaded": test_processed_count,
            "test_embedding_errors": test_error_count,
            "random_seed": args.random_seed,
            # Fall back to the actual train split size when no explicit count was requested.
            "num_train_questions": args.num_train_questions if args.num_train_questions is not None else len(train_questions)
        }
        try:
            with open(logs_file, "w") as f:
                json.dump(log_data, f, indent=4)
            print(f"Processing logs saved to {logs_file}")
        except Exception as e:
            # Logging is best-effort; the embeddings are already on disk.
            print(f"Error saving logs: {e}")


    print("Script finished.")
