import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5ForConditionalGeneration, T5Tokenizer
from safetensors.torch import load_file
from sklearn.linear_model import LinearRegression


# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("[INFO] Using device:", device)


# Load the pretrained T5 tokenizer and model, move the model to the chosen
# device and switch it to evaluation mode.
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
model = model.eval()


# Fixed probe batch: the same query string repeated `batch_size` times.
query_text = "Hey There!"
batch_size = 32
queries = [query_text for _ in range(batch_size)]


def black_box_api_prob(query_batch, prompt_emb):
    """Return label probabilities for a batch of queries behind a soft prompt.

    The queries are tokenized and embedded with the module-level ``tokenizer``
    and ``model``; ``prompt_emb`` (if non-empty) is prepended to every
    sequence in the batch before the encoder runs. A single decoder step is
    executed from the decoder start token and the logits of that step are
    restricted to the first sub-token of "negative" and "positive".

    Args:
        query_batch: Text queries accepted by the tokenizer (the script
            passes a list of strings).
        prompt_emb: Tensor whose first dimension is the prompt length and
            whose last dimension matches the token-embedding width, or
            ``None`` / a zero-length tensor for "no prompt".

    Returns:
        NumPy array with one row per query and two columns:
        ``[P(negative), P(positive)]``, normalized over those two tokens only.
    """
    with torch.no_grad():
        encoded = tokenizer(query_batch, padding=True, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        # Keep the tokenizer's mask so padding tokens are not attended to when
        # the queries in a batch have different lengths.
        text_mask = encoded.attention_mask.to(device)
        token_embeds = model.encoder.embed_tokens(input_ids)
        n_queries = token_embeds.size(0)

        if prompt_emb is not None and prompt_emb.shape[0] > 0:
            # Share one prompt across the batch (expand creates a view, no copy).
            sp_batch = prompt_emb.unsqueeze(0).expand(n_queries, -1, -1)
            inputs_embeds = torch.cat([sp_batch, token_embeds], dim=1)
            # Soft-prompt positions are always real input, hence all ones.
            prompt_mask = torch.ones(
                (n_queries, prompt_emb.shape[0]),
                dtype=text_mask.dtype,
                device=device,
            )
            attention_mask = torch.cat([prompt_mask, text_mask], dim=1)
        else:
            inputs_embeds = token_embeds
            attention_mask = text_mask

        # Batch size comes from the embedded tensor rather than
        # len(query_batch), which would be wrong for a bare string.
        decoder_start = torch.full(
            (n_queries, 1),
            model.config.decoder_start_token_id,
            dtype=torch.long,
            device=device,
        )

        encoder_outputs = model.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        outputs = model(
            encoder_outputs=encoder_outputs,
            # Also handed to the model so decoder cross-attention skips padding.
            attention_mask=attention_mask,
            decoder_input_ids=decoder_start,
            return_dict=True,
            use_cache=False,
        )

        logits_first = outputs.logits[:, 0, :]
        label_tokens = ("negative", "positive")
        # Only the first sub-token of each label word is scored.
        label_token_ids = [tokenizer.encode(l, add_special_tokens=False)[0] for l in label_tokens]
        label_token_ids = torch.tensor(label_token_ids, device=device)
        probs = torch.softmax(logits_first[:, label_token_ids], dim=-1)
        return probs.cpu().numpy()


# Location of the tuned soft prompt; the file extension selects the loader.
adapter_path = "/sst2.pt"

if adapter_path.endswith(".npy"):
    try:
        soft_prompt_np = np.load(adapter_path)
        soft_prompt = torch.tensor(soft_prompt_np, device=device, dtype=torch.float32)
    except Exception as e:
        # Chain the cause explicitly so the original traceback is preserved.
        raise RuntimeError(f"Failed to load .npy soft prompt from {adapter_path}: {e}") from e

elif adapter_path.endswith(".safetensors"):
    ckpt = load_file(adapter_path)
    print("[INFO] Inspecting adapter keys:", list(ckpt.keys()))
    if "prompt_embeddings" in ckpt:
        soft_prompt = ckpt["prompt_embeddings"].to(device).float()
    else:
        raise RuntimeError(f"No 'prompt_embeddings' key found in {adapter_path}")

elif adapter_path.endswith(".pt"):
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(adapter_path, map_location=device)
    if isinstance(ckpt, torch.Tensor):
        # The file holds the prompt tensor directly.
        soft_prompt = ckpt.float().to(device)
    elif isinstance(ckpt, dict):
        # The prompt may sit at the top level or inside a nested "state_dict".
        if "prompt_embeddings" in ckpt:
            soft_prompt = ckpt["prompt_embeddings"].float().to(device)
        elif "state_dict" in ckpt and "prompt_embeddings" in ckpt["state_dict"]:
            soft_prompt = ckpt["state_dict"]["prompt_embeddings"].float().to(device)
        else:
            raise RuntimeError(f"Unknown .pt structure, available keys: {list(ckpt.keys())}")
    else:
        raise RuntimeError(f"Unsupported .pt format type: {type(ckpt)}")

else:
    raise RuntimeError(f"Unsupported file type: {adapter_path}")


# Validate the prompt now: it is only fed to the model after the full
# random-prompt sweep, so a malformed adapter would otherwise fail very late
# and with an obscure concatenation error.
if soft_prompt.dim() != 2:
    raise RuntimeError(
        f"Expected a 2-D soft prompt (length, dim), got shape {tuple(soft_prompt.shape)}"
    )
if soft_prompt.shape[1] != model.config.d_model:
    raise RuntimeError(
        f"Soft prompt dim {soft_prompt.shape[1]} does not match model d_model {model.config.d_model}"
    )

# Ground-truth prompt length, printed next to the timing-based estimate.
m_true = soft_prompt.shape[0]
print(f"[INFO] Tuned soft prompt length={m_true}, dim={soft_prompt.shape[1]}")


# Sweep configuration: random prompts of length 0..max_length are timed,
# each with n_repeats measured forward passes.
max_length = 100
n_repeats = 300
results_mean = []
results_std = []

# Only synchronize with the GPU when the script actually runs on one;
# torch.cuda.synchronize() fails on hosts without CUDA.
_sync_cuda = str(device).startswith("cuda")


def _time_prompt_ms(prompt_emb, n_warmup, n_timed):
    """Return per-call latencies (ms) of black_box_api_prob for one prompt.

    Runs ``n_warmup`` untimed calls first, then ``n_timed`` measured calls on
    the module-level ``queries`` batch. On CUDA each call is followed by a
    device synchronize so the interval covers the queued GPU work.
    """
    for _ in range(n_warmup):
        black_box_api_prob(queries, prompt_emb)
        if _sync_cuda:
            torch.cuda.synchronize()

    samples = []
    for _ in range(n_timed):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        black_box_api_prob(queries, prompt_emb)
        if _sync_cuda:
            torch.cuda.synchronize()
        samples.append((time.perf_counter() - t0) * 1000)
    return samples


# Calibration: latency as a function of prompt length, using random prompts.
for m in range(max_length + 1):
    sp = torch.randn(m, model.config.d_model, device=device) if m > 0 else torch.empty(0, model.config.d_model, device=device)

    timings = _time_prompt_ms(sp, n_warmup=3, n_timed=n_repeats)

    mean_t = np.mean(timings)
    std_t = np.std(timings)
    results_mean.append(mean_t)
    results_std.append(std_t)
    print(f"Random Prompt length {m:3d} -> mean {mean_t:.3f} ms, std {std_t:.3f} ms")


# Measurement of the tuned soft prompt whose length is to be estimated.
timings_soft = _time_prompt_ms(soft_prompt, n_warmup=4, n_timed=n_repeats)

mean_soft = np.mean(timings_soft)
std_soft = np.std(timings_soft)
print(f"[TUNED] Soft prompt -> mean {mean_soft:.3f} ms, std {std_soft:.3f} ms")


# Fit latency = slope * length + intercept on the calibration sweep and invert
# the line at the tuned prompt's mean latency to estimate its length.
# Lengths below fit_start are left out of the fit, as in the original script
# (presumably because the shortest prompts are noisiest -- TODO confirm).
fit_start = 10
X = np.arange(fit_start, max_length + 1).reshape(-1, 1)
y = np.array(results_mean[fit_start:])
# Inverse-variance weights; the epsilon avoids division by zero when a
# length happened to have zero measured spread.
weights = 1 / (np.array(results_std[fit_start:]) ** 2 + 1e-8)

reg = LinearRegression().fit(X, y, sample_weight=weights)
slope = reg.coef_[0]
intercept = reg.intercept_

# The inversion is only meaningful when latency grows with prompt length.
# A zero slope would produce inf (and an OverflowError in round()), and a
# negative slope would produce a meaningless length.
if not np.isfinite(slope) or slope <= 0:
    raise RuntimeError(
        f"Latency does not increase with prompt length (slope={slope}); "
        "cannot estimate the soft prompt length from timing."
    )

estimated_length_cont = (mean_soft - intercept) / slope

estimated_length = int(round(estimated_length_cont))
# Report a +/-2 token band around the point estimate, floored at zero.
estimate_range = (max(0, estimated_length-2), estimated_length+2)

print(f"[REGRESSION ESTIMATE] Soft prompt length ≈ {estimated_length} (range {estimate_range[0]}-{estimate_range[1]})")
print(f"[TRUE] Tuned soft prompt length: {m_true}")
