from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def run_reppopmapped_two(args: argparse.Namespace,
                          human_embeddings: np.ndarray,
                          train_questions: List,
                          test_questions: List,
                          pd_questions: pd.DataFrame,
                          user_responses_train: Dict,
                          model,
                          tokenizer,
                          reducer,
                          user_ids: List[str],
                          domain=None) -> Tuple[Dict, List, List]:
    """Greedily build one proxy agent per human, then greedily assemble a population.

    Phase 1 (candidate construction): for each ``user_ids[idx]`` with at least
    ``args.k_examples`` non-empty training answers, an in-context prompt is grown
    one example at a time. At each of the ``args.k_examples`` steps, up to
    ``args.sampling_size`` not-yet-used examples are sampled, the agent is
    queried with each tentative prompt, and the example whose agent embedding has
    the smallest ``compute_dist_score`` to ``human_embeddings[idx]`` is kept.

    Phase 2 (population selection): for up to ``args.n_agents`` iterations, the
    unused candidate whose embedding, stacked with the agents already chosen,
    minimises ``compute_population_dist`` against ``human_embeddings`` is added.
    Each iteration writes train/test responses (pickle) and a runtime log under
    ``args.output_dir/iterations/iter_<t>/`` and rewrites
    ``selected_contexts_reppopmapped_two.json``.

    Args:
        args: Run configuration; reads ``seed``, ``output_dir``, ``k_examples``,
            ``sampling_size``, ``role_play``, ``domain``, ``distance`` and
            ``n_agents``.
        human_embeddings: Target human embeddings. Row ``idx`` is used as the
            target for ``user_ids[idx]`` (assumes the caller aligned them —
            TODO confirm).
        train_questions: Questions answered by the agent to build its embedding.
        test_questions: Questions answered by each selected agent for output.
        pd_questions: Question metadata forwarded to the agent/embedding helpers.
        user_responses_train: Per-user response lists keyed by the user id with
            the ``user_`` prefix removed (as ``str`` or ``int``). Each response
            is indexed as ``(question_id, question_or_image, answer, ...)``;
            responses with ``answer == ""`` are ignored.
        model: Forwarded to the response/embedding helpers.
        tokenizer: Forwarded to the response/embedding helpers.
        reducer: Only used for ``wikiarts``; its ``transform`` projects the
            agent embedding.
        user_ids: Ids of the form ``"user_<id>"``.
        domain: ``'wikiarts'`` selects the multimodal prompt/response path;
            anything else uses text prompts and ``args.domain`` for embeddings.

    Returns:
        ``(selected_contexts, all_agent_embeddings_reduced, dist_scores)``:
        per-iteration context records, the chosen agents' embeddings in
        selection order, and the population distance after each addition.

    Raises:
        ValueError: if text prompts are used with ``args.role_play`` set and
            ``args.domain`` has no role-play instruction.
    """
    random.seed(args.seed)

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}
    used_humans = set()

    # Text prompt -> (agent embedding, agent train responses). The responses are
    # cached next to the embedding so that a cache hit (which is what normally
    # happens for the final prompt, since it equals the best prompt of the last
    # greedy round) still returns the responses that produced that embedding.
    dist_cache = {}

    contexts_file = os.path.join(args.output_dir, 'selected_contexts_reppopmapped_two.json')

    def _text_instruction():
        """Return the instruction closing a text prompt, or None if unsupported."""
        if not args.role_play:
            return "Answer the following question.\n\n"
        if args.domain == 'opinionqa':
            return OPIONQA_INSTRUCTION
        if args.domain == 'eedi':
            return EEDI_INSTRUCTION
        if args.domain == 'wikiarts':
            return WIKIARTS_INSTRUCTION
        return None

    text_instruction = None
    if domain != 'wikiarts':
        text_instruction = _text_instruction()
        if text_instruction is None:
            raise ValueError(
                f"role_play is enabled but domain {args.domain!r} has no instruction prompt"
            )

    def _add_example(context, response):
        """Return a new context with one (question/image, answer) example appended."""
        if domain == 'wikiarts':
            return context + [
                {"type": "text", "text": "Painting:\n"},
                {"type": "image", "image": response[1]},
                {"type": "text", "text": f"Response: {response[2]}\n"},
            ]
        return context + f"Question: {response[1]}\nAnswer: {response[2]}\n\n"

    def _with_instruction(context):
        """Return a new context closed with the domain's instruction."""
        if domain == 'wikiarts':
            return context + [{"type": "text", "text": WIKIARTS_INSTRUCTION}]
        return context + text_instruction

    def _query_agent(final_context):
        """Answer ``train_questions`` under ``final_context``; return (embedding, responses).

        Raises ValueError when ``args.domain`` has no text embedding function.
        """
        if domain == 'wikiarts':
            results = get_agent_responses_wikiarts(model, tokenizer, train_questions, pd_questions, final_context)
            qa_pairs = [(qid, question, answer, emotion, explanation)
                        for qid, question, answer, emotion, explanation in results]
            data_dir = os.path.join(args.output_dir, "train")
            agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
            return reducer.transform(agent_emb.reshape(1, -1)), results

        results = get_agent_responses(model, tokenizer, train_questions, pd_questions, final_context)
        qa_pairs = []
        for qid, question, answer_key, answer_text, answer_ordinal, _raw_answer, max_val, min_val in results:
            try:
                answer = '{"' + answer_key + '":"' + answer_text + '"}'
            except TypeError:
                # Non-string key/text: record an empty answer instead.
                answer = ''
            qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

        if args.domain == 'opinionqa':
            return get_embedding(model, tokenizer, qa_pairs), results
        if args.domain == 'eedi':
            return get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions), results
        raise ValueError(f"No text embedding function for domain {args.domain!r}")

    def _evaluate(final_context):
        """Return (embedding, train responses, cache_hit) for a full prompt."""
        if domain == 'wikiarts':
            # Multimodal contexts are lists (unhashable), so they are never cached.
            embedding, results = _query_agent(final_context)
            return embedding, results, False
        if final_context in dist_cache:
            embedding, results = dist_cache[final_context]
            return embedding, results, True
        embedding, results = _query_agent(final_context)
        dist_cache[final_context] = (embedding, results)
        return embedding, results, False

    print("Pre-computing proxy agents with greedily constructed contexts...")
    candidate_agents = {}

    # Logged runtimes are cumulative from here (candidate construction included).
    iter_start_time = time.time()

    for idx, user_id_str in tqdm(enumerate(user_ids), total=len(user_ids), desc="Processing users"):
        # Response dicts may be keyed by the numeric id as str or as int.
        user_id = user_id_str.replace('user_', '')
        if user_id not in user_responses_train:
            user_id = int(user_id)

        valid_responses_with_indices = [
            (i, resp) for i, resp in enumerate(user_responses_train[user_id]) if resp[2] != ""
        ]
        if len(valid_responses_with_indices) < args.k_examples:
            print(f"Skipping {user_id_str}: has only {len(valid_responses_with_indices)} valid responses < {args.k_examples}")
            continue

        print(f"Building greedy context for human {user_id_str}...")
        target_human_embedding = human_embeddings[idx]

        current_context = [] if domain == 'wikiarts' else ""
        selected_examples = []
        used_questions = set()

        for k in range(args.k_examples):
            print(f"  Selecting example {k+1}/{args.k_examples}...")

            available_responses = [
                (i, resp) for i, resp in valid_responses_with_indices
                if resp[0] not in used_questions
            ]
            if not available_responses:
                print(f"  No more available responses for human {user_id_str}")
                break

            num_samples = min(args.sampling_size, len(available_responses))
            sampled_responses = random.sample(available_responses, num_samples)

            best_dist_to_human = float('inf')
            best_example = None
            local_cache_hits = 0
            local_cache_attempts = 0

            for original_idx, response in sampled_responses:
                temp_final_context = _with_instruction(_add_example(current_context, response))
                if domain != 'wikiarts':
                    local_cache_attempts += 1
                try:
                    agent_emb_reduced, _, cache_hit = _evaluate(temp_final_context)
                    if cache_hit:
                        local_cache_hits += 1
                    dist_to_human = compute_dist_score(
                        agent_emb_reduced,
                        target_human_embedding,
                        args.distance
                    )
                    if dist_to_human < best_dist_to_human:
                        best_dist_to_human = dist_to_human
                        best_example = (original_idx, response)
                except Exception as e:
                    # Best effort: a failing candidate is skipped, not fatal.
                    print(f"    Error evaluating example: {str(e)}")
                    continue

            if local_cache_attempts > 0:
                print(f"    Cache hits: {local_cache_hits}/{local_cache_attempts} ({local_cache_hits/local_cache_attempts*100:.1f}%)")

            if best_example is None:
                print(f"    Could not find a suitable example to add. Stopping.")
                break

            original_idx, response = best_example
            current_context = _add_example(current_context, response)
            selected_examples.append((user_id_str, original_idx, response))
            used_questions.add(response[0])
            print(f"    Selected example with dist score: {best_dist_to_human:.4f}")

        final_context = _with_instruction(current_context)
        try:
            agent_emb_reduced, results, _ = _evaluate(final_context)
        except Exception as e:
            print(f"Skipping {user_id_str}: error building final agent: {str(e)}")
        else:
            try:
                final_dist_to_human = compute_dist_score(
                    agent_emb_reduced,
                    target_human_embedding,
                    args.distance
                )
            except Exception:
                final_dist_to_human = float('inf')

            candidate_agents[user_id_str] = {
                "context": final_context,
                "examples": selected_examples,
                "agent_embedding": agent_emb_reduced,
                "train_responses": results,
                "dist_to_human": final_dist_to_human
            }

        torch.cuda.empty_cache()

    for t in range(args.n_agents):
        iter_dir = f"iter_{t+1}"
        iter_path = f'{args.output_dir}/iterations/{iter_dir}'
        print(f"\n--- Iteration {t + 1}/{args.n_agents} ---")

        # Pick the unused candidate that most reduces the population distance.
        best_dist = float('inf')
        best_human = None
        for human_id, candidate in candidate_agents.items():
            if human_id in used_humans:
                continue

            agent_emb = candidate["agent_embedding"]
            if all_agent_embeddings_reduced:
                combined_embeddings = np.vstack(all_agent_embeddings_reduced + [agent_emb])
            else:
                combined_embeddings = agent_emb.reshape(1, -1) if len(agent_emb.shape) == 1 else agent_emb

            potential_dist = compute_population_dist(
                human_embeddings,
                combined_embeddings,
                args.distance
            )
            if potential_dist < best_dist:
                best_dist = potential_dist
                best_human = human_id

        if best_human is None:
            break

        used_humans.add(best_human)
        selected_candidate = candidate_agents[best_human]
        final_context = selected_candidate["context"]

        all_agent_embeddings_reduced.append(selected_candidate["agent_embedding"])
        actual_population_dist = compute_population_dist(
            human_embeddings,
            np.vstack(all_agent_embeddings_reduced),
            args.distance
        )
        dist_scores.append(actual_population_dist)

        train_responses_output_path = f'{iter_path}/reppopmapped_two_train_responses.pkl'
        os.makedirs(os.path.dirname(train_responses_output_path), exist_ok=True)
        with open(train_responses_output_path, 'wb') as f:
            pickle.dump(selected_candidate["train_responses"], f)

        selected_contexts[iter_dir] = {
            "context": final_context,
            "population_dist": float(actual_population_dist),
            "human_id": best_human,
            "human_dist": float(selected_candidate["dist_to_human"]),
            "selected_examples": [{
                "user_id": example_user,
                "example_index": example_index,
                "question_id": example_response[0],
                "question": example_response[1],
                "answer": example_response[2]
            } for example_user, example_index, example_response in selected_candidate["examples"]]
        }

        # Rewritten every iteration so partial progress survives a crash.
        with open(contexts_file, 'w') as f:
            json.dump(selected_contexts, f, indent=2)

        iter_runtime = time.time() - iter_start_time
        logs_path = f'{iter_path}/logs.json'
        os.makedirs(os.path.dirname(logs_path), exist_ok=True)

        # Merge into any logs other methods already wrote for this iteration.
        logs = {}
        if os.path.exists(logs_path):
            with open(logs_path, 'r') as f:
                logs = json.load(f)
        logs['runtime_reppopmapped_two'] = iter_runtime
        with open(logs_path, 'w') as f:
            json.dump(logs, f, indent=2)

        try:
            if domain == 'wikiarts':
                results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
            else:
                results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
            test_output_path = f'{iter_path}/reppopmapped_two_test_responses.pkl'
            with open(test_output_path, 'wb') as f:
                pickle.dump(results_test, f)
        except Exception as e:
            print(f"Error generating test responses for iteration {t+1}: {str(e)}")

    return selected_contexts, all_agent_embeddings_reduced, dist_scores


