import os
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel, get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit
from datasets import load_dataset
import evaluate
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face identifier of the base model every adapter is attached to.
BASE_MODEL = "t5-base"

# Adapter checkpoint directory per task name; main() loads each of these
# as a "victim" adapter.
ADAPTER_PATHS = {
    "sst2": "/sst2_adapter_checkpoint",
    "yelp": "/yelp_polarity_adapter_checkpoint",
    "amazon": "/amazon_polarity_adapter_checkpoint",
    "rotten": "/rotten_tomatoes_adapter_checkpoint",
    "cnn": "/cnn_summarization",
    "arxiv": "/arxiv_prompt_adapter",
}

# Adapter checkpoint directory per task name; main() builds the training set
# of the probability-to-prompt mapper from these adapters.
STUDENT_ADAPTER_PATHS = {
    "amazon": "/student_adapter_checkpoint"
}

# Number of dataset examples main() requests per task when extracting
# first-token probabilities.
NUM_SAMPLES_PER_TASK = 1000
# Batch size for probability extraction and for the mapper's training loader.
BATCH_SIZE = 8
# Batch size used by eval_with_prompt().
EVAL_BATCH_SIZE = 8
# Training epochs and Adam learning rate for the mapper in main().
EPOCHS = 8
LR = 1e-4
# Seed applied to torch, numpy and random below.
SEED = 1339
# Default width, depth and head count of ProbToPromptResNet.
TRANSFORMER_DIM = 768
TRANSFORMER_LAYERS = 4
TRANSFORMER_HEADS = 8
# When not None, only the top-k probabilities are kept (others zeroed) for the
# training features extracted in main().
TOP_K = None
# Tokenizer truncation length for probability extraction and summarization eval.
MAX_INPUT_LENGTH = 256
# max_length passed to generate() during summarization eval.
GEN_MAX_LENGTH = 128

# Seed every RNG the script uses so runs are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# torch.device form of DEVICE, shared by the functions below.
device = torch.device(DEVICE)

