from data import NegativeImageFolder, COCOPositive
import utils
import data 
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import get_cosine_schedule_with_warmup
from tqdm import tqdm
import itertools
from torch.utils.data import DataLoader
from PIL import Image
from torch.optim import Adam
import random 
import torch
import torch 
from typing import Dict,List
from pathlib import Path
import gc
from transfer_eval import run_multi
from pope_eval import evaluate, ADAPTERS

# ---------------------------------------------------------------------------
# Default configuration.
# When this file is run as a script, load_args_to_globals() overwrites every
# value below except cache_path and coco_dir with the CLI values. The CLI
# defaults are not always identical to these (e.g. --learning_rate defaults
# to 1e-6).
# ---------------------------------------------------------------------------

# Model / hardware
model_name = "llava"  # "llava" or "qwen"
device = "cuda"
cache_path = "cache"

# Data locations. The folder templates contain a "{cls}" placeholder;
# presumably it is filled with a class name by the data loaders (TODO confirm).
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff", ".tif"}
coco_dir = "../PathtoCOCO/COCO"
_LLAVA_ATTACK_IMAGES = (
    "logs/attack/llava/{cls}_lr=0.1_steps=100_threshold=0.8_"
    "num_generation=4_guidance_scale=5.0_lambda_contrast=15.0_"
    "lambda_reg=10.0_OD_threshold=0.5__sort=True_t=30_num_inf=50_"
    "deep_llava_bs=64_lr=0_0002_epochs=10_context_dim=4096_hidden_dim=1024/images"
)
# Qwen variant of the attack-image folder:
# "logs/attack/qwen/{cls}_lr=0.1_steps=100_threshold=0.8_num_generation=4_guidance_scale=5.0_lambda_contrast=15.0_lambda_reg=10.0_OD_threshold=0.5__sort=True_t=30_num_inf=50_qwen_bs=32_lr=0_0002_epochs=10_context_dim=4096_hidden_dim=2048/images"
negative_image_folder = _LLAVA_ATTACK_IMAGES
image_dir = _LLAVA_ATTACK_IMAGES
positive_image_folder = "coco_generate_with_prompt_sorted/{cls}"

# LoRA adapter hyperparameters
lora_rank = 8
lora_alpha = 32
lora_dropout = 0.05
lora_bias = "none"

# Optimisation schedule
EPOCHS = 5
BATCH_SIZE = 16  # adjust to the available GPU memory
LEARNING_RATE = 2e-5
WARMUP_RATIO = 0.1
def pil_collate_fn(batch):
    """Split a batch of ``(image, object_name, ...)`` samples into two parallel lists.

    torch's default collate would try to stack the PIL images into a tensor;
    here they are passed through untouched.

    Returns:
        ``(images, object_names)`` taken from positions 0 and 1 of each sample.
    """
    images, object_names = [], []
    for sample in batch:
        images.append(sample[0])
        object_names.append(sample[1])
    return images, object_names

def get_messsage(prompt):
    """Build a single-turn user chat message: one image slot followed by *prompt* text.

    (The function name's spelling is kept for existing callers.)
    """
    image_part = {"type": "image"}
    text_part = {"type": "text", "text": prompt}
    return [{"role": "user", "content": [image_part, text_part]}]

def get_image_paths(root: str, exts=None) -> List[str]:
    """Return image file paths under *root* (searched recursively), or ``[root]`` if it is an image file.

    Args:
        root: An image file or a directory.
        exts: Optional set of lowercase, dot-prefixed suffixes to accept.
            Defaults to the module-level ``IMG_EXTS``, read at call time so that
            CLI overrides of that global still apply.

    Returns:
        Sorted list of path strings. Returns ``[]`` when *root* does not exist
        or is a non-image file. Directories are skipped even if their name ends
        in an image suffix (e.g. ``foo.png/``).
    """
    allowed = IMG_EXTS if exts is None else exts
    p = Path(root)
    if not p.exists():
        return []
    if p.is_file():
        return [str(p)] if p.suffix.lower() in allowed else []
    return sorted(
        str(fp) for fp in p.rglob("*")
        if fp.is_file() and fp.suffix.lower() in allowed
    )
def compute_summary(rows: List[Dict[str, float]]) -> Dict[str, float]:
    """Aggregate per-image prediction rows into a small summary.

    Returns:
        ``yes_share``: fraction of rows whose ``predicted_answer`` is "yes".
        ``yes_probability_avg``: mean of ``yes_probability`` (missing counts as 0.0).
        ``len``: number of rows.
        Both ratios are rounded to 4 decimals. All zeros for an empty input.
    """
    n = len(rows)
    if n == 0:
        return {"yes_share": 0.0, "yes_probability_avg": 0.0, "len": 0}
    yes_votes = 0
    prob_total = 0.0
    for row in rows:
        if row.get("predicted_answer") == "yes":
            yes_votes += 1
        prob_total += row.get("yes_probability", 0.0)
    return {
        "yes_share": round(yes_votes / n, 4),
        "yes_probability_avg": round(prob_total / n, 4),
        "len": n,
    }

