import os
import argparse
import numpy as np
import joblib
import umap
import torch
import pickle
import json  
import seaborn as sns
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from code.agents import init_agent, init_agent_wikiarts
from code.utils import compute_population_dist, compute_dist_score
import numba
import random
import math
from sklearn.decomposition import PCA
import pandas as pd
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import matplotlib.pyplot as plt
plt.rc('text', usetex=True)
plt.rc('font', family='sans-serif')
plt.rcParams.update({'font.size': 24})
from sklearn.cluster import DBSCAN 
from scipy.stats import sem 
from code.visualize.configs import *
from code.visualize.plots import *
# Seed every RNG this script imports (Python `random`, NumPy, PyTorch) so that
# runs are reproducible; `random` was imported but previously left unseeded.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)


if __name__ == "__main__":
    # ---- Command-line interface ----
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="Path to the data directory")
    parser.add_argument("--output_dir", type=str, nargs='+', required=True,
                        help="Path(s) to the output directory. Multiple paths can be provided for aggregating results across seeds.")
    parser.add_argument("--model", type=str,
                        help="Name of the model (required for --domain wikiarts)")
    parser.add_argument("--dataset_split", type=str, default="train", choices=["train", "test"],
                        help="Dataset split to use: 'train' or 'test'")
    parser.add_argument("--domain", type=str, required=True,
                        help="Domain to use: 'opinionqa', 'wikiarts' or 'eedi'")
    # Only these two metrics have a UMAP metric and a normalisation term further
    # down; anything else used to crash later with a NameError. `required=True`
    # also made the documented default unreachable, so the flag is now optional.
    parser.add_argument('--distance', type=str, default='euclidean',
                        choices=['euclidean', 'manhattan'],
                        help='Distance metric to use (default: euclidean)')
    parser.add_argument('--umap_n_neighbors', type=int, default=4,
                        help='Number of neighbors to consider in UMAP (default: 4)')
    parser.add_argument('--umap_min_dist', type=float, default=0.0,
                        help='Minimum distance between points in UMAP (default: 0.0)')
    parser.add_argument('--dbscan_eps', type=float, default=1.0,
                        help='DBSCAN epsilon parameter (default: 1.0)')
    parser.add_argument('--dbscan_min_samples', type=int, default=5,
                        help='DBSCAN min_samples parameter (default: 5)')
    parser.add_argument('--use_dbscan', action='store_true', default=False,
                        help='Flag to enable DBSCAN clustering for human embeddings in individual plots.')
    parser.add_argument('--arrow_iterations', type=int, nargs='+', default=None,
                        help='List of 1-indexed agent iteration numbers to draw arrows for. If None, no arrows are drawn.')
    parser.add_argument('--methods_to_plot', type=str, nargs='+', required=True,
                        help='List of agent types to include in plots.')
    args = parser.parse_args()

    # The wikiarts path loads a transformers model from --model; fail early with
    # a usage message instead of deep inside from_pretrained(None).
    if args.domain == 'wikiarts' and not args.model:
        parser.error("--model is required when --domain is 'wikiarts'")

    

    # Several --output_dir values mean one experiment per seed, aggregated at the end.
    multi_seed_mode = len(args.output_dir) > 1

    if multi_seed_mode:
        # Aggregated results go next to the seed directories. normpath strips a
        # trailing separator: without it, dirname('runs/seed1/') is 'runs/seed1'
        # and the aggregate would be written *inside* the first seed directory.
        base_dir = os.path.dirname(os.path.normpath(args.output_dir[0]))
        aggregate_output_dir = os.path.join(base_dir, "aggregated_results")
        os.makedirs(aggregate_output_dir, exist_ok=True)
        print(f"Multi-seed mode: Aggregating results from {len(args.output_dir)} experiments to {aggregate_output_dir}")
    else:
        # Single experiment: "aggregated" outputs go to the one directory given.
        aggregate_output_dir = args.output_dir[0]

    # agent_type -> list with one per-iteration dist-score list per seed
    # (only filled in multi-seed mode).
    all_seeds_dist_scores = {}
    
    # ---- Inputs shared by every experiment directory (loaded once) ----
    # None of these depend on `output_dir`; reloading them per seed (notably
    # re-instantiating the Gemma model) was the dominant cost of multi-seed runs.
    if args.domain == 'eedi':
        # Columns: key, concept, concept_name, question, option_mapping,
        # references, option_ordinal, CorrectAnswer
        selected_questions = pd.read_csv('data/eedi/selected_questions.csv')
        # Cast to str (presumably to match string keys in the responses -- confirm).
        selected_questions['key'] = selected_questions['key'].astype(str)
        selected_questions['CorrectAnswer'] = selected_questions['CorrectAnswer'].astype(str)

    if args.domain == 'wikiarts':
        model = Gemma3ForConditionalGeneration.from_pretrained(
            args.model,
            device_map="auto",
            attn_implementation="eager",
            # bf16 where the GPU supports it, fp16 otherwise.
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
        ).eval()
        tokenizer = AutoProcessor.from_pretrained(args.model)
    else:
        # The other domains' embedding helpers are called with model/tokenizer = None.
        model, tokenizer = None, None

    # Human embeddings: one <uid>.npy file per person in human_embeddings_<split>/.
    human_embeddings_path = os.path.join(args.data_dir, f'human_embeddings_{args.dataset_split}')
    human_embeddings_dict = {}
    for file_name in os.listdir(human_embeddings_path):
        if file_name.endswith('.npy'):
            uid = file_name.split('.')[0]
            human_embeddings_dict[uid] = np.load(os.path.join(human_embeddings_path, file_name))
    if not human_embeddings_dict:
        raise FileNotFoundError(f"No .npy human embeddings found in {human_embeddings_path}")

    # ---- Process each output directory (one experiment / seed) ----
    for exp_idx, output_dir in enumerate(args.output_dir):
        print(f"\n{'='*30}\nProcessing experiment directory {exp_idx+1}/{len(args.output_dir)}: {output_dir}\n{'='*30}")

        if args.domain == 'wikiarts':
            # The reducer is stored in each experiment directory, so it is per-seed.
            reducer = joblib.load(os.path.join(output_dir, 'reducer_64d.joblib'))

        # ---- Human embeddings in the agents' space + 2-D UMAP layout ----
        user_ids = list(human_embeddings_dict.keys())
        user_embeddings_np = np.stack(list(human_embeddings_dict.values()), axis=0)

        # wikiarts agent embeddings go through `reducer.transform`, so the humans
        # are projected the same way; other domains use the raw embeddings.
        if args.domain == 'wikiarts':
            human_embeddings_reduced = reducer.transform(user_embeddings_np)
        else:
            human_embeddings_reduced = user_embeddings_np

        print(f" human embeddings: {human_embeddings_reduced.shape}")
        num_nans = np.isnan(human_embeddings_reduced).sum()
        if num_nans:
            print(f"WARNING: human_embeddings_reduced contains NaN values. Number of NaNs: {num_nans}")
        else:
            print("human_embeddings_reduced does not contain NaN values.")
        print(f"Sample of human_embeddings_reduced (first 2): \n{human_embeddings_reduced[:2]}")

        num_humans = human_embeddings_reduced.shape[0]
        embedding_dim = human_embeddings_reduced.shape[1]

        # Metric-specific settings: the UMAP metric, and the term population
        # distances are divided by (num_humans * sqrt(dim) for L2,
        # num_humans * dim for L1).
        if args.distance == 'euclidean':
            umap_metric = l2_dist
            normalized_term = num_humans * math.sqrt(embedding_dim)
        elif args.distance == 'manhattan':
            umap_metric = l1_dist
            normalized_term = num_humans * embedding_dim
        else:
            # Any other value would leave umap_reducer_2d / normalized_term undefined.
            raise ValueError(f"Unsupported --distance {args.distance!r}; expected 'euclidean' or 'manhattan'")

        umap_reducer_2d = umap.UMAP(
            n_components=2,
            init='random',
            random_state=42,
            n_jobs=-1,
            metric=umap_metric,
            n_neighbors=args.umap_n_neighbors,
            min_dist=args.umap_min_dist,
        )
        # Fit on the humans only; agent embeddings are placed later via transform().
        umap_reducer_2d.fit(human_embeddings_reduced, force_all_finite='allow-nan')
        human_embeddings_2d = umap_reducer_2d.transform(human_embeddings_reduced, force_all_finite='allow-nan')

        # ---- Per-experiment accumulators ----
        iterations_dir = os.path.join(output_dir, 'iterations')
        # agent_type -> {iter_num: raw responses loaded from the pickle}
        all_agent_responses = {agent_type: {} for agent_type in args.methods_to_plot}
        # agent_type -> embeddings, one entry per iteration that had responses
        all_agent_embeddings_reduced = {agent_type: [] for agent_type in args.methods_to_plot}
        all_agent_embeddings_original = {agent_type: [] for agent_type in args.methods_to_plot}
        dist_scores = {}

        # 'iter_<n>' sub-directories, sorted numerically (iter_10 after iter_9).
        # Folders whose suffix is not a number are skipped instead of crashing int().
        iteration_folders = sorted(
            (
                d for d in os.listdir(iterations_dir)
                if d.startswith('iter_')
                and d.split('_')[1].isdecimal()
                and os.path.isdir(os.path.join(iterations_dir, d))
            ),
            key=lambda d: int(d.split('_')[1]),
        )
        if not iteration_folders:
            raise FileNotFoundError(f"No 'iter_<n>' directories found in {iterations_dir}")

        # Representative assignments: the two-stage "small" file is read from the
        # last iteration folder, the online-adaptive one from the experiment root.
        last_iter_folder = iteration_folders[-1]
        REPRESENTATIVE_ASSIGNMENTS_FILE_SMALL = os.path.join(iterations_dir, last_iter_folder, 'representative_assignments.json')
        REPRESENTATIVE_ASSIGNMENTS_FILE_ADAPTIVE = os.path.join(output_dir, 'representative_assignments_two_stage_online_adaptive.json')

        two_stage_small_representative_idxs = load_representative_idxs(REPRESENTATIVE_ASSIGNMENTS_FILE_SMALL)
        print(f"Loaded {len(two_stage_small_representative_idxs)} Two-Stage Small Representative IDs:", two_stage_small_representative_idxs)

        two_stage_adaptive_representative_idxs = load_representative_idxs(REPRESENTATIVE_ASSIGNMENTS_FILE_ADAPTIVE)
        print(f"Loaded {len(two_stage_adaptive_representative_idxs)} Two-Stage Adaptive Representative IDs:", two_stage_adaptive_representative_idxs)

        # reppopdemo (greedy) agent: iteration number -> user_ids of the humans
        # listed under 'selected_examples'. Only loaded for the train split; a
        # missing or malformed file is reported and otherwise ignored.
        selected_contexts_reppopdemo_path = os.path.join(output_dir, 'selected_contexts_reppopdemo.json')
        greedy_selected_human_ids_per_iter = {}
        if args.dataset_split == 'train':
            try:
                with open(selected_contexts_reppopdemo_path, 'r') as f:
                    selected_contexts_greedy = json.load(f)
                for iter_key, data in selected_contexts_greedy.items():
                    iter_num = int(iter_key.split('_')[1])
                    if 'selected_examples' in data:
                        # user_id values are used verbatim.
                        greedy_selected_human_ids_per_iter[iter_num] = [ex['user_id'] for ex in data['selected_examples']]
                print(f"Loaded selected human IDs for reppopdemo from {selected_contexts_reppopdemo_path}")
            except FileNotFoundError:
                print(f"Warning: {selected_contexts_reppopdemo_path} not found. Cannot connect dots for greedy agent.")
            except Exception as e:
                print(f"Warning: Error loading or parsing {selected_contexts_reppopdemo_path}: {e}")

        # ---- Per-iteration embeddings and distance scores ----
        # Directory passed to get_embedding_llm (wikiarts). Bound once per
        # experiment: it used to be assigned only inside the per-agent branch, so
        # the k-medoids path could raise NameError (no other agent had responses)
        # or silently reuse the previous seed's directory.
        data_dir = os.path.join(output_dir, args.dataset_split)

        def _responses_to_qa_pairs(responses):
            """Turn one agent's pickled responses into the tuples the embedding helpers take.

            wikiarts rows (qid, question, answer, emotion, explanation) pass through.
            Other domains' rows are 8-tuples; answer_key/answer_text are folded into a
            '{"<key>":"<text>"}' string and raw_answer is dropped.
            """
            if args.domain == 'wikiarts':
                return [(qid, question, answer, emotion, explanation)
                        for qid, question, answer, emotion, explanation in responses]
            qa_pairs = []
            for qid, question, answer_key, answer_text, answer_ordinal, raw_answer, max_val, min_val in responses:
                try:
                    answer = '{"' + answer_key + '":"' + answer_text + '"}'
                except TypeError:
                    # answer_key / answer_text is not a str (e.g. None): keep the row, empty answer.
                    answer = ''
                qa_pairs.append((qid, question, answer, answer_key, answer_ordinal, max_val, min_val))
            return qa_pairs

        def _embed_qa_pairs(qa_pairs):
            """Embed one agent's QA pairs for the current domain.

            Returns (raw embedding, single-row 2-D embedding in the humans' space:
            reducer-projected for wikiarts, the reshaped raw embedding otherwise).
            """
            if args.domain == 'wikiarts':
                emb = get_embedding_llm(model, tokenizer, qa_pairs, data_dir)
                return emb, reducer.transform(emb.reshape(1, -1))
            if args.domain == 'eedi':
                emb = get_embedding_eedi(model, tokenizer, qa_pairs, selected_questions)
            else:
                emb = get_embedding(model, tokenizer, qa_pairs)
            return emb, emb.reshape(1, -1)

        # The k-medoids agent of iteration i is read from iter_<i>/kmedoids/, a file
        # that does not depend on the iteration being processed. Memoise its
        # embedding instead of reloading and re-embedding iterations 1..iter_num on
        # every pass (quadratic in the number of iterations).
        kmedoids_embedding_cache = {}

        def _kmedoids_embedding(i):
            """Reduced embedding of the k-medoids agent saved for iteration i (cached)."""
            if i not in kmedoids_embedding_cache:
                responses_path = os.path.join(iterations_dir, f'iter_{i}', 'kmedoids', f'{i}_{args.dataset_split}_responses.pkl')
                with open(responses_path, 'rb') as f:
                    responses = pickle.load(f)
                kmedoids_embedding_cache[i] = _embed_qa_pairs(_responses_to_qa_pairs(responses))[1]
            return kmedoids_embedding_cache[i]

        def _normalized_population_dist(agent_type, stacked_agent_embeddings):
            """Population distance humans<->agents divided by normalized_term; None on a dim mismatch."""
            if human_embeddings_reduced.shape[1] != stacked_agent_embeddings.shape[1]:
                print(f"  Warning: Shape mismatch for {agent_type}. Human: {human_embeddings_reduced.shape[1]}, Agent: {stacked_agent_embeddings.shape[1]}. Skipping dist.")
                return None
            return compute_population_dist(
                human_embeddings_reduced,
                stacked_agent_embeddings,
                args.distance
            ) / normalized_term

        # One list per requested agent; each gets at most one score per iteration.
        dist_scores = {agent_type: [] for agent_type in args.methods_to_plot}
        kmedoids_all_agent_embeddings_reduced = []

        for iter_folder in iteration_folders:
            iter_num = int(iter_folder.split('_')[1])
            iter_path = os.path.join(iterations_dir, iter_folder)
            print(f"Loading data for iteration {iter_num} from {iter_path}")

            # Load this iteration's responses for every requested agent.
            for agent_type in args.methods_to_plot:
                agent_responses_path = os.path.join(iter_path, f'{agent_type}_{args.dataset_split}_responses.pkl')
                if os.path.exists(agent_responses_path):
                    with open(agent_responses_path, 'rb') as f:
                        all_agent_responses[agent_type][iter_num] = pickle.load(f)
                    print(f"  Loaded {agent_type}_{args.dataset_split}_responses.pkl for iter {iter_num}")
                else:
                    print(f"  Warning: {agent_responses_path} not found.")

            # Agents without responses in this iteration get an empty list.
            qa_pairs_dict = {
                agent_type: (_responses_to_qa_pairs(all_agent_responses[agent_type][iter_num])
                             if iter_num in all_agent_responses[agent_type] else [])
                for agent_type in args.methods_to_plot
            }

            print("Processing embeddings for iteration", iter_num)
            print(f"--- Iteration {iter_num}: Original Agent Embeddings Check ---")

            # agent_type -> reduced embedding, only for agents with responses this iteration.
            agent_embeddings_reduced = {}
            for agent_type, qa_pairs in qa_pairs_dict.items():
                if not qa_pairs:
                    continue
                agent_emb, agent_emb_reduced = _embed_qa_pairs(qa_pairs)
                all_agent_embeddings_original[agent_type].append(agent_emb)
                all_agent_embeddings_reduced[agent_type].append(agent_emb_reduced)
                agent_embeddings_reduced[agent_type] = agent_emb_reduced
                print(f"{agent_type} agent_emb_reduced shape:", agent_emb_reduced.shape)
                if agent_emb is not None:
                    print(f"  {agent_type} original agent_emb shape: {agent_emb.shape}, NaNs: {np.isnan(agent_emb).sum()}")
                    if iter_num == 1 and len(all_agent_embeddings_original[agent_type]) == 1:  # Print sample only for the first one
                        print(f"    Sample of {agent_type} original_emb (first 5 elements): {agent_emb.flatten()[:5]}")

            # k-medoids: one embedding per iteration 1..iter_num.
            kmedoids_all_agent_embeddings_reduced = []
            if 'kmedoids' in args.methods_to_plot:
                kmedoids_all_agent_embeddings_reduced = [_kmedoids_embedding(i) for i in range(1, iter_num + 1)]

            if args.dataset_split == 'train':
                # Train scores are read from the 'population_dist' entry of
                # selected_contexts_<name>.json and normalised. Agent types containing
                # 'agent' use the file named after their first '_'-separated token.
                for agent_type in args.methods_to_plot:
                    context_name = agent_type.split("_")[0] if 'agent' in agent_type else agent_type
                    selected_contexts_file = os.path.join(output_dir, f'selected_contexts_{context_name}.json')
                    if not os.path.exists(selected_contexts_file):
                        print(f"  Warning: {selected_contexts_file} not found. Cannot get dist score for {agent_type} in iter {iter_folder}.")
                        continue
                    try:
                        with open(selected_contexts_file, 'r') as f:
                            selected_contexts = json.load(f)
                        if iter_folder in selected_contexts:
                            current_dist = selected_contexts[iter_folder]['population_dist'] / normalized_term
                            dist_scores[agent_type].append(current_dist)
                        else:
                            print(f"  Warning: Iteration {iter_folder} not found in {selected_contexts_file}. Skipping dist score.")
                    except Exception as e:
                        # Best effort: a bad file costs this agent's score, not the run.
                        print(f"  Warning: Error reading or parsing {selected_contexts_file} for iter {iter_folder}: {e}")

            elif args.dataset_split == 'test':
                for agent_type in args.methods_to_plot:
                    if agent_type == 'kmedoids':
                        stacked_agent_embeddings = np.vstack(kmedoids_all_agent_embeddings_reduced)
                        print(f"size of np.vstack(kmedoids_all_agent_embeddings_reduced)", stacked_agent_embeddings.shape)
                        current_dist = _normalized_population_dist(agent_type, stacked_agent_embeddings)
                        if current_dist is not None:
                            print(f"Current dist for {agent_type} at iteration {iter_num}: {current_dist}")
                            dist_scores[agent_type].append(current_dist)
                        # Exactly one k-medoids score per iteration: falling through
                        # would append a second one whenever a top-level
                        # kmedoids_<split>_responses.pkl also exists in iter_path.
                        continue

                    if not all_agent_embeddings_reduced[agent_type]:
                        continue
                    print("size of human_embeddings_reduced:", human_embeddings_reduced.shape)
                    # Every embedding recorded so far for this agent (iterations with responses).
                    stacked_agent_embeddings = np.vstack(all_agent_embeddings_reduced[agent_type])
                    print(f"size of np.vstack(all_agent_embeddings_reduced[{agent_type}])", stacked_agent_embeddings.shape)
                    current_dist = _normalized_population_dist(agent_type, stacked_agent_embeddings)
                    if current_dist is not None:
                        print(f"Current dist for {agent_type} at iteration {iter_num}: {current_dist}")
                        # Record a score only for iterations where the agent produced responses.
                        if agent_type in agent_embeddings_reduced:
                            dist_scores[agent_type].append(current_dist)

        # Place agent embeddings in the 2-D UMAP layout fitted on the humans.
        # Each agent maps to a list of single-row (1, 2) arrays, one per embedding.
        def _project_rows_2d(embedding_rows):
            stacked = np.vstack([row.reshape(1, -1) for row in embedding_rows])
            projected = umap_reducer_2d.transform(stacked, force_all_finite='allow-nan')
            return [point[np.newaxis, :] for point in projected]

        all_agent_embeddings_2d = {
            agent_type: _project_rows_2d(rows)
            for agent_type, rows in all_agent_embeddings_reduced.items()
            if rows
        }

        # k-medoids: one point per iteration 1..last, replacing any entry set above.
        if 'kmedoids' in args.methods_to_plot:
            all_agent_embeddings_2d['kmedoids'] = _project_rows_2d(kmedoids_all_agent_embeddings_reduced)
                
        # Keep this seed's per-iteration scores for the cross-seed aggregation.
        if multi_seed_mode:
            for agent_type, scores in dist_scores.items():
                all_seeds_dist_scores.setdefault(agent_type, []).append(scores)

        # Per-seed plots: distance progress, then one 2-D scatter per agent.
        plot_dist_progress(args, dist_scores, output_dir, args.dataset_split)

        representative_idxs_by_agent = {
            'two_stage_pe_small': two_stage_small_representative_idxs,
            'two_stage_pe_online_adaptive': two_stage_adaptive_representative_idxs,
        }
        plottable = (
            (agent_type, embeddings)
            for agent_type, embeddings in all_agent_embeddings_2d.items()
            if agent_type in args.methods_to_plot
        )
        for agent_type, embeddings in plottable:
            plot_individual_agent_2d(
                args,
                human_embeddings_2d,
                embeddings,
                user_ids,
                representative_idxs_by_agent.get(agent_type, []),
                agent_type,
                output_dir,
                args.dataset_split,
                args.use_dbscan,
                args.dbscan_eps,
                args.dbscan_min_samples,
                human_embeddings_for_clustering=human_embeddings_reduced,
                current_agent_clustering_dim_embeddings_list=all_agent_embeddings_reduced.get(agent_type, []),
                dist_for_clustering=args.distance,
                # Only the greedy reppopdemo agent has selected-human links to draw.
                selected_human_ids_per_iter=greedy_selected_human_ids_per_iter if agent_type == 'reppopdemo' else None,
                arrow_iterations=args.arrow_iterations
            )
    
    # ---- Cross-seed aggregation: mean and standard error per iteration ----
    if multi_seed_mode:
        print("\n\nCalculating aggregated statistics across all seeds...")
        print("all_seeds_dist_scores", all_seeds_dist_scores)

        aggregate_data = {}
        for agent_type, per_seed_scores in all_seeds_dist_scores.items():
            # Seeds may record different numbers of iterations: right-pad the
            # shorter ones with NaN so they stack into a (seeds, iterations) array.
            longest = max(map(len, per_seed_scores))
            scores_array = np.array([
                list(seed_scores) + [float('nan')] * (longest - len(seed_scores))
                for seed_scores in per_seed_scores
            ])
            print("scores_array", scores_array)

            # NaN-aware statistics so the padding does not bias the results.
            mean_scores = np.nanmean(scores_array, axis=0)
            print("mean_scores", mean_scores)
            stderr_scores = sem(scores_array, axis=0, nan_policy='omit')
            print("stderr_scores", stderr_scores)

            aggregate_data[agent_type] = {'mean': mean_scores, 'stderr': stderr_scores}
            print(f"Agent {agent_type}: Mean scores shape {mean_scores.shape}, StdErr shape {stderr_scores.shape}")
        print("aggregate_data", aggregate_data)

        reference_dist_scores = {
            agent_type: stats['mean'].tolist()
            for agent_type, stats in aggregate_data.items()
        }
        plot_dist_progress(args, reference_dist_scores, aggregate_output_dir, args.dataset_split, aggregate_data)

        aggregate_pickle_path = os.path.join(aggregate_output_dir, f'aggregated_dist_scores_{args.dataset_split}.pkl')
        with open(aggregate_pickle_path, 'wb') as f:
            pickle.dump({'aggregate_data': aggregate_data, 'all_seeds_data': all_seeds_dist_scores}, f)

        print(f"Completed aggregated analysis across {len(args.output_dir)} seeds.")

