import os
import time
import copy
import random
import string
import torch
import numpy as np
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel, get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit
from torch.nn import functional as F
from sklearn.exceptions import UndefinedMetricWarning
import warnings
from sklearn.metrics import accuracy_score, f1_score


# Device used for every model/tensor placement in this script.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed applied to torch, NumPy and Python's `random` below.
SEED = 1995
# Hugging Face model id loaded for the teachers and the student.
BASE_MODEL = "t5-base"

# Teacher name -> directory of a saved PEFT adapter (loaded by load_teacher_pair).
TEACHER_ADAPTERS = {
    "sst2": "/adapter_checkpoint/sst2_adapter_checkpoint",
    "yelp": "/adapter_checkpoint/yelp_polarity_adapter_checkpoint",
    "amazon": "/adapter_checkpoint/amazon_polarity_adapter_checkpoint",
}

# distill_prompt saves the student adapter here; main() reloads it from here.
SAVE_STUDENT_DIR = "./student_adapter_checkpoint"


NUM_VIRTUAL_TOKENS = 10      # soft-prompt length of the student (make_peft_config)
BATCH_SIZE = 32              # distillation batch size
KL_EPOCHS = 15               # distillation epochs
LR = 5e-3                    # AdamW learning rate for the student's trainable params
SAMPLES_FOR_DISTILL = 1000   # total number of queries split across query sources
EVAL_BATCH = 64              # generation batch size in the eval_* helpers
MAX_LENGTH = 128             # tokenizer truncation for distillation queries and eval_on_dataset
EVAL_SAMPLES = 1000          # default number of rows evaluated by eval_on_dataset

# When True, main() builds a Defense that perturbs the defended teacher's soft prompt.
ENABLE_DEFENSE = True



# Hyperparameters read by Defense.__init__.
DEFENSE_CONFIG = {
    "hash_bits": 11,            # number of LSH hyperplanes; buckets = 2 ** hash_bits
    "lambda_val": 0.0005,       # constant term of the noise level in compute_cost
    "alpha": 8.0,               # scales the novelty and spread terms in compute_cost

    "beta": 0.19,               # coverage fraction is divided by this before exp() in compute_cost
    "batch_size": 100,          # divisor for newly-hit buckets (novelty); NOT the actual BATCH_SIZE

    "coverage_weight": 0.05,    # weight of the exp(coverage / beta) - 1 term
    "novelty_weight": 0.35,     # weight of the new-bucket term
    "spread_weight": 0.45       # weight of the (capped) batch-spread term
}


# Seed all RNGs used in this script.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# torch.device used by every `.to(device)` call below.
device = torch.device(DEVICE)


def make_peft_config(num_virtual_tokens=NUM_VIRTUAL_TOKENS):
    """Return a prompt-tuning PEFT config for a seq2seq LM student.

    The soft prompt of ``num_virtual_tokens`` tokens is randomly initialised
    (``PromptTuningInit.RANDOM``), and the tokenizer path points at BASE_MODEL.
    """
    settings = dict(
        task_type=TaskType.SEQ_2_SEQ_LM,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        num_virtual_tokens=num_virtual_tokens,
        tokenizer_name_or_path=BASE_MODEL,
    )
    return PromptTuningConfig(**settings)

def load_teacher_pair(adapter_name):
    """Load two independent copies of the same teacher adapter, plus its tokenizer.

    Each copy gets its own freshly loaded base model. That way, later changes to
    the "defended" copy's prompt weights cannot affect the "clean" copy. Both
    copies are moved to ``device`` and put in eval mode.

    Returns:
        (teacher_clean, teacher_defended, tokenizer)

    Raises:
        ValueError: if ``adapter_name`` is not a key of TEACHER_ADAPTERS.
    """
    try:
        adapter_path = TEACHER_ADAPTERS[adapter_name]
    except KeyError:
        raise ValueError(f"Unknown adapter '{adapter_name}'") from None
    tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL)

    def _fresh_teacher():
        # New base weights per copy, so the two adapters share no tensors.
        backbone = T5ForConditionalGeneration.from_pretrained(BASE_MODEL).to(device)
        wrapped = PeftModel.from_pretrained(backbone, adapter_path).to(device)
        wrapped.eval()
        return wrapped

    teacher_clean = _fresh_teacher()
    teacher_defended = _fresh_teacher()

    print(f"[Teacher] Loaded clean & defended adapters from {adapter_path}")
    return teacher_clean, teacher_defended, tokenizer

