import os
import argparse
import numpy as np
import joblib
import umap
import torch
import pickle
import json  
import seaborn as sns
from code.utils import get_embedding, get_embedding_llm, get_embedding_eedi
from code.agents import init_agent, init_agent_wikiarts
from code.utils import compute_population_dist, compute_dist_score
import numba
import random
import math
from sklearn.decomposition import PCA
import pandas as pd
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import matplotlib.pyplot as plt
# Process-wide matplotlib defaults: LaTeX text rendering, sans-serif family,
# and a 24pt base font size for every figure drawn after this point.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'sans-serif',
    'font.size': 24,
})
from sklearn.cluster import DBSCAN 
from scipy.stats import sem 
from code.visualize.configs import *


def load_representative_idxs(assignments_file):
    """Load representative user ids from a JSON assignments file.

    The file is read as a JSON list whose items each carry a "rep" string;
    the last "_"-separated token of each "rep" is returned. This matches the
    ``uid.split('_')[-1]`` stripping that plot_individual_agent_2d applies to
    user ids, so the two can be compared directly (a rep without "_" is
    returned unchanged instead of failing).

    Returns:
        list[str] of ids, or [] (after printing a warning) when the file is
        missing or cannot be read/parsed.
    """
    try:
        with open(assignments_file, 'r', encoding='utf-8') as f:
            assignments = json.load(f)
        # Last token, not index 1: stays consistent with user-id stripping and
        # tolerates reps with zero or several underscores.
        return [assignment["rep"].split('_')[-1] for assignment in assignments]
    except FileNotFoundError:
        print(f"Warning: Representative assignments file not found: {assignments_file}")
        return []
    except Exception as e:  # best-effort: a malformed file degrades to "no representatives"
        print(f"Warning: Could not load or parse representative assignments: {e}")
        return []

@numba.njit()
def l2_dist(a,b):
    """NaN-tolerant Euclidean distance between two vectors (presumably 1-D).

    Positions that are NaN in either input are skipped. The squared-difference
    sum over the remaining positions is scaled by (total / usable) before the
    square root, so partially observed vectors stay on the same scale as fully
    observed ones. Returns inf (after printing a warning) if nothing is usable.
    """
    keep = ~(np.isnan(a) | np.isnan(b))
    n_keep = np.sum(keep)
    if n_keep == 0:
        print("Warning: No valid coordinates found for distance calculation.")
        return np.inf
    delta = a[keep] - b[keep]
    scale = a.shape[0] / n_keep
    return np.sqrt(scale * np.sum(delta ** 2))

@numba.njit()
def l1_dist(a,b):
    """NaN-tolerant Manhattan distance between two vectors (presumably 1-D).

    Positions that are NaN in either input are skipped. The absolute-difference
    sum over the remaining positions is scaled by (total / usable). Returns inf
    (after printing a warning) if nothing is usable.
    """
    usable = ~(np.isnan(a) | np.isnan(b))
    n_usable = np.sum(usable)
    if n_usable == 0:
        print("Warning: No valid coordinates found for distance calculation.")
        return np.inf
    gaps = np.abs(a[usable] - b[usable])
    return (a.shape[0] / n_usable) * np.sum(gaps)

