import os
import sys
import argparse
import logging
import random
import pickle
from pathlib import Path
from copy import deepcopy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import gc
from tqdm import tqdm
from torch.nn import Parameter
from torch.utils.data import DataLoader, Subset

from transformers import (
    T5ForSequenceClassification,
    T5Tokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    default_data_collator
)
from datasets import load_dataset, Dataset
import evaluate

try:
    import wandb
    os.environ["WANDB__SERVICE_WAIT"] = "800"
except ImportError:
    wandb = None

# Configure root logging once at import time: INFO level, timestamped records,
# emitted through a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
# Module-level logger used by the helpers and by main() below.
logger = logging.getLogger(__name__)

# Directory passed as `cache_dir` to the Hugging Face `from_pretrained` and
# `load_dataset` calls in this file.
hf_cache_dir = "./hf_cache/"

# Maps the short identifiers accepted by --model_a / --model_b to the model
# ids that are handed to `from_pretrained` for the base (non-fine-tuned) models.
MODELS = {
    "t5": "google-t5/t5-base",
    "t5flan": "google/flan-t5-base",
    "t5-1-1": "google/t5-v1_1-base",
    "mt5": "google/mt5-base",
}

# Dataset configurations, keyed by the value accepted by --dataset.
#
# Fields read elsewhere in this file:
#   dataset_name / dataset_config : positional arguments for `load_dataset`
#   num_labels                    : classifier size and upper bound for valid labels
#   premise_key / hypothesis_key  : the two text columns joined into the prompt
#   label_key                     : column holding the raw label
#   train_key / eval_key          : split names used for training and evaluation
#   label_type / label_mapping    : consumed by `convert_label`; "string" labels
#                                   are looked up in `label_mapping`, anything
#                                   else is passed through `int()`
DATASETS = {
    "MNLI": {
        "dataset_name": "glue",
        "dataset_config": "mnli",
        "num_labels": 3,
        "premise_key": "premise",
        "hypothesis_key": "hypothesis",
        "label_key": "label",
        "eval_key": "validation_matched",
        "train_key": "train",
        "label_type": "int",
        "label_mapping": None

    },
    "SNLI": {
        "dataset_name": "stanfordnlp/snli",
        "dataset_config": None,
        "num_labels": 3,
        "premise_key": "premise",
        "hypothesis_key": "hypothesis",
        "label_key": "label",
        "eval_key": "validation",
        "train_key": "train",
        "label_type": "int",
        "label_mapping": None

    },
    "QNLI": {
        "dataset_name": "glue",
        "dataset_config": "qnli",
        "num_labels": 2,
        "premise_key": "question",
        "hypothesis_key": "sentence",
        "label_key": "label",
        "eval_key": "validation",
        "train_key": "train",
        "label_type": "int",
        "label_mapping": None

    },
    "RTE": {
        "dataset_name": "glue",
        "dataset_config": "rte",
        "num_labels": 2,
        "premise_key": "sentence1",
        "hypothesis_key": "sentence2",
        "label_key": "label",
        "eval_key": "validation",
        "train_key": "train",
        "label_type": "int",
        "label_mapping": None

    },
    "SCITAIL": {
        "dataset_name": "scitail",
        "dataset_config": "dgem_format",
        "num_labels": 2,
        "premise_key": "premise",
        "hypothesis_key": "hypothesis",
        "label_key": "label",
        "eval_key": "validation",
        "train_key": "train",
        # NOTE(review): "num_epochs" is not read anywhere in this file; it is
        # presumably consumed by a separate fine-tuning script — confirm.
        "num_epochs": 50,
        "label_type": "string",
        "label_mapping": {"entails": 1, "neutral": 0}  # Map string labels to integers
    }

}