def load_peft_adapter(adapter_dir):
    """Load the base model with the PEFT adapter stored in ``adapter_dir``.

    The base model is moved to the module-level ``device`` and cast to half
    precision when ``DEVICE`` is ``"cuda"`` before the adapter is attached.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL)

    backbone = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
    backbone = backbone.to(device)
    if DEVICE == "cuda":
        backbone.half()

    adapted = PeftModel.from_pretrained(backbone, adapter_dir)
    adapted = adapted.to(device)
    adapted.eval()
    return adapted, tokenizer

def get_prompt_embedding_cpu(model):
    """Return the "default" prompt-encoder embedding weight of ``model``.

    The result is detached from autograd, moved to the CPU and cast to
    float32.
    """
    weight = model.prompt_encoder["default"].embedding.weight
    detached = weight.detach()
    return detached.cpu().float()

# Loading recipe per dataset name:
# (load_dataset positional args, text column, prefix prepended to each text).
_FIRST_TOKEN_SOURCES = {
    "sst2": (("glue", "sst2"), "sentence", "sst2 sentence: "),
    "yelp": (("yelp_polarity",), "text", "yelp review: "),
    "amazon": (("amazon_polarity",), "content", "amazon review: "),
    "rotten": (("rotten_tomatoes",), "text", "rotten tomatoes review: "),
    "cnn": (("cnn_dailymail", "3.0.0"), "article", "cnn article: "),
    "arxiv": (("scientific_papers", "arxiv"), "article", "arxiv article: "),
}


@torch.no_grad()
def get_first_token_probs(model, tok, dataset_name, num_samples=4, top_k=None, batch_size=BATCH_SIZE):
    """Compute the model's first-decoder-step token distribution per input.

    The first ``num_samples`` training examples of ``dataset_name`` are
    prefixed with a task tag, tokenized, and run through ``model`` with a
    single decoder start token. The logits of that step are softmaxed into a
    probability vector.

    Args:
        model: Seq2seq model exposing ``config.decoder_start_token_id``.
        tok: Tokenizer used to encode the prefixed texts.
        dataset_name: One of the keys of ``_FIRST_TOKEN_SOURCES``.
        num_samples: Number of leading training examples to use.
        top_k: When not None, keep only the ``top_k`` largest probabilities of
            each row and zero the rest (rows are not renormalized).
        batch_size: Number of examples per forward pass.

    Returns:
        A float32 CPU tensor with one probability row per input text.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name not in _FIRST_TOKEN_SOURCES:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    load_args, text_column, prefix = _FIRST_TOKEN_SOURCES[dataset_name]
    ds = load_dataset(*load_args, split=f"train[:{num_samples}]")
    texts = [f"{prefix}{t}" for t in ds[text_column]]

    enc_all = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_LENGTH)
    use_cuda = DEVICE == "cuda"
    all_probs = []

    for i in range(0, enc_all["input_ids"].size(0), batch_size):
        batch_enc = {k: v[i:i+batch_size].to(device) for k, v in enc_all.items()}
        # A single start token per row: only the first decoding step is needed.
        dec_input_ids = torch.full((batch_enc["input_ids"].size(0), 1),
                                   model.config.decoder_start_token_id,
                                   dtype=torch.long,
                                   device=device)
        forward_kwargs = {
            "input_ids": batch_enc["input_ids"],
            "attention_mask": batch_enc["attention_mask"],
            "decoder_input_ids": dec_input_ids,
        }
        if use_cuda:
            with torch.autocast("cuda", dtype=torch.float16):
                out = model(**forward_kwargs)
        else:
            out = model(**forward_kwargs)

        # Upcast before softmax: half-precision logits would flush small
        # probabilities to zero and lose the tail of the distribution.
        probs = F.softmax(out.logits[:, -1, :].float(), dim=-1).cpu()
        if top_k is not None:
            top_vals, top_idx = torch.topk(probs, top_k, dim=-1)
            new_probs = torch.zeros_like(probs)
            new_probs.scatter_(1, top_idx, top_vals)
            probs = new_probs
        all_probs.append(probs)
        del out
        if use_cuda:
            torch.cuda.empty_cache()

    return torch.cat(all_probs, dim=0)

class ProbToPromptDataset(Dataset):
    """Dataset pairing probability feature rows with prompt target rows.

    Both constructor arguments are sequences of tensors that are concatenated
    along dimension 0; the two concatenated tensors must end up with the same
    number of rows. Item ``idx`` is the tuple ``(probs[idx], prompts[idx])``.
    """

    def __init__(self, prob_feat_tensors, prompt_target_tensors):
        stacked_features = torch.cat(prob_feat_tensors, dim=0)
        stacked_targets = torch.cat(prompt_target_tensors, dim=0)
        self.probs = stacked_features
        self.prompts = stacked_targets
        # Features and targets are paired by row index.
        assert self.probs.size(0) == self.prompts.size(0)

    def __len__(self):
        return self.probs.shape[0]

    def __getitem__(self, idx):
        feature_row = self.probs[idx]
        target_row = self.prompts[idx]
        return feature_row, target_row

