# -*- coding: utf-8 -*-
"""
Created on Wed Jul  9 15:09:48 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 29 16:35:08 2025

@author: baran
"""
import random

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

# Experiment settings shared by every trial.
k, t = 10, 100            # k: cache size, t: number of rounds per simulated stream
delta = 0.8               # confidence parameter inside the log term of the cost bounds
epsilon_greedy_eps = 0.1  # probability of a random removal in the epsilon-greedy policy

# GPT-2 tokenizer: token counts of the queries are used as their raw cost.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Sentence encoder: embeddings are used to measure distances between queries.
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# Each pair holds a short keyword query followed by a longer natural-language
# query on the same topic.
_QUERY_PAIRS = (
    ("weather", "weather in New York City for the weekend"),
    ("NBA", "latest updates on the NBA playoffs"),
    ("AAPL", "current stock price and news for Apple Inc"),
    ("cheap flights", "compare flight prices from NYC to Paris"),
    ("Python", "comprehensive Python tutorial for beginners"),
    ("bitcoin", "current trends and forecasts for Bitcoin"),
    ("tennis scores", "who won the latest tennis grand slam final"),
    ("AI papers", "recent breakthroughs in AI and machine learning research"),
    ("Rome attractions", "top 10 tourist attractions and hidden gems in Rome"),
    (
        "guitar lessons online",
        "how to learn to play acoustic guitar from scratch online",
    ),
)

# Further candidate queries, currently disabled:
#   "best pizza recipes",
#   "quick vegetarian dinner ideas",
#   "Marvel movies ranked",
#   "when is the next Marvel movie release",
#   "symptoms of the common cold",
#   "how to boost your immune system naturally",
#   "USD to EUR exchange rate",
#   "crypto wallet security tips",
#   "Netflix top shows this week",
#   "is The Witcher renewed for another season",
#   "daily meditation techniques",
#   "how to start mindfulness practice",
#   "beginner workout routine at home",
#   "calorie calculator for weight loss",
#   "cheap hotels in Tokyo",
#   "travel visa requirements for Japan",
#   "Elon Musk latest news",
#   "Tesla Model 3 vs Model Y comparison",
#   "top programming languages in 2025",
#   "difference between machine learning and AI",
#   "best budget smartphones 2025",
#   "how to block spam calls on iPhone",
#   "COVID-19 latest variant symptoms",
#   "how to file taxes online in the US",
#   "resume tips for data science roles",
#   "interview questions for software engineers",
#   "stock market news today",
#   "best apps for learning Spanish",
#   "who is leading in the US presidential polls",
#   "does drinking coffee help with focus"

# Flatten the pairs in order; a query's position in this list becomes its id.
distinct_queries = [query for pair in _QUERY_PAIRS for query in pair]
# Two-way lookup between a query's text and its integer id (its list position).
query_to_id = {text: idx for idx, text in enumerate(distinct_queries)}
id_to_query = {idx: text for text, idx in query_to_id.items()}
query_stream_ids = [query_to_id[text] for text in distinct_queries]

# Raw cost of a query is its GPT-2 token count, min-max scaled to [0, 1].
token_lengths = [len(tokenizer.encode(text)) for text in distinct_queries]
min_len, max_len = min(token_lengths), max(token_lengths)
costs_stream = [(n_tokens - min_len) / (max_len - min_len) for n_tokens in token_lengths]

# Random initial cache of k query ids (also advances the global RNG state).
cache_size = k
current_cache_ids = random.sample(range(len(distinct_queries)), cache_size)

sigma = 0.05  # noise scale for observed costs; not applied in this block

# True (noise-free) cost per query id.
true_cost_dict = {}
for i, q_id in enumerate(query_stream_ids):
    true_cost = costs_stream[i]
    true_cost_dict[q_id] = true_cost

