import os
import logging
import random
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel
from sklearn.metrics import accuracy_score
from tqdm import tqdm


# Run on the GPU when torch reports that CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Name passed to from_pretrained() for both the tokenizer and the base T5 model.
MODEL_NAME = "t5-base"
# Adapter directory that main() hands to load_oracle_t5_with_adapter().
SST2_ADAPTER_PATH = "/adapter_checkpoint"
# CSV file read by main(); its "prompt" column supplies the query texts.
PROMPT_CSV = "/prompts.csv"

# Upper bound on the number of prompts sent to the model.
NUM_QUERIES = 1000
# Number of prompts tokenized and run per forward pass.
BATCH_SIZE = 8
# Rank estimation in extract_final_layer(): the largest ratio between
# consecutive singular values must reach this value for the gap-based
# estimate to be kept.
GAP_THRESHOLD = 5.0
# Rank estimation fallback: only singular values above this are counted.
MIN_SINGULAR_VAL = 1e-6

# When True, main() adds Gaussian noise to the collected logits before extraction.
ENABLE_DEFENSE = False
# Keyword arguments forwarded to cost_function() to derive the noise level.
DEFENSE_PARAMS = dict(lambda_val=0.1, alpha=5.0, beta=0.9)

# Configure root logging once at import time; all messages in this file go
# through the "attack" logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger("attack")


def cost_function(coverage_fraction, lambda_val=0.9, alpha=5, beta=0.5):
    """Map a coverage fraction onto an exponentially growing cost.

    Evaluates
    ``lambda_val * (exp(ln(alpha / lambda_val) * coverage_fraction / beta) - 1)``.
    The result is 0 for ``coverage_fraction == 0`` and ``alpha - lambda_val``
    for ``coverage_fraction == beta``.

    Args:
        coverage_fraction: Coverage value to convert (scalar or NumPy array).
        lambda_val: Scale of the curve; must be non-zero.
        alpha: Together with ``lambda_val`` sets the growth rate
            ``ln(alpha / lambda_val)``.
        beta: Coverage fraction at which the cost equals
            ``alpha - lambda_val``; must be non-zero.

    Returns:
        The cost as a NumPy float (or array, matching ``coverage_fraction``).
    """
    growth_rate = np.log(alpha / lambda_val)
    # Same left-to-right multiplication order as the closed-form expression.
    scaled_coverage = growth_rate * coverage_fraction * (1 / beta)
    growth = np.exp(scaled_coverage) - 1
    return lambda_val * growth

def compute_rmse(true_matrix: np.ndarray, predicted_matrix: np.ndarray) -> float:
    """Return the root-mean-square error between two arrays.

    The difference is taken element-wise (NumPy broadcasting rules apply) and
    the mean runs over every element of the result.

    Args:
        true_matrix: Reference values.
        predicted_matrix: Values to compare against the reference.

    Returns:
        The RMSE as a built-in ``float``.
    """
    residual = true_matrix - predicted_matrix
    mean_squared_error = np.mean(residual ** 2)
    return float(np.sqrt(mean_squared_error))

def add_noise_to_logits(logits_matrix, noise_level=0.1):
    """Return a copy of the logits with zero-mean Gaussian noise added.

    The noise is drawn from NumPy's global random state, so results are only
    reproducible if the caller seeds ``np.random`` beforehand. The input array
    is not modified.

    Args:
        logits_matrix: Array of logits; the noise has the same shape.
        noise_level: Standard deviation of the noise (must be >= 0).

    Returns:
        A new array ``logits_matrix + noise``.
    """
    perturbation = np.random.normal(loc=0, scale=noise_level, size=logits_matrix.shape)
    return logits_matrix + perturbation

