from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def _sample_context_examples(train_questions: List,
                             user_responses_train: Dict,
                             k_examples: int,
                             domain=None) -> Tuple[str, List[Dict], List[Tuple]]:
    """Sample few-shot examples (one human answer per question) from training data.

    Draws ``k_examples`` distinct questions from ``train_questions`` with
    ``random.sample`` (clamped to the number of available questions). For each
    sampled question, every user contributes at most one candidate: the first
    of that user's responses whose ``response[0]`` equals the question and whose
    ``response[2]`` is non-empty. One candidate is then picked with
    ``random.choice``; questions with no candidate are skipped.

    Args:
        train_questions: Pool of training question identifiers.
        user_responses_train: Mapping ``user_id -> list of responses``. Each
            response is indexed here as ``response[0]`` (question id),
            ``response[1]`` (question text, or image path for wikiarts) and
            ``response[2]`` (answer) — inferred from usage; confirm with loader.
        k_examples: Number of questions to sample.
        domain: ``'wikiarts'`` builds multimodal message parts; anything else
            builds a plain-text prefix.

    Returns:
        ``(context, context_list, valid_examples)``: the text prefix (empty for
        wikiarts), the multimodal parts (empty otherwise), and one
        ``(f"user_{user_id}", question_id, response)`` tuple per selected example.
    """
    pool = list(train_questions)
    if k_examples > len(pool):
        # random.sample would raise ValueError; fall back to every question.
        print(f"Warning: k_examples={k_examples} exceeds the {len(pool)} "
              f"available training questions; using all of them.")
        k_examples = len(pool)
    sampled_questions = random.sample(pool, k_examples)

    text_parts: List[str] = []
    context_list: List[Dict] = []
    valid_examples: List[Tuple] = []
    for question in sampled_questions:
        # At most one candidate per user: that user's first non-empty answer.
        candidates = []
        for user_id, responses in user_responses_train.items():
            first_valid = next(
                (r for r in responses if r[0] == question and r[2] != ""),
                None,
            )
            if first_valid is not None:
                candidates.append((user_id, first_valid))
        if not candidates:
            continue

        user_id, response = random.choice(candidates)
        if domain == 'wikiarts':
            context_list.append({"type": "text", "text": "Painting:\n"})
            context_list.append({"type": "image", "image": response[1]})
            context_list.append({"type": "text", "text": f"Response: {response[2]}\n"})
        else:
            text_parts.append(f"Question: {response[1]}\nAnswer: {response[2]}\n\n")
        valid_examples.append((f"user_{user_id}", response[0], response))

    return "".join(text_parts), context_list, valid_examples


def _build_final_context(args: argparse.Namespace, domain, context: str,
                         context_list: List[Dict]):
    """Append the task instruction to the sampled few-shot examples.

    Returns a new list of message parts for wikiarts (the examples followed by
    ``WIKIARTS_INSTRUCTION``) and a string otherwise. Without
    ``args.role_play`` the generic "Answer the following question." prompt is
    appended; with it, the instruction is chosen by ``args.domain``.

    Raises:
        ValueError: if role play is requested for a domain with no instruction.
    """
    if domain == 'wikiarts':
        return context_list + [{"type": "text", "text": WIKIARTS_INSTRUCTION}]
    if not args.role_play:
        return context + "Answer the following question.\n\n"
    if args.domain == 'opinionqa':
        return context + OPIONQA_INSTRUCTION
    if args.domain == 'eedi':
        return context + EEDI_INSTRUCTION
    if args.domain == 'wikiarts':
        return context + WIKIARTS_INSTRUCTION
    raise ValueError(f"No role-play instruction defined for domain {args.domain!r}")


def _generate_and_embed(args: argparse.Namespace, model, tokenizer, questions: List,
                        pd_questions: pd.DataFrame, final_context, reducer, domain=None):
    """Query the agent on ``questions`` with ``final_context`` and embed its answers.

    Returns:
        ``(results, agent_emb_reduced)``: the raw responses from the agent
        helper, and the agent embedding. For wikiarts the embedding is reshaped
        to a single row and projected with ``reducer.transform``; otherwise the
        embedding helper's output is returned unchanged.

    Raises:
        ValueError: for a non-wikiarts ``args.domain`` with no embedding helper.
    """
    if domain == 'wikiarts':
        results = get_agent_responses_wikiarts(model, tokenizer, questions, pd_questions, final_context)
        qa_pairs = [(qid, question, answer, emotion, explanation)
                    for qid, question, answer, emotion, explanation in results]
        data_dir = os.path.join(args.output_dir, "train")
        agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
        return results, reducer.transform(agent_emb.reshape(1, -1))

    results = get_agent_responses(model, tokenizer, questions, pd_questions, final_context)
    qa_pairs = []
    for qid, question, answer_key, answer_text, answer_ordinal, _raw_answer, max_val, min_val in results:
        try:
            # Hand-built (no JSON escaping); kept as-is so the embedded text is unchanged.
            answer = '{"' + answer_key + '":"' + answer_text + '"}'
        except TypeError:
            # A non-string key/text (presumably None on a failed parse — confirm).
            answer = ''
        qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

    if args.domain == 'opinionqa':
        return results, get_embedding(model, tokenizer, qa_pairs)
    if args.domain == 'eedi':
        return results, get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions)
    raise ValueError(f"No embedding function defined for domain {args.domain!r}")