def plot_dist_progress(args, dist_scores, output_dir, dataset_split, aggregate_data=None):
    """Plot average distance against agent-set size for each selected agent and save a PDF.

    Args:
        args: namespace read for ``methods_to_plot`` (which agents to draw) and
            ``distance`` (a value containing 'euclidean' labels the y-axis
            "Average RMSE", otherwise "Average MAE").
        dist_scores: mapping agent_type -> sequence of per-step scores. These
            are plotted directly unless ``aggregate_data`` has an entry for
            that agent.
        output_dir: directory the PDF is written to.
        dataset_split: split name embedded in the output filename.
        aggregate_data: optional mapping agent_type -> {'mean': [...],
            'stderr': [...]}. When present for an agent, its mean curve is
            drawn with standard-error bars, and "_aggregated" is appended to the
            filename.
    """
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.2)

    plt.figure(figsize=(6, 5))

    # Color / marker / display label for every requested agent that has a style entry.
    styles = {}
    for agent_type in args.methods_to_plot:
        if agent_type in AGENT_PLOT_STYLES:
            styles[agent_type] = {
                'color': AGENT_PLOT_STYLES[agent_type]['color'],
                'marker': AGENT_PLOT_STYLES[agent_type]['marker'],
                'label': AGENT_DISPLAY_NAMES.get(agent_type, agent_type)
            }

    # Length of the longest curve actually drawn. The x-ticks follow this so they
    # match the plot: the first dist_scores entry may be unplotted, and aggregated
    # means can differ in length from the raw entries.
    n_points = 0

    for agent_type, scores in dist_scores.items():
        if agent_type not in args.methods_to_plot:
            continue

        m_size = 12 if 'robust' in agent_type else 10

        if aggregate_data and agent_type in aggregate_data:
            means = aggregate_data[agent_type]['mean']
            errors = aggregate_data[agent_type]['stderr']
            steps = range(1, len(means) + 1)

            sns.lineplot(
                x=steps,
                y=means,
                color=styles[agent_type]['color'],
                label=styles[agent_type]['label'],
                linewidth=3.5,
                zorder=AGENT_PLOT_STYLES[agent_type]['zorder']
            )

            # Error bars sit just beneath their mean curve.
            plt.errorbar(
                steps,
                means,
                yerr=errors,
                fmt='none',
                ecolor=styles[agent_type]['color'],
                alpha=0.8,
                capsize=0,
                elinewidth=3,
                zorder=AGENT_PLOT_STYLES[agent_type]['zorder'] - 0.5
            )
            n_points = max(n_points, len(means))
        elif scores:
            sns.lineplot(
                x=range(1, len(scores) + 1),
                y=scores,
                color=styles[agent_type]['color'],
                marker=styles[agent_type]['marker'],
                label=styles[agent_type]['label'],
                linewidth=3,
                markersize=m_size,
                zorder=AGENT_PLOT_STYLES[agent_type]['zorder']
            )
            n_points = max(n_points, len(scores))

    # Raw string: '\(' is not a valid Python escape. The LaTeX inline-math delimiters
    # \( \) are meant for the text renderer (text.usetex is enabled at module level).
    plt.xlabel(r'Size of Agent Set \(M\)', fontsize=22)

    if 'euclidean' in args.distance:
        plt.ylabel('Average RMSE', fontsize=22)
    else:
        plt.ylabel('Average MAE', fontsize=22)

    plt.grid(True, linestyle='--', alpha=0.7, zorder=1)

    # Reorder legend entries to follow AGENT_LEGEND_ORDER, keeping only labels present.
    handles, labels = plt.gca().get_legend_handles_labels()
    handle_label_dict = dict(zip(labels, handles))
    ordered_labels = [label for label in AGENT_LEGEND_ORDER if label in handle_label_dict]
    ordered_handles = [handle_label_dict[label] for label in ordered_labels]
    plt.legend(ordered_handles, ordered_labels, bbox_to_anchor=(0.42, 1.2), loc='center', ncol=2, fontsize=16)

    plt.xticks(range(1, n_points + 1))

    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)

    plt.tight_layout()

    suffix = "_aggregated" if aggregate_data else ""
    plot_path = os.path.join(output_dir, f'dist_progress_all_agents_{dataset_split}{suffix}.pdf')
    plt.savefig(plot_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Saved combined dist progress plot to {plot_path}")

def _assign_humans_to_agent_iters(human_embeddings_for_clustering, agent_iter_embeddings, dist_func, num_humans):
    """Return an int array holding, for each human, the index of the closest agent iteration.

    An entry stays -1 when there are no iterations, or when no iteration gives a
    distance strictly below inf.
    """
    assigned = np.full(num_humans, -1, dtype=int)
    agent_vectors = [emb.flatten() for emb in agent_iter_embeddings]
    for human_idx in range(num_humans):
        human_vec = human_embeddings_for_clustering[human_idx]
        best_dist, best_iter = np.inf, -1
        for iter_idx, agent_vec in enumerate(agent_vectors):
            dist = dist_func(human_vec, agent_vec)
            # Strict '<' keeps the earliest iteration on ties and never picks an inf/NaN distance.
            if dist < best_dist:
                best_dist, best_iter = dist, iter_idx
        assigned[human_idx] = best_iter
        if human_idx < 5:
            print(f"DEBUG: Human {human_idx} (using reduced emb for dist) assigned to iteration {best_iter} with distance {best_dist}")
    return assigned


def _agent_iteration_palette(agent_base_color, num_agent_iterations):
    """Return one color per agent iteration, shaded from the agent's base color when distinct enough."""
    if num_agent_iterations <= 0:
        return []
    if num_agent_iterations == 1:
        return [agent_base_color]

    try:
        # Request two extra shades and drop the lightest, which is near white.
        iteration_palette = sns.light_palette(agent_base_color, n_colors=num_agent_iterations + 2, as_cmap=False)[1:num_agent_iterations + 1]
        unique_colors_in_palette = {tuple(c[:3]) for c in iteration_palette}
        if len(unique_colors_in_palette) < max(2, num_agent_iterations // 2):
            print(f"DEBUG: light_palette for '{agent_base_color}' produced non-distinct colors ({len(unique_colors_in_palette)} unique from {len(iteration_palette)}). Forcing fallback.")
            iteration_palette = None
    except Exception as e:  # best-effort: any palette failure falls back to a categorical palette
        print(f"DEBUG: Exception during light_palette generation for '{agent_base_color}': {e}. Using fallback palette.")
        iteration_palette = None

    if not iteration_palette or len(iteration_palette) < num_agent_iterations:
        print(f"DEBUG: Falling back to a standard categorical palette for {num_agent_iterations} colors.")
        palette_name = "husl" if num_agent_iterations <= 10 else "viridis"
        iteration_palette = sns.color_palette(palette_name, n_colors=num_agent_iterations)
    return iteration_palette


def _dbscan_cluster_colors(points_2d, eps, min_samples):
    """Color each point by its DBSCAN cluster on ``points_2d``; noise (-1) gets 'lightgray'."""
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(points_2d)
    cluster_ids = [label for label in np.unique(labels) if label != -1]
    palette = sns.color_palette("husl", len(cluster_ids)) if cluster_ids else []
    color_of = {-1: 'lightgray'}
    color_of.update(zip(cluster_ids, palette))
    return [color_of[label] for label in labels]


def _stack_agent_iteration_coords(agent_embeddings_2d, agent_type):
    """Stack the per-iteration agent coordinates into one row each, or return None if there are none."""
    # len() rather than truthiness so a numpy array input does not raise "ambiguous truth value".
    if agent_embeddings_2d is None or len(agent_embeddings_2d) == 0:
        return None
    rows = [emb.flatten() for emb in agent_embeddings_2d
            if emb is not None and hasattr(emb, 'flatten') and emb.ndim > 0]
    if not rows:
        print(f"WARNING: No valid agent embeddings found to vstack for agent {agent_type} in plot_individual_agent_2d.")
        return None
    try:
        return np.vstack(rows)
    except ValueError as e:
        print(f"WARNING: Could not vstack agent_embeddings_2d for agent {agent_type} due to shape mismatch or empty list after filtering: {e}")
        return None


def _save_human_assignments(assigned_iteration_index, human_ids, num_agent_iterations, output_dir, agent_type, dataset_split):
    """Write {"cluster_<k>": [ids of humans assigned to agent iteration k]} (k is 1-based) as JSON."""
    assignments = {f"cluster_{k + 1}": [] for k in range(num_agent_iterations)}
    for human_idx, iter_assigned in enumerate(assigned_iteration_index):
        if 0 <= iter_assigned < num_agent_iterations:
            assignments[f"cluster_{iter_assigned + 1}"].append(human_ids[human_idx])

    assignments_filename = os.path.join(output_dir, f'{agent_type}_{dataset_split}_human_assignments_to_agent_iters.json')
    try:
        with open(assignments_filename, 'w', encoding='utf-8') as f:
            json.dump(assignments, f, indent=4)
        print(f"Saved human assignments for {agent_type} on {dataset_split} split to {assignments_filename}")
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving human assignments for {agent_type} to {assignments_filename}: {e}")


def plot_individual_agent_2d(args,human_embeddings_2d, agent_embeddings_2d, user_ids, representative_idxs, agent_type, output_dir, dataset_split, use_dbscan, dbscan_eps, dbscan_min_samples, human_embeddings_for_clustering, current_agent_clustering_dim_embeddings_list, dist_for_clustering, selected_human_ids_per_iter=None, arrow_iterations=None):
    """Scatter the 2-D human embeddings together with one agent's per-iteration points, and save a PDF.

    Each human is first assigned to the agent iteration whose clustering-space
    embedding is nearest. That assignment drives the optional human -> agent
    arrows and is saved to
    ``{agent_type}_{dataset_split}_human_assignments_to_agent_iters.json``.

    Args:
        args: namespace read for ``methods_to_plot``; agents not listed are skipped.
        human_embeddings_2d: plotting coordinates, indexed as [i, 0] / [i, 1].
        agent_embeddings_2d: per-iteration agent plotting coordinates (list or
            array). Each entry is flattened into one row.
        user_ids: ids aligned with the rows of ``human_embeddings_2d``. Their last
            "_"-separated token is used for matching and saving.
        representative_idxs: ids of two-stage representatives, compared as
            strings. None is treated as empty.
        agent_type: key into AGENT_PLOT_STYLES / AGENT_DISPLAY_NAMES.
        output_dir, dataset_split: destination directory and filename tag.
        use_dbscan, dbscan_eps, dbscan_min_samples: DBSCAN settings for the
            per-human color list (see the NOTE in the body).
        human_embeddings_for_clustering: per-human vectors used for assignment.
        current_agent_clustering_dim_embeddings_list: per-iteration agent vectors
            in the same space as ``human_embeddings_for_clustering``.
        dist_for_clustering: 'euclidean' selects l2_dist; anything else selects l1_dist.
        selected_human_ids_per_iter: accepted for interface compatibility; unused here.
        arrow_iterations: None disables arrows. An empty sequence draws arrows for
            every iteration. Otherwise only the listed 1-based iterations get arrows.
    """
    if agent_type not in args.methods_to_plot:
        return

    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.2)

    plt.figure(figsize=(4, 4))

    num_total_humans = human_embeddings_2d.shape[0]
    num_agent_iterations = len(current_agent_clustering_dim_embeddings_list)

    print(f"DEBUG: Processing {dataset_split} split for agent_type={agent_type}")
    print(f"DEBUG: Number of agent iterations (for clustering): {num_agent_iterations}")
    print(f"DEBUG: Human embeddings for clustering shape: {human_embeddings_for_clustering.shape}")

    dist_func = l2_dist if dist_for_clustering == 'euclidean' else l1_dist
    print(f"DEBUG: Using distance function: {dist_func.__name__} on embeddings of dim {human_embeddings_for_clustering.shape[1] if num_total_humans > 0 else 'N/A'}")

    # 1. Nearest agent iteration for every human (in the clustering space).
    assigned_iteration_index = _assign_humans_to_agent_iters(
        human_embeddings_for_clustering,
        current_agent_clustering_dim_embeddings_list,
        dist_func,
        num_total_humans,
    )

    # 2. Per-human colors.
    agent_base_color = AGENT_PLOT_STYLES[agent_type]['color']
    print(f"DEBUG: Agent base color for palette: {agent_base_color}")
    iteration_palette = _agent_iteration_palette(agent_base_color, num_agent_iterations)
    print(f"DEBUG: Final palette has {len(iteration_palette)} colors for {num_agent_iterations} iterations.")
    print(f"DEBUG: Iteration palette values: {iteration_palette}")

    # NOTE(review): human_colors_for_plot is filled from the iteration palette, or from
    # DBSCAN clusters when use_dbscan is set. It is never passed to the scatter calls
    # below, which draw every human in 'gray'. Kept for parity; confirm whether this
    # coloring was meant to be applied.
    human_colors_for_plot = ['gray'] * num_total_humans
    for human_idx, iter_assigned in enumerate(assigned_iteration_index):
        if 0 <= iter_assigned < len(iteration_palette):
            human_colors_for_plot[human_idx] = iteration_palette[iter_assigned]

    print("DEBUG: assigned_iteration_index distribution:", np.bincount(assigned_iteration_index[assigned_iteration_index >= 0]))
    print("DEBUG: Number of unassigned humans:", np.sum(assigned_iteration_index == -1))
    if use_dbscan:
        human_colors_for_plot = _dbscan_cluster_colors(human_embeddings_2d, dbscan_eps, dbscan_min_samples)

    # 3. Humans, split into two-stage representatives and everyone else.
    user_ids_stripped = [uid.split('_')[-1] for uid in user_ids]
    representative_idxs = [str(idx) for idx in (representative_idxs or [])]
    representative_idxs_set = set(representative_idxs)
    # dtype=bool so an empty user list still yields a boolean mask that '~' can invert.
    mask_rep = np.array([uid in representative_idxs_set for uid in user_ids_stripped], dtype=bool)
    mask_non_rep = ~mask_rep

    # Empty scatter that only contributes the gray "Humans" legend entry.
    plt.scatter([], [], color='gray', alpha=0.6, s=120, label='Humans')

    sns.scatterplot(
        x=human_embeddings_2d[mask_non_rep, 0],
        y=human_embeddings_2d[mask_non_rep, 1],
        color='gray',
        alpha=0.6,
        s=120,
        legend=False,
        zorder=1
    )

    if np.any(mask_rep) and ('two_stage' in agent_type or agent_type == 'all'):
        sns.scatterplot(
            x=human_embeddings_2d[mask_rep, 0],
            y=human_embeddings_2d[mask_rep, 1],
            color='gray',
            marker='o',
            s=150,
            alpha=1.0,
            linewidth=1.0,
            label='Two-Stage representatives',
            zorder=3
        )

        # Label each representative with its 1-based position in representative_idxs.
        rep_id_to_overall_num = {rep_id: i + 1 for i, rep_id in enumerate(representative_idxs)}
        for idx in np.where(mask_rep)[0]:
            rep_num = rep_id_to_overall_num.get(user_ids_stripped[idx])
            if rep_num is None:
                continue
            plt.text(
                human_embeddings_2d[idx, 0],
                human_embeddings_2d[idx, 1],
                str(rep_num),
                ha='center',
                va='center',
                color='white',
                fontweight='bold',
                fontsize=8,
                zorder=20
            )

    # 4. Agent iterations, numbered 1..K.
    agent_iterations_2d_coords_np = _stack_agent_iteration_coords(agent_embeddings_2d, agent_type)
    has_agent_coords = agent_iterations_2d_coords_np is not None and agent_iterations_2d_coords_np.shape[0] > 0

    if has_agent_coords:
        agent_style = AGENT_PLOT_STYLES[agent_type]
        agent_label = AGENT_DISPLAY_NAMES.get(agent_type, agent_type)
        m_size = 350 if 'robust' in agent_type else 250
        for i in range(agent_iterations_2d_coords_np.shape[0]):
            agent_x, agent_y = agent_iterations_2d_coords_np[i, 0], agent_iterations_2d_coords_np[i, 1]
            sns.scatterplot(
                x=[agent_x],
                y=[agent_y],
                color=agent_style['color'],
                marker=agent_style['marker'],
                s=m_size,
                alpha=0.8,
                label=f'{agent_label}' if i == 0 else None,  # one legend entry per agent
                edgecolor='black',
                linewidth=0.5,
                zorder=5
            )

            plt.text(
                agent_x,
                agent_y,
                str(i+1),
                ha='center',
                va='center',
                color=agent_style.get('text_color', 'black'),
                fontweight='bold',
                fontsize=10,
                zorder=11
            )

    # 5. Optional arrows from each human to its assigned agent iteration.
    print(f"DEBUG: Attempting to draw arrows for {agent_type}. Specified arrow iterations: {arrow_iterations}")
    arrows_drawn = 0
    if arrow_iterations is not None and has_agent_coords:
        n_agent_points = agent_iterations_2d_coords_np.shape[0]
        for human_idx, iter_assigned in enumerate(assigned_iteration_index):
            if not (0 <= iter_assigned < n_agent_points):
                continue
            # An empty arrow_iterations means "all iterations"; otherwise filter on the 1-based iteration.
            if arrow_iterations and (iter_assigned + 1) not in arrow_iterations:
                continue

            human_x, human_y = human_embeddings_2d[human_idx, 0], human_embeddings_2d[human_idx, 1]
            agent_x, agent_y = agent_iterations_2d_coords_np[iter_assigned, 0], agent_iterations_2d_coords_np[iter_assigned, 1]
            if np.isnan(human_x) or np.isnan(human_y) or np.isnan(agent_x) or np.isnan(agent_y):
                continue

            plt.annotate("",
                         xy=(agent_x, agent_y), xycoords='data',
                         xytext=(human_x, human_y), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3,rad=0.1",
                                         color='black',
                                         lw=0.7,
                                         alpha=0.4,
                                         shrinkA=5,
                                         shrinkB=5),
                         zorder=2)
            arrows_drawn += 1
    if arrows_drawn > 0:
        print(f"DEBUG: Drew {arrows_drawn} arrows for {agent_type}")
    elif num_total_humans > 0 and has_agent_coords:
        print(f"DEBUG: No arrows drawn for {agent_type} (num_total_humans: {num_total_humans}, assigned_iteration_index non -1 sum: {np.sum(assigned_iteration_index != -1)})")

    # 6. Persist the human -> iteration assignment.
    if num_total_humans > 0 and num_agent_iterations > 0:
        _save_human_assignments(assigned_iteration_index, user_ids_stripped, num_agent_iterations, output_dir, agent_type, dataset_split)

    plt.xticks([])
    plt.yticks([])
    plt.legend(loc='upper right', ncol=1)
    plt.tight_layout()

    plot_path = os.path.join(output_dir, f'embeddings_2d_{agent_type}_{dataset_split}.pdf')
    plt.savefig(plot_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Saved {agent_type} embeddings plot to {plot_path}")