def track_embedding_coverage(data, proj_list):
    """Estimate how many sign-hash buckets the rows of ``data`` occupy.

    For every projection matrix, each row of ``data`` is projected, the sign
    of each projected coordinate becomes one bit, and the bits are packed
    (most significant first) into a bucket id. The number of distinct bucket
    ids is counted per projection and averaged over all projections.

    Args:
        data: 2-D array whose rows are the points to hash.
        proj_list: Non-empty list of projection matrices. The number of
            columns of the first one fixes the bit width used for all.

    Returns:
        Tuple ``(coverage_percent, mean_occupied, total_buckets)`` where
        ``total_buckets`` is ``2 ** num_bits``, ``mean_occupied`` is the
        average count of distinct buckets, and ``coverage_percent`` is
        ``mean_occupied / total_buckets * 100``.
    """
    bits_per_hash = proj_list[0].shape[1]
    bucket_total = 2 ** bits_per_hash
    # Place values for packing the bits big-endian: [2^(b-1), ..., 2, 1].
    place_values = 2 ** np.arange(bits_per_hash - 1, -1, -1)

    def _distinct_buckets(projection):
        # One bit per projected coordinate: 1 if strictly positive, else 0.
        sign_bits = (data @ projection > 0).astype(np.uint8)
        return len(np.unique(sign_bits @ place_values))

    occupied_counts = [_distinct_buckets(projection) for projection in proj_list]

    mean_occupied = np.mean(occupied_counts)
    coverage = mean_occupied / bucket_total
    return coverage * 100, mean_occupied, bucket_total


def load_oracle_t5_with_adapter(adapter_dir: str, device: str = DEVICE):
    """Load the base T5 model, attach the PEFT adapter, and switch to eval mode.

    Args:
        adapter_dir: Directory passed to ``PeftModel.from_pretrained``.
        device: Device the base model and the wrapped model are moved to.

    Returns:
        Tuple ``(model, tokenizer)``: the adapter-wrapped model in eval mode
        and the tokenizer loaded from ``MODEL_NAME``.
    """
    logger.info(f"Loading base model {MODEL_NAME} and adapter from {adapter_dir} ...")

    oracle_tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

    backbone = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    backbone = backbone.to(device)

    oracle = PeftModel.from_pretrained(backbone, adapter_dir)
    oracle = oracle.to(device)
    oracle.eval()

    logger.info("Loaded oracle model (T5 + adapter).")
    return oracle, oracle_tokenizer