# One record per query with its text and true cost rounded to 4 decimals,
# plus the same records indexed by query text.
query_stats = [
    {"query": id_to_query[idx], "c(q)": round(true_cost_dict[idx], 4)}
    for idx in range(len(distinct_queries))
]
query_stats_dict = {record["query"]: record for record in query_stats}
from itertools import combinations
def brute_force_optimal_cache(k, epsilon, l2_distance_matrix, query_stats_dict, id_to_query):
    """Exhaustively search every size-``k`` cache for the lowest loss.

    Every ``k``-element combination of query indices is scored with
    ``compute_loss_and_avg_dist_actual_prob``; the first combination that
    reaches the smallest (rounded) loss wins.

    Args:
        k: Number of query indices in a cache.
        epsilon: Coverage threshold forwarded to the loss function.
        l2_distance_matrix: Square matrix of pairwise query distances; its
            length defines how many queries exist.
        query_stats_dict: Per-query statistics forwarded to the loss function.
        id_to_query: Mapping from query index to query text, forwarded as well.

    Returns:
        ``(best_loss, best_cache)`` where ``best_cache`` is a set of indices.
    """
    num_queries = len(l2_distance_matrix)
    winner = None
    lowest_loss = float("inf")

    for candidate in combinations(range(num_queries), k):
        candidate_loss, _ = compute_loss_and_avg_dist_actual_prob(
            candidate, epsilon, l2_distance_matrix, query_stats_dict, id_to_query
        )
        # Strict comparison keeps the earliest combination on ties.
        if candidate_loss < lowest_loss:
            lowest_loss, winner = candidate_loss, candidate

    return lowest_loss, set(winner)

def compute_loss_and_avg_dist_actual_prob(M_set, epsilon, l2_distance_matrix, query_stats_dict, id_to_query):
    """Score a cache under a uniform query distribution and the true costs.

    Every query is weighted ``1 / m`` where ``m`` is the number of rows of
    ``l2_distance_matrix``. A query whose nearest cached query lies within
    ``epsilon`` contributes that distance to the loss; any other query
    contributes its ``"c(q)"`` entry from ``query_stats_dict``.

    Args:
        M_set: Iterable of cached query indices.
        epsilon: Largest distance at which a query counts as covered.
        l2_distance_matrix: Pairwise distances, indexable as ``matrix[q][m]``.
        query_stats_dict: Maps query text to a dict holding a ``"c(q)"`` cost.
        id_to_query: Maps query index to query text.

    Returns:
        ``(loss, avg_covered_distance)``, both rounded to 4 decimals; the
        average is ``0.0`` when no query is covered.
    """
    num_queries = len(l2_distance_matrix)
    weight = 1 / num_queries
    total_loss = 0.0
    hit_distances = []

    for q_idx in range(num_queries):
        nearest = min(l2_distance_matrix[q_idx][cached] for cached in M_set)
        text = id_to_query[q_idx]

        if nearest <= epsilon:
            hit_distances.append(nearest)
            contribution = weight * nearest
        else:
            contribution = weight * query_stats_dict[text]["c(q)"]
        total_loss += contribution

    mean_hit_distance = np.mean(hit_distances) if hit_distances else 0.0
    return round(total_loss, 4), round(mean_hit_distance, 4)
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances

# Grid of coverage thresholds swept by the experiment.
epsilon_values = np.linspace(0, 1, 100)

# Embed every distinct query once and keep a text -> embedding lookup.
emb_queries = model.encode(distinct_queries)
query_to_emb = dict(zip(distinct_queries, emb_queries))

# Raw pairwise Euclidean distances, also kept as a labelled table.
l2_distance_matrix = euclidean_distances(emb_queries)
l2_df = pd.DataFrame(l2_distance_matrix, index=distinct_queries, columns=distinct_queries)

# Min-max scale the distances so they share the [0, 1] range of epsilon.
min_val, max_val = np.min(l2_distance_matrix), np.max(l2_distance_matrix)
normalized_l2 = (l2_distance_matrix - min_val) / (max_val - min_val)
l2_distance_matrix = normalized_l2

# Brute-force optimal (loss, cache) for every threshold, computed once up front.
brute_force_cache = {}
for eps in epsilon_values:
    if eps in brute_force_cache:
        loss_opt, opt_cache = brute_force_cache[eps]
        continue
    loss_opt, opt_cache = brute_force_optimal_cache(
        k, eps, l2_distance_matrix, query_stats_dict, id_to_query
    )
    brute_force_cache[eps] = (loss_opt, opt_cache)



# Keys of the per-query statistics dictionaries. Each literal is written once
# so that building the records and reading them can never disagree.
_KEY_P_HAT = "p̂(q)"   # empirical request frequency
_KEY_C_HAT = "ĉ(q)"   # empirical mean of the observed (miss) costs
_KEY_C_UCB = "c̄(q)"   # upper confidence bound on the cost
_KEY_C_LCB = "č(q)"   # lower confidence bound on the cost


