import os
import numpy as np
import torch
import gc
from tqdm import tqdm
from joblib import Parallel, delayed

def get_embedding_llm(model, processor, user_responses_list, data_dir):
    """Embed a user's free-text answers by mean-pooling the model's last hidden layer.

    Each answer becomes one text segment ``"Response: <answer>\\n"``. All segments
    form a single user turn, which is encoded with ``processor.apply_chat_template``
    and moved to ``model.device``.

    Args:
        model: Called as ``model(**inputs, output_hidden_states=True)``. Its output
            must expose ``hidden_states``.
        processor: Object providing ``apply_chat_template``.
        user_responses_list: Iterable of 5-tuples
            ``(qid, question, answer, emotion, explanation)``. Only ``answer`` is used.
        data_dir: Accepted for interface compatibility; not used.

    Returns:
        A float32 NumPy array, or ``None`` if the forward pass or pooling raised.
        Exceptions raised while building or encoding the prompt are NOT caught
        here; they propagate to the caller.
    """
    segments = [
        {"type": "text", "text": f"Response: {answer}\n"}
        for _qid, _question, answer, _emotion, _explanation in user_responses_list
    ]
    conversation = [{"role": "user", "content": segments}]
    encoded = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = encoded.to(model.device)

    try:
        with torch.no_grad():
            result = model(**model_inputs, output_hidden_states=True)
        last_layer = result.hidden_states[-1]
        # NOTE(review): presumably shaped (1, seq_len, hidden) for this single
        # conversation. squeeze(0) drops the size-1 leading axis and mean(dim=0)
        # then averages over the following axis.
        pooled = last_layer.squeeze(0).mean(dim=0)
        return pooled.to(torch.float32).cpu().numpy()
    except Exception as err:
        print(f"Error during model inference for embedding: {err}")
        return None

def get_embedding(model, tokenizer, user_responses_list):
    """Encode ordinal answers as a float32 vector linearly rescaled per question.

    For each response, ``answer_ordinal`` is mapped so that ``min_val`` becomes -1
    and ``max_val`` becomes +1. An unanswered question (``answer_ordinal is None``)
    becomes NaN in the returned array.

    Args:
        model: Unused; kept for a uniform embedding-function signature.
        tokenizer: Unused; kept for a uniform embedding-function signature.
        user_responses_list: Iterable of 7-tuples
            ``(question_id, question_text, answer_formatted, answer_letter,
            answer_ordinal, max_val, min_val)``.

    Returns:
        1-D ``np.float32`` array with one entry per response, in input order.

    Raises:
        ZeroDivisionError: if ``max_val == min_val`` for an answered integer-valued
            question.
    """
    def _rescale(value, upper, lower):
        if value is None:
            return None  # becomes NaN in the float32 conversion below
        return 2 * (value - lower) / (upper - lower) - 1

    scaled = [
        _rescale(ordinal, upper, lower)
        for _qid, _text, _formatted, _letter, ordinal, upper, lower in user_responses_list
    ]
    return np.array(scaled, dtype=np.float32)


def get_embedding_eedi(model, tokenizer, user_responses_list, selected_questions_df):
    """Encode a user's answers as correctness flags against a question key table.

    For each answered question, the row of ``selected_questions_df`` whose ``'key'``
    column equals ``str(question_id)`` is looked up. The first match's
    ``'CorrectAnswer'`` is compared with ``answer_ordinal`` after ``int()``
    conversion of both.

    Args:
        model: Unused; kept for a uniform embedding-function signature.
        tokenizer: Unused; kept for a uniform embedding-function signature.
        user_responses_list: Iterable of 7-tuples
            ``(question_id, question_text, answer_formatted, answer_letter,
            answer_ordinal, max_val, min_val)``.
        selected_questions_df: DataFrame with ``'key'`` and ``'CorrectAnswer'``
            columns.

    Returns:
        1-D ``np.float32`` array: 1.0 for correct, 0.0 for incorrect, and NaN where
        ``answer_ordinal is None``.

    Raises:
        IndexError: if an answered question has no matching ``'key'`` row.
    """
    flags = []
    for response in user_responses_list:
        qid, _text, _formatted, _letter, chosen, _upper, _lower = response
        if chosen is None:
            flags.append(None)
            continue
        key_match = selected_questions_df['key'] == str(qid)
        expected = selected_questions_df[key_match]['CorrectAnswer'].values[0]
        flags.append(int(chosen) == int(expected))
    return np.array(flags, dtype=np.float32)

