# model_removal.py
import os
import json
import random
import math
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
from classes.pakucb import PAKUCB
from classes.active_knn_ucb import KNN_UCB_Bandit
from classes.neuronal_s_nets import ExploitationNet, ExplorationNet
import pickle

# -------- Configuration --------
# Dataset name; interpolated into the metadata path and the output filenames.
dataset         = "carrot-bowl"
metadata_path   = f"../datasets/{dataset}/metadata.json"

# Start with complete pool; remove 2 models during run
# (the two removed names are hard-coded inside single_run).
selected_models = ["Unidiffuser", "LCM", "SSD-1B", "SDXL-Turbo", "Sana", "Koala"]
baseline_model  = "SSD-1B"   # baseline reference for O2B
assert baseline_model in selected_models

# Number of iterations per run.
T           = 2000
BUDGET      = T // 3         # for BALROG
BUDGET_NS   = BUDGET         # for neuronal-s (if enabled)

# BALROG issues a full query when select_arm's second return value is below
# epsilon (and budget remains); epsilon also appears in the neuronal-s beta term.
epsilon     = 0.45
# Passed as `theta` to every KNN_UCB_Bandit instance.
THETA       = 1.0
# Number of independent repetitions averaged by the main driver.
num_runs    = 10
# Upper bound on how many stored CLIP scores sample_mean averages per (prompt, model).
generations = 5
# Cap on the number of prompts sampled by load_data.
max_prompts = T