def _nearest_cached_distance(q_id, cache, l2_distance_matrix):
    """Return the distance from query ``q_id`` to its closest member of ``cache``."""
    return min(l2_distance_matrix[q_id, m_id] for m_id in cache)


def _estimated_loss(cache, epsilon_val, cost_key, query_stats, l2_distance_matrix):
    """Return the estimated loss of ``cache`` at threshold ``epsilon_val``.

    A query counts as covered when its nearest cached query is within
    ``epsilon_val``. A covered query contributes ``p̂(q) * distance``; an
    uncovered one contributes ``p̂(q) * stats[cost_key]``.
    """
    loss_val = 0.0
    for q_id, stats in enumerate(query_stats):
        min_dist = _nearest_cached_distance(q_id, cache, l2_distance_matrix)
        if min_dist <= epsilon_val:
            loss_val += stats[_KEY_P_HAT] * min_dist
        else:
            loss_val += stats[_KEY_P_HAT] * stats[cost_key]
    return loss_val


def _greedy_removal_cache(k, epsilon_val, cost_key, query_stats, l2_distance_matrix, explore_prob=None):
    """Shrink the full query set to ``k`` ids by removing one id at a time.

    Each step removes the id whose removal yields the smallest estimated loss
    (first such id on ties). When ``explore_prob`` is not ``None`` a uniform
    random number is drawn at every step and, if it falls below
    ``explore_prob``, a uniformly random id is removed instead.
    """
    import random

    import numpy as np

    cache = set(range(len(query_stats)))
    for _ in range(len(cache) - k):
        if explore_prob is not None and np.random.rand() < explore_prob:
            to_remove = random.choice(list(cache))
        else:
            to_remove, best_loss = None, float("inf")
            for candidate in cache:
                candidate_loss = _estimated_loss(
                    cache - {candidate}, epsilon_val, cost_key, query_stats, l2_distance_matrix
                )
                if candidate_loss < best_loss:
                    best_loss, to_remove = candidate_loss, candidate
        cache.remove(to_remove)
    return cache


def _stream_cost_saved(cache, epsilon_val, query_stream_ids, stream_true_costs, l2_distance_matrix):
    """Return the total cost the cache saves over the stream, rounded to 4 decimals.

    For every request whose nearest cached query is within ``epsilon_val`` the
    saving is ``true_cost - distance``, floored at zero.
    """
    nearest = [
        _nearest_cached_distance(q_id, cache, l2_distance_matrix)
        for q_id in range(len(l2_distance_matrix))
    ]
    total_saved = 0.0
    for q_id, true_cost in zip(query_stream_ids, stream_true_costs):
        min_dist = nearest[q_id]
        if min_dist <= epsilon_val:
            total_saved += max(true_cost - min_dist, 0.0)
    return round(total_saved, 4)