class ResidualTransformerBlock(nn.Module):
    """Transformer encoder layer wrapped in an extra skip connection.

    The input is added to the encoder layer's output and the sum is passed
    through a LayerNorm over the last (``d_model``) dimension.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048):
        super().__init__()
        encoder_options = {
            "d_model": d_model,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "activation": "relu",
        }
        self.layer = nn.TransformerEncoderLayer(**encoder_options)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        skip = x
        transformed = self.layer(x)
        return self.norm(skip + transformed)

class ProbToPromptResNet(nn.Module):
    """Map a batch of probability vectors to prompt embeddings.

    Each ``prob_dim``-sized input row is projected to ``d_model``, passed
    through ``nlayers`` ResidualTransformerBlock modules, and decoded by a
    two-layer MLP into ``prompt_len * prompt_hidden_dim`` values.
    """

    def __init__(self, prob_dim, prompt_len, prompt_hidden_dim, d_model=TRANSFORMER_DIM, nhead=TRANSFORMER_HEADS, nlayers=TRANSFORMER_LAYERS):
        super().__init__()
        self.input_proj = nn.Linear(prob_dim, d_model)

        self.blocks = nn.ModuleList()
        for _ in range(nlayers):
            self.blocks.append(ResidualTransformerBlock(d_model, nhead))

        widened = d_model*2
        flat_prompt_size = prompt_len*prompt_hidden_dim
        self.decoder = nn.Sequential(
            nn.Linear(d_model, widened),
            nn.ReLU(),
            nn.Linear(widened, flat_prompt_size),
        )
        self.prompt_len = prompt_len
        self.prompt_hidden_dim = prompt_hidden_dim

    def forward(self, prob_vec):
        """Return ``(prompts, hidden)`` for the rows of ``prob_vec``.

        ``prompts`` is viewed as ``(prob_vec.size(0), prompt_len,
        prompt_hidden_dim)``; ``hidden`` is the block-stack output that was fed
        to the decoder.
        """
        n_rows = prob_vec.size(0)
        # A leading dimension of size 1 is added for the blocks and removed after.
        hidden = self.input_proj(prob_vec).unsqueeze(0)
        for stage in self.blocks:
            hidden = stage(hidden)
        hidden = hidden.squeeze(0)
        flat_prompts = self.decoder(hidden)
        prompts = flat_prompts.view(n_rows, self.prompt_len, self.prompt_hidden_dim)
        return prompts, hidden

def _sanitize_prompt_embedding(pe: torch.Tensor) -> torch.Tensor:
    """Return a detached, contiguous float32 copy of ``pe`` with finite values.

    NaN, +inf and -inf entries are all replaced by 0.0.
    """
    cleaned = pe.detach().float().contiguous()
    return torch.nan_to_num(cleaned, nan=0.0, posinf=0.0, neginf=0.0)

# Evaluation recipe per dataset name:
# (load_dataset positional args, split, text column, text prefix,
#  target column, is_summarization).
_EVAL_SOURCES = {
    "sst2": (("glue", "sst2"), "validation[:1000]", "sentence", "sst2 sentence: ", "label", False),
    "rotten": (("rotten_tomatoes",), "validation[:1000]", "text", "rotten tomatoes review: ", "label", False),
    "amazon": (("amazon_polarity",), "test[:1000]", "content", "amazon review: ", "label", False),
    "yelp": (("yelp_polarity",), "test[:1000]", "text", "yelp review: ", "label", False),
    "cnn": (("cnn_dailymail", "3.0.0"), "test[:500]", "article", "cnn article: ", "highlights", True),
    "arxiv": (("scientific_papers", "arxiv"), "test[:500]", "article", "arxiv article: ", "abstract", True),
}


def _load_eval_examples(dataset_name):
    """Return ``(texts, targets, is_summarization)`` for an eval dataset."""
    if dataset_name not in _EVAL_SOURCES:
        raise ValueError(f"Unknown eval dataset: {dataset_name}")
    load_args, split, text_column, prefix, target_column, is_summarization = _EVAL_SOURCES[dataset_name]
    ds = load_dataset(*load_args, split=split)
    texts = [prefix + t for t in ds[text_column]]
    return texts, ds[target_column], is_summarization


def _call_with_optional_autocast(fn, use_cuda, **kwargs):
    """Call ``fn(**kwargs)``, under fp16 CUDA autocast when ``use_cuda``."""
    if use_cuda:
        with torch.autocast("cuda", dtype=torch.float16):
            return fn(**kwargs)
    return fn(**kwargs)


def _build_prompted_model(pe, device, use_cuda):
    """Create a prompt-tuned base model whose prompt weights are ``pe``."""
    base_model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL).to(device)
    if use_cuda:
        base_model.half()
    peft_config = PromptTuningConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        # NOTE(review): halved presumably because PEFT's seq2seq prompt
        # embedding holds num_virtual_tokens rows per transformer submodule
        # (two of them) - confirm against the installed PEFT version.
        num_virtual_tokens=pe.shape[0]//2,
        tokenizer_name_or_path=BASE_MODEL
    )
    model = get_peft_model(base_model, peft_config)
    with torch.no_grad():
        model.prompt_encoder["default"].embedding.weight.copy_(pe)
    model.eval()
    return model


def _eval_classification(model, tokenizer, texts, labels, device, use_cuda):
    """Score first-token sentiment predictions with accuracy and F1."""
    preds = []
    with torch.inference_mode():
        for i in range(0, len(texts), EVAL_BATCH_SIZE):
            batch_texts = texts[i:i+EVAL_BATCH_SIZE]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
            enc = {k: v.to(device) for k, v in enc.items()}
            dec_input_ids = torch.full((enc["input_ids"].size(0), 1),
                                       model.config.decoder_start_token_id,
                                       dtype=torch.long,
                                       device=device)
            out = _call_with_optional_autocast(
                model, use_cuda,
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                decoder_input_ids=dec_input_ids,
            )

            first_token_ids = torch.argmax(out.logits[:, -1, :], dim=-1).unsqueeze(1)
            decoded = tokenizer.batch_decode(first_token_ids, skip_special_tokens=True)
            # Anything that does not decode to "positive" counts as class 0.
            preds.extend(1 if "positive" in d.lower() else 0 for d in decoded)

            del out
            if use_cuda:
                torch.cuda.empty_cache()

    return {"accuracy": accuracy_score(labels, preds), "f1": f1_score(labels, preds)}


def _eval_summarization(model, tokenizer, texts, references, device, use_cuda):
    """Generate summaries with beam search and score them with ROUGE and BLEU."""
    rouge_metric = evaluate.load("rouge")
    bleu_metric = evaluate.load("bleu")
    preds = []
    refs = []
    with torch.inference_mode():
        for i in range(0, len(texts), EVAL_BATCH_SIZE):
            batch_texts = texts[i:i+EVAL_BATCH_SIZE]
            enc = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_LENGTH)
            enc = {k: v.to(device) for k, v in enc.items()}

            outs = _call_with_optional_autocast(
                model.generate, use_cuda,
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                max_length=GEN_MAX_LENGTH,
                num_beams=4,
                early_stopping=True,
            )

            decoded = tokenizer.batch_decode(outs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            preds.extend(decoded)
            refs.extend(references[i:i+len(decoded)])
            del outs
            if use_cuda:
                torch.cuda.empty_cache()

    rouge_res = rouge_metric.compute(predictions=preds, references=refs, use_stemmer=True)
    rouge_summary = {
        "rouge1": float(rouge_res.get("rouge1", 0.0)),
        "rouge2": float(rouge_res.get("rouge2", 0.0)),
        "rougeL": float(rouge_res.get("rougeL", 0.0))
    }

    # The `evaluate` BLEU metric takes raw strings and tokenizes internally.
    # Passing pre-split token lists is rejected, which previously forced the
    # score to 0.0. `str.split` keeps the intended whitespace tokenization.
    try:
        bleu_res = bleu_metric.compute(
            predictions=preds,
            references=[[r] for r in refs],
            tokenizer=str.split,
        )
        bleu_score = float(bleu_res.get("bleu", 0.0))
    except Exception as exc:
        # Best effort: keep the run going but make the failure visible.
        print(f"BLEU computation failed ({type(exc).__name__}: {exc}); reporting 0.0")
        bleu_score = 0.0

    return {"rouge": rouge_summary, "bleu": bleu_score}


def eval_with_prompt(prompt_embedding, dataset_name, tokenizer, device=DEVICE):
    """Evaluate the base model with ``prompt_embedding`` as its soft prompt.

    The embedding is sanitized, copied into a freshly created prompt-tuning
    wrapper around the base model, and the wrapped model is evaluated on a
    fixed slice of ``dataset_name``.

    Args:
        prompt_embedding: 2-D tensor copied into the prompt encoder's
            embedding weight.
        dataset_name: One of the keys of ``_EVAL_SOURCES``.
        tokenizer: Tokenizer used for encoding inputs and decoding outputs.
        device: Device (string or ``torch.device``) to run on. Half precision
            and autocast are enabled when it is a CUDA device.

    Returns:
        ``{"accuracy": ..., "f1": ...}`` for the classification datasets, or
        ``{"rouge": {"rouge1", "rouge2", "rougeL"}, "bleu": ...}`` for the
        summarization datasets.

    Raises:
        ValueError: If ``dataset_name`` is not supported. This is checked
            before any model is loaded.
    """
    texts, targets, is_summarization = _load_eval_examples(dataset_name)

    use_cuda = torch.device(device).type == "cuda"
    pe = _sanitize_prompt_embedding(prompt_embedding).to(device)
    model = _build_prompted_model(pe, device, use_cuda)

    if is_summarization:
        return _eval_summarization(model, tokenizer, texts, targets, device, use_cuda)
    return _eval_classification(model, tokenizer, texts, targets, device, use_cuda)

def compute_similarity_metrics(recon_prompt, teacher_prompt):
    """Compare a reconstructed prompt with the reference prompt.

    Both tensors are flattened to 1-D first; ``reshape`` is used so that
    non-contiguous inputs (for example transposed tensors) are accepted too.

    Args:
        recon_prompt: Reconstructed prompt tensor.
        teacher_prompt: Reference prompt tensor with the same number of
            elements.

    Returns:
        A ``(l2, cos_sim)`` tuple of Python floats. ``l2`` is the square root
        of the mean squared error between the flattened tensors (a root mean
        square error, not the unnormalized Euclidean distance); ``cos_sim`` is
        their cosine similarity.
    """
    recon = recon_prompt.reshape(-1)
    teacher = teacher_prompt.reshape(-1)
    l2 = float(F.mse_loss(recon, teacher).item() ** 0.5)
    cos_sim = float(F.cosine_similarity(recon, teacher, dim=0).item())
    return l2, cos_sim

def transform_prompts(prompt_embedding, num_repeats, noise_scale=0.1, scale_range=(0.9,1.1)):
    """Create ``num_repeats`` randomly perturbed, flattened copies of a prompt.

    Each copy is multiplied by one scalar drawn uniformly from
    ``scale_range``; when ``noise_scale`` is positive, Gaussian noise scaled by
    ``noise_scale`` is added on top.

    Returns:
        A tensor of shape ``(num_repeats, prompt_len * hidden_dim)``.
    """
    prompt_len, hidden_dim = prompt_embedding.shape
    low, high = scale_range[0], scale_range[1]

    copies = prompt_embedding.unsqueeze(0).repeat(num_repeats, 1, 1)
    # One scale factor per copy, broadcast over the whole prompt.
    per_copy_scale = torch.rand(num_repeats, 1, 1) * (high-low) + low
    copies = copies * per_copy_scale

    if noise_scale > 0:
        jitter = torch.randn_like(copies) * noise_scale
        copies = copies + jitter

    flat_width = prompt_len * hidden_dim
    return copies.view(num_repeats, flat_width)

def main():
    """Train the probability-to-prompt mapper and evaluate it on victim adapters.

    Steps:
      1. Load every adapter in ``ADAPTER_PATHS`` (victims) and
         ``STUDENT_ADAPTER_PATHS`` (stolen adapters used for training).
      2. For each stolen adapter, extract first-token probability features and
         pair them with perturbed copies of its prompt embedding.
      3. Train ``ProbToPromptResNet`` with an MSE loss on those pairs.
      4. For each victim whose name is not among the stolen adapters,
         reconstruct a prompt from its probability features and evaluate the
         teacher, reconstructed and random prompts on the victim's task.
      5. Print a per-task report.
    """
    tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL)
    victim_adapters = {name: load_peft_adapter(path) for name, path in ADAPTER_PATHS.items()}
    stolen_adapters = {name: load_peft_adapter(path) for name, path in STUDENT_ADAPTER_PATHS.items()}

    stolen_train_probs, stolen_train_targets = [], []
    prompt_len = None
    hidden_dim = None

    for stolen_name, (stolen_model, stolen_tok) in stolen_adapters.items():
        stolen_prompt = get_prompt_embedding_cpu(stolen_model)
        prompt_len, hidden_dim = stolen_prompt.size()
        stolen_probs = get_first_token_probs(stolen_model, stolen_tok, stolen_name, NUM_SAMPLES_PER_TASK, top_k=TOP_K, batch_size=BATCH_SIZE)
        # One perturbed prompt target per probability row.
        stolen_targets = transform_prompts(stolen_prompt, stolen_probs.size(0))
        stolen_train_probs.append(stolen_probs)
        stolen_train_targets.append(stolen_targets)

    train_ds = ProbToPromptDataset(stolen_train_probs, stolen_train_targets)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    model = ProbToPromptResNet(prob_dim=stolen_train_probs[0].size(1), prompt_len=prompt_len, prompt_hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        samples_seen = 0
        for probs_batch, prompts_batch in train_loader:
            batch_n = probs_batch.size(0)
            probs_batch = probs_batch.to(device).float()
            prompts_batch = prompts_batch.to(device).float().view(batch_n, prompt_len, hidden_dim)
            optimizer.zero_grad()
            pred_prompts, _ = model(probs_batch)
            loss = criterion(pred_prompts, prompts_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch_n
            samples_seen += batch_n
        # drop_last=True may skip a trailing partial batch, so average over the
        # samples actually processed rather than the dataset size.
        avg_loss = total_loss / max(samples_seen, 1)
        print(f"Epoch {epoch}/{EPOCHS} TrainLoss={avg_loss:.6e}")

    # Leave training mode so dropout is disabled while reconstructing prompts.
    model.eval()

    results = {}
    for victim_name, (victim_model, victim_tok) in victim_adapters.items():
        if victim_name in stolen_adapters:
            continue
        teacher_prompt = get_prompt_embedding_cpu(victim_model)
        feats = get_first_token_probs(victim_model, victim_tok, victim_name, NUM_SAMPLES_PER_TASK, batch_size=BATCH_SIZE)
        with torch.no_grad():
            # Average the per-example predictions into a single prompt.
            recon_prompt = model(feats.to(device).float())[0].mean(dim=0).cpu()
        random_prompt = torch.randn_like(teacher_prompt)
        teacher_metrics = eval_with_prompt(teacher_prompt, victim_name, tokenizer, device=device)
        recon_metrics = eval_with_prompt(recon_prompt, victim_name, tokenizer, device=device)
        random_metrics = eval_with_prompt(random_prompt, victim_name, tokenizer, device=device)
        l2, cos_sim = compute_similarity_metrics(recon_prompt, teacher_prompt)
        results[victim_name] = {
            "Teacher": teacher_metrics,
            "Reconstructed": recon_metrics,
            "Random": random_metrics,
            "L2": l2,
            "Cosine": cos_sim,
        }

    print(f"Stolen adapters used for training: {', '.join(n.upper() for n in stolen_adapters)}")
    print(f"Evaluated victim-only adapters: {', '.join(n.upper() for n in results)}\n")
    for name, metrics in results.items():
        print(f"{name.upper()} RESULTS:")
        if name in ("cnn", "arxiv"):
            teacher_rouge = metrics["Teacher"]["rouge"]
            recon_rouge = metrics["Reconstructed"]["rouge"]
            random_rouge = metrics["Random"]["rouge"]
            print(f"Teacher ROUGE: {teacher_rouge}, Reconstructed ROUGE: {recon_rouge}, Random ROUGE: {random_rouge}")
            print(f"Teacher BLEU: {metrics['Teacher']['bleu']}, Reconstructed BLEU: {metrics['Reconstructed']['bleu']}, Random BLEU: {metrics['Random']['bleu']}")
        else:
            print(f"Teacher ACC/F1: {metrics['Teacher']}, Reconstructed ACC/F1: {metrics['Reconstructed']}, Random ACC/F1: {metrics['Random']}")
        # Similarity is computed for every victim, so report it for every task.
        print(f"L2: {metrics['L2']:.6e}, Cosine: {metrics['Cosine']:.6f}")
        print("")

if __name__ == "__main__":
    main()
