# -*- coding: utf-8 -*-
"""
Created on Wed Jul  9 15:09:48 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 29 16:35:08 2025

@author: baran
"""

import random
import pandas as pd
import numpy as np
import math
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from collections import defaultdict

# Sentence-embedding model; used below to embed both the query stream and the
# list of distinct queries.
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# GPT-2 tokenizer; used below only to measure each query's token count.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# NOTE: a stray ``import defaultdict`` used to follow here. There is no such
# module, so it raised ModuleNotFoundError before any of the script ran;
# ``defaultdict`` is already imported from ``collections`` at the top of the file.
# Universe of 20 queries: each topic appears as a short query followed by a
# longer, more specific query on the same topic.
distinct_queries = [
    "weather",
    "weather in New York City for the weekend",
    "NBA",
    "latest updates on the NBA playoffs",
    "AAPL",
    "current stock price and news for Apple Inc",
    "cheap flights",
    "compare flight prices from NYC to Paris",
    "Python",
    "comprehensive Python tutorial for beginners",
    "bitcoin",
    "current trends and forecasts for Bitcoin",
    "tennis scores",
    "who won the latest tennis grand slam final",
    "AI papers",
    "recent breakthroughs in AI and machine learning research",
    "Rome attractions",
    "top 10 tourist attractions and hidden gems in Rome",
    "guitar lessons online",
    "how to learn to play acoustic guitar from scratch online",
]

# Number of distinct queries; the rest of the script assumes exactly 20.
m = len(distinct_queries)
assert m == 20

# Simulation constants.
no_rounds = 100            # length of the simulated query stream
cache_size = 5             # number of queries the cache may hold
k = cache_size
delta = 0.8                # enters the confidence bonus of the cost estimate
epsilon_greedy_eps = 0.1   # not referenced by the visible part of the script

# Query stream: one uniform draw from the universe per round.
query_stream = [random.choice(distinct_queries) for _round in range(no_rounds)]

# Two-way mapping between query text and its position in ``distinct_queries``.
query_to_id = dict(zip(distinct_queries, range(m)))
id_to_query = dict(enumerate(distinct_queries))
query_stream_ids = [query_to_id[text] for text in query_stream]

# Sentence embeddings for the stream and for the universe of queries.
emb_stream = model.encode(query_stream)
emb_queries = model.encode(distinct_queries)
query_to_emb = {text: vec for text, vec in zip(distinct_queries, emb_queries)}

from sklearn.metrics.pairwise import euclidean_distances

# Raw pairwise Euclidean distances between the distinct-query embeddings.
l2_distance_matrix = euclidean_distances(emb_queries)

# Labelled view of the raw (un-normalised) distances, handy for inspection.
l2_df = pd.DataFrame(
    l2_distance_matrix, index=distinct_queries, columns=distinct_queries
)

# Min-max scale the distances. The result is bound to a new array and then
# takes over the ``l2_distance_matrix`` name; the raw array is not modified.
min_val = l2_distance_matrix.min()
max_val = l2_distance_matrix.max()
normalized_l2 = (l2_distance_matrix - min_val) / (max_val - min_val)
l2_distance_matrix = normalized_l2

# Per-round cost: the query's token count, min-max scaled over the stream.
token_lengths = [len(token_ids) for token_ids in map(tokenizer.encode, query_stream)]
min_len, max_len = min(token_lengths), max(token_lengths)
costs_stream = [
    (n_tokens - min_len) / (max_len - min_len) for n_tokens in token_lengths
]
# N_p, N_c, L_c, cache_history_ids, true_cost_dict, C_qt = defaultdict(int), defaultdict(int), defaultdict(float), [], {}, []
# true_cost_dict = {}
# Arrival probability of a single query: one out of ``m``, as a scalar.
p_true = 1.0 / m
# Standard deviation of the Gaussian noise added to observed costs.
sigma = 0.05
# Random initial cache: ``cache_size`` distinct query ids.
current_cache_ids = set(random.sample(range(len(distinct_queries)), cache_size))
# Similarity thresholds to sweep: 50 evenly spaced values covering [0, 1].
epsilon_values = np.linspace(0.0, 1.0, num=50)
# Collects one loss value per swept epsilon.
loss_per_epsilon = list()




