from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def _index_first_valid_responses(user_responses_train: Dict) -> Dict[Any, List[Tuple[Any, Any]]]:
    """Map each question id to the ``(user_id, response)`` pairs usable as examples.

    A user contributes only its first response to a given question whose
    answer field (``response[2]``) is not ``""``; pairs are listed in the
    iteration order of ``user_responses_train``. Assumes the question id
    ``response[0]`` is hashable.
    """
    index: Dict[Any, List[Tuple[Any, Any]]] = {}
    for user_id, responses in user_responses_train.items():
        seen = set()
        for response in responses:
            question_id = response[0]
            if response[2] == "" or question_id in seen:
                continue
            seen.add(question_id)
            index.setdefault(question_id, []).append((user_id, response))
    return index


def _build_final_context(args: argparse.Namespace, domain, context: str, context_list: List):
    """Append the task instruction to the assembled in-context examples.

    Returns a list of multimodal message parts when ``domain == 'wikiarts'``,
    otherwise a plain string.

    Raises:
        ValueError: if ``args.role_play`` is set and ``args.domain`` has no
            known role-play instruction.
    """
    if domain == 'wikiarts':
        # New list (shallow copy) so the per-iteration example list is not mutated.
        return context_list + [{"type": "text", "text": WIKIARTS_INSTRUCTION}]
    if not args.role_play:
        return context + "Answer the following question.\n\n"
    if args.domain == 'opinionqa':
        return context + OPIONQA_INSTRUCTION
    if args.domain == 'eedi':
        return context + EEDI_INSTRUCTION
    if args.domain == 'wikiarts':
        return context + WIKIARTS_INSTRUCTION
    raise ValueError(f"No role-play instruction for domain {args.domain!r}")


def _compute_agent_embedding(args: argparse.Namespace, domain, model, tokenizer, reducer,
                             questions: List, pd_questions: pd.DataFrame, final_context):
    """Query the agent on ``questions`` and embed its answers.

    Returns ``(results, agent_emb)``: the raw agent responses and the embedding
    used for the population distance. Only the wikiarts embedding is passed
    through ``reducer``; other domains return the embedding unreduced.

    Raises:
        ValueError: for a non-wikiarts ``args.domain`` without an embedding function.
    """
    if domain == 'wikiarts':
        results = get_agent_responses_wikiarts(model, tokenizer, questions, pd_questions, final_context)
        qa_pairs = [(qid, question, answer, emotion, explanation)
                    for qid, question, answer, emotion, explanation in results]
        data_dir = os.path.join(args.output_dir, "train")
        agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
        return results, reducer.transform(agent_emb.reshape(1, -1))

    results = get_agent_responses(model, tokenizer, questions, pd_questions, final_context)
    qa_pairs = []
    for qid, question, answer_key, answer_text, answer_ordinal, raw_answer, max_val, min_val in results:
        try:
            answer = '{"' + answer_key + '":"' + answer_text + '"}'
        except TypeError:
            # Non-string key/text (presumably None for unparsable replies — TODO confirm).
            answer = ''
        qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

    if args.domain == 'opinionqa':
        agent_emb = get_embedding(model, tokenizer, qa_pairs)
    elif args.domain == 'eedi':
        agent_emb = get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions)
    else:
        raise ValueError(f"No embedding function for domain {args.domain!r}")
    return results, agent_emb


