import os
import json
import pickle
import random
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
from classes.active_knn_ucb import KNN_UCB_Bandit

# -------- Configuration --------
# Data source: `image_reward` switches both the metadata file used here and
# the score field that load_data() reads from each record.
dataset = "carrot-bowl"
image_reward = False
distance = "cosine"  # "cosine" or "L2"
metadata_path = (
    f"../datasets/{dataset}/metadata_IR.json"
    if image_reward
    else f"../datasets/{dataset}/metadata.json"
)

# Arms of the bandit; records for any other model are ignored when loading.
selected_models = [
    "Sana",
    "Unidiffuser",
    "LCM",
    "Koala",
    "SDXL-Turbo",
    "SSD-1B",
]
baseline_model = "SSD-1B"
assert baseline_model in selected_models

# Horizon, and the number of rounds in which "BALROG 20" may update every
# model instead of only the selected one (20% of the horizon).
T = 2000
BUDGET20 = int(0.2 * T)

epsilon_balrog = 0.22  # threshold compared against delta_k from select_arm
THETA = 0.5            # theta handed to the bandit for "BALROG 20"
num_runs = 1           # independent repetitions per algorithm
generations = 5        # max stored scores sampled per model at each step
max_prompts = 10000    # upper bound on prompts kept by load_data

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# CLIP is loaded once at import time: the model is moved to `device` and put
# in eval mode, and the matching processor is loaded from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
clip_model = clip_model.to(device).eval()
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Shared memo (prompt -> normalized embedding) used when the caller does not
# supply its own cache, so each distinct prompt is encoded only once.
_PROMPT_EMBEDDING_CACHE = {}