def get_first_decoder_token_logits_for_dataset(model: torch.nn.Module,
                                               tokenizer: T5Tokenizer,
                                               dataset_split,
                                               num_queries: int = NUM_QUERIES,
                                               batch_size: int = BATCH_SIZE,
                                               device: str = DEVICE) -> Tuple[np.ndarray, np.ndarray]:
    """Query the model and collect the logits of the first decoder step.

    Each text is encoded, the decoder is fed only the decoder start token,
    and the logits at the last (and only) decoder position are recorded.

    Args:
        model: Encoder-decoder model exposing ``config.decoder_start_token_id``
            and returning an object with a ``logits`` attribute.
        tokenizer: Tokenizer used to encode the texts (padded, truncated to
            128 tokens).
        dataset_split: Sliceable sequence of input strings.
        num_queries: Only the first ``num_queries`` texts are used.
        batch_size: Number of texts per forward pass; must be >= 1.
        device: Device the encoded batch and decoder inputs are placed on.

    Returns:
        Tuple ``(logits, labels)``. ``logits`` stacks one row of logits per
        text; ``labels`` is a 1-D array of the same length filled with the
        placeholder value ``-1`` (no ground truth is available here).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or there are no texts
            to query.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

    texts = dataset_split[:num_queries]
    if len(texts) == 0:
        # Without this guard np.vstack([]) fails later with an opaque message.
        raise ValueError(
            "No input texts to query: dataset_split is empty or num_queries is 0."
        )

    # Loop-invariant: every batch starts decoding from the same token.
    start_token_id = model.config.decoder_start_token_id

    all_logits, all_labels = [], []
    for i in tqdm(range(0, len(texts), batch_size), desc="Collecting logits"):
        batch_texts = texts[i:i+batch_size]
        enc = tokenizer(batch_texts, return_tensors="pt", padding=True,
                        truncation=True, max_length=128).to(device)

        # A single decoder position per example: just the start token.
        dec_input_ids = torch.full(
            (enc["input_ids"].size(0), 1),
            start_token_id,
            dtype=torch.long,
            device=device
        )

        with torch.no_grad():
            out = model(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                decoder_input_ids=dec_input_ids
            )
            first_step = out.logits[:, -1, :]
            # bfloat16 cannot be converted to NumPy and float16 is rejected by
            # numpy.linalg downstream, so upcast half-precision outputs.
            if first_step.dtype in (torch.float16, torch.bfloat16):
                first_step = first_step.float()
            logits = first_step.cpu().numpy()

        all_logits.append(logits)
        # Placeholder labels: the prompts carry no ground truth.
        all_labels.append(np.full(len(batch_texts), -1))

    return np.vstack(all_logits), np.concatenate(all_labels)


def extract_final_layer(logits_matrix: np.ndarray,
                        model_W: np.ndarray,
                        svd_rank: int = None,
                        gap_threshold: float = GAP_THRESHOLD,
                        min_singular_val: float = MIN_SINGULAR_VAL):
    """Recover the output projection from observed logits via SVD + least squares.

    The transposed logits matrix is decomposed with a thin SVD. The leading
    left singular vectors, scaled by their singular values, form a basis
    ``W_tilde``. A least-squares map ``G`` aligns that basis with the true
    weights ``model_W`` so the reconstruction error can be measured.

    Rank selection when ``svd_rank`` is None:
      * with at least two singular values, use the position of the largest
        ratio between consecutive singular values;
      * if that largest ratio is below ``gap_threshold``, fall back to the
        number of singular values above ``min_singular_val`` (at least 1);
      * with a single singular value, use rank 1.

    Args:
        logits_matrix: 2-D array with one row per query; its column count
            must equal the row count of ``model_W``.
        model_W: Reference weight matrix to compare the reconstruction with.
        svd_rank: Rank to use instead of estimating it. Must be >= 1; values
            above the number of singular values are clamped.
        gap_threshold: Minimum ratio for the gap-based estimate to be kept.
        min_singular_val: Threshold used by the fallback rank count.

    Returns:
        Tuple ``(rms, W_rec, W_tilde, G, est_rank)``: the RMSE between
        ``model_W`` and ``W_rec``, the reconstruction ``W_tilde @ G``, the
        scaled singular basis, the least-squares alignment matrix, and the
        rank that was used.

    Raises:
        ValueError: If the shapes are incompatible or ``svd_rank`` is < 1.
    """
    # Check shapes before paying for the SVD; W_tilde has as many rows as
    # logits_matrix has columns.
    if logits_matrix.shape[1] != model_W.shape[0]:
        raise ValueError("Shape mismatch: W_tilde rows must equal model_W rows.")

    if svd_rank is not None and svd_rank < 1:
        # A non-positive rank would silently slice from the end (or yield an
        # empty basis) instead of failing.
        raise ValueError(f"svd_rank must be a positive integer, got {svd_rank}.")

    A = logits_matrix.T
    U, S, _ = np.linalg.svd(A, full_matrices=False)

    if svd_rank is not None:
        # Never report more components than the decomposition provides.
        est_rank = min(int(svd_rank), S.size)
    elif S.size >= 2:
        # The epsilon keeps the ratio finite when a singular value is exactly 0.
        gaps = S[:-1] / (S[1:] + 1e-12)
        if gaps.max() < gap_threshold:
            # No clear gap: count the numerically non-zero singular values.
            est_rank = max(int(np.sum(S > min_singular_val)), 1)
        else:
            est_rank = int(np.argmax(gaps) + 1)
    else:
        est_rank = 1

    # Column scaling by broadcasting; equivalent to U_h @ diag(S_h) without
    # building a dense rank x rank matrix.
    W_tilde = U[:, :est_rank] * S[:est_rank]

    G, *_ = np.linalg.lstsq(W_tilde, model_W, rcond=None)
    W_rec = W_tilde @ G
    rms = compute_rmse(model_W, W_rec)
    return rms, W_rec, W_tilde, G, est_rank


def main():
    """Run the extraction pipeline end to end.

    Steps:
      1. Load the adapter-wrapped T5 model and its tokenizer.
      2. Read prompts from ``PROMPT_CSV`` (column ``"prompt"``).
      3. Collect first-decoder-step logits for up to ``NUM_QUERIES`` prompts.
      4. Read the model's output embedding matrix as the ground truth.
      5. If ``ENABLE_DEFENSE`` is set, measure bucket coverage of the logits
         and add Gaussian noise whose scale comes from ``cost_function``.
      6. Reconstruct the output matrix, report the error, and save the
         arrays under ``attack_output/``.
    """
    device = torch.device(DEVICE)
    oracle_model, tokenizer = load_oracle_t5_with_adapter(SST2_ADAPTER_PATH, device=device)

    df = pd.read_csv(PROMPT_CSV)
    # Drop missing cells first: astype(str) would otherwise turn them into the
    # literal prompt "nan" and spend queries on it.
    prompts = df["prompt"].dropna().astype(str).tolist()
    num_missing = len(df) - len(prompts)
    if num_missing:
        logger.warning(f"Skipped {num_missing} CSV rows with a missing prompt.")
    logger.info(f"Prompt CSV size: {len(prompts)}. Using up to NUM_QUERIES={NUM_QUERIES}")

    logits_matrix, labels = get_first_decoder_token_logits_for_dataset(
        oracle_model, tokenizer, prompts,
        num_queries=NUM_QUERIES, batch_size=BATCH_SIZE, device=device
    )
    logger.info(f"Collected logits matrix shape: {logits_matrix.shape}")

    # Ground-truth output projection, used only to score the reconstruction.
    out_embeddings = (oracle_model.base_model.get_output_embeddings()
                      if hasattr(oracle_model, "base_model")
                      else oracle_model.get_output_embeddings())
    model_W = out_embeddings.weight.detach().cpu().numpy()
    logger.info(f"Model output embedding shape: {model_W.shape}")

    if ENABLE_DEFENSE:
        logger.info("[DEFENSE] Enabled: tracking coverage and applying exponential noise")
        # 20 independent 6-bit sign hashes (64 buckets each).
        proj_list = [np.random.randn(logits_matrix.shape[1], 6) for _ in range(20)]
        coverage_fraction, occupied, total = track_embedding_coverage(logits_matrix, proj_list)
        logger.info(f"[DEFENSE] Coverage: {coverage_fraction:.4f}% ({occupied}/{total} buckets)")

        # track_embedding_coverage reports a percentage; cost_function takes a fraction.
        noise_level = cost_function(coverage_fraction/100.0,
                                    **DEFENSE_PARAMS)
        logger.info(f"[DEFENSE] Noise level = {noise_level:.6f}")
        logits_matrix = add_noise_to_logits(logits_matrix, noise_level=noise_level)

        preds = np.argmax(logits_matrix, axis=1)
        # Score only rows with a real label; -1 is the placeholder produced
        # during logit collection and would always count as a miss.
        has_label = labels >= 0
        if has_label.any():
            acc = accuracy_score(labels[has_label], preds[has_label])
            logger.info(f"[DEFENSE] Noisy predictions accuracy (dummy labels=-1 ignored): {acc:.4f}")
        else:
            logger.info("[DEFENSE] No ground-truth labels available (all dummy -1); skipping accuracy.")

    rms, W_rec, _, _, est_dim = extract_final_layer(
        logits_matrix,
        model_W,
        svd_rank=None,
        gap_threshold=GAP_THRESHOLD,
        min_singular_val=MIN_SINGULAR_VAL
    )
    print(f"Extraction finished. RMSE: {rms:.6f}, estimated rank: {est_dim}")

    # Reconstruction error per row of the output matrix.
    per_row_rmse = np.sqrt(np.mean((model_W - W_rec) ** 2, axis=1))
    logger.info(f"Per-token RMSE stats: mean={per_row_rmse.mean():.6f}, "
                f"std={per_row_rmse.std():.6f}, max={per_row_rmse.max():.6f}")

    out_dir = "attack_output"
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, "logits_matrix.npy"), logits_matrix)
    np.save(os.path.join(out_dir, "model_W.npy"), model_W)
    np.save(os.path.join(out_dir, "W_rec.npy"), W_rec)
    logger.info(f"Saved logits_matrix.npy, model_W.npy, W_rec.npy to {out_dir}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