# Final-cache loss per swept epsilon. Kept under this name for anything that
# reads it later; it receives the same values as ``loss_per_epsilon``.
loss_over_time_per_epsilon = []
for epsilon in epsilon_values:
    # Learner state is rebuilt from scratch for every epsilon in the sweep.
    N_counts = defaultdict(int)         # arrivals seen per query id
    Nc_counts = defaultdict(int)        # noisy-cost observations per query id
    L_costs = defaultdict(float)        # running sum of observed noisy costs
    phat_dict = defaultdict(float)      # empirical arrival frequency
    chat_dict = defaultdict(float)      # empirical mean of observed costs
    lcb_dict = defaultdict(float)       # lower confidence bound on the cost
    true_cost_dict = defaultdict(list)  # noiseless costs seen per query id
    cache_history = []                  # cache chosen in each round
    C_qt = []                           # cost paid in each round (0.0 if served)

    def update_estimates():
        """Recompute ``phat_dict``, ``chat_dict`` and ``lcb_dict`` for every id.

        ``phat`` is the share of arrivals so far. ``chat`` is the mean observed
        cost and ``lcb`` is ``chat`` minus a confidence bonus, floored at 0;
        both are 0.0 for ids whose cost has never been observed.
        """
        # The arrival total is the same for every id, so compute it once.
        # The 1e-8 term avoids a division by zero before the first arrival.
        total_arrivals = sum(N_counts.values()) + 1e-8
        for i in range(m):
            phat_dict[i] = N_counts[i] / total_arrivals
            if Nc_counts[i] > 0:
                chat_dict[i] = L_costs[i] / Nc_counts[i]
                bonus = 0.05 * math.sqrt(
                    2 * math.log(4 * m * no_rounds / delta) / Nc_counts[i]
                )
                lcb_dict[i] = max(chat_dict[i] - bonus, 0.0)
            else:
                chat_dict[i] = 0.0
                lcb_dict[i] = 0.0

    def reverse_greedy(k, epsilon):
        """Return a set of ``k`` query ids chosen by greedy backward elimination.

        Starts from all ``m`` ids and, ``m - k`` times, removes the id whose
        removal yields the smallest estimated loss. The estimated loss charges
        each id ``i`` its distance to the nearest kept id when that distance is
        at most ``epsilon``, and ``lcb_dict[i]`` otherwise, weighted by
        ``phat_dict[i]``. Ties keep the first candidate encountered.
        """
        Q = set(range(m))
        for _ in range(m - k):
            best_q = None
            best_loss = float("inf")
            for q in Q:
                temp = Q - {q}
                loss = 0.0
                for i in range(m):
                    min_dist = min(l2_distance_matrix[i][j] for j in temp)
                    if min_dist <= epsilon:
                        loss += phat_dict[i] * min_dist
                    else:
                        loss += phat_dict[i] * lcb_dict[i]
                if loss < best_loss:
                    best_loss = loss
                    best_q = q
            Q.remove(best_q)
        return Q

    for t in range(no_rounds):
        q = query_stream_ids[t]
        true_cost = costs_stream[t]
        true_cost_dict[q].append(true_cost)
        N_counts[q] += 1

        update_estimates()
        epsilon_t = epsilon  # can be varied per round

        cache = reverse_greedy(cache_size, epsilon_t)
        cache_history.append(cache)
        min_dist = min(l2_distance_matrix[q, c] for c in cache)
        # NOTE(review): the serve-or-pay decision compares the distance with
        # the cost LCB, whereas the planner's hit rule uses ``epsilon``.
        # Left unchanged -- confirm this asymmetry is intended.
        if min_dist > lcb_dict[q]:
            # Pay for the query and observe a noisy cost clipped to [0, 1].
            Nc_counts[q] += 1
            noisy_cost = np.clip(
                np.random.normal(loc=true_cost, scale=sigma), 0.0, 1.0
            )
            L_costs[q] += noisy_cost
            C_qt.append(noisy_cost)
        else:
            C_qt.append(0.0)

    def _cost_from_tokens(query_id):
        """Return the stream-normalised token-length cost of ``query_id``.

        Applies the same min-max scaling as ``costs_stream``. A query that never
        appeared in the stream can fall outside the stream's length range, so
        the value is clipped to [0, 1], the range of every other cost here.
        """
        n_tokens = len(tokenizer.encode(id_to_query[query_id]))
        # max_len != min_len here: otherwise building ``costs_stream`` would
        # already have failed with a division by zero.
        scaled = (n_tokens - min_len) / (max_len - min_len)
        return min(max(scaled, 0.0), 1.0)

    # Noiseless cost of every query id. Ids seen in the stream use the mean of
    # their recorded costs; ids the random stream never produced have no
    # record (this used to raise KeyError) and fall back to the token-length cost.
    c_true = {
        i: np.mean(true_cost_dict[i]) if i in true_cost_dict else _cost_from_tokens(i)
        for i in range(m)
    }

    def compute_loss_actual(cache_set, epsilon):
        """Return the loss of ``cache_set`` under the true arrival law and costs.

        Each id ``i`` contributes its distance to the nearest cached id when
        that distance is at most ``epsilon``, and ``c_true[i]`` otherwise. The
        stream is sampled uniformly, so every id carries the same weight: the
        scalar ``p_true`` (it is not a per-id sequence and cannot be indexed).
        """
        loss = 0.0
        for i in range(m):
            min_dist = min(l2_distance_matrix[i][j] for j in cache_set)
            if min_dist <= epsilon:
                loss += p_true * min_dist
            else:
                loss += p_true * c_true[i]
        return loss

    # Score the cache chosen in the last round of this epsilon.
    loss_val = compute_loss_actual(cache, epsilon_t)
    # ``loss_per_epsilon`` is the list that gets plotted against
    # ``epsilon_values``; it was previously never filled.
    loss_per_epsilon.append(loss_val)
    loss_over_time_per_epsilon.append(loss_val)