def _build_yes_no_id_lists(tokenizer) -> "tuple[List[int], List[int]]":
    """Collect token ids that decode (after stripping) to 'yes' or 'no'.

    Every id in ``range(tokenizer.vocab_size)`` is decoded exactly once. An
    earlier version re-decoded the whole vocabulary for each fallback pass.

    Primary match: the decoded text, whitespace-stripped, then '.'-stripped and
    lowercased, equals "yes" / "no".

    Fallback (applied only to a side with no primary match): the lowercased,
    whitespace-stripped text contains "yes" and is at most 4 characters long,
    or contains "no" and is at most 3 characters long. Note that the "no" rule
    also admits tokens such as "not".

    Returns:
        ``(yes_ids, no_ids)``, each in ascending id order.
    """
    decoded = [tokenizer.decode(tok_id) for tok_id in range(tokenizer.vocab_size)]

    yes_ids, no_ids = [], []
    for tok_id, text in enumerate(decoded):
        norm = text.strip().strip(".").lower()
        if norm == "yes":
            yes_ids.append(tok_id)
        elif norm == "no":
            no_ids.append(tok_id)

    if not yes_ids or not no_ids:
        # Fallback normalisation differs from the primary one (no '.' stripping).
        loose = [text.lower().strip() for text in decoded]
        if not yes_ids:
            yes_ids = [i for i, t in enumerate(loose) if "yes" in t and len(t) <= 4]
        if not no_ids:
            no_ids = [i for i, t in enumerate(loose) if "no" in t and len(t) <= 3]
    return yes_ids, no_ids

def eval_one_class(
    model,
    processor,  # chat processor; forwarded unchanged to yes_no_probabilities()
    yes_id,
    no_id,
    images_dir: str,
    object_name: str,
    silent: bool,
    max_images: int = 50,
) -> Dict[str, object]:
    """Ask "Is there a <object_name> in this image?" for images found under *images_dir*.

    Images are visited in sorted path order and at most *max_images* of them
    are scored (previously a hard-coded cap of 50).

    Args:
        model: Model passed to yes_no_probabilities; switched to eval mode here.
        processor: Forwarded to yes_no_probabilities, which calls its
            ``apply_chat_template`` and ``__call__``.
        yes_id, no_id: Iterables of token ids whose probabilities are summed
            into the yes / no mass.
        images_dir: Image file or directory, resolved by get_image_paths.
        object_name: Object name inserted into the question.
        silent: Suppress all per-image and summary printing.
        max_images: Maximum number of images to score (None scores all).

    Returns:
        ``{"summary": compute_summary(rows), "rows": rows}``. There is one row per
        image that could be opened, holding ``predicted_answer`` ("yes" on ties),
        ``confidence``, ``yes_probability`` and ``no_probability``.
    """
    image_paths = sorted(get_image_paths(images_dir))
    model.eval()
    if not image_paths:
        if not silent:
            print(f"[WARN] No images found for class '{object_name}' in: {images_dir}")
        return {"summary": compute_summary([]), "rows": []}

    selected = image_paths[:max_images]
    total = len(selected)  # report progress against what is actually scored
    question = f"Is there a {object_name} in this image?"
    rows = []

    for idx, img_path in enumerate(selected, start=1):
        probs = yes_no_probabilities(model, processor, yes_id, no_id, img_path, question)
        if "error" in probs:
            if not silent:
                print(f"[{idx}/{total}] {img_path} -> ERROR: {probs['error']}")
            continue

        pred = "yes" if probs["yes"] >= probs["no"] else "no"
        conf = probs[pred]
        rows.append({
            "image_path": img_path,
            "predicted_answer": pred,
            "confidence": conf,
            "yes_probability": probs["yes"],
            "no_probability": probs["no"],
        })

        if not silent:
            print(f"[{idx}/{total}] {img_path} -> yes={probs['yes']:.4f}, no={probs['no']:.4f}, pred='{pred}' ({conf:.2%})")

    summary = compute_summary(rows)
    if not silent:
        print("nextobj" * 20)
        print(f"Summary for '{object_name}': yes_share={summary['yes_share']}, "
              f"yes_probability_avg={summary['yes_probability_avg']}, len={summary['len']}")
    return {"summary": summary, "rows": rows}