# neuronal-s hyperparameters: hidden width of both nets, Adam learning rate,
# and the multiplier in the margin threshold 2 * gamma_ns * beta.
hidden_dim_ns = 256
lr_ns         = 1e-5
gamma_ns      = 2.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -------- CLIP Setup --------
# Loaded once at import time; the model is moved to `device` and put in eval mode.
clip_model     = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def get_prompt_embedding(prompt, cache={}):
    """Return the L2-normalised CLIP text feature for ``prompt``.

    Results are memoised in ``cache``. The mutable default is intentional:
    it is shared across calls and acts as a process-wide memo table. Pass an
    explicit dict to use a separate cache.

    Args:
        prompt: Text to embed; used as the cache key.
        cache: Mapping from prompt to its stored embedding.

    Returns:
        The cached value for ``prompt``. On a miss this is the text feature
        divided by its norm along the last dimension, with the leading
        dimension squeezed away.
    """
    # Cache hit: nothing to compute.
    if prompt in cache:
        return cache[prompt]

    tokens = clip_processor(text=[prompt], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_feat = clip_model.get_text_features(**tokens)
    unit_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    cache[prompt] = unit_feat.squeeze(0)
    return cache[prompt]

def load_data(path, max_load):
    """Load CLIP scores from a metadata JSON file and embed a sample of prompts.

    Args:
        path: Path to a JSON file holding a list of entries, each with
            ``"prompt"`` and ``"model"`` keys and an optional
            ``"clip_scores"`` list.
        max_load: Maximum number of prompts to keep.

    Returns:
        A tuple ``(prompts, scores_map, embeddings, X)``:
        ``prompts`` is the randomly sampled list of prompts,
        ``scores_map[prompt][model]`` is the list of collected scores,
        ``embeddings`` maps each sampled prompt to its embedding, and
        ``X`` stacks those embeddings in the order of ``prompts``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        entries = json.load(handle)

    # Gather scores per prompt and per model, ignoring models outside the
    # pool and entries without any score.
    scores_map = defaultdict(lambda: defaultdict(list))
    for entry in entries:
        prompt = entry["prompt"]
        model = entry["model"]
        clip_scores = entry.get("clip_scores", [])
        if model in selected_models and clip_scores:
            scores_map[prompt][model].extend(clip_scores)

    # Keep only the prompts for which every model in the pool has scores.
    complete = []
    for prompt, per_model in scores_map.items():
        if all(per_model[m] for m in selected_models):
            complete.append(prompt)

    n_keep = min(max_load, len(complete))
    prompts = random.sample(complete, n_keep)

    embeddings = {}
    for prompt in tqdm(prompts, desc="Embedding prompts"):
        embeddings[prompt] = get_prompt_embedding(prompt)

    X = torch.stack([embeddings[prompt] for prompt in prompts], dim=0)
    return prompts, scores_map, embeddings, X

def sample_mean(lst, k):
    """Return the mean of up to ``k`` values drawn without replacement from ``lst``.

    Args:
        lst: Sequence of numeric scores.
        k: Maximum number of values to draw; capped at ``len(lst)``.

    Returns:
        The float32 mean of the drawn values as a Python float, or ``0.0``
        when nothing can be drawn (``lst`` is empty or ``k`` is 0).

    Raises:
        ValueError: If ``k`` is negative (raised by ``random.sample``).
    """
    # Drawing zero values would otherwise average an empty tensor and yield NaN.
    if not lst or k == 0:
        return 0.0
    drawn = random.sample(lst, min(len(lst), k))
    # Average on the CPU: shipping a handful of floats to an accelerator and
    # synchronising to read the scalar back on every call is pure overhead.
    # float32 is kept so the rounding matches a default torch tensor.
    return float(torch.tensor(drawn, dtype=torch.float32).mean())

def single_run(prompts, scores_map, embeddings, X, algos):
    """Simulate one T-step run during which two models leave the pool.

    At every step a prompt index is drawn at random, a score is sampled for
    each model still in the pool (plus the baseline), and every algorithm
    named in ``algos`` picks a model. "Unidiffuser" is dropped on the first
    step with ``t > T//3`` and "SSD-1B" on the first step with
    ``t > 2*T//3``; after each removal the bandits are rebuilt and the
    recorded history of the surviving models is replayed into them.

    Args:
        prompts: List of prompts; indexed by the randomly drawn step index.
        scores_map: ``scores_map[prompt][model]`` -> list of scores to sample from.
        embeddings: Mapping from prompt to its embedding (fed to PAK-UCB).
        X: Stacked embeddings, row ``i`` matching ``prompts[i]``; passed to the
            KNN bandits and used as the neuronal-s input.
        algos: Names of the algorithms to evaluate. Recognised names are
            "Optimal", "Always", "Random", "PAK-UCB", "BALROG", "KNN-UCB"
            and "neuronal-s".

    Returns:
        ``(o2b, opr, budgets)``, three dicts of per-step lists keyed by
        algorithm name: ``o2b`` holds reward minus the baseline score,
        ``opr`` holds 1 when the pick equals the best remaining model and 0
        otherwise (no entry for "Optimal"), and ``budgets`` holds the
        cumulative budget spent so far (budgeted algorithms only).
    """
    N = len(prompts)
    o2b = {a: [] for a in algos}
    opr = {a: [] for a in algos if a!="Optimal"}
    budgets = {a: [] for a in ["BALROG","KNN-UCB","neuronal-s"] if a in algos}

    # History for replay when pool changes
    hist_pb  = []  # tuples (emb, reward, model_name)
    hist_nb  = []  # tuples (i, model_name, reward) -- store model name for proper replay
    hist_knn = []  # same layout as hist_nb

    # Start with full pool (copied so the module-level list is not mutated)
    current = selected_models.copy()

    remove_1 = "Unidiffuser"
    remove_2 = "SSD-1B"
    assert remove_1 in current and remove_2 in current

    # Instantiate the bandits (created regardless of which names are in algos):
    # pb -> PAK-UCB, nb -> BALROG, knn -> passive KNN-UCB.
    pb  = PAKUCB(current)
    nb  = KNN_UCB_Bandit(current, X, theta=THETA)
    knn = KNN_UCB_Bandit(current, X, theta=THETA)

    # neuronal-s (optional)
    if "neuronal-s" in algos:
        f1 = ExploitationNet(X.shape[1], hidden_dim_ns, len(current)).to(device)
        f2 = ExplorationNet(hidden_dim_ns + len(current)*hidden_dim_ns + len(current),
                             hidden_dim_ns, len(current)).to(device)
        opt1 = torch.optim.Adam(f1.parameters(), lr=lr_ns)
        opt2 = torch.optim.Adam(f2.parameters(), lr=lr_ns)
        budget_ns = BUDGET_NS

        # NOTE(review): compute_phi is never called in this function; the loop
        # below defines and uses compute_phi_local instead.
        def compute_phi(x):
            x0 = x.unsqueeze(0)
            logits,h1 = f1(x0)
            grads=[]
            for k in range(len(current)):
                f1.zero_grad()
                gW,gB = torch.autograd.grad(logits[0,k],[f1.fc2.weight,f1.fc2.bias],retain_graph=True)
                grads.append(torch.cat([gW.flatten(),gB.flatten()]))
            return torch.cat([h1.flatten(), torch.stack(grads).mean(0)],dim=0)

    budget_knn = BUDGET
    budget_kl  = BUDGET  # (only so KNN-UCB can report a budget series; it is never spent here)
    removed1 = removed2 = False

    def rebuild_after_change():
        """Recreate every learner for the shrunken ``current`` pool.

        The three bandits are rebuilt from scratch and fed the recorded
        history entries whose model is still in ``current``. The neuronal-s
        networks and optimizers are recreated without any replay.
        """
        nonlocal pb, nb, knn, f1, f2, opt1, opt2
        # PAK-UCB
        pb = PAKUCB(current)
        for emb, rw, m in hist_pb:
            if m in current:
                pb.update(emb, rw, m)
        # BALROG/KNN-UCB (rebuild + replay)
        nb = KNN_UCB_Bandit(current, X, theta=THETA)
        for i, m_name, rw in hist_nb:
            if m_name in current:
                nb.update(i, m_name, rw)
        knn = KNN_UCB_Bandit(current, X, theta=THETA)
        for i, m_name, rw in hist_knn:
            if m_name in current:
                knn.update(i, m_name, rw)
        # neuronal-s: recreate the nets because their output dimension depends on len(current)
        if "neuronal-s" in algos:
            f1 = ExploitationNet(X.shape[1], hidden_dim_ns, len(current)).to(device)
            f2 = ExplorationNet(hidden_dim_ns + len(current)*hidden_dim_ns + len(current),
                                 hidden_dim_ns, len(current)).to(device)
            opt1 = torch.optim.Adam(f1.parameters(), lr=lr_ns)
            opt2 = torch.optim.Adam(f2.parameters(), lr=lr_ns)

    for t in tqdm(range(1, T+1), desc="Test iterations (model removal)"):
        # Remove first model on the first step past T/3
        if not removed1 and t > T//3:
            if remove_1 in current:
                current.remove(remove_1)
                rebuild_after_change()
            removed1 = True

        # Remove second model on the first step past 2T/3
        if not removed2 and t > 2*T//3:
            if remove_2 in current:
                current.remove(remove_2)
                rebuild_after_change()
            removed2 = True

        i = random.randrange(N)
        p = prompts[i]
        emb = embeddings[p]
        # Build score map for current models + baseline (even if baseline no longer in current)
        smap = {m: sample_mean(scores_map[p][m], generations) for m in set(current + [baseline_model])}
        base = smap[baseline_model]
        # Best among remaining models
        best_m = max(current, key=lambda m: smap[m])
        best_s = smap[best_m]

        # Optimal
        if "Optimal" in algos:
            o2b["Optimal"].append(best_s - base)

        # Always (always baseline)
        # NOTE(review): once the baseline ("SSD-1B") has been removed, best_m is
        # drawn from `current` and can no longer equal baseline_model, so OPR is 0.
        if "Always" in algos:
            o2b["Always"].append(0.0)
            opr["Always"].append(int(baseline_model == best_m))

        # Random
        if "Random" in algos:
            r = random.choice(current)
            o2b["Random"].append(smap[r] - base)
            opr["Random"].append(int(r == best_m))

        # PAK-UCB
        if "PAK-UCB" in algos:
            choice = pb.select_model(emb)
            # Guard: fall back to the best remaining model if the bandit
            # returns a model that is no longer in the pool.
            if choice not in current:
                choice = best_m
            rw = smap[choice]
            pb.update(emb, rw, choice)
            hist_pb.append((emb, rw, choice))
            o2b["PAK-UCB"].append(rw - base)
            opr["PAK-UCB"].append(int(choice == best_m))

        # BALROG
        if "BALROG" in algos:
            arm, delta_k, ucb, _, _ = nb.select_arm(i, t)
            if delta_k < epsilon and budget_knn > 0:
                # Complete query (all remaining models)
                for m in current:
                    nb.update(i, m, smap[m])
                    hist_nb.append((i, m, smap[m]))
                # The reward credited is still that of the selected arm.
                reward_k = smap[arm]
                budget_knn -= 1
            else:
                # No budget spent: only the selected arm's score is observed.
                reward_k = smap[arm]
                nb.update(i, arm, reward_k)
                hist_nb.append((i, arm, reward_k))
            o2b["BALROG"].append(reward_k - base)
            opr["BALROG"].append(int(arm == best_m))
            budgets["BALROG"].append(BUDGET - budget_knn)

        # KNN-UCB (passive)
        if "KNN-UCB" in algos:
            arm, delta_l, _, _, _ = knn.select_arm(i, t)
            reward_l = smap[arm]
            knn.update(i, arm, reward_l)
            hist_knn.append((i, arm, reward_l))
            o2b["KNN-UCB"].append(reward_l - base)
            opr["KNN-UCB"].append(int(arm == best_m))
            budgets["KNN-UCB"].append(BUDGET - budget_kl)  # always 0; kept for plotting consistency

        # neuronal-s (optional)
        if "neuronal-s" in algos:
            x = X[i].to(device)
            logits, _ = f1(x.unsqueeze(0))
            # indices of the current models in the network output
            cur_idx = list(range(len(current)))
            # feature phi
            # (NB: compute_phi_local depends on f1 and on the current model set)
            def compute_phi_local(x_):
                x0 = x_.unsqueeze(0)
                logits_l, h1 = f1(x0)
                grads = []
                for kk in cur_idx:
                    f1.zero_grad()
                    gW, gB = torch.autograd.grad(logits_l[0, kk], [f1.fc2.weight, f1.fc2.bias], retain_graph=True)
                    grads.append(torch.cat([gW.flatten(), gB.flatten()]))
                # phi = hidden activations followed by the mean over arms of the fc2 gradients
                return torch.cat([h1.flatten(), torch.stack(grads).mean(0)], dim=0)

            phi = compute_phi_local(x)
            u_out = f2(phi.unsqueeze(0))
            sc_all = logits + u_out
            # scores of the current models
            sc = sc_all[0, cur_idx]
            rel = int(sc.argmax())
            choice = current[rel]

            # top-2 margin (0 when only one model remains)
            top2 = torch.topk(sc, k=min(2, len(current))).values
            if len(top2) == 1:
                margin = (top2[0] - top2[0]).item()
            else:
                margin = (top2[0] - top2[1]).item()

            beta = math.sqrt(len(current) * math.log((3 * N) / max(epsilon, 1e-8)) / max(t, 1))
            if margin < 2 * gamma_ns * beta and budget_ns > 0:
                rew, choice_eff = best_s, best_m  # spend one unit of budget to get the "oracle" answer
                budget_ns -= 1
            else:
                rew, choice_eff = smap[choice], choice

            o2b["neuronal-s"].append(rew - base)
            opr["neuronal-s"].append(int(choice_eff == best_m))
            budgets["neuronal-s"].append(BUDGET_NS - budget_ns)

            # update nets (the full score vector of the current models is used every step)
            ut_vec = torch.tensor([smap[m] for m in current], device=device)
            # step f1: regress f1's output onto (scores - f2's output)
            opt1.zero_grad()
            p1, _ = f1(x.unsqueeze(0)); p2 = f2(phi.unsqueeze(0))
            sel = list(range(len(current)))
            loss1 = F.mse_loss(p1[:, sel], (ut_vec.unsqueeze(0) - p2[:, sel]).detach())
            loss1.backward(); opt1.step()
            # step f2: regress f2's output onto (scores - updated f1 output)
            opt2.zero_grad()
            with torch.no_grad():
                new_p1, _ = f1(x.unsqueeze(0))
            loss2 = F.mse_loss(f2(phi.unsqueeze(0)), (ut_vec.unsqueeze(0) - new_p1[:, sel]).detach())
            loss2.backward(); opt2.step()

    return o2b, opr, budgets

# -------- MAIN --------
if __name__ == "__main__":
    # Driver: run `num_runs` independent simulations, pickle the raw per-step
    # series, then plot run-averaged regret, OPR, budget use and sliding O2B.
    prompts, scores_map, embeddings, X = load_data(metadata_path, max_prompts)

    algos = ["Optimal", "Always", "Random", "PAK-UCB", "BALROG", "KNN-UCB"]  # add "neuronal-s" if desired
    all_o2b     = {a: [] for a in algos}
    all_opr     = {a: [] for a in algos if a != "Optimal"}
    budgets_acc = {a: [] for a in ["BALROG", "KNN-UCB", "neuronal-s"] if a in algos}

    for run in range(num_runs):
        print(f"Run {run+1}/{num_runs}")
        o2b_run, opr_run, buds = single_run(prompts, scores_map, embeddings, X, algos)
        for a in algos:
            all_o2b[a].append(o2b_run[a])
            if a != "Optimal":
                all_opr[a].append(opr_run[a])
        for a, b in buds.items():
            budgets_acc[a].append(b)

    # Save raw data for later plotting
    save_path = os.path.join(
        "plots",
        "model_removal",
        "data",
        f"raw_data_model_removal_{dataset}_{T}_{num_runs}runs.pkl"
    )
    # Create the directory the file is actually written to; otherwise open()
    # fails with FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump({
            "all_o2b":       all_o2b,
            "all_opr":       all_opr,
            "budgets_accum": budgets_acc
        }, f)
    print(f"Saved raw data to {save_path}")

    # Compute averages (over runs, per step) and plot
    avg_o2b = {a: np.mean(np.stack(all_o2b[a]), axis=0) for a in algos}
    avg_opr = {a: np.mean(np.stack(all_opr[a]), axis=0) for a in all_opr}
    avg_bud = {a: np.mean(np.stack(budgets_acc[a]), axis=0) for a in budgets_acc}

    styles = {
      "Optimal":    {"linestyle":"-.", "color":"#2ca02c"},
      "Always":     {"linestyle":"--", "color":"#1f77b4"},
      "Random":     {"linestyle":":",  "color":"#ff7f0e"},
      "PAK-UCB":    {"linestyle":"-",  "color":"#d62728"},
      "BALROG":   {"linestyle":"--","color":"#9467bd"},
      "KNN-UCB":    {"linestyle":"-.","color":"#8c564b"},
      "neuronal-s": {"linestyle":"-",  "color":"#17becf"},
    }

    # Removal events shown on every panel. The labels must mirror the models
    # hard-coded as remove_1 / remove_2 inside single_run.
    removal_marks = ((T // 3, "-Unidiffuser"), (2 * T // 3, "-SSD-1B"))

    def mark_removals(ax):
        """Draw a dashed line and a rotated label at each model-removal step."""
        for step, _ in removal_marks:
            ax.axvline(x=step, color='grey', linestyle='--')
        for step, label in removal_marks:
            ax.text(step + 100, ax.get_ylim()[1]*0.9, label, rotation=90, color='grey')

    fig, axes = plt.subplots(1,4, figsize=(24,6))

    # (1) Cumulative Regret (relative to the per-step "Optimal" series)
    ax = axes[0]
    for a in algos:
        if a != "Optimal":
            r = np.cumsum(avg_o2b["Optimal"] - avg_o2b[a])
            ax.plot(np.arange(1, len(r)+1), r, label=a, **styles[a])
    mark_removals(ax)
    ax.set_title("Cumulative Regret"); ax.set_xlabel("Iteration"); ax.set_ylabel("Regret")
    ax.legend(); ax.grid(True)

    # (2) Cumulative Avg OPR (running mean, subsampled to 100 points)
    ax = axes[1]
    for a, v in avg_opr.items():
        cum = np.cumsum(v) / np.arange(1, len(v)+1)
        idx = np.linspace(0, len(cum)-1, 100, dtype=int)
        ax.plot(idx+1, cum[idx], label=a, **styles[a])
    mark_removals(ax)
    ax.set_title("Cumulative Avg OPR"); ax.set_xlabel("Iteration"); ax.set_ylabel("OPR")
    ax.legend(); ax.grid(True)

    # (3) Budget Consumption
    ax = axes[2]
    for a, b in avg_bud.items():
        # 1-based x axis so the curve lines up with the other panels and the
        # removal markers.
        ax.plot(np.arange(1, len(b)+1), b, label=a, **styles[a])
    mark_removals(ax)
    ax.set_title("Budget Consumption"); ax.set_xlabel("Iteration"); ax.set_ylabel("GT Queries")
    ax.legend(); ax.grid(True)

    # (4) Sliding-window Avg O2B (window of T//10 steps, subsampled to 100 points)
    ax = axes[3]
    w = T//10
    for a in algos:
        if len(avg_o2b[a]) >= w:
            mov = np.convolve(avg_o2b[a], np.ones(w)/w, mode="valid")
            idx = np.linspace(0, len(mov)-1, 100, dtype=int)
            ax.plot(np.arange(w, w+len(mov))[idx], mov[idx], label=a, **styles[a])
    mark_removals(ax)
    ax.set_title(f"{w}-Sliding Avg O2B"); ax.set_xlabel("Iteration"); ax.set_ylabel("Avg O2B")
    ax.legend(); ax.grid(True)

    plt.tight_layout()
    out_dir = "results/model_removal"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/results_{dataset}_{T}_{num_runs}runs.pdf")
    plt.close(fig)