def run_single_trial(k,t,delta,epsilon_greedy_eps,seed=None):
    """Simulate one query stream and evaluate three cache-selection policies.

    A stream of ``t`` queries is drawn uniformly from ``distinct_queries`` and
    served by a randomly evolving cache of size ``k``. Request frequencies and
    cost estimates (mean, upper and lower confidence bound) are derived from
    that log. For each threshold in ``np.linspace(0, 1, 100)`` three caches are
    built by greedy removal from the full query set:

    * epsilon-greedy, scoring uncovered queries with the empirical mean cost;
    * reverse greedy scoring uncovered queries with the upper bound (UCB);
    * reverse greedy scoring uncovered queries with the lower bound (LCB).

    Each cache is scored with ``compute_loss_and_avg_dist_actual_prob`` and
    compared with the brute-force optimum stored in ``brute_force_cache``.

    Fixes relative to the previous implementation:

    * the per-round observed cost is actually generated (the cost column used
      to be empty, which made the DataFrame construction fail);
    * the random-removal branch removes the randomly chosen id instead of a
      stale ``best_q``;
    * the result is returned after the whole epsilon sweep, not after the
      first threshold, so every returned series has one entry per threshold;
    * costs come from ``true_cost_dict`` rather than a per-stream min-max
      rescaling, which could divide by zero and disagree with ``c(q)``;
    * each cache is built once per threshold and reused for both the loss and
      the cost-saved metric; the unused stream embedding is gone and the
      module-level ``l2_distance_matrix`` is reused instead of re-encoding.

    Relies on module-level ``distinct_queries``, ``query_to_id``,
    ``id_to_query``, ``true_cost_dict``, ``l2_distance_matrix``,
    ``brute_force_cache`` and ``compute_loss_and_avg_dist_actual_prob``.
    ``brute_force_cache`` must have been built for the same ``k`` and the same
    threshold grid for the suboptimality gaps to be meaningful.

    Args:
        k: Cache size, ``1 <= k <= len(distinct_queries)``.
        t: Number of rounds in the simulated stream (at least 1).
        delta: Confidence parameter (> 0) inside the log term of the bounds.
        epsilon_greedy_eps: Per-step probability of a random removal in the
            epsilon-greedy policy.
        seed: Optional seed for ``random`` and ``numpy.random``; when given,
            a "Starting run" line is printed.

    Returns:
        A 12-tuple of lists, one entry per threshold, in this order:
        ``loss_eps, avg_dist_eps, loss_rev, avg_dist_rev, loss_lcb,
        avg_dist_lcb, cost_saved_eps, cost_saved_rev, cost_saved_lcb,
        subopt_eps, subopt_rev, subopt_lcb``.

    Raises:
        ValueError: If ``k``, ``t`` or ``delta`` is out of range.
    """
    import math
    import random
    from collections import defaultdict

    import numpy as np

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        print(f"Starting run: {seed}")

    m = len(distinct_queries)
    no_rounds = t
    if not 1 <= k <= m:
        raise ValueError(f"k must be between 1 and {m}, got {k}")
    if no_rounds < 1:
        raise ValueError(f"t must be at least 1, got {t}")
    if delta <= 0:
        raise ValueError(f"delta must be positive, got {delta}")

    # --- Simulate the request stream and the evolving cache -----------------
    query_stream = [random.choice(distinct_queries) for _ in range(no_rounds)]
    query_stream_ids = [query_to_id[q] for q in query_stream]
    stream_true_costs = [true_cost_dict[q_id] for q_id in query_stream_ids]

    current_cache_ids = set(random.sample(range(m), k))
    observed_costs = []
    for q_id, true_cost in zip(query_stream_ids, stream_true_costs):
        # On a miss the query replaces a random cached entry half of the time.
        if q_id not in current_cache_ids and random.random() < 0.5:
            current_cache_ids.discard(random.choice(list(current_cache_ids)))
            current_cache_ids.add(q_id)
        # NOTE(review): the cost is judged against the cache recorded for the
        # round, i.e. after the update above, and is noise-free. Confirm this
        # matches the intended data-generation model.
        observed_costs.append(0.0 if q_id in current_cache_ids else true_cost)

    # --- Per-query statistics from the simulated log ------------------------
    n = no_rounds
    request_counts = defaultdict(int)
    paid_counts = defaultdict(int)
    paid_cost_sums = defaultdict(float)
    for q_id, cost in zip(query_stream_ids, observed_costs):
        request_counts[q_id] += 1
        if cost > 0:
            paid_counts[q_id] += 1
            paid_cost_sums[q_id] += cost

    bonus_scale = 0.05
    log_term = 2 * math.log(4 * m * n / delta)  # same for every query
    query_stats = []
    for q_id in range(m):
        Nq = request_counts[q_id]
        Ncq = paid_counts[q_id]
        if Ncq > 0:
            chat = paid_cost_sums[q_id] / Ncq
            bonus = bonus_scale * math.sqrt(log_term / Ncq)
            ucb = chat + bonus
            lcb = max(chat - bonus, 0.0)
        else:
            # No positive cost was ever observed: widest possible interval.
            chat, ucb, lcb = 0.0, 1.0, 0.0
        query_stats.append({
            "query": id_to_query[q_id],
            "N(q)": Nq,
            "Nc(q)": Ncq,
            _KEY_P_HAT: round(Nq / n, 4),
            _KEY_C_HAT: round(chat, 4),
            _KEY_C_UCB: round(ucb, 4),
            _KEY_C_LCB: round(lcb, 4),
            "c(q)": round(true_cost_dict[q_id], 4),
        })
    query_stats_dict = {entry["query"]: entry for entry in query_stats}

    print("Min distance:", np.min(l2_distance_matrix))
    print("Max distance:", np.max(l2_distance_matrix))
    print("Mean distance:", np.mean(l2_distance_matrix))

    # --- Sweep the coverage threshold ---------------------------------------
    loss_eps, avg_dist_eps = [], []
    loss_rev, avg_dist_rev = [], []
    loss_lcb, avg_dist_lcb = [], []
    subopt_eps, subopt_rev, subopt_lcb = [], [], []
    cost_saved_eps, cost_saved_rev, cost_saved_lcb = [], [], []

    # Must be the same grid that was used to fill ``brute_force_cache``.
    epsilon_values = np.linspace(0, 1, 100)
    for eps in epsilon_values:
        cache_eps = _greedy_removal_cache(
            k, eps, _KEY_C_HAT, query_stats, l2_distance_matrix,
            explore_prob=epsilon_greedy_eps,
        )
        cache_ucb = _greedy_removal_cache(k, eps, _KEY_C_UCB, query_stats, l2_distance_matrix)
        cache_lcb = _greedy_removal_cache(k, eps, _KEY_C_LCB, query_stats, l2_distance_matrix)

        # Score with the same function that produced the brute-force optimum.
        l_eps, d_eps = compute_loss_and_avg_dist_actual_prob(
            cache_eps, eps, l2_distance_matrix, query_stats_dict, id_to_query)
        l_rev, d_rev = compute_loss_and_avg_dist_actual_prob(
            cache_ucb, eps, l2_distance_matrix, query_stats_dict, id_to_query)
        l_lcb, d_lcb = compute_loss_and_avg_dist_actual_prob(
            cache_lcb, eps, l2_distance_matrix, query_stats_dict, id_to_query)

        loss_opt, _ = brute_force_cache[eps]
        subopt_eps.append(l_eps - loss_opt)
        subopt_rev.append(l_rev - loss_opt)
        subopt_lcb.append(l_lcb - loss_opt)

        if cache_ucb != cache_lcb:
            print(f"Different cache at epsilon = {eps}")

        loss_eps.append(l_eps)
        avg_dist_eps.append(d_eps)
        loss_rev.append(l_rev)
        avg_dist_rev.append(d_rev)
        loss_lcb.append(l_lcb)
        avg_dist_lcb.append(d_lcb)

        cost_saved_eps.append(_stream_cost_saved(
            cache_eps, eps, query_stream_ids, stream_true_costs, l2_distance_matrix))
        cost_saved_rev.append(_stream_cost_saved(
            cache_ucb, eps, query_stream_ids, stream_true_costs, l2_distance_matrix))
        cost_saved_lcb.append(_stream_cost_saved(
            cache_lcb, eps, query_stream_ids, stream_true_costs, l2_distance_matrix))

    return loss_eps, avg_dist_eps, loss_rev, avg_dist_rev, loss_lcb, avg_dist_lcb, cost_saved_eps, cost_saved_rev, cost_saved_lcb,subopt_eps,subopt_rev,subopt_lcb