def run_reppopmapped_one(args: argparse.Namespace,
                          human_embeddings: np.ndarray,
                          train_questions: List,
                          test_questions: List,
                          pd_questions: pd.DataFrame,
                          user_responses_train: Dict,
                          model,
                          tokenizer,
                          reducer,
                          user_ids: List[str],
                          domain=None) -> Tuple[Dict, List, List]:
    """Build one randomly-prompted proxy agent per human, then greedily assemble a population.

    Phase 1 (candidate construction): for each user in ``user_ids`` with at least
    ``args.k_examples`` non-empty training answers, ``args.k_examples`` of them
    are sampled at random to form an in-context prompt, and the agent's
    embedding on ``train_questions`` under that prompt is computed.

    Phase 2 (population selection): for up to ``args.n_agents`` iterations, the
    unused candidate whose embedding, stacked with the agents already chosen,
    minimises ``compute_population_dist`` against ``human_embeddings`` is added.
    Each iteration writes train/test responses (pickle) and a runtime log under
    ``args.output_dir/iterations/iter_<t>/`` and rewrites
    ``selected_contexts_reppopmapped_one.json``.

    Args:
        args: Run configuration; reads ``seed``, ``output_dir``, ``k_examples``,
            ``role_play``, ``domain``, ``distance`` and ``n_agents``.
        human_embeddings: Target human embeddings the population is matched to.
        train_questions: Questions answered by the agent to build its embedding.
        test_questions: Questions answered by each selected agent for output.
        pd_questions: Question metadata forwarded to the agent/embedding helpers.
        user_responses_train: Per-user response lists keyed by the user id with
            the ``user_`` prefix removed (as ``str`` or ``int``). Each response
            is indexed as ``(question_id, question_or_image, answer, ...)``;
            responses with ``answer == ""`` are ignored.
        model: Forwarded to the response/embedding helpers.
        tokenizer: Forwarded to the response/embedding helpers.
        reducer: Only used for ``wikiarts``; its ``transform`` projects the
            agent embedding.
        user_ids: Ids of the form ``"user_<id>"``.
        domain: ``'wikiarts'`` selects the multimodal prompt/response path;
            anything else uses text prompts and ``args.domain`` for embeddings.

    Returns:
        ``(selected_contexts, all_agent_embeddings_reduced, dist_scores)``:
        per-iteration context records, the chosen agents' embeddings in
        selection order, and the population distance after each addition.

    Raises:
        ValueError: if text prompts are used with ``args.role_play`` set and
            ``args.domain`` has no role-play instruction.
    """
    random.seed(args.seed)

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}
    used_humans = set()

    # Text prompt -> (agent embedding, agent train responses); identical random
    # prompts across users are only evaluated once.
    dist_cache = {}

    contexts_file = os.path.join(args.output_dir, 'selected_contexts_reppopmapped_one.json')

    def _text_instruction():
        """Return the instruction closing a text prompt, or None if unsupported."""
        if not args.role_play:
            return "Answer the following question.\n\n"
        if args.domain == 'opinionqa':
            return OPIONQA_INSTRUCTION
        if args.domain == 'eedi':
            return EEDI_INSTRUCTION
        if args.domain == 'wikiarts':
            return WIKIARTS_INSTRUCTION
        return None

    text_instruction = None
    if domain != 'wikiarts':
        text_instruction = _text_instruction()
        if text_instruction is None:
            # Previously every user silently failed and an empty result was returned.
            raise ValueError(
                f"role_play is enabled but domain {args.domain!r} has no instruction prompt"
            )

    def _build_context(sampled_responses):
        """Return the full prompt (list for wikiarts, str otherwise) for the sampled examples."""
        if domain == 'wikiarts':
            context_list = []
            for _, response in sampled_responses:
                context_list.append({"type": "text", "text": "Painting:\n"})
                context_list.append({"type": "image", "image": response[1]})
                context_list.append({"type": "text", "text": f"Response: {response[2]}\n"})
            context_list.append({"type": "text", "text": WIKIARTS_INSTRUCTION})
            return context_list
        examples_text = "".join(
            f"Question: {response[1]}\nAnswer: {response[2]}\n\n" for _, response in sampled_responses
        )
        return examples_text + text_instruction

    def _query_agent(final_context):
        """Answer ``train_questions`` under ``final_context``; return (embedding, responses).

        Raises ValueError when ``args.domain`` has no text embedding function.
        """
        if domain == 'wikiarts':
            results = get_agent_responses_wikiarts(model, tokenizer, train_questions, pd_questions, final_context)
            qa_pairs = [(qid, question, answer, emotion, explanation)
                        for qid, question, answer, emotion, explanation in results]
            data_dir = os.path.join(args.output_dir, "train")
            agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
            return reducer.transform(agent_emb.reshape(1, -1)), results

        results = get_agent_responses(model, tokenizer, train_questions, pd_questions, final_context)
        qa_pairs = []
        for qid, question, answer_key, answer_text, answer_ordinal, _raw_answer, max_val, min_val in results:
            try:
                answer = '{"' + answer_key + '":"' + answer_text + '"}'
            except TypeError:
                # Non-string key/text: record an empty answer instead.
                answer = ''
            qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

        if args.domain == 'opinionqa':
            return get_embedding(model, tokenizer, qa_pairs), results
        if args.domain == 'eedi':
            return get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions), results
        raise ValueError(f"No text embedding function for domain {args.domain!r}")

    # Logged runtimes are cumulative from here (candidate construction included).
    iter_start_time = time.time()

    candidate_agents = {}

    for user_id_str in tqdm(user_ids, total=len(user_ids), desc="Processing users"):
        # Response dicts may be keyed by the numeric id as str or as int.
        user_id = user_id_str.replace('user_', '')
        if user_id not in user_responses_train:
            user_id = int(user_id)

        valid_responses = [
            (i, resp) for i, resp in enumerate(user_responses_train[user_id]) if resp[2] != ""
        ]
        if len(valid_responses) < args.k_examples:
            continue

        sampled_responses = random.sample(valid_responses, args.k_examples)
        examples = [(user_id_str, original_index, response)
                    for original_index, response in sampled_responses]
        final_context = _build_context(sampled_responses)

        try:
            if domain == 'wikiarts':
                # Multimodal contexts are lists (unhashable), so they are never cached.
                agent_emb_reduced, results = _query_agent(final_context)
            elif final_context in dist_cache:
                agent_emb_reduced, results = dist_cache[final_context]
            else:
                agent_emb_reduced, results = _query_agent(final_context)
                dist_cache[final_context] = (agent_emb_reduced, results)
        except Exception as e:
            # Best effort: a user whose agent cannot be built is skipped, not fatal.
            print(f"Skipping {user_id_str}: error building agent: {str(e)}")
        else:
            candidate_agents[user_id_str] = {
                "context": final_context,
                "examples": examples,
                "agent_embedding": agent_emb_reduced,
                "train_responses": results
            }

        torch.cuda.empty_cache()

    for t in range(args.n_agents):
        iter_dir = f"iter_{t+1}"
        iter_path = f'{args.output_dir}/iterations/{iter_dir}'
        print(f"\n--- Iteration {t + 1}/{args.n_agents} ---")

        # Pick the unused candidate that most reduces the population distance.
        best_dist = float('inf')
        best_human = None
        for human_id, candidate in candidate_agents.items():
            if human_id in used_humans:
                continue

            agent_emb = candidate["agent_embedding"]
            if all_agent_embeddings_reduced:
                combined_embeddings = np.vstack(all_agent_embeddings_reduced + [agent_emb])
            else:
                combined_embeddings = agent_emb.reshape(1, -1) if len(agent_emb.shape) == 1 else agent_emb

            potential_dist = compute_population_dist(
                human_embeddings,
                combined_embeddings,
                args.distance
            )
            if potential_dist < best_dist:
                best_dist = potential_dist
                best_human = human_id

        if best_human is None:
            break

        used_humans.add(best_human)
        selected_candidate = candidate_agents[best_human]
        final_context = selected_candidate["context"]

        all_agent_embeddings_reduced.append(selected_candidate["agent_embedding"])
        actual_population_dist = compute_population_dist(
            human_embeddings,
            np.vstack(all_agent_embeddings_reduced),
            args.distance
        )
        dist_scores.append(actual_population_dist)

        train_responses_output_path = f'{iter_path}/reppopmapped_one_train_responses.pkl'
        os.makedirs(os.path.dirname(train_responses_output_path), exist_ok=True)
        with open(train_responses_output_path, 'wb') as f:
            pickle.dump(selected_candidate["train_responses"], f)

        selected_contexts[iter_dir] = {
            "context": final_context,
            "population_dist": float(actual_population_dist),
            "human_id": best_human,
            "selected_examples": [{
                "user_id": example_user,
                "example_index": example_index,
                "question_id": example_response[0],
                "question": example_response[1],
                "answer": example_response[2]
            } for example_user, example_index, example_response in selected_candidate["examples"]]
        }

        # Rewritten every iteration so partial progress survives a crash.
        with open(contexts_file, 'w') as f:
            json.dump(selected_contexts, f, indent=2)

        iter_runtime = time.time() - iter_start_time
        logs_path = f'{iter_path}/logs.json'
        os.makedirs(os.path.dirname(logs_path), exist_ok=True)

        # Merge into any logs other methods already wrote for this iteration.
        logs = {}
        if os.path.exists(logs_path):
            with open(logs_path, 'r') as f:
                logs = json.load(f)
        logs['runtime_reppopmapped_one'] = iter_runtime
        with open(logs_path, 'w') as f:
            json.dump(logs, f, indent=2)

        try:
            if domain == 'wikiarts':
                results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
            else:
                results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
            test_output_path = f'{iter_path}/reppopmapped_one_test_responses.pkl'
            with open(test_output_path, 'wb') as f:
                pickle.dump(results_test, f)
        except Exception as e:
            print(f"Error generating test responses for iteration {t+1}: {str(e)}")

    return selected_contexts, all_agent_embeddings_reduced, dist_scores