def yes_no_probabilities(model, processor, yes_id, no_id, image_path: str, question: str, max_new_tokens: int = 1) -> Dict[str, float]:
    """Return the model's next-token probability mass on the "yes" and "no" token ids.

    The image and *question* are formatted with the processor's chat template,
    with an added instruction to answer only 'Yes' or 'No'. A single forward
    pass then runs under ``torch.no_grad()``. The softmax of the last-position
    logits is summed over *yes_id* and over *no_id*. This is the same quantity
    the training loop in main() optimises.

    An earlier version ran an extra grad-tracking forward pass and then
    ``model.generate``, returning probabilities from its ``scores``. Those scores
    are post-processed by the generation config's logits processors/warpers,
    so they are not the raw distribution. Both passes are replaced by the
    single forward pass above.

    Args:
        model: Causal VLM; inputs are moved to ``model.device``.
        processor: Provides ``apply_chat_template`` and ``__call__`` (text + images).
        yes_id, no_id: Iterables of token ids; an empty iterable yields 0.0.
        image_path: Path of the image to load (converted to RGB).
        question: Question text placed before the Yes/No instruction.
        max_new_tokens: Unused; kept so existing callers keep working.

    Returns:
        ``{"yes": float, "no": float}``, or ``{"error": str}`` if the image
        cannot be opened.
    """
    try:
        with Image.open(image_path) as raw:  # close the file handle promptly
            img = raw.convert("RGB")
    except Exception as e:  # best-effort: report unreadable files instead of raising
        return {"error": f"Could not open image: {e}"}

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": f"{question} Please respond with only 'Yes' or 'No'."},
        ],
    }]

    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text_prompt], images=[img], return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        # Only the last position is needed; cast just that slice to float32.
        last_logits = model(**inputs).logits[0, -1, :].float()
    probs = torch.softmax(last_logits, dim=-1)

    def _mass(token_ids) -> float:
        index = torch.as_tensor(list(token_ids), dtype=torch.long, device=probs.device)
        return float(probs.index_select(0, index).sum().item())

    return {"yes": _mass(yes_id), "no": _mass(no_id)}

def _print_pope_stats(stats: Dict[str, float]) -> None:
    """Print the metrics dict returned by pope_eval.evaluate for the 'random' split."""
    print(f"POPE [{'random'}] on {model_name}  n={stats['n']}")
    print(f"Accuracy:  {stats['accuracy']:.4f}")
    print(f"Precision: {stats['precision']:.4f}")
    print(f"Recall:    {stats['recall']:.4f}")
    print(f"F1:        {stats['f1']:.4f}   (primary metric in paper)")
    print(f"Yes-rate:  {stats['yes_rate']:.4f}")


def _evaluate_checkpoint(lora_dir: str, pope_dir: str, coco_split: str) -> Dict[str, float]:
    """Evaluate the LoRA checkpoint in *lora_dir* on POPE 'random', then print and return the stats.

    *coco_split* names the image folder under ``<coco_dir>/images``. The last
    argument to ``evaluate`` (1000) is passed through unchanged from the
    original script.
    """
    adapter = ADAPTERS[model_name](lora=True, lorapath=lora_dir, device=device)
    try:
        stats = evaluate(Path(pope_dir), Path(coco_dir) / "images" / coco_split, "random", adapter, 1000)
    finally:
        # NOTE(review): the adapter presumably loads its own model copy; release
        # it (and cached GPU memory) before training continues.
        del adapter
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    _print_pope_stats(stats)
    return stats


def _next_token_nll(lora_model, processor, prompt: str, image, token_id: int) -> torch.Tensor:
    """Return -log p(token_id) for the token following (prompt, image)."""
    if model_name == "pali":
        inputs = processor(text=prompt, images=image, return_tensors="pt", padding=True).to(device)
    else:
        inputs = utils.vllm_standard_preprocessing(processor, prompt, image)
    # Cast only the last position to float32 rather than the full (seq, vocab) logits.
    last_logits = lora_model(**inputs).logits[:, -1, :].float()
    prob_target = torch.softmax(last_logits, dim=-1)[0, token_id]
    return -torch.log(prob_target + 1e-8)  # epsilon guards against log(0)


def _batch_nll(lora_model, processor, batch, templates, token_id: int) -> torch.Tensor:
    """Sum _next_token_nll over a collated (images, object_names) batch, one random template per sample."""
    images, object_names = batch
    prompts = [random.choice(templates).format(obj=name) for name in object_names]
    total = 0.0
    for prompt, image in zip(prompts, images):
        total = total + _next_token_nll(lora_model, processor, prompt, image, token_id)
    return total