def instantiate_student_random():
    """Build a new student: base T5 wrapped with a randomly initialised prompt adapter.

    Returns:
        (student, tokenizer) with the student on ``device``.
    """
    tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL)
    student = get_peft_model(
        T5ForConditionalGeneration.from_pretrained(BASE_MODEL).to(device),
        make_peft_config(),
    ).to(device)
    return student, tokenizer

# Called form `@torch.no_grad()` works on every torch version; the bare
# `@torch.no_grad` decorator is only accepted by torch >= 2.0.
@torch.no_grad()
def teacher_first_token_probs(teacher_model, tokenizer, texts):
    """Return the teacher's softmax distribution for the first decoder step.

    Each example's decoder input is only ``config.decoder_start_token_id``, so
    each row is P(first output token | text). Runs without autograd.

    Returns:
        Tensor with one probability row per text (rows sum to 1).
    """
    enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
                    max_length=MAX_LENGTH).to(device)
    batch_size = enc["input_ids"].shape[0]
    decoder_start = torch.tensor(
        [[teacher_model.config.decoder_start_token_id]] * batch_size
    ).to(device)
    outputs = teacher_model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                            decoder_input_ids=decoder_start, return_dict=True)
    logits = outputs.logits[:, 0, :]
    return F.softmax(logits, dim=-1)

def student_first_token_logits(student_model, tokenizer, texts):
    """Return the student's raw vocabulary logits for the first decoder step.

    Autograd stays enabled because the result feeds the distillation loss.
    The decoder is primed with ``config.decoder_start_token_id`` only.
    """
    batch = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
                      max_length=MAX_LENGTH).to(device)
    start_id = student_model.config.decoder_start_token_id
    n_texts = batch["input_ids"].shape[0]
    decoder_input_ids = torch.tensor([[start_id] for _ in range(n_texts)]).to(device)
    result = student_model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=decoder_input_ids,
        return_dict=True,
    )
    return result.logits[:, 0, :]

