from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def _reppopdemo_instruction(args: argparse.Namespace, domain) -> str:
    """Return the instruction text appended after the demonstrations.

    For ``domain == 'wikiarts'`` the WikiArts instruction is always used.
    Otherwise the text depends on ``args.role_play`` and ``args.domain``.

    Raises:
        ValueError: If role-play is requested for an ``args.domain`` that has
            no instruction prompt.
    """
    if domain == 'wikiarts':
        return WIKIARTS_INSTRUCTION
    if not args.role_play:
        return "Answer the following question.\n\n"
    if args.domain == 'opinionqa':
        return OPIONQA_INSTRUCTION
    if args.domain == 'eedi':
        return EEDI_INSTRUCTION
    if args.domain == 'wikiarts':
        return WIKIARTS_INSTRUCTION
    raise ValueError(
        f"Unsupported domain for role-play instruction: {args.domain!r}"
    )


def _reppopdemo_extend_context(context, example, domain):
    """Return a new context equal to *context* plus one demonstration.

    ``example`` is a ``(user_id, index, response)`` triple; ``response[1]`` is
    the question text (image path for wikiarts) and ``response[2]`` the answer.
    The input context is never mutated.
    """
    response = example[2]
    if domain == 'wikiarts':
        return context + [
            {"type": "text", "text": "Painting:\n"},
            {"type": "image", "image": response[1]},
            {"type": "text", "text": f"Response: {response[2]}\n"},
        ]
    return context + f"Question: {response[1]}\nAnswer: {response[2]}\n\n"


def _reppopdemo_finalize_context(context, instruction, domain):
    """Return *context* with the instruction appended (list or str form)."""
    if domain == 'wikiarts':
        return context + [{"type": "text", "text": instruction}]
    return context + instruction


def _reppopdemo_agent_embedding(args: argparse.Namespace, model, tokenizer,
                                reducer, questions, pd_questions, context,
                                domain) -> Tuple[Any, Any]:
    """Query the agent with *context* and embed its answers.

    Returns:
        ``(results, embedding)`` where ``results`` is the raw return value of
        the agent-response helper and ``embedding`` is the agent embedding.
        ``embedding`` is ``None`` for a non-wikiarts run whose ``args.domain``
        is neither ``'opinionqa'`` nor ``'eedi'``.
    """
    if domain == 'wikiarts':
        results = get_agent_responses_wikiarts(
            model, tokenizer, questions, pd_questions, context)
        # Unpacking enforces the expected 5-field layout of each result.
        qa_pairs = [
            (qid, q_text, ans, emotion, explanation)
            for qid, q_text, ans, emotion, explanation in results
        ]
        data_dir = os.path.join(args.output_dir, "train")
        agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
        return results, reducer.transform(agent_emb.reshape(1, -1))

    results = get_agent_responses(
        model, tokenizer, questions, pd_questions, context)
    qa_pairs = []
    for (qid, q_text, answer_key, answer_text, answer_ordinal,
         _raw_answer, max_val, min_val) in results:
        try:
            answer = '{"' + answer_key + '":"' + answer_text + '"}'
        except TypeError:
            # Non-string key/text (e.g. None): fall back to an empty answer.
            answer = ''
        qa_pairs.append(
            (qid, q_text, answer, answer_key, answer_ordinal, max_val, min_val))

    embedding = None
    if args.domain == 'opinionqa':
        embedding = get_embedding(model, tokenizer, qa_pairs)
    elif args.domain == 'eedi':
        embedding = get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions)
    return results, embedding