def main():
    """Fine-tune LoRA adapters so the model answers 'Yes' on positives and 'No' on negatives.

    Each step pairs one positive batch (target: first token of "Yes") with one
    negative batch (target: first token of "No"). The loss is the summed
    -log p(target) over both batches.

    After every epoch the adapter is saved and evaluated on the POPE 'random'
    split over COCO train2017. The best-accuracy checkpoint is kept and finally
    evaluated on POPE over COCO val2017.
    """
    run_dir = f"./Finetuned_{model_name}{EPOCHS}2"
    best_dir = f"{run_dir}/lora-finetuned-best"

    # --- model ---
    model, processor = utils.get_model(model_name, cache_path=cache_path, token="")
    model = model.to(device)
    if getattr(model.config, "use_cache", None):
        # KV cache is not needed for training (gradient checkpointing is enabled below).
        model.config.use_cache = False

    # --- data ---
    neg_data = NegativeImageFolder(negative_image_folder)
    neg_loader = DataLoader(neg_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pil_collate_fn)
    pos_data = COCOPositive(positive_image_folder, object_names=neg_data.object_names)
    pos_loader = DataLoader(pos_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pil_collate_fn)
    print(f"negative data classes:{neg_data.object_names}, len:{len(neg_data)}")
    print(f"positive data len:{len(pos_data)}")

    # --- LoRA ---
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias=lora_bias,
        target_modules=[
            # Vision encoder attention layers
            "qkv",
            "proj",
            # Language model attention layers
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            # MLP layers for both language and vision
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.gradient_checkpointing_enable()
    lora_model.print_trainable_parameters()
    lora_model.enable_input_require_grads()

    # --- optimisation ---
    optimizer = Adam([p for p in lora_model.parameters() if p.requires_grad], lr=LEARNING_RATE)
    steps_per_epoch = min(len(pos_loader), len(neg_loader))  # zip() stops at the shorter loader
    if steps_per_epoch == 0:
        raise ValueError("Positive or negative dataset is empty; nothing to train on.")
    total_training_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(WARMUP_RATIO * total_training_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_training_steps,
    )

    # Only the first sub-token of "Yes" / "No" is used as the target.
    yes_token_id = processor.tokenizer.encode("Yes", add_special_tokens=False)[0]
    no_token_id = processor.tokenizer.encode("No", add_special_tokens=False)[0]
    templates = utils.get_prompt_templates()
    best_acc = -1.0  # guarantees a "best" checkpoint exists after the first epoch

    for epoch in range(EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")
        pbar = tqdm(zip(pos_loader, neg_loader), total=steps_per_epoch)
        total_loss_epoch = 0.0

        for pos_batch, neg_batch in pbar:
            lora_model.train()
            optimizer.zero_grad()
            loss_detail = {}
            for name, batch, token_id in (
                ("positive", pos_batch, yes_token_id),
                ("negative", neg_batch, no_token_id),
            ):
                batch_loss = _batch_nll(lora_model, processor, batch, templates, token_id)
                # Backprop each half immediately: gradients accumulate, so the
                # update equals (pos + neg).backward() while only one half's
                # graph is held in memory at a time.
                batch_loss.backward()
                loss_detail[name] = batch_loss.item()
            optimizer.step()
            scheduler.step()

            step_loss = loss_detail["positive"] + loss_detail["negative"]
            total_loss_epoch += step_loss
            pbar.set_description(f'Loss: {step_loss:.4f} (Pos: {loss_detail["positive"]:.4f}, Neg: {loss_detail["negative"]:.4f})')

        avg_loss = total_loss_epoch / steps_per_epoch

        # --- save this epoch's adapter and evaluate it ---
        output_dir = f"{run_dir}/lora-finetuned-manual{epoch}"
        lora_model.eval()
        lora_model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        stats = _evaluate_checkpoint(output_dir, "POPE/output_train/coco", "train2017")
        if stats["accuracy"] > best_acc:
            best_acc = stats["accuracy"]
            lora_model.save_pretrained(best_dir)
            processor.save_pretrained(best_dir)
            print(f"Best model saved with accuracy: {best_acc:.4f}")

        print(f"\n✅ Epoch {epoch+1} complete. LoRA weights saved to {output_dir}")
        print(f"Average loss for epoch {epoch+1}: {avg_loss:.4f}")

    print("\n--- Fine-tuning complete ---", "best model on validation set:")
    _evaluate_checkpoint(best_dir, "POPE/output/coco", "val2017")


    
import argparse
from typing import Set

# ========== Argument Parsing ==========

def parse_img_exts(exts_str: str) -> Set[str]:
    """Parse a comma-separated extension list into lowercase, dot-prefixed suffixes.

    Entries are stripped and lowercased, empty entries are dropped, and a
    missing leading dot is added. Lowercasing matters because image lookup
    compares against ``Path.suffix.lower()``, so an upper-case entry such as
    ``.JPG`` would otherwise never match.

    >>> sorted(parse_img_exts("JPG, .png,,webp"))
    ['.jpg', '.png', '.webp']
    """
    exts = set()
    for raw in exts_str.split(","):
        ext = raw.strip().lower()
        if ext:
            exts.add(ext if ext.startswith(".") else f".{ext}")
    return exts

def get_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; each flag corresponds to a module-level config global."""
    attack_images = (
        "logs/attack/llava/{cls}_lr=0.1_steps=100_threshold=0.8_"
        "num_generation=4_guidance_scale=5.0_lambda_contrast=15.0_"
        "lambda_reg=10.0_OD_threshold=0.5__sort=True_t=30_num_inf=50_"
        "deep_llava_bs=64_lr=0_0002_epochs=10_context_dim=4096_hidden_dim=1024/images"
    )
    parser = argparse.ArgumentParser(
        description="Training / attack script - CLI for hyperparameters and paths."
    )
    add = parser.add_argument

    # model and device
    add("--model_name", choices=["llava", "qwen"], default="llava")
    add("--device", choices=["cuda", "cpu"], default="cuda")

    # image extensions (comma-separated; see parse_img_exts)
    add("--img_exts", type=str, default=".jpg,.jpeg,.png,.bmp,.webp,.tiff,.tif",
        help="Comma-separated list of extensions")

    # folders / paths
    add("--negative_image_folder", type=str, default=attack_images)
    add("--positive_image_folder", type=str, default="coco_generate_with_prompt_sorted/{cls}")
    add("--image_dir", type=str, default=attack_images)

    # LoRA adapter hyperparameters
    for flag, kind, default in (
        ("--lora_rank", int, 8),
        ("--lora_alpha", int, 32),
        ("--lora_dropout", float, 0.05),
        ("--lora_bias", str, "none"),
    ):
        add(flag, type=kind, default=default)

    # training schedule
    # NOTE(review): --learning_rate defaults to 1e-6, while the module-level
    # LEARNING_RATE is 2e-5; confirm which value is intended.
    for flag, kind, default in (
        ("--epochs", int, 5),
        ("--batch_size", int, 16),
        ("--learning_rate", float, 1e-6),
        ("--warmup_ratio", float, 0.1),
    ):
        add(flag, type=kind, default=default)

    return parser

# ========== Apply to Globals ==========

def load_args_to_globals(argv=None):
    """Parse CLI flags and copy them into the module-level config globals.

    ``cache_path`` and ``coco_dir`` are not touched. ``--image_dir`` is now
    honoured; previously ``image_dir`` was overwritten with
    ``--negative_image_folder`` and the flag was silently ignored.

    Args:
        argv: Optional list of CLI tokens. None parses ``sys.argv[1:]``, as before.

    Returns:
        The parsed ``argparse.Namespace``. Callers that ignore the return value
        are unaffected.
    """
    global model_name, IMG_EXTS, negative_image_folder, positive_image_folder
    global lora_rank, lora_alpha, lora_dropout, lora_bias
    global EPOCHS, BATCH_SIZE, LEARNING_RATE, WARMUP_RATIO, device, image_dir

    args = get_parser().parse_args(argv)

    # model / paths
    model_name = args.model_name
    device = args.device
    IMG_EXTS = parse_img_exts(args.img_exts)
    negative_image_folder = args.negative_image_folder
    positive_image_folder = args.positive_image_folder
    image_dir = args.image_dir

    # LoRA
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    lora_bias = args.lora_bias

    # training
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    WARMUP_RATIO = args.warmup_ratio
    return args

# ========== Script entry point ==========

if __name__ == "__main__":
    # Apply CLI overrides to the config globals, echo the key settings, then train.
    load_args_to_globals()
    for label, value in (
        ("model_name:", model_name),
        ("IMG_EXTS:", IMG_EXTS),
        ("negative_image_folder:", negative_image_folder),
        ("positive_image_folder:", positive_image_folder),
        ("image_dir:", image_dir),
    ):
        print(label, value)
    print("lora_rank:", lora_rank, "lora_alpha:", lora_alpha)
    print("EPOCHS:", EPOCHS, "BATCH_SIZE:", BATCH_SIZE, "LR:", LEARNING_RATE)
    main()