@torch.no_grad()
def get_query_embeddings(model, tokenizer, texts):
    """Return one padding-aware mean-pooled encoder embedding per text.

    Calls ``model.encoder`` directly on the tokenized texts. It then averages
    ``last_hidden_state`` only over positions where ``attention_mask`` is 1. A plain
    ``.mean(dim=1)`` would also average pad positions, which would make a query's
    embedding depend on the longest text in its batch.

    NOTE(review): this bypasses the PEFT forward, so the soft prompt is presumably
    not part of these embeddings — confirm if that matters to the defense.

    Returns:
        Tensor with one row per text.
    """
    enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True,
                    max_length=MAX_LENGTH).to(device)
    outputs = model.encoder(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    hidden = outputs.last_hidden_state
    mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    # clamp_min guards a (degenerate) all-padding row against division by zero.
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

def get_soft_prompt_embeddings(model):
    """Locate the ``nn.Embedding`` that holds ``model``'s soft-prompt vectors.

    Lookup order:
      1. ``model.prompt_encoder.embedding``
      2. ``model.prompt_encoder["default"].embedding`` when the encoder is a ModuleDict
      3. the first ``nn.Embedding`` found by walking ``prompt_encoder.named_modules()``

    Raises:
        AttributeError: if there is no ``prompt_encoder`` or no embedding can be found.
    """
    encoder = getattr(model, "prompt_encoder", None)
    if encoder is None:
        raise AttributeError("Model has no prompt_encoder")

    if hasattr(encoder, "embedding"):
        return encoder.embedding

    keyed_by_adapter = isinstance(encoder, torch.nn.ModuleDict) and "default" in encoder
    if keyed_by_adapter and hasattr(encoder["default"], "embedding"):
        return encoder["default"].embedding

    found = next(
        (sub for _, sub in encoder.named_modules() if isinstance(sub, torch.nn.Embedding)),
        None,
    )
    if found is None:
        raise AttributeError("Could not locate prompt embeddings")
    return found


class Defense:
    """Query-adaptive Gaussian noise for a soft prompt.

    Query embeddings are hashed with random-hyperplane LSH into ``2 ** hash_bits``
    buckets. The defense records every bucket ever hit and measures how spread out
    each batch is. It turns those statistics into a noise scale (see
    :meth:`compute_cost`), which :meth:`perturb_soft_prompts` uses to return a noisy
    copy of a soft prompt.

    Config keys (all optional): hash_bits, lambda_val, alpha, beta, batch_size,
    coverage_weight, novelty_weight, spread_weight.
    """

    def __init__(self, config):
        self.hash_bits = config.get("hash_bits", 12)
        self.num_buckets = 2 ** self.hash_bits
        self.lambda_val = config.get("lambda_val", 0.6)
        self.alpha = config.get("alpha", 8.0)
        self.beta = config.get("beta", 0.9)
        # Divisor for the novelty term. This is the configured value, not the
        # number of vectors actually passed to update_global_coverage.
        self.batch_size = config.get("batch_size", 100)

        self.coverage_weight = config.get("coverage_weight", 1.0)
        self.novelty_weight = config.get("novelty_weight", 1.0)
        self.spread_weight = config.get("spread_weight", 1.0)

        # Every LSH bucket id seen so far; only ever grows.
        self.global_buckets = set()
        # (dim, hash_bits) random hyperplanes, created lazily on first use.
        self.projection_matrix = None

        # Cumulative elapsed seconds spent in each stage.
        self.total_coverage_time = 0.0
        self.total_spread_time = 0.0
        self.total_noise_time = 0.0

    def compute_cost(self, coverage_fraction, delta_coverage, spread):
        """Return the noise standard deviation as a 0-d tensor.

        cost = lambda_val + coverage_weight * (exp(coverage_fraction / beta) - 1)
             + alpha * novelty_weight * delta_coverage
             + alpha * spread_weight * min(spread / 10, 1)
        """
        normalized_coverage = coverage_fraction / (self.beta + 1e-6)
        raw_coverage = torch.exp(torch.tensor(normalized_coverage)) - 1
        coverage_penalty = self.lambda_val + self.coverage_weight * raw_coverage
        novelty_penalty = self.alpha * self.novelty_weight * delta_coverage
        # Spread is capped so this term contributes at most alpha * spread_weight.
        spread_penalty = self.alpha * self.spread_weight * min(spread / 10.0, 1.0)
        return coverage_penalty + novelty_penalty + spread_penalty

    def get_lsh_buckets(self, vectors: torch.Tensor):
        """Hash each row of a 2-D ``vectors`` tensor and return the set of bucket ids."""
        # Match the hyperplanes to the input so other dtypes/devices do not crash the matmul.
        planes = self.projection_matrix.to(device=vectors.device, dtype=vectors.dtype)
        projections = vectors @ planes
        binary_codes = (projections >= 0).int()
        powers = 2 ** torch.arange(binary_codes.size(1), device=vectors.device)
        hash_indices = (binary_codes * powers).sum(dim=1)
        bucket_ids = hash_indices % self.num_buckets
        return set(bucket_ids.cpu().tolist())

    def update_global_coverage(self, embeddings: torch.Tensor):
        """Add this batch's buckets to the global set.

        Returns:
            (coverage_fraction, delta_coverage): the fraction of all buckets ever hit,
            and the number of newly hit buckets divided by ``self.batch_size``.
        """
        t0 = time.perf_counter()
        # reshape (not view) so non-contiguous inputs are accepted.
        vectors = embeddings.reshape(-1, embeddings.size(-1))
        if self.projection_matrix is None:
            # Drawn from torch's default CPU generator, then moved to the input's device.
            self.projection_matrix = torch.randn(vectors.size(1), self.hash_bits).to(vectors.device)
        new_buckets = self.get_lsh_buckets(vectors)
        prev_count = len(self.global_buckets)
        self.global_buckets.update(new_buckets)
        curr_count = len(self.global_buckets)
        coverage_fraction = curr_count / self.num_buckets
        delta_coverage = max(curr_count - prev_count, 0) / self.batch_size
        self.total_coverage_time += time.perf_counter() - t0
        return coverage_fraction, delta_coverage

    def batch_spread(self, embeddings: torch.Tensor):
        """Return the mean Euclidean distance from each row to the batch centroid.

        An empty batch returns 0.0 instead of NaN, which would otherwise turn the
        whole noise level into NaN.
        """
        t0 = time.perf_counter()
        if embeddings.size(0) == 0:
            spread = 0.0
        else:
            centroid = embeddings.mean(dim=0, keepdim=True)
            spread = torch.norm(embeddings - centroid, dim=1).mean().item()
        self.total_spread_time += time.perf_counter() - t0
        return spread

    def perturb_soft_prompts(self, soft_prompts: torch.Tensor, query_embeddings: torch.Tensor):
        """Update coverage with ``query_embeddings`` and return a noisy copy of ``soft_prompts``.

        ``soft_prompts`` itself is not modified. The result is
        ``soft_prompts + N(0, 1) * compute_cost(...)``.
        """
        coverage_fraction, delta_coverage = self.update_global_coverage(query_embeddings)
        spread = self.batch_spread(query_embeddings)
        t0 = time.perf_counter()
        noise_level = self.compute_cost(coverage_fraction, delta_coverage, spread)
        perturbed = soft_prompts + torch.randn_like(soft_prompts) * noise_level
        self.total_noise_time += time.perf_counter() - t0
        return perturbed


def random_gibberish(n=20, min_len=128, max_len=256):
    """Return ``n`` strings of the form "gibberish: <random alphanumerics>".

    Each suffix has a length drawn uniformly from [min_len, max_len]. Characters
    come from ASCII letters and digits, using the module-level ``random`` RNG.
    """
    alphabet = string.ascii_letters + string.digits
    return [
        "gibberish: " + "".join(random.choices(alphabet, k=random.randint(min_len, max_len)))
        for _ in range(n)
    ]

# Source name -> (load_dataset positional args, split, row -> prefixed query text).
_DISTILL_SOURCES = {
    "amazon": (("amazon_polarity",), "test",
               lambda row: "amazon review: " + row["content"]),
    "yelp": (("yelp_review_full",), "test",
             lambda row: "yelp review: " + row["text"]),
    "rotten": (("rotten_tomatoes",), "test",
               lambda row: "rotten review: " + row["text"]),
    "sst2": (("glue", "sst2"), "validation",
             lambda row: "sst2 sentence: " + row["sentence"]),
    "imdb": (("imdb",), "test",
             lambda row: "imdb review: " + row["text"]),
    "ag_news": (("ag_news",), "test",
                lambda row: "ag news: " + row["text"]),
    "dbpedia": (("dbpedia_14",), "test",
                lambda row: "dbpedia: " + row["content"]),
    "snli": (("snli",), "validation",
             lambda row: "snli premise: " + row["premise"] + " hypothesis: " + row["hypothesis"]),
    "mnli": (("multi_nli",), "validation_matched",
             lambda row: "mnli premise: " + row["premise"] + " hypothesis: " + row["hypothesis"]),
    # The Hugging Face `quora` dataset has only a "train" split. Its `questions`
    # field is a {"id": [...], "text": [...]} record, so the texts are joined
    # (concatenating the record to a str raises TypeError). Verify against the
    # installed `datasets` version.
    "quora": (("quora",), "train",
              lambda row: "quora: " + " ".join(row["questions"]["text"])),
    "trec": (("trec",), "test",
             lambda row: "trec question: " + row["text"]),
    "tweet_eval": (("tweet_eval", "sentiment"), "validation",
                   lambda row: "tweet: " + row["text"]),
    "cnn_dailymail": (("cnn_dailymail", "3.0.0"), "validation",
                      lambda row: "cnn article: " + row["article"]),
    "squad": (("squad",), "validation",
              lambda row: "squad context: " + row["context"] + " question: " + row["question"]),
}


def _sample_source_texts(source, per_source):
    """Return up to ``per_source`` randomly sampled, prefixed texts from one source."""
    if source == "gibberish":
        return random_gibberish(per_source)
    try:
        load_args, split, to_text = _DISTILL_SOURCES[source]
    except KeyError:
        raise ValueError(f"Unknown source {source}") from None
    ds = load_dataset(*load_args, split=split)
    sampled = random.sample(range(len(ds)), min(per_source, len(ds)))
    return [to_text(ds[i]) for i in sampled]


def get_queries_for_distill(query_sources=("amazon",), n_samples=SAMPLES_FOR_DISTILL):
    """Build the distillation query set by sampling from each named source in order.

    Each source contributes ``max(1, n_samples // len(query_sources))`` texts, or
    fewer if its split is smaller. Texts are sampled without replacement with the
    module-level ``random`` RNG and prefixed with a source-specific tag.
    "gibberish" yields random alphanumeric strings instead of dataset rows.

    Raises:
        ValueError: if ``query_sources`` is empty or names an unknown source.
    """
    if not query_sources:
        raise ValueError("query_sources must name at least one source")
    per_source = max(1, n_samples // len(query_sources))

    texts = []
    for source in query_sources:
        texts.extend(_sample_source_texts(source, per_source))
    return texts


def _kl_distill_loss(student_logits, teacher_probs, top_k):
    """Return KL(teacher || student) for the first-token distributions (batchmean).

    If ``top_k`` is None, the full vocabulary is used. Otherwise both distributions
    are restricted to the student's top-k tokens, and the teacher slice is
    renormalized.
    """
    if top_k is None:
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    topk_vals, topk_indices = torch.topk(student_logits, k=top_k, dim=-1)
    teacher_sel = torch.gather(teacher_probs, dim=1, index=topk_indices)
    # clamp_min: an all-underflowed slice would otherwise divide 0/0 -> NaN loss.
    teacher_sel = teacher_sel / teacher_sel.sum(dim=1, keepdim=True).clamp_min(1e-12)
    student_log_probs = F.log_softmax(topk_vals, dim=-1)
    return F.kl_div(student_log_probs, teacher_sel, reduction="batchmean")


def distill_prompt(teacher, student, tokenizer, texts_for_distill, defense=None, epochs=KL_EPOCHS, top_k=5):
    """Train ``student``'s trainable parameters to match ``teacher``'s first-token distribution.

    For each batch:
      1. if ``defense`` is given, overwrite the teacher's soft-prompt weights with
         ``defense.perturb_soft_prompts(clean_prompt, <this batch's embeddings>)``;
      2. query the (possibly perturbed) teacher for first-token probabilities;
      3. take one AdamW step (lr=LR) on the KL loss.
    Per-batch and cumulative timings are printed. The student adapter is saved to
    SAVE_STUDENT_DIR at the end.

    ``texts_for_distill`` is not modified; a shuffled copy is used. The teacher's
    clean soft prompt is written back when distillation ends, even if an error occurs.

    Args:
        teacher: prompt-tuned teacher whose prompt embedding may be perturbed.
        student: model whose ``requires_grad`` parameters are optimized.
        tokenizer: tokenizer used for both teacher and student inputs.
        texts_for_distill: query texts.
        defense: optional Defense instance.
        epochs: number of passes over the queries.
        top_k: if not None, restrict the KL to the student's top-k first tokens.
    """
    trainable_params = [p for p in student.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, lr=LR)
    # Shuffle a copy so the caller's list is left untouched.
    texts = list(texts_for_distill)
    random.shuffle(texts)
    n = len(texts)

    embedding_layer = get_soft_prompt_embeddings(teacher)
    clean_prompt = embedding_layer.weight.data.clone().detach()

    total_query_time = 0.0
    total_defense_time = 0.0
    total_optimization_time = 0.0

    try:
        for ep in range(epochs):
            running_loss, batches = 0.0, 0

            for i in range(0, n, BATCH_SIZE):
                batch_texts = texts[i:i + BATCH_SIZE]

                # Defense runs BEFORE the query: the noise derived from this batch's
                # queries must be in place when the teacher answers them. Applied
                # after the query, each batch would see the previous batch's noise.
                batch_defense_time = 0.0
                if defense is not None:
                    t0 = time.perf_counter()
                    query_embeddings = get_query_embeddings(teacher, tokenizer, batch_texts)
                    try:
                        perturbed_prompt = defense.perturb_soft_prompts(clean_prompt, query_embeddings=query_embeddings)
                        embedding_layer.weight.data.copy_(perturbed_prompt)
                    except Exception as e:
                        # Best-effort: report and continue with the prompt currently loaded.
                        print(f"[Defense] Perturbation failed: {e}")
                    batch_defense_time = time.perf_counter() - t0
                    total_defense_time += batch_defense_time

                # Query teacher
                t0 = time.perf_counter()
                with torch.no_grad():
                    teacher_probs = teacher_first_token_probs(teacher, tokenizer, batch_texts)
                batch_query_time = time.perf_counter() - t0
                total_query_time += batch_query_time

                # Student optimization
                t0 = time.perf_counter()
                student_logits = student_first_token_logits(student, tokenizer, batch_texts)
                loss = _kl_distill_loss(student_logits, teacher_probs, top_k)
                opt.zero_grad()
                loss.backward()
                opt.step()
                batch_optimization_time = time.perf_counter() - t0
                total_optimization_time += batch_optimization_time

                running_loss += loss.item()
                batches += 1

                print(f"[Batch Timing] Query={batch_query_time:.2f}s, Defense={batch_defense_time:.2f}s, "
                      f"Optimization={batch_optimization_time:.2f}s")

            print(f"[Distill] Epoch {ep+1}/{epochs} avg KL loss: {running_loss/max(1,batches):.6f}")
            print(f"[Epoch Total Timing] Query={total_query_time:.2f}s, Defense={total_defense_time:.2f}s, "
                  f"Optimization={total_optimization_time:.2f}s")
            if defense is not None:
                print(f"[Defense Breakdown Total] Coverage={defense.total_coverage_time:.2f}s, "
                      f"Spread={defense.total_spread_time:.2f}s, Noise={defense.total_noise_time:.2f}s")
    finally:
        # Do not leave the teacher carrying the last batch's perturbed prompt.
        if defense is not None:
            embedding_layer.weight.data.copy_(clean_prompt)

    os.makedirs(SAVE_STUDENT_DIR, exist_ok=True)
    student.save_pretrained(SAVE_STUDENT_DIR)
    print(f"[Distill] Saved student adapter to {SAVE_STUDENT_DIR}")


def eval_on_yelp(model, tokenizer, split="test", n_samples=None):
    """Evaluate ``model`` on Yelp polarity by generating a short label string.

    A prediction is 1 when the decoded output contains "positive" (case-insensitive),
    otherwise 0. Inputs are prefixed with "yelp review: " and truncated to 512 tokens.

    Args:
        split: dataset split to load.
        n_samples: if truthy, evaluate only the first ``n_samples`` rows, capped at the
            split size (``Dataset.select`` would raise on out-of-range indices).

    Returns:
        {"accuracy": float, "f1": float}
    """
    ds = load_dataset("yelp_polarity", split=split)
    if n_samples:
        ds = ds.select(range(min(n_samples, len(ds))))
    texts = ["yelp review: " + t for t in ds["text"]]
    labels = ds["label"]
    preds = []
    batch_size = EVAL_BATCH
    model.eval()
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            gen = model.generate(**enc, max_length=8)
            decoded = tokenizer.batch_decode(gen, skip_special_tokens=True)
            preds.extend(1 if "positive" in d.lower() else 0 for d in decoded)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    return {"accuracy": acc, "f1": f1}


def eval_on_dataset(model, tokenizer, ds_name="amazon", split="test", n_samples=EVAL_SAMPLES):
    """Evaluate ``model`` on a sentiment dataset by generating a short label string.

    A prediction is 1 when the decoded output contains "positive" (case-insensitive),
    otherwise 0. Inputs are truncated to MAX_LENGTH tokens. F1 is "weighted" if the
    labels have more than two distinct values, else "binary".

    NOTE(review): GLUE sst2's "test" split appears to be unlabeled (label -1), so
    ``ds_name="sst2"`` presumably needs ``split="validation"`` — verify.

    Returns:
        {"accuracy": float, "f1": float}

    Raises:
        ValueError: for an unknown ``ds_name``.
    """
    # ds_name -> (load_dataset positional args, text column, prompt prefix)
    dataset_specs = {
        "amazon": (("amazon_polarity",), "content", "amazon review: "),
        "rotten": (("rotten_tomatoes",), "text", "rotten review: "),
        "sst2": (("glue", "sst2"), "sentence", "sst2 sentence: "),
        "imdb": (("imdb",), "text", "imdb review: "),
        "yelp": (("yelp_polarity",), "text", "yelp review: "),
    }
    if ds_name not in dataset_specs:
        raise ValueError(f"Unknown dataset {ds_name}")
    load_args, text_column, prefix = dataset_specs[ds_name]

    ds = load_dataset(*load_args, split=split)
    if n_samples:
        # Restrict rows before reading columns. Otherwise the whole split's text
        # (e.g. every amazon_polarity test review) is materialized just to be sliced.
        ds = ds.select(range(min(n_samples, len(ds))))
    texts = [prefix + t for t in ds[text_column]]
    labels = ds["label"]

    preds = []
    model.eval()
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(texts), EVAL_BATCH):
            batch_texts = texts[i:i + EVAL_BATCH]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(device)
            gen = model.generate(**enc, max_length=8)
            decoded = tokenizer.batch_decode(gen, skip_special_tokens=True)
            preds.extend(1 if "positive" in d.lower() else 0 for d in decoded)

    average = "weighted" if len(set(labels)) > 2 else "binary"
    f1 = f1_score(labels, preds, average=average)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

def eval_model_on_rotten(model, tokenizer, split="test", n_samples=None):
    """Evaluate ``model`` on Rotten Tomatoes by generating a short label string.

    A prediction is 1 when the decoded output contains "positive" (case-insensitive),
    otherwise 0. Inputs are truncated to 512 tokens.

    NOTE(review): the prompt prefix here is "review: ", while eval_on_dataset and the
    distillation queries use "rotten review: " — confirm which one the rotten adapter expects.

    Args:
        split: dataset split to load.
        n_samples: if truthy, evaluate only the first ``n_samples`` rows, capped at the
            split size (``Dataset.select`` would raise on out-of-range indices).

    Returns:
        {"accuracy": float, "f1": float}
    """
    ds = load_dataset("rotten_tomatoes", split=split)
    if n_samples:
        ds = ds.select(range(min(n_samples, len(ds))))
    texts = ["review: " + t for t in ds["text"]]
    labels = ds["label"]
    preds = []
    model.eval()
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(texts), EVAL_BATCH):
            batch_texts = texts[i:i + EVAL_BATCH]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True,
                            truncation=True, max_length=512).to(device)
            gen = model.generate(**enc, max_length=8)
            decoded = tokenizer.batch_decode(gen, skip_special_tokens=True)
            preds.extend(1 if "positive" in d.lower() else 0 for d in decoded)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