class T5TaskVector:
    """Task vector for T5 models: a per-parameter weight delta.

    The vector is a dict mapping state-dict keys to tensors. It is either
    supplied directly through ``vector`` or computed as
    ``finetuned - pretrained`` over the keys both models share.
    """

    def __init__(self, pretrained_model=None, finetuned_model=None, vector=None):
        """Create a task vector.

        Args:
            pretrained_model: Model providing the starting weights. Ignored
                when ``vector`` is given.
            finetuned_model: Model providing the fine-tuned weights. Ignored
                when ``vector`` is given.
            vector: Pre-built mapping of state-dict key to delta tensor. An
                empty dict is valid and yields a no-op task vector.

        Raises:
            ValueError: If ``vector`` is None and either model is missing.
        """
        if vector is not None:
            self.vector = vector
        elif pretrained_model is None or finetuned_model is None:
            # Fail early with a clear message instead of an AttributeError on
            # `None.state_dict()` further down.
            raise ValueError(
                "T5TaskVector needs either `vector` or both "
                "`pretrained_model` and `finetuned_model`."
            )
        else:
            self.vector = self._compute_task_vector(pretrained_model, finetuned_model)

    def _compute_task_vector(self, pretrained_model, finetuned_model):
        """Compute task vector as difference between fine-tuned and pretrained weights.

        Keys present only in the pretrained model are skipped.
        """
        pretrained_dict = pretrained_model.state_dict()
        finetuned_dict = finetuned_model.state_dict()

        return {
            key: finetuned_dict[key] - weights
            for key, weights in pretrained_dict.items()
            if key in finetuned_dict
        }

    def apply_to_model(self, model, alpha=1.0, layers_to_skip=("xxxxxxxxxxxxxxxxxxx",)):
        """Add ``alpha * vector`` to the weights of ``model`` in place.

        Args:
            model: Module whose state dict is updated.
            alpha: Scaling factor applied to every delta.
            layers_to_skip: Substrings; any state-dict key containing one of
                them keeps its current weights. A single string is treated as
                one substring and ``None`` as "skip nothing". The default is a
                placeholder that is not expected to match any key.

        Returns:
            The same ``model`` instance, after loading the updated weights.
        """
        if layers_to_skip is None:
            layers_to_skip = ()
        elif isinstance(layers_to_skip, str):
            # Iterating a bare string would match on single characters.
            layers_to_skip = (layers_to_skip,)

        model_dict = model.state_dict()
        new_dict = {}

        for key, weights in model_dict.items():
            delta = self.vector.get(key)
            if delta is None:
                # No delta for this entry: keep the weights untouched.
                new_dict[key] = weights
            elif any(layer in key for layer in layers_to_skip):
                print(f'Skipping layer {key}')
                new_dict[key] = weights
            else:
                # Move the delta next to the target weights so a vector built
                # on another device can still be applied.
                new_dict[key] = weights + alpha * delta.to(weights.device)

        model.load_state_dict(new_dict)
        return model


