import json, random, math
import torch
import numpy as np
from collections import defaultdict
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# -------- Configuration --------
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def phi(t):
    """Return log(t + 1); used to scale the neighbour-distance term of the bonus."""
    return math.log(t + 1)


# Upper bound on the number of nearest history records examined per query.
K_MAX = 500

# -------- KNN-UCB Bandit (cosine or L2 distance) --------
class KNN_UCB_Bandit:
    """Contextual bandit that scores each model with a k-nearest-neighbour UCB.

    Every observation is stored as (prompt_idx, model_name, reward). To score
    a prompt, the closest past prompts (at most K_MAX of them, across all
    models) are looked up in embedding space and, per model, the neighbour
    count k with the smallest exploration bonus is retained.
    """

    def __init__(self, models, embeddings, theta=1.0, beta=1, distance="cosine"):
        """
        models: list of model names (the arms)
        embeddings: torch.Tensor of shape (num_prompts, embedding_dim)
        theta: multiplier of the log(t + 1) / k term inside the bonus
        beta: multiplier of the neighbour-distance term inside the bonus
        distance: "L2" selects the Euclidean distance; any other value
            (default "cosine") selects the cosine distance
        """
        self.models = models
        self.emb = embeddings
        self.history = []  # list of tuples (prompt_idx, model_name, reward)
        self.theta = theta
        self.beta = beta
        self.distance = distance

    # --- NOTE: originally marked "to be added to classes/active_knn_ucb.py" ---

    def _ucb_per_model(self, x_idx, t):
        """
        Computes (ucb, bonus, var) for each model at time t.

        x_idx: row index of the query prompt in the embedding matrix
        t: time step; log(t + 1) drives the confidence width

        Returns a list [(model, ucb, bonus, var)], unsorted (same order as
        self.models). With an empty history every model gets
        (inf, inf, inf); a model with no record among the retained
        neighbours gets (inf, inf, 0.0).
        """
        if not self.history:
            # No history at all: UCB = +inf for every model to force exploration.
            return [(m, float('inf'), float('inf'), float('inf')) for m in self.models]

        x = self.emb[x_idx : x_idx + 1]
        hist_idxs = [i for (i, _, _) in self.history]
        hist_x = self.emb[hist_idxs]

        # Distance from the query prompt to every prompt in the history.
        if self.distance == "L2":
            # L2 distance: ||x - hist_x||_2
            dists = torch.norm(x - hist_x, p=2, dim=1)
        else:  # cosine (default)
            cos_sim = torch.nn.functional.cosine_similarity(x, hist_x, dim=1)
            dists = 1.0 - cos_sim

        # Keep only the K_MAX closest history records (all models together).
        k_max = min(len(dists), K_MAX)
        vals, idxs = torch.topk(dists, k_max, largest=False)

        # Group (reward, distance) pairs by model. Converting the tensors
        # once avoids one tensor-to-float conversion per neighbour.
        buckets = {m: [] for m in self.models}
        for dist, hist_pos in zip(vals.tolist(), idxs.tolist()):
            _, m, r = self.history[hist_pos]
            buckets[m].append((r, dist))

        stats = []
        logt = math.log(t + 1)
        phi_t = phi(t)  # loop-invariant scaling of the distance term
        for m in self.models:
            pairs = buckets[m]
            if not pairs:
                # Model absent from the neighbourhood: infinite UCB.
                stats.append((m, float('inf'), float('inf'), 0.0))
                continue

            # Closest neighbours first, so prefix k holds the k nearest.
            pairs.sort(key=lambda pair: pair[1])
            best_ucb_m = -float('inf')
            best_bonus_m = float('inf')
            best_k = 0
            reward_sum = 0.0

            for k, (r_val, dist) in enumerate(pairs, start=1):
                reward_sum += r_val

                mu_hat = reward_sum / k
                # Confidence width shrinks with k; the distance to the k-th
                # neighbour grows with k.
                bonus = math.sqrt(self.theta * logt / k) + phi_t * self.beta * dist
                ucb = mu_hat + bonus

                # k is chosen to minimise the bonus, not the UCB itself.
                if bonus < best_bonus_m:
                    best_ucb_m = ucb
                    best_bonus_m = bonus
                    best_k = k

            # Variance of the rewards of the best_k nearest neighbours.
            best_k_scores = [r for r, _ in pairs[:best_k]]
            var_k = float(np.var(best_k_scores)) if best_k_scores else 0.0
            stats.append((m, best_ucb_m, best_bonus_m, var_k))

        return stats

    def rank_arms(self, x_idx, t):
        """
        Returns the models sorted by decreasing UCB:
        [(model, ucb, bonus, var), ...], best first.

        The sort is stable, so models with equal UCB keep their order from
        self.models.
        """
        stats = self._ucb_per_model(x_idx, t)
        stats.sort(key=lambda z: z[1], reverse=True)
        return stats

    def select_arm(self, x_idx, t):
        """
        Picks the model with the highest UCB for prompt x_idx at time t.

        Returns: best_model, delta_ucb, best_bonus, best_var, second_model
          - delta_ucb: UCB gap between the best and the second-best model;
            0.0 when they are tied (including both being +inf), +inf when
            there is only one model.
          - best_bonus: exploration bonus of the best model (the third
            element is the bonus, not the UCB value).
          - second_model: runner-up, or None when there is only one model.
        """
        stats = self.rank_arms(x_idx, t)
        best_model, best_ucb, best_bonus, best_var = stats[0]
        if len(stats) > 1:
            second_model, second_ucb, _, _ = stats[1]
        else:
            second_model, second_ucb = None, -float('inf')

        if best_ucb == second_ucb:
            # Covers inf - inf, which would otherwise evaluate to NaN.
            delta_ucb = 0.0
        else:
            delta_ucb = best_ucb - second_ucb
        return best_model, delta_ucb, best_bonus, best_var, second_model

    def update(self, x_idx, arm, reward):
        """
        Append a new observation (prompt_idx, model_name, reward).
        """
        self.history.append((x_idx, arm, reward))