import numpy as np
import matplotlib.pyplot as plt

n_runs = 10
epsilon_values = np.linspace(0, 1, 100)

# One list per metric and policy; each collects one series per trial.
all_loss_eps, all_loss_rev, all_loss_lcb = [], [], []
all_subopt_eps, all_subopt_rev, all_subopt_lcb = [], [], []
all_dist_eps, all_dist_rev, all_dist_lcb = [], [], []
all_cost_saved_eps, all_cost_saved_rev, all_cost_saved_lcb = [], [], []

# Collectors listed in the order in which run_single_trial returns its series.
_collectors = (
    all_loss_eps, all_dist_eps,
    all_loss_rev, all_dist_rev,
    all_loss_lcb, all_dist_lcb,
    all_cost_saved_eps, all_cost_saved_rev, all_cost_saved_lcb,
    all_subopt_eps, all_subopt_rev, all_subopt_lcb,
)

for i in range(n_runs):
    trial_series = run_single_trial(
        k=k, t=t, delta=delta, epsilon_greedy_eps=epsilon_greedy_eps, seed=i
    )
    # Unpacking keeps the per-trial names available and checks the arity.
    (loss_eps, avg_dist_eps, loss_rev, avg_dist_rev, loss_lcb, avg_dist_lcb,
     cost_saved_eps, cost_saved_rev, cost_saved_lcb,
     subopt_eps, subopt_rev, subopt_lcb) = trial_series

    for collector, series in zip(_collectors, trial_series):
        collector.append(series)

def _mean_and_std(trial_series):
    """Return the across-trial mean and standard deviation (axis 0) of the series."""
    return np.mean(trial_series, axis=0), np.std(trial_series, axis=0)