def get_prompt_embedding(prompt, cache=None):
    """Return the unit-norm CLIP text embedding of ``prompt``.

    Args:
        prompt: Text to embed.
        cache: Optional dict (prompt -> embedding) used for memoization.
            When omitted, a module-level cache shared by all calls is used.

    Returns:
        The cached entry for ``prompt``. On a cache miss the entry is the
        output of ``clip_model.get_text_features`` divided by its norm
        along the last dimension, with the leading batch axis squeezed out.
    """
    if cache is None:
        cache = _PROMPT_EMBEDDING_CACHE
    if prompt not in cache:
        # truncation=True clips over-long prompts to the tokenizer's maximum
        # length instead of letting the text encoder fail on them.
        inp = clip_processor(
            text=[prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            feat = clip_model.get_text_features(**inp)
        cache[prompt] = (feat / feat.norm(dim=-1, keepdim=True)).squeeze(0)
    return cache[prompt]

def load_data(path, max_load):
    """Read per-prompt model scores from a JSON file and embed the prompts.

    Each record must carry "prompt" and "model" keys. Scores are read from
    "image_reward_scores" when the global ``image_reward`` flag is set and
    from "clip_scores" otherwise. A record is used only if its model is in
    ``selected_models`` and its scores are a list made up solely of floats.

    Args:
        path: Path of the metadata JSON file (a list of records).
        max_load: Maximum number of prompts to keep.

    Returns:
        A tuple ``(prompts, scores_map, embeddings, X)`` where ``prompts``
        is a random sample of the prompts that have at least one score for
        every selected model, ``scores_map[prompt][model]`` is the list of
        collected scores, ``embeddings`` maps each sampled prompt to its
        embedding, and ``X`` stacks those embeddings in ``prompts`` order.
    """
    with open(path, "r", encoding="utf-8") as fh:
        records = json.load(fh)

    score_key = "image_reward_scores" if image_reward else "clip_scores"

    scores_map = defaultdict(lambda: defaultdict(list))
    for record in records:
        prompt = record["prompt"]
        model = record["model"]
        scores = record.get(score_key, [])
        if model not in selected_models:
            continue
        if not isinstance(scores, list):
            continue
        if not all(isinstance(value, float) for value in scores):
            continue
        scores_map[prompt][model].extend(scores)

    # A prompt is usable only when every selected model has a non-empty
    # score list for it.
    valid = []
    for prompt, per_model in scores_map.items():
        if all(per_model[m] for m in selected_models):
            valid.append(prompt)

    prompts = random.sample(valid, min(max_load, len(valid)))

    embeddings = {}
    for prompt in tqdm(prompts, desc="Embedding prompts"):
        embeddings[prompt] = get_prompt_embedding(prompt)

    X = torch.stack([embeddings[prompt] for prompt in prompts], dim=0)
    return prompts, scores_map, embeddings, X

def get_reward_estimates(bandit, x_idx, t):
    """Estimate every model's reward at one point from nearby history.

    The query point is ``bandit.emb[x_idx]``. Its distance (L2 when
    ``bandit.distance == "L2"``, otherwise one minus cosine similarity) to
    each point recorded in ``bandit.history`` is computed and the closest
    records (at most 500) are kept. For each model, the neighbour count k
    that minimises ``sqrt(theta * log(t + 1) / k) + log(t + 1) * beta * d_k``
    is chosen, where ``d_k`` is the distance of that model's k-th nearest
    record, and the mean reward of those k records is the estimate.

    Args:
        bandit: Object exposing ``history`` (iterated as
            ``(index, model, reward)`` triples), ``emb``, ``models``,
            ``distance``, ``theta`` and ``beta``.
        x_idx: Row index of the query point in ``bandit.emb``.
        t: Current step, used in the ``log(t + 1)`` terms.

    Returns:
        Dict mapping each model in ``bandit.models`` to its estimate; 0.0
        when the history is empty or the model has no record among the
        kept neighbours.
    """
    from math import log, sqrt

    if not bandit.history:
        return dict.fromkeys(bandit.models, 0.0)

    query = bandit.emb[x_idx : x_idx + 1]
    past_points = bandit.emb[[idx for idx, _, _ in bandit.history]]

    # Distance from the query to every previously observed point.
    if bandit.distance == "L2":
        dists = torch.norm(query - past_points, p=2, dim=1)
    else:
        dists = 1.0 - torch.nn.functional.cosine_similarity(query, past_points, dim=1)

    k_max = min(len(dists), 500)  # K_MAX
    near_dists, near_pos = torch.topk(dists, k_max, largest=False)
    neighbours = [
        (bandit.history[pos], d)
        for pos, d in zip(near_pos.tolist(), near_dists.tolist())
    ]

    log_term = log(t + 1)
    estimates = {}
    for model in bandit.models:
        # This model's (reward, distance) pairs, nearest first.
        candidates = sorted(
            ((reward, d) for (_, owner, reward), d in neighbours if owner == model),
            key=lambda pair: pair[1],
        )

        lowest_bonus = float("inf")
        chosen_mean = 0.0
        running_total = 0.0
        for k, (reward, d) in enumerate(candidates, start=1):
            running_total += reward
            bonus = sqrt(bandit.theta * log_term / k) + log_term * bandit.beta * d
            if bonus < lowest_bonus:
                lowest_bonus = bonus
                chosen_mean = running_total / k

        # Stays 0.0 when the model has no record among the neighbours.
        estimates[model] = chosen_mean

    return estimates

def single_run(prompts, scores_map, embeddings, X, algo_name):
    """Run one bandit simulation and record per-step estimation errors.

    At each of the ``T`` steps a prompt index is drawn, up to
    ``generations`` stored scores per model are sampled and averaged to
    form that step's rewards, and the absolute gap between
    ``get_reward_estimates`` and those rewards is logged for every model
    before the bandit picks an arm and is updated.

    Args:
        prompts: List of prompt strings; indices into it are passed to the
            bandit.
        scores_map: ``scores_map[prompt][model]`` -> list of scores.
        embeddings: Mapping prompt -> embedding (only indexed here).
        X: Embedding matrix handed to ``KNN_UCB_Bandit``.
        algo_name: "BALROG 20" enables the budgeted all-model updates; any
            other value runs the bandit with single-arm updates only.

    Returns:
        Dict mapping each selected model to its list of ``T`` absolute
        estimation errors, in step order.
    """
    active = algo_name == "BALROG 20"
    bandit = KNN_UCB_Bandit(
        selected_models,
        X,
        theta=THETA if active else 4.0,
        distance=distance,
    )
    # Only consulted when `active`; counts remaining all-model updates.
    budget_remaining = BUDGET20 if active else 0

    estimation_errors = {m: [] for m in selected_models}

    # Without replacement when there are enough prompts, with replacement
    # otherwise.
    n_prompts = len(prompts)
    if T <= n_prompts:
        schedule = random.sample(range(n_prompts), T)
    else:
        schedule = random.choices(range(n_prompts), k=T)

    for t in tqdm(range(1, T + 1), desc=f"{algo_name}"):
        i = schedule[t - 1]
        p = prompts[i]
        # Value is not used below; the bandit is driven by the index `i`.
        emb = embeddings[p]

        # This step's reward per model: mean of a random subset of scores.
        smap = {}
        for m in selected_models:
            pool = scores_map[p][m]
            drawn = random.sample(pool, min(len(pool), generations))
            smap[m] = float(torch.tensor(drawn, device=device).float().mean())

        # Estimates are taken before the bandit sees this step's rewards.
        estimates = get_reward_estimates(bandit, i, t)
        for model in selected_models:
            estimation_errors[model].append(abs(estimates[model] - smap[model]))

        arm, delta_k, _, _, _ = bandit.select_arm(i, t)

        update_all = active and delta_k < epsilon_balrog and budget_remaining > 0
        if update_all:
            for m in selected_models:
                bandit.update(i, m, smap[m])
            budget_remaining -= 1
        else:
            bandit.update(i, arm, smap[arm])

    return estimation_errors


# -------- MAIN --------
if __name__ == "__main__":
    prompts, scores_map, embeddings, X = load_data(metadata_path, max_prompts)

    algos = ["BALROG 20", "KNN-UCB"]
    # all_errors[algo][model] collects one per-iteration error list per run.
    all_errors = {}
    for algo in algos:
        all_errors[algo] = {model: [] for model in selected_models}

    for algo in algos:
        print(f"\n{'='*60}")
        print(f"Running {algo}")
        print(f"{'='*60}")

        for run_no in range(1, num_runs + 1):
            print(f"Run {run_no}/{num_runs}")
            run_errors = single_run(prompts, scores_map, embeddings, X, algo)
            for model in selected_models:
                all_errors[algo][model].append(run_errors[model])

    # Persist the raw per-run errors before any aggregation.
    os.makedirs("results/intermediate_results/data", exist_ok=True)
    save_path = os.path.join(
        "results",
        "intermediate_results",
        "data",
        f"estimation_errors_{dataset}_{T}_{num_runs}runs.pkl"
    )
    with open(save_path, "wb") as out_file:
        pickle.dump(all_errors, out_file)
    print(f"\nSaved estimation errors to {save_path}")

    # Mean over runs, per algorithm and model (one value per iteration).
    avg_errors = {
        algo: {
            model: np.mean(np.stack(all_errors[algo][model]), axis=0)
            for model in selected_models
        }
        for algo in algos
    }

    # Mean over models at each iteration.
    avg_across_models = {
        algo: np.mean(
            np.stack([avg_errors[algo][m] for m in selected_models], axis=0),
            axis=0,
        )
        for algo in algos
    }

    # Moving average; series shorter than the window are left untouched.
    window = 500
    kernel = np.ones(window) / window
    sliding_avg_errors = {}
    for algo, series in avg_across_models.items():
        if len(series) >= window:
            sliding_avg_errors[algo] = np.convolve(series, kernel, mode="valid")
        else:
            sliding_avg_errors[algo] = series

    # Plotting
    plt.rcParams.update({
        'font.size': 16,
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14,
        'legend.fontsize': 14
    })

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    # (color, linestyle) per algorithm.
    line_look = {
        "BALROG 20": ("indigo", "-"),
        "KNN-UCB": ("magenta", "--"),
    }

    for algo in algos:
        smoothed = sliding_avg_errors[algo]
        # A "valid" moving average first covers iterations 1..window, so its
        # first point is placed at `window`; unsmoothed series start at 1.
        first_iter = 1 if len(avg_across_models[algo]) < window else window
        xs = np.arange(first_iter, first_iter + len(smoothed))
        line_color, line_dash = line_look[algo]
        ax.plot(xs, smoothed,
                color=line_color, linestyle=line_dash,
                linewidth=2.5, label=algo)

    ax.set_xlabel('Iteration')
    ax.set_ylabel(f'{window}-Sliding Avg Absolute Error (Avg over Models)')
    ax.set_title('Estimation Error Comparison')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='best')

    plt.tight_layout()
    os.makedirs("results/intermediate_results", exist_ok=True)
    plot_path = f"results/intermediate_results/estimation_errors_{dataset}_{T}_{num_runs}runs.pdf"
    plt.savefig(plot_path, dpi=600, bbox_inches="tight")
    print(f"\nSaved plot to {plot_path}")

