from code.prompts import *
import argparse
import numpy as np
import os
import torch
import pandas as pd
from typing import Dict, List, Tuple, Any
from pathlib import Path
from code.utils import compute_dist_score, compute_population_dist
from code.agents import init_agent, init_agent_wikiarts, get_agent_responses, get_agent_responses_wikiarts
import json
import time
import joblib
import pickle
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from tqdm import tqdm
from itertools import combinations
import random
import math
import copy

def _pairwise_distance_matrix(embeddings: np.ndarray, metric) -> np.ndarray:
    """Return the dense matrix of ``compute_dist_score`` over all row pairs.

    The diagonal is left at 0.0. Both (i, j) and (j, i) are evaluated
    separately because symmetry of ``compute_dist_score`` is not assumed.
    """
    n_samples = embeddings.shape[0]
    distance_matrix = np.zeros((n_samples, n_samples))
    for i in range(n_samples):
        for j in range(n_samples):
            if i != j:
                distance_matrix[i, j] = compute_dist_score(
                    embeddings[i],
                    embeddings[j],
                    metric
                )
    return distance_matrix


def _collect_cluster_responses(cluster_users: List, user_responses_train: Dict) -> List[Tuple]:
    """Gather ``(user_id_str, response_index, response)`` for every valid response.

    A response is considered valid when its third field is not the empty string.
    """
    cluster_responses = []
    for user_id_str in cluster_users:
        # Training responses may be keyed by the bare id as str or as int.
        user_id = user_id_str.replace('user_', '')
        if user_id not in user_responses_train:
            user_id = int(user_id)

        for i, response in enumerate(user_responses_train[user_id]):
            if response[2] != "":  # Valid response
                cluster_responses.append((user_id_str, i, response))
    return cluster_responses


def _build_agent_context(args: argparse.Namespace, domain, sampled_responses: List[Tuple]):
    """Build the few-shot context given to the agent for one cluster.

    Returns a list of message parts when ``domain == 'wikiarts'`` and a
    string otherwise.

    Raises:
        ValueError: if ``args.role_play`` is set and ``args.domain`` has no
            matching instruction prompt.
    """
    if domain == 'wikiarts':
        context_list = []
        for _, _, response in sampled_responses:
            context_list.append({"type": "text", "text": "Painting:\n"})
            context_list.append({"type": "image", "image": response[1]})
            context_list.append({"type": "text", "text": f"Response: {response[2]}\n"})
        context_list.append({"type": "text", "text": WIKIARTS_INSTRUCTION})
        return context_list

    context = "".join(
        f"Question: {response[1]}\nAnswer: {response[2]}\n\n"
        for _, _, response in sampled_responses
    )
    if not args.role_play:
        return context + "Answer the following question.\n\n"

    if args.domain == 'opinionqa':
        return context + OPIONQA_INSTRUCTION
    if args.domain == 'eedi':
        return context + EEDI_INSTRUCTION
    if args.domain == 'wikiarts':
        return context + WIKIARTS_INSTRUCTION
    raise ValueError(f"No role-play instruction for domain {args.domain!r}")


def _embed_train_responses(args: argparse.Namespace, domain, model, tokenizer, reducer,
                           results, pd_questions: pd.DataFrame):
    """Turn an agent's training-question responses into its embedding.

    Raises:
        ValueError: if the non-wikiarts path is taken and ``args.domain`` is
            neither ``'opinionqa'`` nor ``'eedi'``.
    """
    if domain == 'wikiarts':
        qa_pairs = [
            (qid, question, answer, emotion, explanation)
            for qid, question, answer, emotion, explanation in results
        ]
        data_dir = os.path.join(args.output_dir, "train")
        agent_emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
        return reducer.transform(agent_emb.reshape(1, -1))

    qa_pairs = []
    for qid, question, answer_key, answer_text, answer_ordinal, raw_answer, max_val, min_val in results:
        try:
            answer = '{"' + answer_key + '":"' + answer_text + '"}'
        except TypeError:
            # Non-string key/text (e.g. None for an unparsable answer).
            answer = ''
        qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))

    if args.domain == 'opinionqa':
        return get_embedding(model, tokenizer, qa_pairs)
    if args.domain == 'eedi':
        return get_embedding_eedi(model, tokenizer, qa_pairs, pd_questions)
    raise ValueError(f"No embedding function for domain {args.domain!r}")