def _record_single_runtime(logs_path: str, runtime: float) -> None:
    """Store ``runtime`` under ``'runtime_single'`` in the JSON log, keeping other keys."""
    os.makedirs(os.path.dirname(logs_path), exist_ok=True)
    logs = {}
    if os.path.exists(logs_path):
        with open(logs_path, 'r') as f:
            logs = json.load(f)
    logs['runtime_single'] = runtime
    with open(logs_path, 'w') as f:
        json.dump(logs, f, indent=2)


def run_single(args: argparse.Namespace,
               human_embeddings: np.ndarray,
               train_questions: List,
               test_questions: List,
               pd_questions: pd.DataFrame,
               user_responses_train: Dict,
               model,
               tokenizer,
               reducer,
               domain=None) -> Tuple[Dict, List, List]:
    """Evaluate a population of agents that all share one sampled few-shot context.

    A single context is built from ``args.k_examples`` randomly sampled
    training questions, with one human answer each (seeded by ``args.seed``).
    The agent is then run ``args.n_agents`` times with that same context.
    After each run, its training answers are embedded and
    ``compute_population_dist`` is recomputed between ``human_embeddings``
    and all agent embeddings collected so far.

    Per iteration ``t`` (1-based), files are written under
    ``{args.output_dir}/iterations/iter_{t}/``: ``single_train_responses.pkl``,
    ``logs.json`` (key ``runtime_single``) and ``single_test_responses.pkl``.
    ``selected_contexts_single.json`` in ``args.output_dir`` is rewritten
    after each successful iteration. A failing iteration is reported (with
    traceback) and skipped.

    Args:
        args: Reads ``seed``, ``k_examples``, ``n_agents``, ``output_dir``,
            ``role_play``, ``domain`` and ``distance``.
        human_embeddings: Reference embeddings for ``compute_population_dist``.
        train_questions: Questions used for few-shot sampling and for scoring.
        test_questions: Questions answered after each iteration and saved.
        pd_questions: Question table forwarded to the agent/embedding helpers.
        user_responses_train: ``user_id -> responses`` used for few-shot examples.
        model: Forwarded to the agent and embedding helpers.
        tokenizer: Forwarded to the agent and embedding helpers.
        reducer: Used (``.transform``) only for wikiarts embeddings.
        domain: ``'wikiarts'`` selects the image pipeline; otherwise
            ``args.domain`` must be ``'opinionqa'`` or ``'eedi'``.

    Returns:
        ``(selected_contexts, agent_embeddings, dist_scores)``.
        ``selected_contexts`` maps ``"iter_{t}"`` to the context, population
        distance and selected examples. The two lists hold one entry for every
        iteration whose population distance was computed.

    Raises:
        ValueError: if the domain has no embedding pipeline.
    """
    import traceback  # only needed to report failed iterations

    random.seed(args.seed)

    if domain != 'wikiarts' and args.domain not in ('opinionqa', 'eedi'):
        # Otherwise every iteration would fail for lack of an embedding helper.
        raise ValueError(
            f"Unsupported domain {args.domain!r}: expected 'opinionqa' or 'eedi', "
            "or domain='wikiarts' for the image pipeline."
        )

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}
    contexts_file = os.path.join(args.output_dir, 'selected_contexts_single.json')

    context, context_list, valid_examples = _sample_context_examples(
        train_questions, user_responses_train, args.k_examples, domain)
    # The same context is reused for every iteration below.
    final_context = _build_final_context(args, domain, context, context_list)

    for t in range(args.n_agents):
        iter_name = f"iter_{t+1}"
        iter_dir = f'{args.output_dir}/iterations/{iter_name}'
        print(f"\nIteration {t + 1}/{args.n_agents}")
        iter_start_time = time.time()

        try:
            results, agent_emb_reduced = _generate_and_embed(
                args, model, tokenizer, train_questions, pd_questions,
                final_context, reducer, domain)

            # Score the population including this agent. The embedding is
            # committed only once scoring succeeds, so the two returned lists
            # always stay the same length.
            current_dist = compute_population_dist(
                human_embeddings,
                np.vstack(all_agent_embeddings_reduced + [agent_emb_reduced]),
                args.distance,
            )
            all_agent_embeddings_reduced.append(agent_emb_reduced)
            dist_scores.append(current_dist)

            os.makedirs(iter_dir, exist_ok=True)
            with open(f'{iter_dir}/single_train_responses.pkl', 'wb') as f:
                pickle.dump(results, f)

            selected_contexts[iter_name] = {
                "context": final_context,
                "population_dist": float(current_dist),
                "selected_examples": [{
                    "user_id": example[0],
                    # Holds the question id (same as "question_id"); kept for output compatibility.
                    "example_index": example[1],
                    "question_id": example[2][0],
                    "question": example[2][1],
                    "answer": example[2][2]
                } for example in valid_examples]
            }
            with open(contexts_file, 'w') as f:
                json.dump(selected_contexts, f, indent=2)

            # Runtime covers generation, embedding, scoring and saving — not the test pass.
            _record_single_runtime(f'{iter_dir}/logs.json', time.time() - iter_start_time)

            if domain == 'wikiarts':
                results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
            else:
                results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
            with open(f'{iter_dir}/single_test_responses.pkl', 'wb') as f:
                pickle.dump(results_test, f)

        except Exception as e:
            # Best-effort per iteration: report the failure and continue.
            print(f"Iteration {t + 1} failed: {type(e).__name__}: {e}")
            traceback.print_exc()
        finally:
            # Release cached GPU memory even after a failed (e.g. OOM) iteration.
            torch.cuda.empty_cache()

    return selected_contexts, all_agent_embeddings_reduced, dist_scores