from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def _samplegreedy_collect_valid_responses(user_ids, user_responses_train):
    """Flatten all non-empty training responses into (user_id_str, index, response) triples."""
    valid_responses = []
    for user_id_str in user_ids:
        # The response dict may be keyed by the bare id as a string or as an int;
        # try the string form first and fall back to the integer form.
        user_id = user_id_str.replace('user_', '')
        if user_id not in user_responses_train:
            user_id = int(user_id)

        for position, resp in enumerate(user_responses_train[user_id]):
            if resp[2] != "":  # Valid response
                valid_responses.append((user_id_str, position, resp))
    return valid_responses


def _samplegreedy_build_context(args, domain, sampled_responses):
    """Turn sampled (user_id_str, index, response) triples into a prompt context."""
    if domain == 'wikiarts':
        # Multimodal context: a list of text/image parts followed by the instruction.
        parts = []
        for _, _, response in sampled_responses:
            parts.append({"type": "text", "text": "Painting:\n"})
            parts.append({"type": "image", "image": response[1]})
            parts.append({"type": "text", "text": f"Response: {response[2]}\n"})
        parts.append({"type": "text", "text": WIKIARTS_INSTRUCTION})
        return parts

    context = "".join(
        f"Question: {response[1]}\nAnswer: {response[2]}\n\n"
        for _, _, response in sampled_responses
    )
    if not args.role_play:
        return context + "Answer the following question.\n\n"

    if args.domain == 'opinionqa':
        return context + OPIONQA_INSTRUCTION
    if args.domain == 'eedi':
        return context + EEDI_INSTRUCTION
    if args.domain == 'wikiarts':
        return context + WIKIARTS_INSTRUCTION
    # Previously this fell through and crashed later with an UnboundLocalError.
    raise ValueError(
        f"Unsupported args.domain {args.domain!r} for role-play contexts; "
        "expected 'opinionqa', 'eedi' or 'wikiarts'."
    )


def _samplegreedy_evaluate_candidate(args, domain, model, tokenizer, reducer,
                                     train_questions, pd_questions, context):
    """Query the agent on the train questions and return (embedding, raw results)."""
    if domain == 'wikiarts':
        results = get_agent_responses_wikiarts(model, tokenizer, train_questions, pd_questions, context)
        qa_pairs = [
            (qid, question, answer, emotion, explanation)
            for qid, question, answer, emotion, explanation in results
        ]
        data_dir = os.path.join(args.output_dir, "train")
        agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
        return reducer.transform(agent_emb.reshape(1, -1)), results

    results = get_agent_responses(model, tokenizer, train_questions, pd_questions, context)
    qa_pairs = []
    for qid, question, answer_key, answer_text, answer_ordinal, raw_answer, max_val, min_val in results:
        try:
            answer = '{"' + answer_key + '":"' + answer_text + '"}'
        except TypeError:
            # Non-string key/text (e.g. None) cannot be concatenated: record an empty answer.
            answer = ''
        qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

    if args.domain == 'opinionqa':
        return get_embedding(model, tokenizer, qa_pairs), results
    if args.domain == 'eedi':
        return get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions), results
    # Previously no embedding was assigned here and the resulting
    # UnboundLocalError was swallowed without any message.
    raise ValueError(
        f"Unsupported args.domain {args.domain!r} for embedding; "
        "expected 'opinionqa' or 'eedi' (or domain='wikiarts')."
    )