def run_random(args: argparse.Namespace,
                        human_embeddings: np.ndarray,
                        train_questions: List,
                        test_questions: List,
                        pd_questions: pd.DataFrame,
                        user_responses_train: Dict,
                        model,
                        tokenizer,
                        reducer,
                        domain=None) -> Tuple[Dict, List, List]:
    """Random baseline: give each agent randomly sampled human examples as context.

    For each of ``args.n_agents`` iterations, ``args.k_examples`` training
    questions are sampled without replacement. For every sampled question one
    ``(user, response)`` pair is drawn uniformly from the users who gave a
    non-empty answer to it (each user contributing its first such response);
    questions nobody answered are skipped. The agent built from that context
    is queried on the training questions, embedded, added to the agent
    population, and the population distance to ``human_embeddings`` (using
    ``args.distance``) is recorded.

    Per-iteration artifacts are written under
    ``{args.output_dir}/iterations/iter_{t}/`` (train/test response pickles,
    ``logs.json`` with ``runtime_random``), and all contexts so far are
    rewritten to ``{args.output_dir}/selected_contexts_random.json`` after each
    successful iteration.

    A failing iteration is printed (with traceback) and skipped. An embedding
    is only appended together with its distance score, so the two returned
    lists stay index-aligned.

    Returns:
        ``(selected_contexts, all_agent_embeddings_reduced, dist_scores)``.
    """
    import traceback

    random.seed(args.seed)

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}
    contexts_file = os.path.join(args.output_dir, 'selected_contexts_random.json')

    # The candidate examples per question are the same in every iteration, so
    # index them once instead of rescanning all users for every sampled question.
    candidates_by_question = _index_first_valid_responses(user_responses_train)

    for t in range(args.n_agents):
        print(f"\nIteration {t + 1}/{args.n_agents}")
        iter_start_time = time.time()
        iter_dir = f'{args.output_dir}/iterations/iter_{t+1}'

        # --- Build context for this iteration ---
        context = ""
        context_list = []
        valid_examples = []
        for question in random.sample(list(train_questions), args.k_examples):
            candidates = candidates_by_question.get(question)
            if not candidates:
                continue
            selected_user_id, selected_response = random.choice(candidates)
            if domain == 'wikiarts':
                context_list.append({"type": "text", "text": "Painting:\n"})
                context_list.append({"type": "image", "image": selected_response[1]})
                context_list.append({"type": "text", "text": f"Response: {selected_response[2]}\n"})
            else:
                context += f"Question: {selected_response[1]}\nAnswer: {selected_response[2]}\n\n"
            valid_examples.append((f"user_{selected_user_id}", selected_response[0], selected_response))

        try:
            final_context = _build_final_context(args, domain, context, context_list)
            results, agent_emb_reduced = _compute_agent_embedding(
                args, domain, model, tokenizer, reducer,
                train_questions, pd_questions, final_context,
            )

            # Compute the distance before mutating state so a failure here cannot
            # leave an embedding without a matching score (or pollute later iterations).
            current_dist = compute_population_dist(
                human_embeddings,
                np.vstack(all_agent_embeddings_reduced + [agent_emb_reduced]),
                args.distance,
            )
            all_agent_embeddings_reduced.append(agent_emb_reduced)
            dist_scores.append(current_dist)

            # Save the responses for this context
            os.makedirs(iter_dir, exist_ok=True)
            with open(f'{iter_dir}/random_train_responses.pkl', 'wb') as f:
                pickle.dump(results, f)

            # Save context information; the file is rewritten after every iteration.
            selected_contexts[f"iter_{t+1}"] = {
                "context": final_context,
                "population_dist": float(current_dist),
                "selected_examples": [{
                    "user_id": example[0],
                    "example_index": example[1],
                    "question_id": example[2][0],
                    "question": example[2][1],
                    "answer": example[2][2]
                } for example in valid_examples]
            }
            with open(contexts_file, 'w') as f:
                json.dump(selected_contexts, f, indent=2)

            # Merge this algorithm's runtime into the (possibly shared) iteration log.
            iter_runtime = time.time() - iter_start_time
            logs_path = f'{iter_dir}/logs.json'
            logs = {}
            if os.path.exists(logs_path):
                with open(logs_path, 'r') as f:
                    logs = json.load(f)
            logs['runtime_random'] = iter_runtime
            with open(logs_path, 'w') as f:
                json.dump(logs, f, indent=2)

            # Get test set responses
            if domain == 'wikiarts':
                results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
            else:
                results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
            with open(f'{iter_dir}/random_test_responses.pkl', 'wb') as f:
                pickle.dump(results_test, f)
        except Exception as e:
            # Best effort: one failed agent must not abort the run, but the
            # failure has to be visible rather than silently swallowed.
            print(f"Error in iteration {t + 1}: {e}")
            traceback.print_exc()
        finally:
            # Release cached GPU memory whether or not the iteration succeeded.
            torch.cuda.empty_cache()

    return selected_contexts, all_agent_embeddings_reduced, dist_scores