import random
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from sklearn.cluster import KMeans
from scipy.stats import norm
from core.infra.llm import LLMApi
from core.model.cluster_arm import ClusterArm

class Optimizer:
    """Pick promising prompts using Bayesian optimisation or a clustered bandit.

    Two selection strategies are offered:

    * ``bayesian_select`` fits a Gaussian process on ``evaluated_prompts`` and
      ranks candidates by expected improvement.
    * ``mab_select`` clusters candidate embeddings with k-means and ranks the
      clusters by the UCB score of the matching ``ClusterArm``.

    Attributes:
        evaluator: Object passed in by the caller; stored but not used by the
            methods of this class.
        evaluated_prompts: History of evaluations. Each entry is indexed as
            ``entry[0]`` (feature vector fed to the GP) and ``entry[1]``
            (observed score). Presumably ``(embedding, score)`` pairs appended
            by the caller -- verify against the code that fills this list.
        arms: One ``ClusterArm`` per bandit arm.
        total_pulls: Pull counter handed to ``ClusterArm.ucb_score``; this class
            only reads it.
        embedding_cache: Maps prompt text to its embedding.
        embedding_dim: Embedding length accepted by ``mab_select``.
    """

    DEFAULT_NUM_ARMS = 5
    DEFAULT_EMBEDDING_DIM = 2560

    def __init__(self, evaluator, num_arms=DEFAULT_NUM_ARMS,
                 embedding_dim=DEFAULT_EMBEDDING_DIM):
        """Create an optimizer.

        Args:
            evaluator: Stored on the instance as ``self.evaluator``.
            num_arms: Number of bandit arms (and the maximum number of
                clusters used by ``mab_select``). Must be at least 1.
            embedding_dim: Expected embedding length; embeddings of any other
                length are replaced by zero vectors in ``mab_select``.

        Raises:
            ValueError: If ``num_arms`` is smaller than 1.
        """
        if num_arms < 1:
            raise ValueError("num_arms must be at least 1")
        self.evaluator = evaluator
        self.evaluated_prompts = []
        self.arms = [ClusterArm(i) for i in range(num_arms)]
        self.total_pulls = 0
        self.embedding_cache = {}
        self.embedding_dim = embedding_dim

    def compute_embedding(self, prompt):
        """Return the embedding of ``prompt.text``, caching it by text.

        The result is also stored on ``prompt.embedding``.

        Args:
            prompt: Object exposing a ``text`` attribute.

        Returns:
            The cached embedding if one exists, otherwise whatever
            ``LLMApi.get_embedding`` returns for the text.
        """
        text = prompt.text
        if text in self.embedding_cache:
            prompt.embedding = self.embedding_cache[text]
        else:
            # Build the API client only on a cache miss; a hit never needs it.
            prompt.embedding = LLMApi().get_embedding(text)
            self.embedding_cache[text] = prompt.embedding
        return prompt.embedding

    def bayesian_select(self, candidate_prompts, min_samples=10, top_n=5):
        """Select up to ``top_n`` candidates with the highest expected improvement.

        While fewer than ``min_samples`` evaluations are recorded, a random
        sample of the candidates is returned instead.

        Args:
            candidate_prompts: Sequence of prompt objects to choose from.
            min_samples: Minimum history size required before fitting the GP.
            top_n: Maximum number of prompts to return.

        Returns:
            A list of at most ``top_n`` prompts, best first when the GP is
            used. Empty when ``top_n`` is not positive or there are no
            candidates.
        """
        # ``seq[-0:]`` is the whole sequence, so a non-positive ``top_n`` must
        # be handled before any slicing; empty input has nothing to rank.
        if top_n <= 0 or not candidate_prompts:
            return []

        if len(self.evaluated_prompts) < min_samples:
            sample_size = min(top_n, len(candidate_prompts))
            return random.sample(candidate_prompts, k=sample_size)

        X = np.array([entry[0] for entry in self.evaluated_prompts])
        Y = np.array([entry[1] for entry in self.evaluated_prompts]).reshape(-1, 1)
        gp = GaussianProcessRegressor(kernel=RBF())
        gp.fit(X, Y)

        embeddings = [self.compute_embedding(p) for p in candidate_prompts]
        ei = self.expected_improvement(embeddings, gp)

        # Indices of the ``top_n`` largest EI values, largest first.
        top_indices = np.argsort(ei)[::-1][:top_n]
        return [candidate_prompts[i] for i in top_indices]

    def expected_improvement(self, X, gp, xi=0.01):
        """Compute the expected improvement of each row of ``X``.

        Args:
            X: Feature vectors accepted by ``gp.predict``.
            gp: Fitted model whose ``predict(X, return_std=True)`` returns a
                ``(mean, std)`` pair.
            xi: Exploration margin subtracted from the predicted gain.

        Returns:
            1-D float array of EI values. Entries whose predicted standard
            deviation is not strictly positive (zero or NaN) are 0.
        """
        mu, sigma = gp.predict(X, return_std=True)
        mu = np.asarray(mu, dtype=float).ravel()
        sigma = np.asarray(sigma, dtype=float).ravel()
        mu_best = max(entry[1] for entry in self.evaluated_prompts)

        gain = mu - mu_best - xi
        ei = np.zeros_like(mu)
        # Evaluate the formula only where the division is well defined. This
        # avoids divide-by-zero warnings and keeps NaN out of the result,
        # which argsort would otherwise rank as the largest value.
        valid = sigma > 0.0
        z = gain[valid] / sigma[valid]
        ei[valid] = gain[valid] * norm.cdf(z) + sigma[valid] * norm.pdf(z)
        return ei

    def mab_select(self, candidate_prompts, min_candidates=10, top_n=5):
        """Select up to ``top_n`` candidates using clustered UCB arms.

        Candidates are clustered by embedding; clusters are ranked by the UCB
        score of the arm with the same index, and one random prompt is drawn
        from each of the best clusters. If that yields fewer than ``top_n``
        prompts, the remainder is filled with random unselected candidates.

        Args:
            candidate_prompts: Sequence of prompt objects to choose from.
            min_candidates: Below this many candidates a random sample is
                returned without clustering.
            top_n: Maximum number of prompts to return.

        Returns:
            A list of at most ``top_n`` prompts. Prompts drawn from a cluster
            get ``cluster_id`` set to that cluster's index. Empty when
            ``top_n`` is not positive or there are no candidates.
        """
        # Same reasoning as in bayesian_select: ``[-0:]`` selects everything
        # and a negative sample size makes random.sample raise.
        if top_n <= 0 or not candidate_prompts:
            return []

        if len(candidate_prompts) < min_candidates:
            sample_size = min(top_n, len(candidate_prompts))
            return random.sample(candidate_prompts, k=sample_size)

        dim = self.embedding_dim
        embeddings = []
        for p in candidate_prompts:
            emb = self.compute_embedding(p)
            if emb is None or len(emb) != dim:
                print(f"Invalid embedding for prompt: {p.text[:50]}... Using zeros.")
                emb = np.zeros(dim, dtype=np.float32)
            embeddings.append(emb)

        embeddings_array = np.array(embeddings)
        unique_embeddings = np.unique(embeddings_array, axis=0)

        # k-means cannot form more clusters than there are distinct points,
        # and there is one arm per cluster index.
        n_clusters = min(len(self.arms), max(1, len(unique_embeddings)))

        # NOTE(review): k-means is re-fit on every call, so label k is not
        # guaranteed to describe the same region as arm k between calls --
        # confirm this pairing is intended.
        kmeans = KMeans(n_clusters=n_clusters, init='k-means++')
        labels = kmeans.fit_predict(embeddings_array)

        scores = [arm.ucb_score(self.total_pulls) for arm in self.arms[:n_clusters]]
        top_clusters = np.argsort(scores)[::-1][:top_n]

        selected_prompts = []
        for cluster in top_clusters:
            members = [p for p, label in zip(candidate_prompts, labels) if label == cluster]
            if members:
                chosen = random.choice(members)
                # Store a plain int rather than a numpy scalar.
                chosen.cluster_id = int(cluster)
                selected_prompts.append(chosen)

        # Top up with random unselected prompts when clusters gave too few.
        if len(selected_prompts) < top_n:
            selected_texts = {sp.text for sp in selected_prompts}
            leftovers = [
                p for p in candidate_prompts
                if p not in selected_prompts and p.text not in selected_texts
            ]
            missing = top_n - len(selected_prompts)
            selected_prompts.extend(random.sample(leftovers, k=min(len(leftovers), missing)))

        return selected_prompts