def eval_model_on_sst2(model, tokenizer, split="validation", n_samples=None):
    """Evaluate ``model`` on GLUE SST-2 by generating a short label string.

    A prediction is 1 when the decoded output contains "positive" (case-insensitive),
    otherwise 0. Inputs are prefixed with "sst2 sentence: " and truncated to 512 tokens.

    Args:
        split: dataset split to load (defaults to "validation").
        n_samples: if truthy, evaluate only the first ``n_samples`` rows, capped at the
            split size (``Dataset.select`` would raise on out-of-range indices).

    Returns:
        {"accuracy": float, "f1": float}
    """
    ds = load_dataset("glue", "sst2", split=split)
    if n_samples:
        ds = ds.select(range(min(n_samples, len(ds))))
    texts = ["sst2 sentence: " + t for t in ds["sentence"]]
    labels = ds["label"]

    preds = []
    model.eval()
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(texts), EVAL_BATCH):
            batch_texts = texts[i:i + EVAL_BATCH]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True,
                            truncation=True, max_length=512).to(device)
            gen = model.generate(**enc, max_length=8)
            decoded = tokenizer.batch_decode(gen, skip_special_tokens=True)
            preds.extend(1 if "positive" in d.lower() else 0 for d in decoded)

    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    return {"accuracy": acc, "f1": f1}