def run_kmedoids(args: argparse.Namespace,
                                   human_embeddings: np.ndarray,
                                   train_questions: List,
                                   test_questions: List,
                                   pd_questions: pd.DataFrame,
                                   user_responses_train: Dict,
                                   model,
                                   tokenizer,
                                   reducer,
                                   user_ids: List[str],
                                   domain=None) -> Tuple[Dict, List, List]:
    """Build agent populations by K-Medoids clustering of human embeddings.

    For every cluster count ``t`` from 1 up to ``min(args.n_agents,
    len(human_embeddings))`` the users are clustered on a precomputed
    distance matrix. For each cluster, up to ``args.k_examples`` valid
    training responses are sampled from its users and used as few-shot
    context; the agent's responses on ``train_questions`` are embedded and
    its responses on train/test questions are pickled under
    ``{args.output_dir}/iterations/iter_{t}/kmedoids``. After each iteration
    the accumulated contexts are written to ``selected_contexts_kmedoids.json``
    and a ``logs.json`` for the iteration is created or updated.

    Per-cluster failures are reported with ``print`` and the cluster is
    skipped, so one failing agent does not abort the run.

    Args:
        args: Run configuration. Reads ``seed``, ``output_dir``, ``n_agents``,
            ``distance``, ``k_examples``, ``role_play`` and ``domain``.
        human_embeddings: One row per user, aligned with ``user_ids``.
        train_questions: Questions used to embed each agent.
        test_questions: Questions answered by each agent and saved to disk.
        pd_questions: Question table forwarded to the agent/embedding helpers.
        user_responses_train: Training responses keyed by the user id without
            the ``user_`` prefix (as ``str`` or ``int``).
        model: Model forwarded to the agent/embedding helpers.
        tokenizer: Tokenizer forwarded to the agent/embedding helpers.
        reducer: Fitted reducer; only used on the ``wikiarts`` path.
        user_ids: ``user_``-prefixed ids, in the row order of
            ``human_embeddings``.
        domain: ``'wikiarts'`` selects the multimodal context and helpers;
            any other value uses the text path.

    Returns:
        ``(selected_contexts, all_agent_embeddings_reduced, dist_scores)``.
        The embeddings list is that of the last iteration which produced at
        least one agent. ``dist_scores`` has one entry per iteration, NaN
        when no agent was created.
    """
    # Set seed
    random.seed(args.seed)

    all_agent_embeddings_reduced = []
    dist_scores = []
    selected_contexts = {}

    # Save contexts to json file
    contexts_file = os.path.join(args.output_dir, 'selected_contexts_kmedoids.json')

    # Cannot form more clusters than there are samples.
    max_clusters = min(args.n_agents, len(human_embeddings))
    if max_clusters < 1:
        return selected_contexts, all_agent_embeddings_reduced, dist_scores

    # Imported lazily so the optional dependency is only needed when clustering runs.
    from sklearn_extra.cluster import KMedoids

    # The distances do not depend on the number of clusters: compute them once.
    distance_matrix = _pairwise_distance_matrix(human_embeddings, args.distance)

    for t in range(1, max_clusters + 1):
        iter_start_time = time.time()

        kmedoids = KMedoids(n_clusters=t, random_state=args.seed, metric='precomputed')
        cluster_labels = kmedoids.fit_predict(distance_matrix)
        medoid_indices = kmedoids.medoid_indices_

        # Group users by cluster label, keeping first-appearance order.
        clusters = {}
        for sample_idx, label in enumerate(cluster_labels):
            clusters.setdefault(int(label), []).append(user_ids[sample_idx])

        cluster_sizes = [len(users) for users in clusters.values()]

        iter_dir = f"iter_{t}"
        kmedoids_dir = f"{args.output_dir}/iterations/{iter_dir}/kmedoids"
        iteration_agent_embeddings = []

        for cluster_id in sorted(clusters):
            torch.cuda.empty_cache()

            cluster_users = clusters[cluster_id]
            cluster_responses = _collect_cluster_responses(cluster_users, user_responses_train)

            k_to_use = min(len(cluster_responses), args.k_examples)
            if k_to_use == 0:
                continue

            sampled_responses = random.sample(cluster_responses, k_to_use)

            try:
                final_context = _build_agent_context(args, domain, sampled_responses)
            except ValueError as e:
                print(f"Error building context for cluster {cluster_id+1} in iteration {t}: {str(e)}")
                continue

            try:
                if domain == 'wikiarts':
                    results = get_agent_responses_wikiarts(model, tokenizer, train_questions, pd_questions, final_context)
                else:
                    results = get_agent_responses(model, tokenizer, train_questions, pd_questions, final_context)

                agent_emb_reduced = _embed_train_responses(
                    args, domain, model, tokenizer, reducer, results, pd_questions
                )
                iteration_agent_embeddings.append(agent_emb_reduced)

                # Cluster labels index into medoid_indices_, so look the
                # medoid up by label rather than by loop position.
                medoid_info = {}
                if 0 <= cluster_id < len(medoid_indices):
                    medoid_user_id = user_ids[medoid_indices[cluster_id]]
                    medoid_info = {
                        "medoid_user_id": medoid_user_id,
                        "is_medoid_user_in_cluster": medoid_user_id in cluster_users
                    }

                os.makedirs(kmedoids_dir, exist_ok=True)

                cluster_id_plus_1 = cluster_id + 1
                train_responses_output_path = f'{kmedoids_dir}/{cluster_id_plus_1}_train_responses.pkl'
                with open(train_responses_output_path, 'wb') as f:
                    pickle.dump(results, f)

                cluster_dir = f"iter_{t}_cluster_{cluster_id_plus_1}"

                selected_contexts[cluster_dir] = {
                    "context": final_context,
                    "iteration": t,
                    "cluster_id": int(cluster_id),
                    "cluster_size": len(cluster_users),
                    **medoid_info,
                    "selected_examples": [{
                        "user_id": example[0],
                        "example_index": example[1],
                        "question_id": example[2][0],
                        "question": example[2][1],
                        "answer": example[2][2]
                    } for example in sampled_responses]
                }

                # Test responses are best-effort: a failure here must not
                # discard the agent that was already embedded above.
                try:
                    torch.cuda.empty_cache()

                    if domain == 'wikiarts':
                        results_test = get_agent_responses_wikiarts(model, tokenizer, test_questions, pd_questions, final_context)
                    else:
                        results_test = get_agent_responses(model, tokenizer, test_questions, pd_questions, final_context)
                    test_output_path = f'{kmedoids_dir}/{cluster_id_plus_1}_test_responses.pkl'
                    with open(test_output_path, 'wb') as f:
                        pickle.dump(results_test, f)

                    del results_test
                except Exception as e:
                    print(f"Error generating test responses for cluster {cluster_id+1}: {str(e)}")
            except Exception as e:
                # Keep going with the remaining clusters, but do not hide the failure.
                print(f"Error processing cluster {cluster_id+1} in iteration {t}: {str(e)}")
                continue

            torch.cuda.empty_cache()

        if iteration_agent_embeddings:
            all_agent_embeddings_reduced = iteration_agent_embeddings

            emb_stack = np.vstack(all_agent_embeddings_reduced)
            current_dist = compute_population_dist(
                human_embeddings,
                emb_stack,
                args.distance
            )
            dist_scores.append(current_dist)

            selected_contexts[iter_dir] = {
                "iteration": t,
                "num_clusters": t,
                "num_agents_created": len(iteration_agent_embeddings),
                "population_dist": float(current_dist),
                "cluster_sizes": cluster_sizes
            }
        else:
            dist_scores.append(float('nan'))

            selected_contexts[iter_dir] = {
                "iteration": t,
                "num_clusters": t,
                "error": "No valid agents created",
                "population_dist": float('nan')
            }

        # Rewritten every iteration so partial results survive an interrupted run.
        with open(contexts_file, 'w') as f:
            json.dump(selected_contexts, f, indent=2)

        iter_runtime = time.time() - iter_start_time
        logs_path = f'{args.output_dir}/iterations/{iter_dir}/logs.json'
        os.makedirs(os.path.dirname(logs_path), exist_ok=True)

        # Save logs, merging into any existing file for this iteration
        if os.path.exists(logs_path):
            with open(logs_path, 'r') as f:
                logs = json.load(f)
        else:
            logs = {}
        logs['runtime_kmedoids'] = iter_runtime
        logs['num_clusters'] = t
        logs['num_agents'] = len(iteration_agent_embeddings)
        with open(logs_path, 'w') as f:
            json.dump(logs, f, indent=2)

    return selected_contexts, all_agent_embeddings_reduced, dist_scores