def parse_args():
    """Build the command line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="T5 Gradient Sign Masking")
    add = parser.add_argument
    dataset_choices = list(DATASETS.keys())
    model_choices = list(MODELS.keys())

    # --- Dataset, models and filesystem locations ---
    add("--dataset", type=str, required=True, choices=dataset_choices,
        help="Dataset to use for evaluation")
    add("--model_a", type=str, required=True, choices=model_choices,
        help="Source model identifier")
    add("--model_b", type=str, required=True, choices=model_choices,
        help="Target model identifier")
    add("--ft_root", type=str, default="./ft/",
        help="Root directory containing fine-tuned models")
    add("--output_root", type=str, default="./results/",
        help="Output directory for results")

    # --- Data used for the gradient computation ---
    add("--real_samples_per_class", type=int, default=None,
        help="Number of samples per class for gradient computation")
    add("--num_batches", type=int, default=None,
        help="Number of batches to use for gradient computation")
    add("--batch_size", type=int, default=16,
        help="Batch size for training and evaluation")
    add("--max_input_length", type=int, default=256,
        help="Maximum input sequence length")

    # --- How gradient signs are produced ---
    add("--optimize_during_realgrad", action="store_true",
        help="Optimize model B during gradient sign computation")
    add("--sign_mode", type=str, default="mean", choices=["mean", "max"],
        help="Method to compute gradient signs")
    add("--learning_rate", type=float, default=1e-5,
        help="Learning rate for optimization during gradient computation")

    # --- Evaluation sweep and reproducibility ---
    add("--eval_alphas", type=int, default=10,
        help="Number of alpha values to evaluate")
    add("--seed", type=int, default=42,
        help="Random seed")

    # --- Weights & Biases ---
    add("--wandb_mode", type=str, default="disabled",
        choices=["online", "offline", "disabled"],
        help="Weights and Biases logging mode")
    add("--wandb_group", type=str, default="t5",
        help="Weights and Biases group name")

    # --- Switches selecting what is evaluated and on which model ---
    add("--evaluate_only_realgrad", action="store_true",
        help="Evaluate only the real gradient task vector")
    add("--evaluate_base_models", action="store_true",
        help="Evaluate base models without task vectors")
    add("--use_aft", action="store_true",
        help="Use the A model finetuned for grad eval")
    add("--use_bft", action="store_true",
        help="Use the B model finetuned for grad eval")

    # --- Restrict the sign masking to one part of the network ---
    add("--only_head", action="store_true",
        help="Apply gradient sign masking only to classification head")
    add("--only_bb", action="store_true",
        help="Apply gradient sign masking only to backbone (transformer) layers")

    return parser.parse_args()

def convert_label(label, dataset_config):
    """Translate a raw dataset label into an integer class id.

    For datasets whose ``label_type`` is ``"string"`` the label is looked up
    in ``label_mapping``; for every other type it is coerced with ``int()``.
    Labels that cannot be translated yield ``-1``.
    """
    if dataset_config["label_type"] != "string":
        try:
            return int(label)
        except (ValueError, TypeError):
            return -1

    mapping = dataset_config["label_mapping"]
    if mapping and label in mapping:
        return mapping[label]
    return -1
        
        
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and, when present, CUDA) RNGs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def load_t5_model(model_path, num_labels, device):
    """Load a T5 sequence classifier and its tokenizer from ``model_path``.

    The model is loaded in float32 with ``num_labels`` output classes and is
    moved to ``device`` before being returned together with the tokenizer.
    """
    classifier = T5ForSequenceClassification.from_pretrained(
        model_path,
        num_labels=num_labels,
        torch_dtype=torch.float32,
        cache_dir=hf_cache_dir,
    )
    tokenizer = T5Tokenizer.from_pretrained(model_path, cache_dir=hf_cache_dir)
    # Module.to() moves the parameters in place and returns the same module.
    return classifier.to(device), tokenizer

def load_dataset_splits(dataset_name, tokenizer, max_input_length=256):
    """Load a configured dataset, drop invalid labels and tokenize both splits.

    Returns the tokenized train split, the tokenized evaluation split and the
    configuration entry from ``DATASETS`` that was used.

    Raises:
        ValueError: If no training example survives the label filter, or if
            ``dataset_name`` has no prompt format.
    """
    cfg = DATASETS[dataset_name]
    train_key = cfg["train_key"]
    eval_key = cfg["eval_key"]
    label_key = cfg["label_key"]

    raw = load_dataset(
        cfg["dataset_name"],
        cfg["dataset_config"],
        cache_dir=hf_cache_dir,
    )

    def report_sizes(header, splits):
        # Print the number of rows in the train / evaluation splits.
        print(header)
        print(f"  Train: {len(splits[train_key])}")
        print(f"  Validation: {len(splits[eval_key])}")

    def has_valid_label(example):
        # Keep only rows whose converted label is a known class id.
        converted = convert_label(example[label_key], cfg)
        return 0 <= converted < cfg["num_labels"]

    report_sizes("Original dataset sizes:", raw)

    # Show how the first five raw training labels are converted.
    print("\nSample labels from training data:")
    for position, row in enumerate(itertools.islice(raw[train_key], 5)):
        raw_label = row[label_key]
        print(f"  Example {position}: '{raw_label}' -> {convert_label(raw_label, cfg)}")

    filtered = raw.filter(has_valid_label)
    report_sizes("\nAfter filtering invalid labels:", filtered)

    if len(filtered[train_key]) == 0:
        raise ValueError(f"No valid training samples found for {dataset_name}. Check label mapping.")

    # Prompt layout per dataset; SCITAIL reuses the NLI layout.
    prompt_templates = {
        "SNLI": "nli premise: {} hypothesis: {}",
        "MNLI": "nli premise: {} hypothesis: {}",
        "QNLI": "qnli question: {} sentence: {}",
        "RTE": "rte sentence1: {} sentence2: {}",
        "SCITAIL": "nli premise: {} hypothesis: {}",
    }

    def preprocess_function(examples):
        """Turn a batch of raw rows into tokenized inputs plus integer labels."""
        template = prompt_templates.get(dataset_name)
        if template is None:
            raise ValueError(f"Unknown dataset format for {dataset_name}")

        first_texts = examples[cfg["premise_key"]]
        second_texts = examples[cfg["hypothesis_key"]]
        prompts = [template.format(a, b) for a, b in zip(first_texts, second_texts)]

        encoded = tokenizer(
            prompts,
            max_length=max_input_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        encoded["labels"] = [convert_label(raw_label, cfg) for raw_label in examples[label_key]]
        return encoded

    def tokenize_split(split_name):
        # Tokenize one split and drop all of its original columns.
        split = filtered[split_name]
        return split.map(
            preprocess_function,
            batched=True,
            remove_columns=split.column_names,
        )

    train_dataset = tokenize_split(train_key)
    eval_dataset = tokenize_split(eval_key)

    return train_dataset, eval_dataset, cfg

def build_class_indices(dataset):
    """Group sample positions by their integer ``"labels"`` value.

    Returns a dict mapping each class id to the list of dataset positions
    carrying that label, in ascending position order.
    """
    grouped = {}
    for position in range(len(dataset)):
        cls = int(dataset[position]["labels"])
        if cls not in grouped:
            grouped[cls] = []
        grouped[cls].append(position)
    return grouped


def sample_indices_per_class(dataset, k):
    """Draw ``k`` random sample positions for every class in ``dataset``.

    Classes with fewer than ``k`` samples contribute all of their positions
    and trigger a warning. Returns a dict of class id to list of positions.
    """
    chosen = {}
    for cls, members in build_class_indices(dataset).items():
        available = len(members)
        if available >= k:
            chosen[cls] = random.sample(members, k)
            continue
        logger.warning(f"Class {cls} has only {available} samples, requested {k}. Using all.")
        chosen[cls] = members
    return chosen

def build_realgrad_dataloader(train_dataset, base_loader, args, tokenizer, model, sample_indices=None):
    """Build the batch source used for the real-gradient computation.

    Args:
        train_dataset: Tokenized training dataset; its first item is inspected
            to pick a collator.
        base_loader: Existing DataLoader over ``train_dataset``; used when no
            per-class sample indices are given.
        args: Parsed CLI arguments; ``batch_size`` and ``num_batches`` are read.
        tokenizer: Tokenizer handed to ``DataCollatorForSeq2Seq`` when needed.
        model: Model handed to ``DataCollatorForSeq2Seq`` when needed.
        sample_indices: Optional dict of class id to list of dataset positions.

    Returns:
        * a new shuffled DataLoader over the selected positions when
          ``sample_indices`` is given;
        * otherwise an iterable limited to the first ``args.num_batches``
          batches of ``base_loader`` when ``num_batches`` is set;
        * otherwise ``base_loader`` itself.
    """
    # Function-scope import: only needed to recognise torch's default collate.
    from torch.utils.data.dataloader import default_collate

    first_sample = train_dataset[0]

    print("First sample:", first_sample)

    # Sequence-valued labels need the seq2seq collator; scalar labels work
    # with the plain default collator from transformers.
    if isinstance(first_sample["labels"], (list, tuple)):
        collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    else:
        collator = default_data_collator

    if sample_indices is not None:
        flat_indices = [idx for idxs in sample_indices.values() for idx in idxs]
        subset = Subset(train_dataset, flat_indices)
        return DataLoader(
            subset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,   # T5 tokenization can be sensitive to multiprocessing
            # Pinning is only useful (and warning-free) when CUDA is present.
            pin_memory=torch.cuda.is_available(),
            collate_fn=collator,
        )

    # A DataLoader created without `collate_fn` stores torch's default_collate,
    # never None, so the original `is None` check could not fire. The default
    # collate turns list-valued features into per-position lists rather than a
    # (batch, seq) tensor, so install the chosen collator in that case too.
    if base_loader.collate_fn is None or base_loader.collate_fn is default_collate:
        base_loader.collate_fn = collator

    if args.num_batches is not None:
        class FirstNBatches:
            """Iterable yielding at most ``n`` batches from a dataloader."""

            def __init__(self, dataloader, n):
                self.dataloader = dataloader
                self.n = n

            def __iter__(self):
                return itertools.islice(iter(self.dataloader), self.n)

            def __len__(self):
                # Lets progress bars and callers know the number of batches.
                return min(self.n, len(self.dataloader))

        return FirstNBatches(base_loader, args.num_batches)

    return base_loader


def compute_real_gradient_signs(work_model, dataloader, device, optimize=False, 
                               learning_rate=1e-5, sign_mode="mean"):
    """Compute the sign of the descent direction for every trainable parameter.

    Args:
        work_model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``loss`` and ``logits``.
        dataloader: Iterable of dict batches containing those three keys.
        device: Device the model and the batches are moved to.
        optimize: When True, the model is put in train mode and a single AdamW
            step is taken on the accumulated gradients after the loop.
        learning_rate: Learning rate of that AdamW step.
        sign_mode: ``"mean"`` accumulates the batch-loss gradients over all
            batches and takes the sign of the negated sum; ``"max"`` takes
            per-sample gradient signs and returns the sign of their sum
            (a majority vote).

    Returns:
        ``(gradient_signs, model)`` where ``gradient_signs`` maps parameter
        names to sign tensors and ``model`` is ``work_model`` when
        ``optimize`` is True, otherwise ``None``.

    Raises:
        ValueError: If ``sign_mode`` is neither ``"mean"`` nor ``"max"``.
    """
    if sign_mode not in ("mean", "max"):
        # Previously an unknown mode silently produced an empty result.
        raise ValueError(f"Unknown sign_mode: {sign_mode!r}")

    work_model.to(device)
    work_model.eval()
    work_model.zero_grad()

    optimizer = None
    if optimize:
        work_model.train()
        optimizer = torch.optim.AdamW(work_model.parameters(), lr=learning_rate)
        optimizer.zero_grad()

    # Get trainable parameters
    trainable_params = [(name, p) for name, p in work_model.named_parameters() if p.requires_grad]
    trainable_tensors = [p for _, p in trainable_params]

    sign_sums = None
    if sign_mode == "max":
        # Running sum of per-sample signs, used for majority voting
        sign_sums = {name: torch.zeros_like(p, device=device) for name, p in trainable_params}

    for batch in tqdm(dataloader, desc="Computing gradient signs"):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch["labels"]

        outputs = work_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=labels,
        )

        if sign_mode == "mean":
            # Gradients accumulate in .grad across batches (never zeroed here).
            outputs.loss.backward()
        else:
            # Per-sample losses so each sample can cast its own sign vote.
            losses = F.cross_entropy(outputs.logits.view(-1, outputs.logits.size(-1)),
                                     labels.view(-1), reduction="none")

            for sample_loss in losses:
                # Samples with a loss that is not strictly positive are skipped.
                if sample_loss.item() > 0:
                    grads = torch.autograd.grad(
                        sample_loss,
                        trainable_tensors,
                        retain_graph=True,   # graph is reused by the next sample
                        create_graph=False,
                        # Parameters not reached by the graph yield None
                        # instead of raising; they are skipped below.
                        allow_unused=True,
                    )

                    for (name, _), grad in zip(trainable_params, grads):
                        if grad is not None:
                            sign_sums[name] += torch.sign(-grad.detach())

            if optimize:
                # Populate .grad for the optimizer step taken after the loop.
                losses.mean().backward()

    # Extract gradient signs
    if sign_mode == "mean":
        gradient_signs = {
            name: torch.sign(-param.grad)
            for name, param in work_model.named_parameters()
            if param.grad is not None
        }
    else:
        gradient_signs = {name: torch.sign(acc) for name, acc in sign_sums.items()}

    if optimize:
        optimizer.step()

    # The signs are separate tensors, so the accumulated gradients can be
    # dropped; otherwise every later deepcopy of the model copies them too.
    work_model.zero_grad(set_to_none=True)

    return gradient_signs, (work_model if optimize else None)


def taskvector_from_gradient_signs(gradient_signs, reference_taskvector, fallback_taskvector=None, only_head = False, only_bb = False):
    """Mask a reference task vector with a set of gradient signs.

    For every entry of the reference vector that has gradient signs, elements
    whose sign differs from the gradient sign are zeroed. With ``only_bb``
    the ``classification_head`` entries are kept whole, with ``only_head``
    the ``transformer`` entries are kept whole. Entries without gradient
    signs are taken from ``fallback_taskvector`` when it has them, otherwise
    from the reference unchanged.
    """
    assert not (only_head and only_bb), "Cannot have both only_head and only_bb set to True."

    reference = reference_taskvector.vector
    result = {}

    for name, delta in reference.items():
        if name not in gradient_signs:
            has_fallback = fallback_taskvector is not None and name in fallback_taskvector.vector
            result[name] = fallback_taskvector.vector[name] if has_fallback else delta
            continue

        # Parts of the network excluded from masking keep every element.
        keep_everything = (
            ("classification_head" in name and only_bb)
            or ("transformer" in name and only_head)
        )
        if keep_everything:
            keep = torch.ones_like(delta, dtype=torch.bool)
        else:
            keep = torch.sign(delta) == gradient_signs[name]

        result[name] = torch.where(keep, delta, torch.zeros_like(delta))

    return T5TaskVector(vector=result)


def sign_agreement(grad_signs_a, grad_signs_b):
    """Measure how often two sets of sign tensors agree element-wise.

    Only keys present in both inputs are compared. The overall and per-key
    agreement ratios are logged (and sent to wandb when it is available).

    Returns:
        ``(overall_ratio, per_key_ratio_dict)``.
    """
    shared_keys = set(grad_signs_a.keys()) & set(grad_signs_b.keys())
    matched = 0
    compared = 0
    per_key_agreement = {}

    for name in shared_keys:
        signs_a = grad_signs_a[name]
        n_equal = (signs_a == grad_signs_b[name]).sum().item()
        n_elements = signs_a.numel()
        per_key_agreement[name] = n_equal / n_elements if n_elements > 0 else 0

        matched += n_equal
        compared += n_elements

    final_agreement = matched / compared if compared > 0 else 0
    logger.info(f"Sign agreement: {matched}/{compared} = {final_agreement*100:.2f}%")

    for name, ratio in per_key_agreement.items():
        logger.info(f"Sign agreement for {name}: {ratio*100:.2f}%")

    if wandb is not None:
        wandb.log({"sign_agreement": final_agreement})
        for name, ratio in per_key_agreement.items():
            wandb.log({f"sign_agreement_{name}": ratio})

    return final_agreement, per_key_agreement

def compute_metrics(eval_preds):
    """Compute classification accuracy for a ``(predictions, labels)`` pair.

    Args:
        eval_preds: Pair of model outputs and reference labels. When the
            outputs are a tuple, its first element is used as the logits.

    Returns:
        ``{"accuracy": float}``: the fraction of argmax predictions equal to
        the labels, or ``0.0`` when there are no labels.
    """
    predictions, labels = eval_preds
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted = np.argmax(predictions, axis=-1).reshape(-1)
    references = np.asarray(labels).reshape(-1)

    if references.size == 0:
        # Avoid a NaN (and a NumPy warning) from averaging an empty array.
        return {"accuracy": 0.0}

    # Computed directly instead of re-loading the `evaluate` accuracy metric
    # on every call; the value is the same fraction of exact matches.
    return {"accuracy": float(np.mean(predicted == references))}

def evaluate_model_with_taskvector(base_model, task_vector, eval_dataset, 
                                  alpha, device, batch_size=16):
    """Evaluate a copy of ``base_model`` after adding ``alpha * task_vector``.

    Args:
        base_model: Model to copy; it is not modified.
        task_vector: Object exposing ``apply_to_model(model, alpha=...)``.
        eval_dataset: Dataset handed to ``Trainer.evaluate``.
        alpha: Scaling factor for the task vector.
        device: Unused here; kept for interface compatibility.
        batch_size: Per-device evaluation batch size.

    Returns:
        ``(accuracy, loss)`` taken from the ``eval_accuracy`` and
        ``eval_loss`` entries of the evaluation results.
    """
    eval_model = deepcopy(base_model)

    # Apply task vector
    eval_model = task_vector.apply_to_model(eval_model, alpha=alpha)
    eval_model.eval()

    # Create trainer for evaluation
    trainer = Trainer(
        model=eval_model,
        args=TrainingArguments(
            output_dir="./tmp_eval",
            per_device_eval_batch_size=batch_size,
            # The string "none" disables reporting integrations; Python None
            # does not, so these evaluation runs could log to wandb.
            report_to="none",
        ),
        compute_metrics=compute_metrics,
    )

    results = trainer.evaluate(eval_dataset=eval_dataset)
    accuracy, loss = results["eval_accuracy"], results["eval_loss"]

    # The trainer holds a reference to the model copy, so both must go before
    # the memory can actually be reclaimed.
    del trainer
    del eval_model
    gc.collect()
    torch.cuda.empty_cache()

    return accuracy, loss


def evaluate_task_vectors(base_model, eval_dataset, device, alphas, task_vectors, logger):
    """Sweep every task vector over ``alphas`` and record the best setting.

    Returns a dict keyed by task-vector name with ``best_accuracy``,
    ``best_alpha`` and the full list of per-alpha results.
    """
    summary = {}

    for vector_name, vector in task_vectors.items():
        logger.info(f"Evaluating {vector_name}")
        top_accuracy, top_alpha = 0, 0
        sweep = []

        for scale in alphas:
            accuracy, loss = evaluate_model_with_taskvector(
                base_model, vector, eval_dataset, scale, device
            )
            sweep.append({"alpha": scale, "accuracy": accuracy, "loss": loss})

            # Strict comparison: the first alpha reaching the best accuracy wins.
            if accuracy > top_accuracy:
                top_accuracy, top_alpha = accuracy, scale

            logger.info(f"  α={scale:.2f}: acc={accuracy:.4f}, loss={loss:.4f}")

        summary[vector_name] = {
            "best_accuracy": top_accuracy,
            "best_alpha": top_alpha,
            "all_results": sweep
        }

        if wandb is None:
            continue

        # Log the best point and the whole sweep as a table to wandb.
        table_rows = [[entry["alpha"], entry["accuracy"], entry["loss"]] for entry in sweep]
        wandb.log({
            f"{vector_name}_best_accuracy": top_accuracy,
            f"{vector_name}_best_alpha": top_alpha,
            f"{vector_name}_evaluation": wandb.Table(
                data=table_rows,
                columns=["Alpha", "Accuracy", "Loss"]
            )
        })

    return summary


def main():
    """Run the gradient-sign masking experiment end to end.

    Steps: parse arguments, load base and fine-tuned models A and B, build
    both task vectors, compute gradient signs on the selected model, compare
    them with the signs of task vector B, mask task vector A with each set of
    signs, evaluate the resulting vectors over a range of alphas and pickle
    the gradient signs and the results under ``--output_root``.
    """
    args = parse_args()
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize wandb (it is None when the package is not installed, and
    # every other use in this file is guarded the same way).
    config = vars(args)
    if wandb is not None:
        wandb.init(
            project="T5_Gradient_Sign_Masking",
            config=config,
            mode=args.wandb_mode,
            group=args.wandb_group,
        )

    dataset_config = DATASETS[args.dataset]

    # Define model paths
    model_a_base = MODELS[args.model_a]
    model_b_base = MODELS[args.model_b]
    model_a_ft_path = os.path.join(args.ft_root, f"{args.dataset}_{args.model_a}")
    model_b_ft_path = os.path.join(args.ft_root, f"{args.dataset}_{args.model_b}")

    logger.info(f"Loading dataset: {args.dataset}")

    # Load models
    logger.info("Loading models...")
    model_a, tokenizer_a = load_t5_model(model_a_base, dataset_config["num_labels"], device)
    model_b, tokenizer_b = load_t5_model(model_b_base, dataset_config["num_labels"], device)
    model_a_ft, _ = load_t5_model(model_a_ft_path, dataset_config["num_labels"], device)
    model_b_ft, _ = load_t5_model(model_b_ft_path, dataset_config["num_labels"], device)

    # Tokenize the dataset with model B's tokenizer
    train_dataset, eval_dataset, _ = load_dataset_splits(
        args.dataset, tokenizer_b, args.max_input_length
    )

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    # Compute task vectors
    logger.info("Computing task vectors...")
    taskvector_a = T5TaskVector(model_a, model_a_ft)
    taskvector_b = T5TaskVector(model_b, model_b_ft)

    # Evaluate base models (an empty task vector with alpha 0 leaves the
    # weights untouched, so this measures the models as loaded)
    if args.evaluate_base_models:
        logger.info("Evaluating base models...")
        acc_a_base, loss_a_base = evaluate_model_with_taskvector(
            model_a, T5TaskVector(vector={}), eval_dataset, 0, device
        )
        acc_b_base, loss_b_base = evaluate_model_with_taskvector(
            model_b, T5TaskVector(vector={}), eval_dataset, 0, device
        )
        acc_a_ft, loss_a_ft = evaluate_model_with_taskvector(
            model_a_ft, T5TaskVector(vector={}), eval_dataset, 0, device
        )
        acc_b_ft, loss_b_ft = evaluate_model_with_taskvector(
            model_b_ft, T5TaskVector(vector={}), eval_dataset, 0, device
        )

        logger.info(f"Model A base: acc={acc_a_base:.4f}")
        logger.info(f"Model B base: acc={acc_b_base:.4f}")
        logger.info(f"Model A ft: acc={acc_a_ft:.4f}")
        logger.info(f"Model B ft: acc={acc_b_ft:.4f}")

        if wandb is not None:
            wandb.log({
                "model_a_base_accuracy": acc_a_base,
                "model_b_base_accuracy": acc_b_base,
                "model_a_ft_accuracy": acc_a_ft,
                "model_b_ft_accuracy": acc_b_ft,
            })

    # Determine sample indices for gradient computation
    sample_indices = None
    if args.real_samples_per_class is not None:
        sample_indices = sample_indices_per_class(train_dataset, args.real_samples_per_class)
        logger.info(f"Using {args.real_samples_per_class} samples per class for gradient computation")

    # Build gradient computation dataloader
    realgrad_loader = build_realgrad_dataloader(train_dataset, train_loader, args, tokenizer_b, model_b, sample_indices)

    print("Realgrad loader built.")

    # Compute gradient signs
    logger.info("Computing real gradient signs...")

    # Model the gradients are taken on: base B by default, overridden by
    # --use_aft, and --use_bft takes precedence when both are given.
    m = model_b
    if args.use_aft:
        m = model_a_ft
    if args.use_bft:
        m = model_b_ft

    real_gradient_signs, optimized_model = compute_real_gradient_signs(
        m,
        realgrad_loader,
        device,
        optimize=args.optimize_during_realgrad,
        learning_rate=args.learning_rate,
        sign_mode=args.sign_mode,
    )

    if optimized_model is not None:
        acc_final, loss_final = evaluate_model_with_taskvector(
            optimized_model, T5TaskVector(vector={}), eval_dataset, 0, device
        )

        logger.info(f"Optimized model B during realgrad: acc={acc_final:.4f}, loss={loss_final:.4f}")
        if wandb is not None:
            wandb.log({
                "optimized_model_b_accuracy": acc_final,
                "optimized_model_b_loss": loss_final,
            })

    # Save results
    output_dir = Path(args.output_root)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Export real gradient signs
    with open(output_dir / f"real_gradient_signs_{args.dataset}_{args.model_a}_to_{args.model_b}_{args.real_samples_per_class}.pkl", "wb") as f:
        pickle.dump(real_gradient_signs, f)

    # Compute oracle gradient signs (from model B task vector)
    oracle_gradient_signs = {k: torch.sign(v) for k, v in taskvector_b.vector.items()}

    # Compute sign agreement
    sign_agreement(oracle_gradient_signs, real_gradient_signs)

    # Create masked task vectors
    logger.info("Creating masked task vectors...")
    realgrad_taskvector_a = taskvector_from_gradient_signs(
        real_gradient_signs, taskvector_a, fallback_taskvector=taskvector_a, only_head = args.only_head, only_bb = args.only_bb
    )
    oracle_taskvector_a = taskvector_from_gradient_signs(
        oracle_gradient_signs, taskvector_a, fallback_taskvector=taskvector_a, only_head = args.only_head, only_bb = args.only_bb
    )

    # Evaluate task vectors
    alphas = [0.08, 0.05, 0.03, 0.01] + list(np.linspace(0.1, 2, 20))

    task_vectors = {
        "realgrad_taskvector_a": realgrad_taskvector_a,
        "taskvector_a": taskvector_a,
        "oracle_taskvector_a": oracle_taskvector_a,
    }

    if args.evaluate_only_realgrad:
        task_vectors = {
            "realgrad_taskvector_a": realgrad_taskvector_a,
        }

    logger.info("Evaluating task vectors...")
    results = evaluate_task_vectors(
        m, eval_dataset, device, alphas, task_vectors, logger
    )

    # Print final results
    logger.info("\nFinal Results:")
    logger.info("=" * 50)
    for name, result in results.items():
        logger.info(f"{name}: {result['best_accuracy']:.4f} @ alpha={result['best_alpha']:.2f}")

    with open(output_dir / f"results_{args.dataset}_{args.model_a}_to_{args.model_b}.pkl", "wb") as f:
        pickle.dump(results, f)

    logger.info(f"Results saved to {output_dir}")

    if wandb is not None:
        wandb.finish()


if __name__ == "__main__":
    main()