def generate_and_save_embeddings(model, tokenizer, user_responses_dict, output_dir, user_ids, force_recompute, desc_suffix="", use_llm=False, data_dir=None):
    """Compute one embedding per user and save it as a ``.npy`` file in ``output_dir``.

    Files are named ``<user_id>.npy`` when the id already contains ``'user_'``,
    otherwise ``user_<user_id>.npy``.

    Users absent from ``user_responses_dict``, or with empty responses, are skipped
    without being counted. An existing file that loads successfully is reused unless
    ``force_recompute`` is true. A file that fails to load is regenerated.

    Args:
        model: Forwarded to the embedding function.
        tokenizer: Forwarded to the embedding function. It is the chat processor
            when ``use_llm`` is true.
        user_responses_dict: Mapping from user id to that user's response list.
        output_dir: Destination directory; created if missing.
        user_ids: Ids to process, in order.
        force_recompute: If true, regenerate even when a readable file exists.
        desc_suffix: Label appended to progress-bar and log messages.
        use_llm: If true, use ``get_embedding_llm``; otherwise use ``get_embedding``.
        data_dir: Forwarded to ``get_embedding_llm`` only.

    Returns:
        ``(processed_count, error_count)``. ``processed_count`` includes both reused
        and newly written files.
    """
    os.makedirs(output_dir, exist_ok=True)
    processed_count = 0
    error_count = 0
    print(f"Generating embeddings for {len(user_ids)} users ({desc_suffix})...")
    for user_id in tqdm(user_ids, desc=f"Generating embeddings ({desc_suffix})"):
        if user_id not in user_responses_dict or not user_responses_dict[user_id]:
            continue
        responses = user_responses_dict[user_id]

        # Ids that already carry the "user_" marker are used verbatim as the file stem.
        if 'user_' in str(user_id):
            user_file = os.path.join(output_dir, f"{user_id}.npy")
        else:
            user_file = os.path.join(output_dir, f"user_{user_id}.npy")

        if os.path.exists(user_file) and not force_recompute:
            try:
                np.load(user_file)  # only verifies the cached file is readable
                processed_count += 1
                continue
            except Exception as load_err:
                print(f"Could not load existing embedding for user {user_id} ({desc_suffix}): {load_err}. Regenerating...")

        try:
            if use_llm:
                embedding = get_embedding_llm(model, tokenizer, responses, data_dir)
            else:
                embedding = get_embedding(model, tokenizer, responses)
            if embedding is not None:
                np.save(user_file, embedding)
                processed_count += 1
            else:
                print(f"Failed to generate embedding for user {user_id} ({desc_suffix}). Skipping.")
                error_count += 1
        except Exception as e:
            print(f"Error processing user {user_id} during embedding ({desc_suffix}): {e}")
            error_count += 1
        finally:
            # Only the LLM path runs model inference. The ordinal path is plain
            # NumPy, so a per-user gc.collect()/empty_cache() there is pure overhead.
            if use_llm:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    print(f"Finished embedding generation ({desc_suffix}). Successfully generated/loaded embeddings for {processed_count} users.")
    if error_count > 0:
        print(f"Encountered errors during embedding generation ({desc_suffix}) for {error_count} users.")
    return processed_count, error_count

def compute_population_dist(human_list_embeddings, agent_list_embeddings, distance, method="sum", n_jobs=-1):
    """Aggregate each human's distance to its nearest agent embedding.

    For every human embedding, the minimum ``compute_dist_score`` over all agent
    embeddings is computed in parallel with joblib.

    Args:
        human_list_embeddings: Iterable of embeddings, or a dict whose values are
            embeddings.
        agent_list_embeddings: Iterable of embeddings, or a dict whose values are
            embeddings.
        distance: Metric name forwarded to ``compute_dist_score``.
        method: ``"min"`` returns the smallest per-human distance. Any other value
            returns their sum.
        n_jobs: Forwarded to ``joblib.Parallel``.

    Returns:
        The sum (default) or minimum of the per-human nearest-agent distances.
        With no human embeddings, the sum is 0.

    Raises:
        ValueError: from ``min()`` when there are no agent embeddings (and at least
            one human), or when ``method == "min"`` and there are no humans.
    """
    if isinstance(human_list_embeddings, dict):
        human_list_embeddings = list(human_list_embeddings.values())
    # Treat agents like humans: iterating a dict would otherwise yield its keys.
    # Materialising once also lets a one-shot iterator be reused for every human.
    if isinstance(agent_list_embeddings, dict):
        agent_list_embeddings = list(agent_list_embeddings.values())
    else:
        agent_list_embeddings = list(agent_list_embeddings)

    def min_dist(human_embedding):
        return min(
            compute_dist_score(human_embedding, agent_embedding, distance)
            for agent_embedding in agent_list_embeddings
        )

    min_dists = Parallel(n_jobs=n_jobs)(
        delayed(min_dist)(human_embedding)
        for human_embedding in human_list_embeddings
    )

    return min(min_dists) if method == "min" else sum(min_dists)
    
def compute_dist_score(emb1, emb2, distance=None):
    """Distance between two embeddings, ignoring coordinates where either is NaN.

    Both inputs are flattened. Only coordinates where neither value is NaN are
    compared. The result is scaled by ``total_coords / valid_coords`` to compensate
    for the dropped coordinates: the sum of squares is scaled for ``"euclidean"``,
    and the sum of absolute differences for ``"manhattan"``.

    Args:
        emb1: Array-like embedding.
        emb2: Array-like embedding with the same number of elements as ``emb1``.
        distance: ``"euclidean"`` or ``"manhattan"``.

    Returns:
        The weighted distance, or ``float('inf')`` when no coordinate is valid in
        both embeddings.

    Raises:
        ValueError: if ``distance`` is not supported, or the embeddings differ in size.
    """
    if distance not in ("euclidean", "manhattan"):
        # Previously fell through and returned None, which broke min()/sum() callers.
        raise ValueError(
            f"Unsupported distance {distance!r}; expected 'euclidean' or 'manhattan'"
        )
    a = np.asarray(emb1).ravel()
    b = np.asarray(emb2).ravel()
    if a.shape != b.shape:
        # Without this check, a size-1 input would silently broadcast against the other.
        raise ValueError(f"Embedding size mismatch: {a.shape[0]} vs {b.shape[0]}")

    valid = ~(np.isnan(a) | np.isnan(b))
    valid_coords = np.sum(valid)
    if valid_coords == 0:
        return float('inf')
    weight = a.shape[0] / valid_coords
    diff = a[valid] - b[valid]
    if distance == "euclidean":
        return np.sqrt(weight * np.sum(diff ** 2))
    return weight * np.sum(np.abs(diff))