def run_samplegreedy(args: argparse.Namespace,
                          human_embeddings: np.ndarray,
                          train_questions: List,
                          test_questions: List,
                          pd_questions: pd.DataFrame,
                          user_responses_train: Dict,
                          model,
                          tokenizer,
                          reducer,
                          user_ids: List[str],
                          domain=None) -> Tuple[Dict, List, List]:
    """Greedily pick randomly sampled few-shot contexts that best match the humans.

    One candidate context per entry of ``user_ids`` is built by sampling
    ``args.k_examples`` valid training responses (pooled over all users).
    Each candidate is run on ``train_questions`` and embedded. Then, for up to
    ``args.n_agents`` iterations, the unused candidate whose embedding, added
    to the already selected ones, gives the smallest
    ``compute_population_dist`` against ``human_embeddings`` is selected.

    Side effects per iteration (under ``args.output_dir``): the selected
    candidate's train responses and freshly generated test responses are
    pickled, ``selected_contexts_samplegreedy.json`` is rewritten, and
    ``runtime_samplegreedy`` is stored in the iteration's ``logs.json``.

    Args:
        args: Run configuration; reads ``seed``, ``output_dir``, ``k_examples``,
            ``n_agents``, ``distance``, ``role_play`` and ``domain``.
        human_embeddings: Reference embeddings passed to ``compute_population_dist``.
        train_questions: Questions used to score each candidate context.
        test_questions: Questions answered with each selected context.
        pd_questions: Question table forwarded to the agent/embedding helpers.
        user_responses_train: Training responses keyed by user id (str or int);
            each response is indexable with the answer at position 2.
        model: Model handle forwarded to the agent/embedding helpers.
        tokenizer: Tokenizer handle forwarded to the agent/embedding helpers.
        reducer: Object whose ``transform`` is applied to wikiarts embeddings.
        user_ids: Ids of the form ``"user_<id>"``; their count also fixes the
            number of candidate contexts.
        domain: ``'wikiarts'`` selects the multimodal code path; any other
            value uses the text path driven by ``args.domain``.

    Returns:
        ``(selected_contexts, embeddings, dist_scores)``: per-iteration context
        records keyed by ``"iter_<n>"``, the selected embeddings in selection
        order, and the population distance after each selection. Returns
        ``({}, [], [])`` when fewer than ``args.k_examples`` valid responses exist.

    Raises:
        ValueError: If ``args.role_play`` is set with an unsupported
            ``args.domain`` on the text path.
    """
    # Set seed
    random.seed(args.seed)

    contexts_file = os.path.join(args.output_dir, 'selected_contexts_samplegreedy.json')

    all_valid_responses = _samplegreedy_collect_valid_responses(user_ids, user_responses_train)
    if len(all_valid_responses) < args.k_examples:
        return {}, [], []

    # NOTE: a single timer is started here, so the runtime logged for each
    # iteration is cumulative and includes candidate generation/evaluation.
    start_time = time.time()

    # Draw every candidate before any model call so the sampling sequence
    # depends only on the seed.
    candidate_contexts = []
    for context_id in range(len(user_ids)):
        sampled_responses = random.sample(all_valid_responses, args.k_examples)
        candidate_contexts.append({
            "context": _samplegreedy_build_context(args, domain, sampled_responses),
            "examples": sampled_responses,
            "context_id": context_id
        })

    valid_candidates = []
    for candidate in candidate_contexts:
        try:
            embedding, train_responses = _samplegreedy_evaluate_candidate(
                args, domain, model, tokenizer, reducer,
                train_questions, pd_questions, candidate["context"]
            )
        except Exception as e:
            # Best effort: a failing candidate is dropped, but the reason is reported.
            print(f"Skipping candidate context {candidate['context_id']}: "
                  f"{type(e).__name__}: {str(e)}")
        else:
            candidate["agent_embedding"] = embedding
            candidate["train_responses"] = train_responses
            valid_candidates.append(candidate)

        torch.cuda.empty_cache()

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}
    used_context_indices = set()

    for t in range(args.n_agents):
        iter_dir = f"iter_{t+1}"
        print(f"\n--- Iteration {t + 1}/{args.n_agents} ---")

        best_dist = float('inf')
        best_candidate = None

        for candidate in valid_candidates:
            if candidate["context_id"] in used_context_indices:
                continue

            agent_emb = candidate["agent_embedding"]
            if all_agent_embeddings_reduced:
                combined_embeddings = np.vstack(all_agent_embeddings_reduced + [agent_emb])
            elif len(agent_emb.shape) == 1:
                combined_embeddings = agent_emb.reshape(1, -1)
            else:
                combined_embeddings = agent_emb

            potential_dist = compute_population_dist(
                human_embeddings,
                combined_embeddings,
                args.distance
            )

            # Strict comparison: the earliest candidate wins ties.
            if potential_dist < best_dist:
                best_dist = potential_dist
                best_candidate = candidate

        if best_candidate is None:
            # No unused candidate left (or none produced a comparable distance).
            break

        best_context_idx = best_candidate["context_id"]
        used_context_indices.add(best_context_idx)
        final_context = best_candidate["context"]

        all_agent_embeddings_reduced.append(best_candidate["agent_embedding"])
        actual_population_dist = compute_population_dist(
            human_embeddings,
            np.vstack(all_agent_embeddings_reduced),
            args.distance
        )
        dist_scores.append(actual_population_dist)

        train_responses_output_path = f'{args.output_dir}/iterations/{iter_dir}/samplegreedy_train_responses.pkl'
        os.makedirs(os.path.dirname(train_responses_output_path), exist_ok=True)
        with open(train_responses_output_path, 'wb') as f:
            pickle.dump(best_candidate["train_responses"], f)

        selected_contexts[iter_dir] = {
            "context": final_context,
            "population_dist": float(actual_population_dist),
            "context_id": best_context_idx,
            "selected_examples": [{
                "user_id": example[0],
                "example_index": example[1],
                "question_id": example[2][0],
                "question": example[2][1],
                "answer": example[2][2]
            } for example in best_candidate["examples"]]
        }

        # Rewritten every iteration so partial progress survives an interruption.
        with open(contexts_file, 'w') as f:
            json.dump(selected_contexts, f, indent=2)

        iter_runtime = time.time() - start_time
        logs_path = f'{args.output_dir}/iterations/{iter_dir}/logs.json'
        os.makedirs(os.path.dirname(logs_path), exist_ok=True)

        # Merge into an existing log file instead of overwriting other keys.
        logs = {}
        if os.path.exists(logs_path):
            with open(logs_path, 'r') as f:
                logs = json.load(f)

        logs['runtime_samplegreedy'] = iter_runtime
        with open(logs_path, 'w') as f:
            json.dump(logs, f, indent=2)

        try:
            if domain == 'wikiarts':
                results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
            else:
                results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
            test_output_path = f'{args.output_dir}/iterations/{iter_dir}/samplegreedy_test_responses.pkl'
            with open(test_output_path, 'wb') as f:
                pickle.dump(results_test, f)
        except Exception as e:
            print(f"Error generating test responses for iteration {t+1}: {str(e)}")

    return selected_contexts, all_agent_embeddings_reduced, dist_scores