import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import dmi 
from itertools import combinations

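# TODO: add NDCG. NOTE(review): nDCG is already computed for test_type="recommendation"
# (see compute_ndcg); this TODO may be stale or refer to the "base" test type — confirm.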

def get_mi(kernel_matrix, reward_vec, Q=None, F=None):
    """Compute a log-mutual-information style score for a selected item set.

    The mode is chosen by which optional arguments are supplied:

    * ``Q`` only:  ``mi = -log(1 - r^T K^{-1} r / (Q Q^T))``
    * ``Q`` and ``F``:  ``mi = log|det(Q (I - F^T K^{-1} F) Q^T + 1e-3 I)|``

    where ``K = kernel_matrix + 1e-6 * I`` and ``r = reward_vec``.

    Args:
        kernel_matrix: square (k, k) kernel/similarity matrix of the selection.
        reward_vec: reward column of the selection; presumably shape (k, 1)
            (TODO confirm) so that ``r^T K^{-1} r`` is compatible with ``Q Q^T``.
        Q: query matrix (required). In the ``Q``-only mode, ``Q Q^T`` is used as
            a divisor, so presumably ``Q`` is a single row — TODO confirm.
        F: optional feature matrix with one row per selected item.

    Returns:
        ``(mi, exp(mi))``.

    Raises:
        ValueError: if ``Q`` is not given (previously this failed with an
            UnboundLocalError on ``mi``).
    """
    if Q is None:
        raise ValueError("get_mi requires Q")

    # Small ridge so the inverse exists for (near-)singular kernels.
    kernel_inv = np.linalg.inv(kernel_matrix + 1e-6 * np.eye(kernel_matrix.shape[0]))

    if F is None:
        explained = reward_vec.T.dot(kernel_inv).dot(reward_vec) / Q.dot(Q.T)
        mi = -np.log(1 - explained)
    else:
        residual = np.eye(F.shape[1]) - F.T.dot(kernel_inv).dot(F)
        # slogdet returns (sign, logabsdet); keep only the log-determinant, as
        # get_mi_from_Q does. Previously the whole pair was exponentiated.
        mi = np.linalg.slogdet(
            Q.dot(residual).dot(Q.T) + 1e-3 * np.eye(Q.shape[0])
        )[1]

    mi_exp = np.exp(mi)
    return mi, mi_exp

def get_mi_from_Q(kernel_matrix, Q):
    """Return |det(Q^T K Q + 1e-3 I)|, evaluated through its log-determinant.

    Args:
        kernel_matrix: square kernel matrix ``K``.
        Q: matrix whose columns project ``K``; the ridge has size ``Q.shape[1]``.

    Returns:
        The absolute determinant as ``exp(logabsdet)``.
    """
    projected = Q.T @ kernel_matrix @ Q
    ridge = 1e-3 * np.eye(Q.shape[1])
    _sign, log_abs_det = np.linalg.slogdet(projected + ridge)
    return np.exp(log_abs_det)

def get_trace_from_Q(kernel_matrix, Q, F):
    """Return trace(Q P Q^T + 1e-3 I) with P = F^T (K + 1e-6 I)^{-1} F.

    Args:
        kernel_matrix: square kernel matrix ``K`` over the rows of ``F``.
        Q: matrix applied on both sides of the projection ``P``.
        F: feature matrix with one row per kernel entry.

    Returns:
        The trace as a scalar.
    """
    size = kernel_matrix.shape[0]
    regularised_inv = np.linalg.inv(kernel_matrix + 1e-6 * np.eye(size))
    projection = F.T @ regularised_inv @ F
    ridge = 1e-3 * np.eye(Q.shape[0])
    return np.trace(Q @ projection @ Q.T + ridge)

def compute_cholesky_norm(kernel_matrix, reward_vec):
    """Return the Euclidean norm of the reward vector whitened by the kernel's Cholesky factor.

    ``L = cholesky(K + 1e-6 I)``; the result is ``|| (L + 1e-6 I)^{-1} r ||_2``.

    Args:
        kernel_matrix: symmetric positive (semi-)definite kernel ``K``.
        reward_vec: reward vector ``r`` with one entry per kernel row.

    Returns:
        The norm as a scalar.
    """
    jitter = 1e-6 * np.eye(kernel_matrix.shape[0])
    lower = np.linalg.cholesky(kernel_matrix + jitter)
    lower_inv = np.linalg.inv(lower + jitter)
    whitened = lower_inv.dot(reward_vec)
    squared_sum = (whitened ** 2).sum()
    return squared_sum ** (1 / 2)

def get_det_from_kernel(kernel_matrix):
    """Return det(kernel_matrix + 1e-3 I); the ridge keeps singular kernels away from zero."""
    size = kernel_matrix.shape[0]
    ridged = kernel_matrix + 1e-3 * np.eye(size)
    return np.linalg.det(ridged)

def compute_log_det(kernel_matrix):
    """Return log|det(kernel_matrix + 1e-3 I)| using numpy's slogdet."""
    ridged = kernel_matrix + 1e-3 * np.eye(kernel_matrix.shape[0])
    _sign, log_abs_det = np.linalg.slogdet(ridged)
    return log_abs_det

def compute_mean_dist(distance_vec):
    """Return the arithmetic mean over every entry of ``distance_vec`` (via ``np.mean``)."""
    average = np.mean(distance_vec)
    return average

def compute_ilad(sim_matrix):
    """Intra-list average distance (ILAD) of a selection.

    Averages the pairwise distance ``1 - sim[i, j]`` over the ``n(n-1)/2``
    distinct pairs ``i > j`` of the selected items.

    Args:
        sim_matrix: square (n, n) similarity matrix of the selected items.

    Returns:
        The mean pairwise distance, or ``nan`` when fewer than two items are
        given (there are no pairs to average).
    """
    sim = np.asarray(sim_matrix)
    n = sim.shape[0]
    if n < 2:
        return np.nan
    # Previously the mean ran over all n*n entries of 1 - tril(sim, -1), so the
    # zeroed diagonal and upper triangle each contributed a spurious distance of 1.
    rows, cols = np.tril_indices(n, k=-1)
    return np.mean(1 - sim[rows, cols])

def compute_ilmd(sim_matrix):
    """Intra-list minimal distance (ILMD) of a selection.

    Returns the smallest pairwise distance ``1 - sim[i, j]`` over the distinct
    pairs ``i > j`` of the selected items.

    Args:
        sim_matrix: square (n, n) similarity matrix of the selected items.

    Returns:
        The minimum pairwise distance, or ``nan`` when fewer than two items
        are given.
    """
    sim = np.asarray(sim_matrix)
    n = sim.shape[0]
    if n < 2:
        return np.nan
    # Previously the zeroed diagonal/upper triangle of tril() contributed
    # distances of exactly 1, capping the result at 1 for negatively similar pairs.
    rows, cols = np.tril_indices(n, k=-1)
    return np.min(1 - sim[rows, cols])

def compute_mrr(select_idx, ori_scores):
    """Reciprocal rank of the best-ranked selected item.

    Items are ranked by ``ori_scores`` in descending order (rank 1 = highest).
    Consecutive sorted scores within ``1e-8`` of the first score of their run
    are treated as ties and all receive the average of their 1-based positions.

    Args:
        select_idx: indices of the selected items (list or integer array).
        ori_scores: one score per item, in any shape whose flattened length is
            the number of items, e.g. (n,), (n, 1) or (1, n).

    Returns:
        ``1 / min(rank of the selected items)``.
    """
    # Flatten rather than squeeze: squeeze() turns a single score into a 0-d
    # array, and shape[0] of a (1, n) input is 1, not n.
    scores = np.asarray(ori_scores).ravel()
    n = scores.size

    # Positions sorted by descending score.
    order = np.argsort(-scores)
    sorted_scores = scores[order]

    ranks = np.zeros(n, dtype=float)
    i = 0
    while i < n:
        j = i
        # Extend j over the run of scores tied with sorted_scores[i].
        while j + 1 < n and abs(sorted_scores[j + 1] - sorted_scores[i]) <= 1e-8:
            j += 1
        # Average 1-based rank for every tied position.
        ranks[order[i:j + 1]] = (i + j) / 2 + 1
        i = j + 1

    best_rank = ranks[select_idx].min()
    return 1.0 / best_rank


def _sample_matrix(rows, cols, data_distribution):
    """Draw a (rows, cols) matrix: standard normal for "normal", otherwise uniform on [0, 1)."""
    if data_distribution == "normal":
        return np.random.randn(rows, cols)
    return np.random.rand(rows, cols)


def _generate_base_data(data_type, item_size, feature_dimension, data_distribution):
    """Build the single-query data set used by test_type="base"."""
    if data_type == "iid":
        q_vectors = _sample_matrix(1, feature_dimension, data_distribution)
        feature_vectors = _sample_matrix(item_size, feature_dimension, data_distribution)
    elif data_type == "colinear":
        q_vectors = _sample_matrix(1, feature_dimension, data_distribution)
        # The second half of the items are noisy copies of the first half. The
        # first half takes the extra row when item_size is odd so the total is
        # always item_size (odd sizes used to crash in the reshape below); for
        # even sizes the random draws are exactly as before.
        n_copies = item_size // 2
        n_originals = item_size - n_copies
        originals = _sample_matrix(n_originals, feature_dimension, data_distribution)
        noise = np.random.randn(n_copies, feature_dimension) * 2
        feature_vectors = np.concatenate([originals, originals[:n_copies] + noise], axis=0)
    else:
        raise ValueError(f"Unknown data_type: {data_type}")

    q_vectors /= np.linalg.norm(q_vectors, axis=1, keepdims=True)
    feature_vectors /= np.linalg.norm(feature_vectors, axis=1, keepdims=True)

    # Cosine relevance of each item to the query, mapped from [-1, 1] to [0, 1].
    scores = q_vectors.dot(feature_vectors.T).reshape((item_size, 1))
    scores = (scores + 1) / 2

    similarities = np.dot(feature_vectors, feature_vectors.T)
    if data_distribution == "normal":
        # Normal features can have negative cosine similarity; shift to [0, 1].
        similarities = (similarities + 1) / 2
    # Quality-weighted kernel: L[i, j] = score_i * sim[i, j] * score_j.
    kernel_matrix = scores * similarities * scores.T

    return {
        "q_vectors": q_vectors,
        "feature_vectors": feature_vectors,
        "scores": scores,
        "similarities": similarities,
        "kernel_matrix": kernel_matrix,
    }


def _generate_recommendation_data(item_size, feature_dimension, data_distribution, num_user_interactions):
    """Build the user-history / candidate data set used by test_type="recommendation"."""
    if num_user_interactions < 1:
        raise ValueError("num_user_interactions must be at least 1 (the last one is the held-out test item)")

    # User interaction vectors p_u; the last one is held out as the test item.
    p_u = _sample_matrix(num_user_interactions, feature_dimension, data_distribution)
    p_u /= np.linalg.norm(p_u, axis=1, keepdims=True)
    test_item = p_u[-1:]  # (1, d)
    train_items = p_u[:-1]  # (m - 1, d)

    cands = _sample_matrix(item_size, feature_dimension, data_distribution)
    cands /= np.linalg.norm(cands, axis=1, keepdims=True)  # (n, d)

    if train_items.shape[0] > 0:
        # Mean cosine similarity of each candidate to the training items
        # (rows of train_items are already unit-norm from p_u above).
        scores = np.mean(train_items @ cands.T, axis=0).reshape(-1, 1)  # (n, 1)
    else:
        # Only the test interaction exists: no signal for relevance.
        scores = np.zeros((item_size, 1))

    similarities = cands @ cands.T  # (n, n)
    if data_distribution == "normal":
        similarities = (similarities + 1) / 2

    kernel_matrix = scores * similarities * scores.T

    return {
        "train_items": train_items,
        "test_item": test_item,
        "candidate_vectors": cands,
        "scores": scores,  # Relevance of candidates to train items
        "similarities": similarities,  # Similarity among candidates
        "kernel_matrix": kernel_matrix,  # Kernel for candidates
    }


def generate_data(
        data_type,
        item_size,
        feature_dimension,
        data_distribution="normal",
        test_type="base",
        num_user_interactions=10  # For recommendation type
    ):
    """Generate a synthetic data set for the selection benchmarks.

    Args:
        data_type: "iid" or "colinear" (half the items are noisy copies of the
            other half). Only used when ``test_type == "base"``.
        item_size: number of items / candidates.
        feature_dimension: embedding dimension.
        data_distribution: "normal" for standard-normal features; any other
            value draws uniform [0, 1) features.
        test_type: "base" (one query vector) or "recommendation" (a user
            history whose last interaction is held out as the test item).
        num_user_interactions: length of the user history (recommendation only).

    Returns:
        dict. For "base": q_vectors, feature_vectors, scores (item_size, 1),
        similarities, kernel_matrix. For "recommendation": train_items,
        test_item, candidate_vectors, scores (item_size, 1), similarities,
        kernel_matrix.

    Raises:
        ValueError: for an unknown ``test_type`` or ``data_type``, or when
            ``num_user_interactions < 1`` for the recommendation test type.
    """
    if test_type == "base":
        return _generate_base_data(data_type, item_size, feature_dimension, data_distribution)
    if test_type == "recommendation":
        return _generate_recommendation_data(
            item_size, feature_dimension, data_distribution, num_user_interactions
        )
    raise ValueError(f"Unknown test_type: {test_type}")

def dcg_at_k(relevances, k):
    """Discounted cumulative gain of the first ``k`` relevances.

    Position ``p`` (0-based) is discounted by ``log2(p + 2)``. Returns 0.0
    when there is nothing to score.
    """
    top = np.asarray(relevances)[:k]
    if top.size == 0:
        return 0.
    discounts = np.log2(np.arange(2, top.size + 2))
    return np.sum(top / discounts)

def ndcg_at_k(relevances, k):
    """Normalised DCG: DCG of ``relevances`` divided by the DCG of the same values sorted descending.

    Returns 0.0 when the ideal DCG is zero (falsy).
    """
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal else 0.

def compute_ndcg(selected_indices, candidate_vectors, test_item_vector, k):
    """
    Computes nDCG@k for selected candidate items based on similarity to a test item.

    The relevance of every candidate is its cosine similarity to the test item,
    min-max rescaled to [0, 1] over the whole candidate pool. The DCG of the
    selection (in its given order) is divided by the ideal DCG of the best ``k``
    candidates in the pool, so the score reflects both which items were picked
    and how they are ordered.

    Args:
        selected_indices (list): Indices of the selected items from the candidate set.
        candidate_vectors (np.ndarray): Embeddings of all candidate items (n, d).
        test_item_vector (np.ndarray): Embedding of the test item (1, d).
        k (int): Cutoff for nDCG (only the first k selected items are scored).

    Returns:
        float: nDCG@k score in [0, 1]; 0.0 when there are no candidates or the
        ideal DCG is zero (e.g. ``k <= 0``).
    """
    if candidate_vectors.shape[0] == 0:
        return 0.0

    # Ensure vectors are normalized for cosine similarity
    test_item_norm = test_item_vector / np.linalg.norm(test_item_vector)
    candidate_norms = candidate_vectors / np.linalg.norm(candidate_vectors, axis=1, keepdims=True)

    # Cosine similarity of every candidate to the test item, rescaled to [0, 1].
    all_relevances = (test_item_norm @ candidate_norms.T).flatten()
    all_relevances -= np.min(all_relevances)
    spread = np.max(all_relevances)
    if spread > 0:
        all_relevances /= spread
    else:
        # All candidates are equally similar to the test item (e.g. a single
        # candidate); previously this divided 0 by 0 and returned nan.
        all_relevances = np.ones_like(all_relevances)

    # Relevance of the selected items, in their selected order.
    selected_relevances = all_relevances[selected_indices]

    # Ideal ranking: the best k candidates from the whole pool. Re-sorting only
    # the selected items (as before) scored any relevance-ordered selection as
    # 1.0, regardless of which items had been chosen.
    ideal_dcg = dcg_at_k(np.sort(all_relevances)[::-1], k)
    if not ideal_dcg:
        return 0.0
    return dcg_at_k(selected_relevances, k) / ideal_dcg

def run_simulate(
        round, 
        max_length=10, 
        item_size=20, # Number of candidate items for recommendation
        feature_dimension=20, 
        data_type="iid",
        do_bf=False,
        data_distribution="normal",
        quality_metric="mrr", # Falls back to the test type's default if not produced
        diversity_metric="ilad",
        do_plot=True,
        do_numba=True,
        block_num=10,
        test_type="base",
        num_user_interactions=10, # For recommendation
        ndcg_k=3, # nDCG cutoff for recommendation
    ):
    """Benchmark item-selection algorithms on synthetic data and aggregate their metrics.

    For each of ``round`` rounds a fresh data set is drawn with
    ``generate_data``. Every algorithm in the roster selects ``max_length`` items,
    and quality/diversity metrics of that selection are recorded.

    Args:
        round: number of independent simulation rounds.
        max_length: number of items each algorithm selects.
        item_size: number of items (candidates, for "recommendation").
        feature_dimension: embedding dimension of the generated vectors.
        data_type: "iid" or "colinear" (used by the "base" test type).
        do_bf: also run brute-force search over all ``max_length``-subsets
            ("base" only; exponential in ``item_size``).
        data_distribution: "normal", or any other value for uniform features.
        quality_metric: x-axis metric of the trade-off plot. If the test type
            does not produce it, "mrr" (base) or "ndcg" (recommendation) is used.
        diversity_metric: y-axis metric of the trade-off plot.
        do_plot: draw per-metric bar charts and the quality/diversity plot.
        do_numba: also run ``dmi.compute_selection_v3_numba``.
        block_num: block count passed to ``DMIModel.block_compute_selection``.
        test_type: "base" or "recommendation".
        num_user_interactions: user history length for "recommendation"
            (the last interaction is held out as the test item).
        ndcg_k: cutoff k for nDCG in the "recommendation" test type.

    Returns:
        ``(metrics, selects, stats, cxts, last_t_v)``: ``metrics[algo]`` is a
        list of per-round metric dicts, ``selects[algo]`` the per-round
        selections, ``stats[algo][metric]`` holds NaN-ignoring "mean"/"std",
        ``cxts`` is always an empty list and ``last_t_v`` is always None.
        ``({}, {}, {}, {}, None)`` is returned when no algorithm produced results.

    Raises:
        ValueError: for an unknown ``test_type``, or when ``do_plot`` is set and
            ``diversity_metric`` is not produced for this test type.
    """
    metrics = {}
    selects = {}
    cxts = []  # never populated; kept for the return signature

    if test_type == "base":
        active_metrics = ["mi", "log_det", "mean_dist", "ilad", "ilmd", "mrr", "cho_norm"]
        default_quality_metric = "mrr"
    elif test_type == "recommendation":
        active_metrics = ["ndcg", "log_det", "ilad", "ilmd"]
        default_quality_metric = "ndcg"
    else:
        raise ValueError(f"Unknown test_type: {test_type}")

    # Previously quality_metric was unconditionally overwritten, so the argument
    # had no effect; now it is replaced only when this test type cannot produce it.
    if quality_metric not in active_metrics:
        quality_metric = default_quality_metric
    # Fail before running every round rather than with a KeyError while plotting.
    if do_plot and diversity_metric not in active_metrics:
        raise ValueError(
            f"diversity_metric {diversity_metric!r} is not computed for test_type "
            f"{test_type!r}; choose one of {active_metrics}"
        )

    # The roster does not depend on the round, so build it once.
    algos = ["dmi", "dpp", "dpp_no_score", "dmi_block", "dmi_omp", "dmi_omp_sp", "random", "dmi_omp_aug"]
    if do_numba:
        algos += ["dmi_numba"]
    if do_bf and test_type == "base":  # brute force is only wired up for the base test type
        algos += ["bf"]
    # DPP whose quality terms are exp(alpha * score).
    algos += [f"dpp_alpha_{alpha}" for alpha in [0.5, 1, 2, 5, 10, 20]]

    for _round_idx in range(round):
        data = generate_data(
            data_type, item_size, feature_dimension, data_distribution,
            test_type=test_type, num_user_interactions=num_user_interactions
        )

        q_vectors = None  # only defined for the base test type
        test_item = None  # only defined for the recommendation test type
        if test_type == "base":
            q_vectors = data["q_vectors"]
            feature_vectors = data["feature_vectors"]
        else:  # recommendation: the candidates play the role of the items
            test_item = data["test_item"]
            feature_vectors = data["candidate_vectors"]
        similarities = data["similarities"]
        kernel_matrix = data["kernel_matrix"]
        current_item_size = feature_vectors.shape[0]
        scores = data["scores"].copy()

        for algo in algos:
            # --- Algorithm execution ---
            if algo == "dmi_numba":
                res, _ = dmi.compute_selection_v3_numba(similarities, scores, max_length)
                res = res.tolist()
            elif algo == "dmi":
                dmi_model = dmi.DMIModel(None, scores, feature_vectors)
                res, _ = dmi_model.compute_selection(similarities, scores, max_length)
            elif algo == "dmi_omp":
                res = dmi.OrthogonalMatchingPursuit(scores, feature_vectors).compute_selection(max_length)
            elif algo == "dmi_omp_aug":
                res = dmi.OrthogonalMatchingPursuitAug(scores, feature_vectors).compute_selection(max_length)
            elif algo == "dmi_omp_sp":
                # The OrthogonalMatchingPursuitSelfProj variant is not wired up yet.
                print(f"Skipping {algo} as SelfProj variant needs check/implementation")
                continue
            elif algo == "dmi_block":
                dmi_model = dmi.DMIModel(None, scores, feature_vectors)
                res, _ = dmi_model.block_compute_selection(block_num, max_length)
            elif algo == "dpp":
                res = dmi.DPPModel(kernel_matrix, max_length).compute_selection()
            elif algo == "dpp_no_score":
                res = dmi.DPPModel(similarities, max_length).compute_selection()
            elif algo.startswith("dpp_alpha"):
                alpha = float(algo.split("_")[2])
                exp_scores = np.exp(alpha * scores)
                cur_kernel_matrix = exp_scores * similarities * exp_scores.T
                res = dmi.DPPModel(cur_kernel_matrix, max_length).compute_selection()
            elif algo == "bf":
                combs = [list(c) for c in combinations(range(current_item_size), max_length)]
                comb_res = [get_mi(similarities[comb, :][:, comb], scores[comb], Q=q_vectors) for comb in combs]
                res = combs[np.argmax([c[0] for c in comb_res])]
            elif algo == "random":
                res = np.random.choice(current_item_size, max_length, replace=False).tolist()
            else:
                print(f"Skipping unrecognized or BF algo for recommendation: {algo}")
                continue

            # --- Metric calculation ---
            current_metrics = {}
            # Explicit checks instead of `not res`, which raises for
            # multi-element numpy arrays.
            if res is None or len(res) == 0:
                print(f"Warning: Algo {algo} returned empty selection.")
                for metric_k in active_metrics:
                    current_metrics[metric_k] = np.nan
            else:
                select_kernel_matrix = similarities[res, :][:, res]
                select_reward_vec = scores[res]

                if test_type == "base":
                    _, mi = get_mi(select_kernel_matrix, select_reward_vec, Q=q_vectors)
                    current_metrics["mi"] = mi
                    current_metrics["cho_norm"] = compute_cholesky_norm(select_kernel_matrix, select_reward_vec)
                    current_metrics["mean_dist"] = compute_mean_dist(select_reward_vec)  # mean score of selection
                    current_metrics["mrr"] = compute_mrr(res, scores)
                else:
                    current_metrics["ndcg"] = compute_ndcg(res, feature_vectors, test_item, ndcg_k)
                # Diversity metrics shared by both test types.
                current_metrics["log_det"] = compute_log_det(select_kernel_matrix)
                current_metrics["ilad"] = compute_ilad(select_kernel_matrix)
                current_metrics["ilmd"] = compute_ilmd(select_kernel_matrix)

            for metric_k in active_metrics:
                current_metrics.setdefault(metric_k, np.nan)

            metrics.setdefault(algo, []).append(current_metrics)
            selects.setdefault(algo, []).append(res)

    # --- Aggregation ---
    valid_algos = list(metrics.keys())  # algorithms that actually ran
    if not valid_algos:
        print("No algorithms were successfully run.")
        return {}, {}, {}, {}, None

    first_algo_metrics = metrics[valid_algos[0]]
    if not first_algo_metrics:
        print("No metrics recorded for the first algorithm.")
        return {}, {}, {}, {}, None

    metric_keys = list(first_algo_metrics[0].keys())
    stats = {}
    for algo in valid_algos:
        stats[algo] = {}
        for metric_k in metric_keys:
            metric_values = [m.get(metric_k, np.nan) for m in metrics[algo]]
            stats[algo][metric_k] = {
                "mean": np.nanmean(metric_values),
                "std": np.nanstd(metric_values),
            }

    if do_plot:
        plot_across_metrics(metrics, stats, valid_algos)
        plot_qual_diversity_trade_off(metrics, stats, valid_algos, quality_metric, diversity_metric)

    last_t_v = None  # placeholder kept for the return signature
    return metrics, selects, stats, cxts, last_t_v


def plot_across_metrics(metrics, stats, algos):
    """Draw one bar chart per metric comparing the algorithms.

    A figure is created for every metric key of ``metrics[algos[0]][0]``. Each
    bar is ``stats[algo][metric]["mean"]`` with a ``"std"`` error bar.
    ``dpp_alpha_*`` variants are drawn only for the "mi" metric.
    """
    sns.set(style="whitegrid")
    palette = sns.color_palette("husl", len(algos))
    for metric_k in metrics[algos[0]][0].keys():
        fig, ax = plt.subplots()
        for color, algo in zip(palette, algos):
            if algo.startswith("dpp_") and "alpha" in algo and metric_k != "mi":
                continue
            ax.bar(
                algo,
                stats[algo][metric_k]["mean"],
                yerr=stats[algo][metric_k]["std"],
                capsize=10,
                label=algo,
                color=color,
            )
        ax.set_ylabel('Scores')
        ax.set_title(f'{metric_k} Comparison of DMI and DPP')
        # The categorical x-axis already labels each bar with its algo name.
        # set_xticklabels(algos) assumed one tick per algo, which is wrong when
        # variants are skipped above, and triggers matplotlib's FixedFormatter warning.
        ax.tick_params(axis="x", labelrotation=90)
        # Legend per figure (previously plt.legend() only reached the last one).
        ax.legend()
        fig.tight_layout()

def plot_qual_diversity_trade_off(
        metrics, 
        stats, 
        algos,
        quality_metric,
        diversity_metric
    ):
    """Plot each algorithm's mean quality against mean diversity, with std error bars.

    Args:
        metrics: not read by this function.
        stats: ``stats[algo][metric]`` dicts with "mean" and "std" entries.
        algos: algorithms to plot, one colour each.
        quality_metric: metric key for the x-axis.
        diversity_metric: metric key for the y-axis.
    """
    sns.set(style="whitegrid")
    fig, ax = plt.subplots()
    palette = sns.color_palette("husl", len(algos))
    for color, algo in zip(palette, algos):
        quality = stats[algo][quality_metric]
        diversity = stats[algo][diversity_metric]
        # One marker per algorithm. The old ax.plot() of a single point drew no
        # visible line and has been dropped.
        ax.errorbar(
            quality["mean"],
            diversity["mean"],
            xerr=quality["std"],
            yerr=diversity["std"],
            fmt='o',
            capsize=5,
            label=algo,
            color=color,
        )
    ax.set_xlabel(quality_metric)
    ax.set_ylabel(diversity_metric)
    ax.set_title('quality vs diversity Comparison of DMI and DPP')
    fig.tight_layout()
    ax.legend()
    plt.show()