def run_reppopdemo(args: argparse.Namespace,
                        human_embeddings: np.ndarray,
                        train_questions: List,
                        test_questions: List,
                        pd_questions: pd.DataFrame,
                        user_responses_train: Dict,
                        model,
                        tokenizer,
                        reducer,
                        domain=None) -> Tuple[Dict, List, List]:
    """Greedily build one demonstration context per agent.

    For each of ``args.n_agents`` iterations, up to ``args.k_examples``
    (question, answer) demonstrations are picked one at a time. At every step
    a random sample of at most ``args.sampling_size`` candidate responses is
    evaluated: the agent is prompted with the current context plus the
    candidate, its embedding is stacked with the embeddings of previously
    built agents, and the candidate with the smallest
    ``compute_population_dist`` value against ``human_embeddings`` is kept.
    A question contributes at most one demonstration per agent.

    Per iteration the function writes, under
    ``{args.output_dir}/iterations/iter_{t+1}/``, the pickled train responses,
    the pickled test responses and a ``runtime_reppopdemo`` entry in
    ``logs.json``; ``selected_contexts_reppopdemo.json`` in
    ``args.output_dir`` is rewritten after every iteration.

    Args:
        args: Run configuration. Reads ``seed``, ``n_agents``, ``k_examples``,
            ``sampling_size``, ``distance``, ``output_dir``, ``domain`` and
            (for non-wikiarts domains) ``role_play``.
        human_embeddings: Passed unchanged to ``compute_population_dist``.
        train_questions: Question ids used for demonstration selection and
            for eliciting train responses.
        test_questions: Question ids used for the final test responses.
        pd_questions: Question table forwarded to the agent/embedding helpers.
        user_responses_train: Mapping ``user_id -> sequence of responses``;
            each response is indexed as ``[0]`` question id, ``[1]`` question
            text (image path for wikiarts), ``[2]`` answer. Responses with an
            empty-string answer are ignored.
        model: Model forwarded to the agent/embedding helpers.
        tokenizer: Tokenizer forwarded to the agent/embedding helpers.
        reducer: Object with ``transform``; only used when
            ``domain == 'wikiarts'``.
        domain: Domain name; ``'wikiarts'`` selects the multimodal
            (list-of-parts) context format, anything else the text format.

    Returns:
        ``(selected_contexts, all_agent_embeddings_reduced, dist_scores)``:
        the per-iteration context records, the embeddings of successfully
        built agents, and one population distance per iteration (``nan`` when
        the final embedding could not be produced).

    Raises:
        ValueError: If ``domain`` is ``None`` or role-play is requested for
            an unsupported ``args.domain``.
    """
    if domain is None:
        raise ValueError("Domain must be provided for context building")
    # Loop-invariant: depends only on args and domain.
    instruction = _reppopdemo_instruction(args, domain)

    #set seed
    random.seed(args.seed)

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}

    # Save contexts to json file with algorithm name
    contexts_file = os.path.join(args.output_dir, 'selected_contexts_reppopdemo.json')

    # Group non-empty responses by question id.
    question_responses = {}
    for user_id, responses in user_responses_train.items():
        for idx, response in enumerate(responses):
            if response[2] != "":  # Skip empty answers
                question_responses.setdefault(response[0], []).append(
                    (user_id, idx, response))

    for t in tqdm(range(args.n_agents), desc="Processing iterations", total=args.n_agents):
        print(f"\nIteration {t + 1}/{args.n_agents}")
        iter_start_time = time.time()

        # Create the iteration directory up front so that the log/context
        # writes below succeed even when final generation fails.
        iter_dir = f'{args.output_dir}/iterations/iter_{t + 1}'
        os.makedirs(iter_dir, exist_ok=True)

        # Multimodal contexts are lists of parts, text contexts plain strings.
        context = [] if domain == 'wikiarts' else ""
        selected_examples = []
        used_questions = set()

        # Select args.k_examples pairs
        for k in range(args.k_examples):
            # Distances are only comparable for a fixed context, so the cache
            # lives for a single selection step.
            dist_cache = {}
            print(f"\nSelecting example {k+1}/{args.k_examples}")

            available_pairs = []
            for question in train_questions:
                if question in used_questions or question not in question_responses:
                    continue
                available_pairs.extend(
                    (question, resp) for resp in question_responses[question])

            if not available_pairs:
                print("No more available pairs")
                break

            num_samples = min(args.sampling_size, len(available_pairs))
            sampled_pairs = random.sample(available_pairs, num_samples)
            print(f"Evaluating {len(sampled_pairs)} sampled pairs")

            best_dist = float('inf')
            best_example = None
            best_question = None
            failed = 0
            last_error = None

            # Evaluate each sampled pair
            for question, candidate in sampled_pairs:
                candidate_context = _reppopdemo_finalize_context(
                    _reppopdemo_extend_context(context, candidate, domain),
                    instruction, domain)
                # Identical (question, answer) texts yield identical prompts;
                # wikiarts candidates are never cached.
                cache_key = None if domain == 'wikiarts' else (candidate[2][1], candidate[2][2])

                try:
                    if cache_key is not None and cache_key in dist_cache:
                        current_dist = dist_cache[cache_key]
                    else:
                        _, candidate_emb = _reppopdemo_agent_embedding(
                            args, model, tokenizer, reducer, train_questions,
                            pd_questions, candidate_context, domain)
                        current_dist = compute_population_dist(
                            human_embeddings,
                            np.vstack(all_agent_embeddings_reduced + [candidate_emb]),
                            args.distance
                        )
                        if cache_key is not None:
                            dist_cache[cache_key] = current_dist

                    if current_dist < best_dist:
                        best_dist = current_dist
                        best_example = candidate
                        best_question = question
                except Exception as e:
                    # Best-effort: a failing candidate is skipped, but the
                    # failure is reported after the step instead of hidden.
                    failed += 1
                    last_error = e
                    continue

            if failed:
                print(f"Skipped {failed}/{len(sampled_pairs)} candidates due to errors "
                      f"(last error: {last_error!r})")

            if best_example is None:
                # No candidate could be evaluated; stop selecting for this agent.
                break

            context = _reppopdemo_extend_context(context, best_example, domain)
            selected_examples.append(best_example)
            used_questions.add(best_question)

        final_context = _reppopdemo_finalize_context(context, instruction, domain)

        final_agent_embedding_reduced = None
        try:
            final_results, final_agent_embedding_reduced = _reppopdemo_agent_embedding(
                args, model, tokenizer, reducer, train_questions,
                pd_questions, final_context, domain)

            train_responses_output_path = f'{iter_dir}/reppopdemo_train_responses.pkl'
            with open(train_responses_output_path, 'wb') as f:
                pickle.dump(final_results, f)

        except Exception as e:
            print(f"Error generating final results/embedding for iteration {t+1}: {str(e)}")

        example_records = [{
            "user_id": example[0],
            "example_index": example[1],
            "question_id": example[2][0],
            "question": example[2][1],
            "answer": example[2][2]
        } for example in selected_examples]

        if final_agent_embedding_reduced is not None:
            all_agent_embeddings_reduced.append(final_agent_embedding_reduced)
            current_population_dist = compute_population_dist(
                human_embeddings,
                np.vstack(all_agent_embeddings_reduced),
                args.distance
            )
            dist_scores.append(current_population_dist)

            selected_contexts[f"iter_{t+1}"] = {
                "context": final_context,
                "population_dist": float(current_population_dist),
                "selected_examples": example_records
            }
        else:
            dist_scores.append(float('nan'))
            selected_contexts[f"iter_{t+1}"] = {
                "context": final_context,
                "population_dist": float('nan'),
                "selected_examples": example_records,
                "error": "Failed to generate final embedding"
            }

        # Rewritten every iteration so partial runs leave usable output.
        with open(contexts_file, 'w') as f:
            json.dump(selected_contexts, f, indent=2)

        # Save runtime and logs
        iter_runtime = time.time() - iter_start_time
        logs_path = f'{iter_dir}/logs.json'

        logs = {}
        if os.path.exists(logs_path):
            # Merge into an existing log file rather than overwriting it.
            with open(logs_path, 'r') as f:
                logs = json.load(f)

        logs['runtime_reppopdemo'] = iter_runtime
        with open(logs_path, 'w') as f:
            json.dump(logs, f, indent=2)

        if final_context:
            try:
                if domain == 'wikiarts':
                    results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
                else:
                    results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
                test_output_path = f'{iter_dir}/reppopdemo_test_responses.pkl'
                with open(test_output_path, 'wb') as f:
                    pickle.dump(results_test, f)
            except Exception as e:
                print(f"Error generating test responses for iteration {t+1}: {str(e)}")

        torch.cuda.empty_cache()

    return selected_contexts, all_agent_embeddings_reduced, dist_scores