import matplotlib.pyplot as plt

# Plot loss against the swept epsilon values on the current axes, which is
# where the pyplot state-machine calls draw as well.
ax = plt.gca()
ax.plot(epsilon_values, loss_per_epsilon, marker="o")
ax.set_xlabel("Epsilon (ε)")
ax.set_ylabel("Total Loss")
ax.set_title("Loss vs Epsilon for Online Reverse Greedy")
ax.grid(True)
plt.show()



# for t in range(no_rounds):
#     q = query_stream_ids[t]
#     true_cost = costs_stream[t]
#     N_counts[q] += 1

#     update_estimates()
#     epsilon_t = 0.1  # can be varied per round

#     cache = reverse_greedy(cache_size, epsilon_t)
#     cache_history.append(cache)

#     if q in cache:
#         C_qt.append(0.0)
#     else:
#         Nc_counts[q] += 1
#         noisy_cost = np.clip(np.random.normal(loc=true_cost, scale=sigma), 0.0, 1.0)
#         L_costs[q] += noisy_cost
#         C_qt.append(noisy_cost)
# for t in range(no_rounds):
#     q_id = query_stream_ids[t]
#     q = id_to_query[q_id]
#     cost = costs_stream[t]
#     if q_id not in true_cost_dict:
#         true_cost_dict[q_id] = cost
#     if q_id not in current_cache_ids:
#         noisy_cost = np.clip(np.random.normal(loc=cost, scale=sigma), 0.0, 1.0)
#         C_qt.append(noisy_cost)
#         N_c[q_id] += 1
#         L_c[q_id] += noisy_cost
#     else:
#         C_qt.append(0.0)
#     # Update p-hat and c-hat
#     N_p[q_id] += 1
#     chat_dict = {q: (L_c[q] / N_c[q]) if N_c[q] > 0 else 1.0 for q in range(m)}
#     phat_dict = {q: N_p[q] / (t + 1) for q in range(m)}
    
    
#     #chat = L_c[q] / N_c[q] if N_c[q] > 0 else 0.0
#     bonus = 0.05 * math.sqrt(2 * math.log(4 * len(distinct_queries) * no_rounds / delta) / N_c[q]) if N_c[q] > 0 else 1.0
#     lcb = max(chat - bonus, 0.0)
#     ucb = chat + bonus