# Loss of each policy.
loss_eps_mean, loss_eps_std = _mean_and_std(all_loss_eps)
loss_rev_mean, loss_rev_std = _mean_and_std(all_loss_rev)
loss_lcb_mean, loss_lcb_std = _mean_and_std(all_loss_lcb)

# Gap to the brute-force optimum.
subopt_eps_mean, subopt_eps_std = _mean_and_std(all_subopt_eps)
subopt_rev_mean, subopt_rev_std = _mean_and_std(all_subopt_rev)
subopt_lcb_mean, subopt_lcb_std = _mean_and_std(all_subopt_lcb)

# Average distance of the covered queries.
avg_dist_eps_mean, avg_dist_eps_std = _mean_and_std(all_dist_eps)
avg_dist_rev_mean, avg_dist_rev_std = _mean_and_std(all_dist_rev)
avg_dist_lcb_mean, avg_dist_lcb_std = _mean_and_std(all_dist_lcb)

# Total cost saved by the cache.
cost_saved_eps_mean, cost_saved_eps_std = _mean_and_std(all_cost_saved_eps)
cost_saved_rev_mean, cost_saved_rev_std = _mean_and_std(all_cost_saved_rev)
cost_saved_lcb_mean, cost_saved_lcb_std = _mean_and_std(all_cost_saved_lcb)

def _plot_mean_with_band(curves, xlabel, ylabel, title):
    """Draw one figure with a mean line and a +/- one std band per curve.

    ``curves`` is a sequence of ``(mean, std, line_kwargs)`` triples; the
    keyword arguments are forwarded to ``plt.plot``. The x axis is the
    module-level ``epsilon_values``.
    """
    plt.figure(figsize=(8, 5))
    for mean, std, line_kwargs in curves:
        plt.plot(epsilon_values, mean, **line_kwargs)
        plt.fill_between(epsilon_values, mean - std, mean + std, alpha=0.2)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Line styles shared by the first three figures.
_EPS_STYLE = {"marker": 'o', "label": 'Epsilon-Greedy'}
_UCB_STYLE = {"marker": 's', "label": 'Reverse Greedy (UCB)'}
_LCB_STYLE = {"marker": 'x', "label": 'Reverse Greedy (LCB)'}

# Loss of each policy.
_plot_mean_with_band(
    [
        (loss_eps_mean, loss_eps_std, _EPS_STYLE),
        (loss_rev_mean, loss_rev_std, _UCB_STYLE),
        (loss_lcb_mean, loss_lcb_std, _LCB_STYLE),
    ],
    xlabel="Epsilon Value",
    ylabel="Loss",
    title="Loss vs Epsilon: Epsilon-Greedy vs Reverse Greedy",
)

# Average distance of the covered queries.
_plot_mean_with_band(
    [
        (avg_dist_eps_mean, avg_dist_eps_std, _EPS_STYLE),
        (avg_dist_rev_mean, avg_dist_rev_std, _UCB_STYLE),
        (avg_dist_lcb_mean, avg_dist_lcb_std, _LCB_STYLE),
    ],
    xlabel="Epsilon Value",
    ylabel="Average Distance of Covered Queries",
    title="Avg Distance of Covered Queries vs Epsilon",
)

# Total cost saved by the cache.
_plot_mean_with_band(
    [
        (cost_saved_eps_mean, cost_saved_eps_std, _EPS_STYLE),
        (cost_saved_rev_mean, cost_saved_rev_std, _UCB_STYLE),
        (cost_saved_lcb_mean, cost_saved_lcb_std, _LCB_STYLE),
    ],
    xlabel="Epsilon Value",
    ylabel="Total Cost Saved by Cache",
    title="Cost Saved vs Epsilon",
)

# Gap to the brute-force optimum (drawn without markers).
_plot_mean_with_band(
    [
        (subopt_eps_mean, subopt_eps_std, {"label": 'Epsilon-Greedy Gap'}),
        (subopt_rev_mean, subopt_rev_std, {"label": 'Reverse Greedy (UCB) Gap'}),
        (subopt_lcb_mean, subopt_lcb_std, {"label": 'Reverse Greedy (LCB) Gap'}),
    ],
    xlabel="Epsilon",
    ylabel="Suboptimality Gap",
    title="Gap from Brute Force Optimal Cache",
)