def main(teacher_name="amazon", query_sources=("amazon",)):
    """Run one extraction experiment against ``teacher_name``.

    Steps:
      1. Load clean and defended copies of the teacher adapter.
      2. Sample distillation queries from ``query_sources``.
      3. Distill the defended teacher into a randomly initialised student, using a
         Defense when ENABLE_DEFENSE is set.
      4. Reload the saved student and print clean-teacher and student metrics on the
         teacher's task (EVAL_SAMPLES rows each).

    Raises:
        ValueError: if ``teacher_name`` is not in TEACHER_ADAPTERS.
        RuntimeError: if the teacher and student tokenizer vocabularies differ.
    """
    if teacher_name not in TEACHER_ADAPTERS:
        raise ValueError(f"Unknown teacher '{teacher_name}'")
    teacher_clean, teacher_defended, tokenizer = load_teacher_pair(teacher_name)
    defense = Defense(DEFENSE_CONFIG) if ENABLE_DEFENSE else None

    texts_for_distill = get_queries_for_distill(query_sources=query_sources, n_samples=SAMPLES_FOR_DISTILL)
    print(f"[Main] Using {len(texts_for_distill)} queries for distillation.")

    student, tokenizer_student = instantiate_student_random()
    # Explicit check: an `assert` would be silently skipped under `python -O`.
    if tokenizer.get_vocab() != tokenizer_student.get_vocab():
        raise RuntimeError("Teacher and student tokenizers have different vocabularies")

    print("[Main] Starting distillation using defended teacher...")
    distill_prompt(teacher_defended, student, tokenizer, texts_for_distill, defense=defense, epochs=KL_EPOCHS, top_k=None)

    print("[Main] Reloading student from checkpoint for evaluation...")
    base_model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL).to(device)
    student_reloaded = PeftModel.from_pretrained(base_model, SAVE_STUDENT_DIR).to(device)
    student_reloaded.eval()

    # Task-specific evaluators; any other teacher goes through eval_on_dataset.
    # NOTE(review): "rotten" is unreachable while TEACHER_ADAPTERS has no "rotten" entry.
    task_evaluators = {
        "yelp": lambda m: eval_on_yelp(m, tokenizer, n_samples=EVAL_SAMPLES),
        "rotten": lambda m: eval_model_on_rotten(m, tokenizer, n_samples=EVAL_SAMPLES),
        "sst2": lambda m: eval_model_on_sst2(m, tokenizer, n_samples=EVAL_SAMPLES),
    }
    evaluate = task_evaluators.get(
        teacher_name,
        lambda m: eval_on_dataset(m, tokenizer, ds_name=teacher_name, n_samples=EVAL_SAMPLES),
    )

    print("\n[Eval] CLEAN TEACHER (baseline)")
    print(f"Teacher Clean {teacher_name.upper()} ->", evaluate(teacher_clean))

    print("\n[Eval] STUDENT")
    print(f"{teacher_name.upper()} -> {evaluate(student_reloaded)}")

if __name__ == "__main__":
    # Two runs against the same "amazon" teacher. The first draws queries from many
    # unrelated sources (plus random gibberish); the second uses amazon reviews only.
    experiments = [
        ("Using diverse queries", (
            "yelp", "imdb", "ag_news", "rotten", "sst2",
            "dbpedia", "snli", "mnli", "tweet_eval", "gibberish",
        )),
        ("Only legitimate queries (amazon):", ("amazon",)),
    ]
    for banner, sources in experiments:
        print(banner)
        main(teacher_name="amazon", query_sources=sources)

    