import os
import sys
import argparse
import logging
import random
import pickle
from pathlib import Path
from copy import deepcopy
import math
import clip 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
import itertools

from utils import set_seed, get_models, parse_arguments, load_dataset, evaluate_model, evaluate_with_task_vector
from src.models import OpenCLIPModel
from task_vectors.src.task_vectors import TaskVector
from permutations.permutation_spec import CLIP_Visual_PermutationSpecBuilder
from permutations.weights_matcher import WeightMatcher, LayerIterationOrder
from permutations.utils import apply_permutation_to_statedict
from synth_image import SyntheticImage, PRECOMPUTED_DIR

# wandb is an optional dependency: when the import fails, `wandb` is bound to
# None and the logging helpers below guard their calls with `wandb is not None`.
try:
    import wandb
    # NOTE(review): presumably raises the W&B service start-up timeout (in
    # seconds) — confirm against the wandb environment-variable docs.
    os.environ["WANDB__SERVICE_WAIT"] = "800"
except ImportError:
    wandb = None

# Restrict scaled-dot-product attention to the plain "math" implementation by
# turning off the flash and memory-efficient kernels.
# NOTE(review): the motivation is not visible in this file — confirm before
# re-enabling the fused kernels.
# torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# Root logging configuration: INFO level, timestamped records, written through
# a StreamHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)


def parse_local_args(argv=None):
    """Parse the script-specific options, then delegate the rest to ``utils.parse_arguments()``.

    Only the arguments specific to this script (``--real_imgs_per_class``,
    ``--precomputed_indices``, ``--num_batches``, ``--optimize_during_realgrad``,
    ``--sign_mode``, ``--mask_mode`` and the ``--soft_*`` parameters) are
    consumed here; every other CLI argument is handed to the repository's
    standard parser.

    Args:
        argv: Argument list without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        The namespace returned by ``parse_arguments()`` with the local options
        set on it as attributes (local values overwrite same-named attributes).
    """
    argv = list(sys.argv[1:] if argv is None else argv)

    local = argparse.ArgumentParser(add_help=False)

    # The three data-selection modes are mutually exclusive.
    group = local.add_mutually_exclusive_group(required=False)
    group.add_argument("--real_imgs_per_class", type=int, default=None,
                       help="Randomly sample this many images per class from the training set to compute real gradient signs (mutually exclusive with --precomputed_indices and --num_batches). If omitted, falls back to dataloader mode.")
    group.add_argument("--precomputed_indices", type=str, default=None,
                       help="Either a .pkl path or a spec 'coreset:K' | 'herding:K' | 'k-medoid:K'. If given, only these indices are used (mutually exclusive with --real_imgs_per_class and --num_batches).")
    group.add_argument("--num_batches", type=int, default=None,
                       help="Use dataloader mode and backprop this many batches for gradient-sign computation (mutually exclusive with --real_imgs_per_class and --precomputed_indices). If omitted and no other option is set, defaults to 1 batch.")

    # Flag to also optimize the model while the gradient signs are computed.
    local.add_argument("--optimize_during_realgrad", action="store_true",
                       help="Optimize model B during gradient sign computation.")
    # The help text names the accepted choices exactly ('mean' / 'max').
    local.add_argument("--sign_mode", type=str, default="max", choices=["mean", "max"],
                       help="Method to compute gradient signs: 'mean' for the sign of the mean gradient, 'max' for a per-image majority vote of the gradient signs.")
    local.add_argument("--mask_mode", type=str, default="normal", choices=[
        "normal", "force", "random",
        "soft1", "soft1_rowwise", "soft2", "soft2_rowwise", "soft3", "soft3_rowwise"
    ], help="How gradient signs are combined with the reference task vector (see taskvector_from_gradient_signs).")
    local.add_argument("--soft_beta", type=float, default=5.0, help="Beta parameter for soft1* mode")
    local.add_argument("--soft_gamma", type=float, default=1.0, help="Gamma parameter for soft2* mode")
    local.add_argument("--soft_c", type=float, default=1.0, help="c parameter for soft3* mode")

    local_args, remaining = local.parse_known_args(argv)

    # parse_arguments() is called without arguments, so the leftover options are
    # passed to it by temporarily swapping sys.argv; the original is always restored.
    orig_argv = sys.argv
    try:
        sys.argv = [orig_argv[0]] + remaining
        base_args = parse_arguments()
    finally:
        sys.argv = orig_argv

    # argparse already stores dests with underscores, so names are copied as-is.
    for name, value in vars(local_args).items():
        setattr(base_args, name, value)
    return base_args

def _broadcast_row_mask(row_mask, like):
    """Reshape a per-row mask of shape [rows] so it spans every trailing dim of *like*."""
    shape = [like.shape[0]] + [1] * (like.ndim - 1)
    return row_mask.reshape(shape).expand_as(like)


def _soft_mask(mask_mode, t, g, beta, gamma, c):
    """Return the multiplicative mask (broadcast to the shape of *t*) for a 'soft*' mode."""
    family, _, variant = mask_mode.partition("_")
    # Row-wise variants fall back to the entrywise formula for 0-D / 1-D tensors.
    rowwise = variant == "rowwise" and t.ndim >= 2

    if rowwise:
        # Flatten every dim except the first; reshape (unlike view) also
        # accepts non-contiguous tensors.
        t_rows = t.reshape(t.shape[0], -1)
        g_rows = g.reshape(g.shape[0], -1)

    if family == "soft1":
        if rowwise:
            # Per-row cosine similarity squashed through a sigmoid.
            cos_sim = (t_rows * g_rows).sum(dim=1) / (
                t_rows.norm(dim=1) * g_rows.norm(dim=1) + 1e-8
            )
            return _broadcast_row_mask(torch.sigmoid(beta * cos_sim), t)
        denom = t.norm() * g.norm() + 1e-8
        return torch.sigmoid(beta * (t * g / denom))

    if family == "soft2":
        if rowwise:
            alignment = (t_rows * g_rows).sum(dim=1)
            row_mask = torch.clamp(torch.tanh(gamma * alignment), min=0.0)
            return _broadcast_row_mask(row_mask, t)
        return torch.clamp(torch.tanh(gamma * (t * g)), min=0.0)

    # family == "soft3"
    if rowwise:
        agreement = (torch.sign(t_rows) == torch.sign(g_rows)).float().mean(dim=1)
        g_magnitude = g_rows.abs().mean(dim=1)
        return _broadcast_row_mask(agreement * (g_magnitude / (g_magnitude + c)), t)
    same_sign = (torch.sign(t) == torch.sign(g)).float()
    return same_sign * torch.abs(g) / (torch.abs(g) + c)


def taskvector_from_gradient_signs(
    gradient_signs,
    reference_taskvector,
    fallback_taskvector=None,
    mask_mode="normal",
    beta=5.0,
    gamma=1.0,
    c=1.0
):
    """Build a TaskVector from gradient signs using a reference vector for magnitudes.

    For every key of ``gradient_signs`` the reference entry ``t`` and the
    gradient entry ``g`` are combined according to ``mask_mode``:

        - 'force': ``|t| * sign(g)``
        - 'normal': keep ``t`` where ``sign(g) == sign(t)``, zero elsewhere
        - 'random': ``|t|`` with a random sign (``g`` only supplies shape/dtype)
        - 'soft1': ``t * sigmoid(beta * t*g / (||t|| ||g||))`` entrywise
        - 'soft1_rowwise': ``t * sigmoid(beta * cos(t_row, g_row))`` per row
        - 'soft2': ``t * max(tanh(gamma * t*g), 0)`` entrywise
        - 'soft2_rowwise': same, with the per-row dot product as alignment
        - 'soft3': ``t * [sign(t) == sign(g)] * |g| / (|g| + c)`` entrywise
        - 'soft3_rowwise': ``t * (row sign-agreement rate) * mean|g| / (mean|g| + c)``

    Row-wise modes treat the first dimension as the row index and use the
    entrywise formula for tensors with fewer than two dimensions.

    Args:
        gradient_signs: dict[name -> tensor]; its keys define the output keys.
        reference_taskvector: object whose ``.vector`` dict provides ``t``.
        fallback_taskvector: if given, every key of the reference vector that
            contains "shortcut" is copied verbatim from ``fallback_taskvector.vector``.
        mask_mode: one of the modes listed above.
        beta, gamma, c: parameters of the soft1*, soft2* and soft3* modes.

    Returns:
        TaskVector wrapping the resulting dict.

    Raises:
        ValueError: if ``mask_mode`` is not a supported mode. This is checked
            up front, so it is raised even when ``gradient_signs`` is empty.
    """
    valid_modes = (
        "normal", "force", "random",
        "soft1", "soft1_rowwise", "soft2", "soft2_rowwise", "soft3", "soft3_rowwise",
    )
    if mask_mode not in valid_modes:
        raise ValueError(
            f"Invalid mode '{mask_mode}'. Use 'normal', 'force', 'random', "
            "'soft1', 'soft1_rowwise', 'soft2', 'soft2_rowwise', "
            "'soft3', 'soft3_rowwise'."
        )

    tv_vector = {}
    for key, g in gradient_signs.items():
        t = reference_taskvector.vector[key]

        if mask_mode == "force":
            tv_vector[key] = torch.abs(t) * torch.sign(g)
        elif mask_mode == "normal":
            tv_vector[key] = torch.where(torch.sign(g) == torch.sign(t), t, torch.zeros_like(t))
        elif mask_mode == "random":
            rand_signs = torch.sign(torch.rand_like(g) * 2 - 1)
            tv_vector[key] = torch.abs(t) * rand_signs
        else:
            tv_vector[key] = _soft_mask(mask_mode, t, g, beta, gamma, c) * t

    # Fallback for "shortcut" entries: taken unchanged from the fallback vector.
    if fallback_taskvector is not None:
        for key in reference_taskvector.vector.keys():
            if "shortcut" in key:
                tv_vector[key] = fallback_taskvector.vector[key]

    return TaskVector(vector=tv_vector)



def evaluate_task_vectors(
    mod_openclip_b,
    test_dataloader,
    test_dataset,
    device,
    alphas,
    taskvector_a,
    taskvector_permuted,
    realgrad_taskvector_permuted,
    realgrad_taskvector_a,
    oracle_taskvector,
    oracle_taskvector_permuted,
    logger,
    metric_prefix="",
):
    """Sweep ``alphas`` for six task vectors on model B and report the best of each.

    For every alpha, ``evaluate_with_task_vector`` is called once per vector in
    the fixed order: T_perm_real, TA, T_perm, TA_real, T_oracle, T_oracle_perm.
    Each call receives the running best (accuracy, alpha) and a per-vector
    results list, and returns the updated best pair.

    When wandb is available, the best pairs, the overall winner and one
    Alpha/Loss/Accuracy table per vector are logged under ``metric_prefix``.
    The per-vector results lists are expected to hold dicts with the keys
    "alpha", "loss" and "accuracy".

    Returns:
        dict[vector_name -> (best_accuracy, best_alpha)].
    """
    # (vector_name, display_name, task_vector, stem of the "best_*" W&B keys)
    methods = [
        ("t_perm_real_grad", "T perm real grad", realgrad_taskvector_permuted, "best_t_perm_real"),
        ("ta", "TA", taskvector_a, "best_ta"),
        ("t_perm", "T_perm", taskvector_permuted, "best_t_perm"),
        ("ta_real_grad", "TA_real_grad", realgrad_taskvector_a, "best_ta_real_grad"),
        ("t_oracle", "T_oracle", oracle_taskvector, "best_t_oracle"),
        ("t_oracle_perm", "T_oracle_perm", oracle_taskvector_permuted, "best_tperm_oracle"),
    ]

    # Running best (accuracy, alpha) and evaluation history for each vector.
    best = {spec[0]: (0, None) for spec in methods}
    history = {spec[0]: [] for spec in methods}

    for alpha in alphas:
        logger.info(f"{alpha=}")
        for vector_name, display_name, task_vector, _ in methods:
            running_acc, running_alpha = best[vector_name]
            new_acc, new_alpha = evaluate_with_task_vector(
                base_model=mod_openclip_b,
                task_vector=task_vector,
                alpha=alpha,
                test_dataloader=test_dataloader,
                test_dataset=test_dataset,
                device=device,
                vector_name=vector_name,
                best_acc=running_acc,
                best_alpha=running_alpha,
                results_list=history[vector_name],
                logger=logger,
                display_name=display_name,
                metric_prefix=metric_prefix
            )
            best[vector_name] = (new_acc, new_alpha)

    # Overall winner: highest best accuracy; ties go to the earliest method.
    ranking = [(best[spec[0]][0], best[spec[0]][1], spec[0]) for spec in methods]
    best_overall_acc, best_overall_alpha, best_method = max(ranking, key=lambda entry: entry[0])

    if wandb is not None:
        summary = {}
        for vector_name, _, _, stem in methods:
            acc, alpha_at_best = best[vector_name]
            summary[f"{metric_prefix}{stem}_alpha"] = alpha_at_best
            summary[f"{metric_prefix}{stem}_accuracy"] = acc
        summary[f"{metric_prefix}best_overall_alpha"] = best_overall_alpha
        summary[f"{metric_prefix}best_overall_accuracy"] = best_overall_acc
        summary[f"{metric_prefix}best_method"] = best_method
        wandb.log(summary)

        tables = {}
        for vector_name, _, _, _ in methods:
            rows = [[r["alpha"], r["loss"], r["accuracy"]] for r in history[vector_name]]
            tables[f"{metric_prefix}{vector_name}_evaluation"] = wandb.Table(
                data=rows,
                columns=["Alpha", "Loss", "Accuracy"]
            )
        wandb.log(tables)

    return {spec[0]: best[spec[0]] for spec in methods}


def build_class_indices(dataset):
    """Group dataset positions by label.

    Every item is read via ``dataset[i]`` and must unpack into exactly two
    values, the second being the label (converted with ``int``).

    Returns:
        dict[class_id -> list of dataset indices], indices in ascending order.
    """
    class_to_positions = {}
    for position in range(len(dataset)):
        sample = dataset[position]
        _, target = sample
        class_id = int(target)
        if class_id not in class_to_positions:
            class_to_positions[class_id] = []
        class_to_positions[class_id].append(position)
    return class_to_positions


def sample_indices_per_class(dataset, k):
    """Draw ``k`` random dataset indices for every class.

    Classes with fewer than ``k`` samples keep all of their indices and a
    warning is logged. Sampling uses the global ``random`` module state.

    Returns:
        dict[class_id -> list of sampled indices].
    """
    chosen = {}
    for class_id, members in build_class_indices(dataset).items():
        if len(members) >= k:
            chosen[class_id] = random.sample(members, k)
            continue
        logger.warning(f"Class {class_id} has only {len(members)} samples, requested {k}. Using all.")
        chosen[class_id] = members
    return chosen


def load_precomputed_indices(path):
    """Read per-class indices from a pickle file and normalise their types.

    Keys are converted to ``int`` when possible and kept unchanged otherwise;
    values are converted to lists of ``int``.

    Raises:
        ValueError: if the unpickled object is not a dict.
    """
    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)

    if isinstance(loaded, dict) is False:
        raise ValueError("Precomputed indices pickle must be a dict[class_id -> list[int]]")

    normalised = {}
    for raw_key, raw_values in loaded.items():
        # Best-effort key conversion: non-numeric keys are kept as they are.
        try:
            class_id = int(raw_key)
        except Exception:
            class_id = raw_key
        normalised[class_id] = [int(item) for item in raw_values]
    return normalised


def autodetect_indices_file(dataset_name: str):
    """Look for a precomputed-indices pickle whose name contains the dataset name.

    Only ``<directory of this file>/precomputed_indices/`` is searched, using
    the glob ``*<dataset_name lower-cased>*indices*.pkl``. When several files
    match, names containing "medoids" are preferred, then "herding", then
    "coreset".

    Returns:
        The best-matching path, or None if the folder is missing or nothing matches.
    """
    pattern = f"*{dataset_name.lower()}*indices*.pkl"

    # Only the folder that sits next to this file is inspected.
    search_dir = Path(__file__).resolve().parent / "precomputed_indices"
    if not search_dir.exists():
        return None

    matches = list(search_dir.glob(pattern))
    if not matches:
        return None

    def _preference(candidate):
        # False sorts before True, so a name containing the keyword wins.
        file_name = candidate.name
        return ("medoids" not in file_name, "herding" not in file_name, "coreset" not in file_name)

    return min(matches, key=_preference)

class FirstNBatches:
    """Iterable view over an existing dataloader that stops after N batches.

    A fresh iterator over the wrapped dataloader is created on every
    ``__iter__`` call.
    """

    def __init__(self, dataloader, n):
        self.dataloader = dataloader
        self.n = int(n)

    def __iter__(self):
        source = iter(self.dataloader)
        return itertools.islice(source, self.n)

    def __len__(self):
        # Loaders without a length raise TypeError; report N in that case.
        try:
            available = len(self.dataloader)
        except TypeError:
            return self.n
        return min(self.n, available)

from torch.utils.data import Subset, DataLoader

def build_realgrad_dataloader(train_dataset, base_loader, args, imgs_indices=None):
    """Choose the dataloader used for real-gradient computation.

    - ``imgs_indices`` given (dict[class_id -> list[int]]): a shuffled
      DataLoader over a Subset of ``train_dataset`` restricted to those
      indices, configured with ``args.batch_size`` and ``args.workers``.
    - otherwise, if ``args.num_batches`` is set: ``base_loader`` wrapped so it
      yields only its first N batches.
    - otherwise: ``base_loader`` unchanged.
    """
    if imgs_indices is None:
        batch_limit = getattr(args, "num_batches", None)
        if batch_limit is None:
            return base_loader
        return FirstNBatches(base_loader, batch_limit)

    # Flatten the per-class lists into one index list, preserving dict order.
    selected = [int(i) for per_class in imgs_indices.values() for i in per_class]
    return DataLoader(
        Subset(train_dataset, selected),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

def compute_real_gradient_signs(
    pretrained_model,
    dataset,
    loss_fn,
    device,
    dataloader,
    optimize=False,
    prompt_ensemble=True,
    vote="mean",
):
    """Compute the sign of the negative gradient of the visual backbone over a dataloader.

    A deep copy of ``pretrained_model`` is moved to ``device`` and used for all
    computation. Text features are built once (without gradients) from the
    dataset's class names and kept fixed, so only ``model.visual`` receives
    gradients.

    Args:
        pretrained_model: CLIP-style model exposing ``visual``, ``encode_image``,
            ``encode_text`` and ``logit_scale``. It is put in eval mode and its
            gradients are cleared.
        dataset: provides ``class_names`` plus ``templates`` (prompt ensemble)
            or ``single_template``.
        loss_fn: loss applied to ``(logits, labels)``; used only when
            ``vote="mean"`` (``vote="max"`` uses per-sample cross-entropy).
        device: device for the model copy and the batches.
        dataloader: sized iterable of ``(images, labels)`` batches.
        optimize: if True, the copy is trained: gradients are accumulated over
            all batches and a single AdamW step (lr=1e-5) is applied at the end.
        prompt_ensemble: average text embeddings over ``dataset.templates``
            when the dataset has them.
        vote: "mean" -> sign of the negated, batch-averaged gradient;
              "max"  -> sign of the sum of per-image negated-gradient signs
              (majority vote).

    Returns:
        ``(gradient_signs, model)``: ``gradient_signs`` maps parameter names of
        ``model.visual`` to sign tensors; ``model`` is the updated copy when
        ``optimize`` is True, else None.

    Raises:
        ValueError: if ``vote`` is neither "mean" nor "max".
    """
    # Validate before touching the model: previously the error object was built
    # but never raised, so a bad value silently produced an empty result.
    if vote not in ("mean", "max"):
        raise ValueError(f"Invalid vote method '{vote}'. Use 'mean' or 'max'.")

    pretrained_model.eval()
    pretrained_model.zero_grad()
    model = deepcopy(pretrained_model).to(device)
    model.eval()

    optimizer = None
    if optimize:
        model.train()
        optimizer = torch.optim.AdamW(model.visual.parameters(), lr=1e-5)
        optimizer.zero_grad()

    named_trainable_parameters = [
        (name, p) for name, p in model.visual.named_parameters() if p.requires_grad
    ]
    trainable_parameters = [p for _, p in named_trainable_parameters]

    def build_text_features():
        # Text features are constants for this computation: no gradients needed.
        if prompt_ensemble and hasattr(dataset, "templates"):
            prompts = [[template(c.lower()) for c in dataset.class_names] for template in dataset.templates]
            with torch.no_grad():
                template_embeddings = []
                for template in prompts:
                    texts = open_clip.tokenize(template).to(device)
                    text_feats = F.normalize(model.encode_text(texts), dim=-1)
                    template_embeddings.append(text_feats)
                return torch.mean(torch.stack(template_embeddings), dim=0)
        else:
            prompts = [dataset.single_template(c.lower()) for c in dataset.class_names]
            with torch.no_grad():
                texts = open_clip.tokenize(prompts).to(device)
                return F.normalize(model.encode_text(texts), dim=-1)

    text_features = build_text_features()
    # Scale each batch loss so the accumulated gradient is a mean over batches.
    total_steps = max(1, len(dataloader))
    scale = 1.0 / float(total_steps)

    sign_sums = None
    if vote == "max":
        sign_sums = {
            name: torch.zeros_like(p, device=device)
            for name, p in named_trainable_parameters
        }

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device).long()

        image_features = F.normalize(model.encode_image(images), dim=-1)
        vl_logits = model.logit_scale.exp() * (image_features @ text_features.t())

        if vote == "mean":
            total_loss = loss_fn(vl_logits, labels)
            (total_loss * scale).backward()
        else:
            losses = F.cross_entropy(vl_logits, labels, reduction="none") * scale
            num_samples = len(images)
            for i in range(num_samples):
                # The graph is shared by all samples of the batch: keep it until
                # the last sample, and beyond that only if the backward pass
                # below (optimize) still needs it.
                keep_graph = (i < num_samples - 1) or optimize
                grads = torch.autograd.grad(
                    losses[i],
                    trainable_parameters,
                    retain_graph=keep_graph,
                    create_graph=False,
                    # Parameters that do not influence the loss yield None
                    # instead of an error; they simply cast no vote.
                    allow_unused=True,
                )
                for (name, _), g in zip(named_trainable_parameters, grads):
                    if g is not None:
                        sign_sums[name] += torch.sign(-g.detach())
            if optimize:
                # Populate .grad for the optimizer step performed after the loop.
                losses.mean().backward()

    if vote == "mean":
        gradient_signs = {
            name: torch.sign(-param.grad)
            for name, param in model.visual.named_parameters()
            if param.grad is not None
        }
    else:
        gradient_signs = {name: torch.sign(acc) for name, acc in sign_sums.items()}

    if optimize:
        optimizer.step()

    return gradient_signs, (model if optimize else None)



def _load_precomputed_indices(spec: str, train_dataset, device, feature_extractor=None):
    """
    Resolve ``spec`` into the per-class image indices used for real gradient signs.

    Accepted 'spec' formats:
      - Path ending in ``.pkl``: the file is unpickled and returned as-is
        (expected: dict[int -> list[int]]).
      - Method spec 'coreset:K' | 'herding:K' | 'k-medoid:K'; ',' or '_' are
        also accepted as separator (the first of ':', ',', '_' found in the
        spec is used).

    For a method spec:
      - PRECOMPUTED_DIR/{dataset}_{tag}_indices_{K}.pkl is tried first, then
        the legacy name {dataset}_{tag}_indices.pkl, where tag is 'coreset' |
        'herding' | 'medoids' (for 'k-medoid').
      - If neither file exists, K * num_classes indices are computed through
        SyntheticImage and its ``synthetic_indices`` are returned.

    Raises:
        FileNotFoundError: a ``.pkl`` path that does not exist.
        ValueError: unknown method, non-integer K, or unrecognised spec.
        RuntimeError: the dataset lacks ``name`` or ``class_names``.
    """
    def _unpickle(pkl_path):
        # NOTE: pickle executes arbitrary code on load; only open trusted files.
        with open(pkl_path, "rb") as handle:
            return pickle.load(handle)

    # Case 1: direct .pkl path
    spec_path = Path(spec)
    if spec_path.suffix == ".pkl":
        if spec_path.exists():
            return _unpickle(spec_path)
        raise FileNotFoundError(f"Pickle not found: {spec_path}")

    # Case 2: method:K (or method,K | method_K)
    separator = next((candidate for candidate in (":", ",", "_") if candidate in spec), None)
    if separator is None:
        raise ValueError(f"--precomputed_indices='{spec}' is neither an existing .pkl nor a valid 'method:K' spec.")

    method_raw, k_raw = spec.split(separator, 1)
    method = method_raw.strip().lower()
    if method not in {"coreset", "herding", "k-medoid"}:
        raise ValueError(f"Invalid method '{method}'. Use: coreset, herding, k-medoid.")
    try:
        k = int(k_raw)
    except Exception as e:
        raise ValueError(f"Invalid K in spec '{spec}': {e}")

    ds_name = getattr(train_dataset, "name", "").lower()
    if not ds_name:
        raise RuntimeError("Dataset does not expose 'name'; cannot build precomputed filename.")

    tag = "medoids" if method == "k-medoid" else method
    precomputed_root = Path(PRECOMPUTED_DIR)
    std = precomputed_root / f"{ds_name}_{tag}_indices_{k}.pkl"
    legacy = precomputed_root / f"{ds_name}_{tag}_indices.pkl"
    for cached in (std, legacy):
        if cached.exists():
            return _unpickle(cached)

    # Nothing cached: compute on the fly via SyntheticImage.
    # NOTE(review): SyntheticImage presumably also saves the result to
    # PRECOMPUTED_DIR — confirm in synth_image.
    logger.info(f"No precomputed found ({std.name}). Computing '{method}' indices with K={k}...")
    num_classes = len(getattr(train_dataset, "class_names", []))
    if num_classes == 0:
        raise RuntimeError("Dataset does not expose class_names; cannot compute K-per-class indices.")

    synth = SyntheticImage(
        dataset=train_dataset,
        num_synthetic=k * num_classes,
        device=device,
        # NOTE(review): the value passed is 'coreset' | 'herding' | 'k-medoid';
        # confirm SyntheticImage accepts the 'k-medoid' spelling.
        initialization=method,
        feature_extractor=feature_extractor,
    )
    return synth.synthetic_indices

def sign_agreement(grad_signs_a, grad_signs_b, names=("A", "B")):
    """
    Return the fraction of entries whose signs match between two dicts of tensors.

    Only keys present in both dicts are compared; entries are reduced to their
    sign (1, -1, 0) first. The result is a ratio in [0, 1] (0 when there is
    nothing to compare); it is logged, and sent to W&B when wandb is available.
    """
    logger.info("Computing sign agreement")
    shared = set(grad_signs_a.keys()).intersection(grad_signs_b.keys())

    matches, count = 0, 0
    for name in shared:
        # compare signs only (1, -1, 0)
        signs_a = torch.sign(grad_signs_a[name])
        signs_b = torch.sign(grad_signs_b[name])
        matches += (signs_a == signs_b).sum().item()
        count += signs_a.numel()

    if count > 0:
        final_agreement = matches / count
    else:
        final_agreement = 0

    metric = f"sign_agreement_{names[0]}_wrt_{names[1]}"
    logger.info(f"{metric}: {matches}/{count} = {final_agreement*100:.2f}%")
    if wandb is not None:
        wandb.log({metric: final_agreement})
    return final_agreement
def cosine_similarity_taskvectors(vec_a: dict, vec_b: dict, names=("A", "B")) -> float:
    """Cosine similarity between two task vectors given as dict[name -> tensor].

    - Tensors of the keys shared by both dicts are flattened and concatenated
      in sorted key order.
    - Keys whose shapes differ are skipped (their count is logged at debug level).
    - Returns NaN when nothing is comparable or either vector has zero norm;
      otherwise the value is logged and sent to W&B when wandb is available.
    """
    shared_keys = sorted(set(vec_a.keys()).intersection(vec_b.keys()))
    pairs = [(vec_a[key], vec_b[key]) for key in shared_keys]
    comparable = [(x, y) for x, y in pairs if x.shape == y.shape]
    skipped = len(pairs) - len(comparable)

    if not comparable:
        logger.warning(f"cosine_similarity: no comparable keys between {names[0]} and {names[1]}")
        return float("nan")
    if skipped:
        logger.debug(f"cosine_similarity: skipped {skipped} mismatched tensors between {names[0]} and {names[1]}")

    flat_a = torch.cat([x.detach().float().reshape(-1) for x, _ in comparable])
    flat_b = torch.cat([y.detach().float().reshape(-1) for _, y in comparable])

    # A zero-norm vector has no direction: bail out instead of dividing by zero.
    norm_a = flat_a.norm(p=2)
    norm_b = flat_b.norm(p=2)
    if norm_a.item() == 0 or norm_b.item() == 0:
        logger.warning(f"cosine_similarity: zero-norm vector for {names}")
        return float("nan")

    val = (torch.dot(flat_a, flat_b) / (norm_a * norm_b)).item()
    metric = f"cosine_{names[0]}_{names[1]}"
    logger.info(f"{metric}: {val:.6f}")
    if wandb is not None:
        wandb.log({metric: val})
    return val

def main():
    """Run the gradient-sign evaluation pipeline on model B.

    Steps, in order:
      1. Parse arguments, set the seed and start the wandb run (when wandb is installed).
      2. Load the pretrained / fine-tuned models A and B and the datasets for both preprocessors.
      3. Evaluate zero-shot B, fine-tuned A and fine-tuned B on the task and support sets.
      4. Build the task vectors of A and B; load the cached visual permutation
         (or compute and cache it) and apply it to A's task vector.
      5. Compute "real" gradient signs on pretrained B from a dedicated dataloader.
      6. Log sign-agreement statistics, build sign-masked task vectors and
         evaluate them on B (and on the optimized B, when requested).
    """
    args = parse_local_args()
    set_seed(args.seed)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    config = {
        'Architecture': args.arch,
        'Pretraining_A': args.pretraining_backbone_A,
        'Pretraining_B': args.pretraining_backbone_B,
        'Base_Folder': args.base_folder,
        'Dataset': args.dataset,
        'Real_Imgs_Per_Class': args.real_imgs_per_class,
        'Num_Batches': getattr(args, 'num_batches', 1),
        'Seed': args.seed,
        'Sign_Mode': args.sign_mode,
        'Mask_Mode': args.mask_mode,
        'Optimize_During_RealGrad': args.optimize_during_realgrad,
        'Soft_Beta': getattr(args, 'soft_beta', None),
        'Soft_Gamma': getattr(args, 'soft_gamma', None),
        'Soft_C': getattr(args, 'soft_c', None),
    }

    def _parse_precomputed(spec: str):
        """Split a 'method<sep>K' spec into (method, K); (None, None) if it does not fit.

        Only the first separator of (':', ',', '_') found in the spec is tried.
        """
        if not spec:
            return None, None
        for sep in (":", ",", "_"):
            if sep in spec:
                method, _, count = spec.partition(sep)
                try:
                    return method.strip().lower(), int(count)
                except ValueError:
                    # The part after the separator is not an integer.
                    return None, None
        return None, None

    # When precomputed_indices has the form method:K, also record the method
    # and K (as Real_Imgs_Per_Class) in the run config.
    if getattr(args, "precomputed_indices", None):
        m, k = _parse_precomputed(args.precomputed_indices)
        if m is not None and k is not None:
            config["Method"] = m              # e.g. 'coreset'
            config['Real_Imgs_Per_Class'] = k  # e.g. 20

    # wandb is an optional dependency (set to None when the import fails), so
    # starting the run must be guarded exactly like every wandb.log call below.
    if wandb is not None:
        wandb.init(
            project="gradient_signs_eval_ft",
            name=args.wandb_run_name,
            config=config,
            mode=args.wandb_mode,
            dir=args.base_folder,
            group=args.wandb_group,
        )
    else:
        logger.warning("wandb is not installed; metrics will only be written to the log.")

    model_a, model_b, model_a_ft, model_b_ft, preprocess_A, preprocess_B = get_models(args, device)

    mod_openclip_a = OpenCLIPModel(model_a).clip_model
    mod_openclip_a_ft = OpenCLIPModel(model_a_ft).clip_model
    mod_openclip_b = OpenCLIPModel(model_b).clip_model
    mod_openclip_b_ft = OpenCLIPModel(model_b_ft).clip_model

    train_loader_A, test_loader_A, val_loader_A, train_dataset_A, test_dataset_A, val_dataset_A, support_loader_A, support_dataset_A = load_dataset(args, preprocess_A, validation=True, support=True)
    train_loader_B, test_loader_B, val_loader_B, train_dataset_B, test_dataset_B, val_dataset_B, support_loader_B, support_dataset_B = load_dataset(args, preprocess_B, validation=True, support=True)
    logger.debug(f"Train dataset classes: {train_dataset_A.class_names}")
    logger.debug(f"Test dataset classes: {test_dataset_A.class_names}")

    # Zero-shot baselines (pretrained B, no task vector applied)
    loss_task, acc_task = evaluate_model(mod_openclip_b, test_loader_B, test_dataset_B, device, prompt_ensemble=True)
    loss_supp, acc_sup = evaluate_model(mod_openclip_b, support_loader_B, support_dataset_B, device, prompt_ensemble=True)
    logger.info(f"Model B ZERO SHOT | TASK : {acc_task}, loss {loss_task}")
    logger.info(f"Model B ZERO SHOT | SUPPORT : {acc_sup}, loss {loss_supp}")
    if wandb is not None:
        wandb.log({
            "zero_shot task_acc": acc_task,
            "zero_shot support_acc": acc_sup,
            "zero_shot task_loss": loss_task,
            "zero_shot support_loss": loss_supp
        })

    # Task vectors of the visual towers (pretrained -> fine-tuned)
    taskvector_a = TaskVector(mod_openclip_a.visual, mod_openclip_a_ft.visual)
    taskvector_b = TaskVector(mod_openclip_b.visual, mod_openclip_b_ft.visual)

    # Fine-tuned A
    loss_task_aft, acc_task_aft = evaluate_model(mod_openclip_a_ft, test_loader_A, test_dataset_A, device, prompt_ensemble=True)
    loss_supp_aft, acc_sup_aft = evaluate_model(mod_openclip_a_ft, support_loader_A, support_dataset_A, device, prompt_ensemble=True)
    logger.info(f"Model A ft | TASK : {acc_task_aft}, loss {loss_task_aft}")
    logger.info(f"Model A ft | SUPPORT : {acc_sup_aft}, loss {loss_supp_aft}")
    if wandb is not None:
        wandb.log({
            "model_A_ft_task_acc": acc_task_aft,
            "model_A_ft_task_loss": loss_task_aft,
            "model_A_ft_support_acc": acc_sup_aft,
            "model_A_ft_support_loss": loss_supp_aft
        })

    # Fine-tuned B
    loss_task_bft, acc_task_bft = evaluate_model(mod_openclip_b_ft, test_loader_B, test_dataset_B, device, prompt_ensemble=True)
    loss_supp_bft, acc_sup_bft = evaluate_model(mod_openclip_b_ft, support_loader_B, support_dataset_B, device, prompt_ensemble=True)
    logger.info(f"Model B ft | TASK : {acc_task_bft}, loss {loss_task_bft}")
    logger.info(f"Model B ft | SUPPORT : {acc_sup_bft}, loss {loss_supp_bft}")
    if wandb is not None:
        wandb.log({
            "model_B_ft_task_acc": acc_task_bft,
            "model_B_ft_task_loss": loss_task_bft,
            "model_B_ft_support_acc": acc_sup_bft,
            "model_B_ft_support_loss": loss_supp_bft
        })

    # Visual permutation (A -> B): reuse the cached one when present, else compute and cache it.
    permutation_spec_visual = CLIP_Visual_PermutationSpecBuilder(depth=mod_openclip_a.visual.transformer.layers).create_permutation_spec()

    permutations_path = Path(args.base_folder) / "permutations" / args.arch
    permutations_path.mkdir(parents=True, exist_ok=True)
    perm_file = permutations_path / f'permutations_visual_{args.pretraining_backbone_A}_to_{args.pretraining_backbone_B}_{args.seed}.pkl'

    if perm_file.exists():
        # NOTE: pickle.load executes arbitrary code from the file; this cache
        # must only ever contain files written by this script.
        with open(perm_file, 'rb') as f:
            permutation_visual, heads_permutation_visual = pickle.load(f)
        logger.info(f"[TransFusion] Loaded visual and heads permutation from {perm_file}")
    else:
        weight_matcher = WeightMatcher(
            ps=permutation_spec_visual,
            max_iter=100,
            fixed=model_b.visual.state_dict(),
            permutee=model_a.visual.state_dict(),
            num_heads=model_a.visual.transformer.resblocks[0].attn.num_heads,
            intra_head=True,
            layer_iteration_order=LayerIterationOrder.RANDOM
        )
        permutation_visual, heads_permutation_visual = weight_matcher.run()
        with open(perm_file, 'wb') as f:
            pickle.dump((permutation_visual, heads_permutation_visual), f)
        logger.info(f"[TransFusion] Saved visual and heads permutation to {perm_file}")

    # Permute a copy of A's task vector so taskvector_a itself stays untouched.
    taskvector_permuted = TaskVector(vector=apply_permutation_to_statedict(
        permutation_spec_visual,
        permutation_visual,
        deepcopy(taskvector_a.vector),
        heads_permutation=heads_permutation_visual,
        num_heads=mod_openclip_a.visual.transformer.resblocks[0].attn.num_heads,
    ))

    # "Oracle" signs: element-wise signs of B's own task vector.
    tb_signs = {k: torch.sign(v) for k, v in taskvector_b.vector.items()}

    # Determine indices for real gradient computation
    imgs_indices = None
    if args.precomputed_indices is not None:
        imgs_indices = _load_precomputed_indices(args.precomputed_indices, train_dataset_B, device, feature_extractor=model_b.encode_image)
        logger.info(f"Using precomputed indices spec '{args.precomputed_indices}'.")
    elif args.real_imgs_per_class is not None:
        imgs_indices = sample_indices_per_class(train_dataset_B, args.real_imgs_per_class)
        logger.info(f"Sampled {args.real_imgs_per_class} images per class for real gradient signs")
    # Otherwise imgs_indices stays None and is passed on as such.

    # Build the dedicated dataloader for real gradient computation
    realgrad_loader = build_realgrad_dataloader(train_dataset_B, train_loader_B, args, imgs_indices)

    # Compute real gradient signs using only the provided dataloader (indices/num_batches handled upstream)
    real_gradient_signs, optimized_b = compute_real_gradient_signs(
        pretrained_model=mod_openclip_b,
        dataset=train_dataset_B,
        loss_fn=nn.CrossEntropyLoss(),
        device=device,
        dataloader=realgrad_loader,
        optimize=getattr(args, "optimize_during_realgrad", False),
        prompt_ensemble=True,
        vote=args.sign_mode
    )

    # Compute sign agreement
    # NOTE(review): tb_signs is sign(taskvector_b.vector), so the ("B", "realgrad")
    # and ("B", "oracle") calls look like duplicates of the ("oracle", "realgrad")
    # and ("A", "B") calls under different metric names — confirm the labels.
    sign_agreement(tb_signs, taskvector_a.vector, names=("A", "B"))
    sign_agreement(tb_signs, real_gradient_signs, names=("oracle", "realgrad"))
    sign_agreement(taskvector_a.vector, real_gradient_signs, names=("A", "realgrad"))
    sign_agreement(taskvector_permuted.vector, real_gradient_signs, names=("A_perm", "realgrad"))
    sign_agreement(taskvector_b.vector, real_gradient_signs, names=("B", "realgrad"))
    sign_agreement(taskvector_b.vector, taskvector_a.vector, names=("B", "oracle"))

    # Evaluate optimized model B
    if optimized_b is not None:
        # NOTE(review): the "task" metrics here use the TRAIN loader/dataset,
        # unlike every other task evaluation in this function (test split) —
        # confirm this is intentional.
        loss_task_opt, acc_task_opt = evaluate_model(optimized_b, train_loader_B, train_dataset_B, device, prompt_ensemble=True)
        loss_supp_opt, acc_sup_opt = evaluate_model(optimized_b, support_loader_B, support_dataset_B, device, prompt_ensemble=True)
        logger.info(f"Model B optimized | TASK : {acc_task_opt}, loss {loss_task_opt}")
        logger.info(f"Model B optimized | SUPPORT : {acc_sup_opt}, loss {loss_supp_opt}")
        if wandb is not None:
            wandb.log({
                "optimized task_acc": acc_task_opt,
                "optimized support_acc": acc_sup_opt,
                "optimized task_loss": loss_task_opt,
                "optimized support_loss": loss_supp_opt
            })

    # Build the sign-masked task vectors; all four share the same masking options.
    mask_kwargs = dict(
        mask_mode=args.mask_mode,
        beta=getattr(args, 'soft_beta', 5.0),
        gamma=getattr(args, 'soft_gamma', 1.0),
        c=getattr(args, 'soft_c', 1.0),
    )
    realgrad_taskvector_permuted = taskvector_from_gradient_signs(
        real_gradient_signs, taskvector_permuted,
        fallback_taskvector=taskvector_permuted,
        **mask_kwargs,
    )
    oracle_taskvector = taskvector_from_gradient_signs(
        tb_signs, taskvector_a,
        fallback_taskvector=taskvector_a,
        **mask_kwargs,
    )
    oracle_taskvector_permuted = taskvector_from_gradient_signs(
        tb_signs, taskvector_permuted,
        fallback_taskvector=taskvector_permuted,
        **mask_kwargs,
    )
    realgrad_taskvector_a = taskvector_from_gradient_signs(
        real_gradient_signs, taskvector_a,
        fallback_taskvector=taskvector_a,
        **mask_kwargs,
    )

    # Arguments shared by the evaluation on pretrained B and on optimized B.
    eval_kwargs = dict(
        test_dataloader=test_loader_B,
        test_dataset=test_dataset_B,
        device=device,
        taskvector_a=taskvector_a,
        taskvector_permuted=taskvector_permuted,
        realgrad_taskvector_permuted=realgrad_taskvector_permuted,
        realgrad_taskvector_a=realgrad_taskvector_a,
        oracle_taskvector=oracle_taskvector,
        oracle_taskvector_permuted=oracle_taskvector_permuted,
        logger=logger,
    )

    evaluate_task_vectors(
        mod_openclip_b=mod_openclip_b,
        alphas=np.linspace(0.1, 1, args.eval_alphas),
        metric_prefix="B_",
        **eval_kwargs,
    )
    if args.optimize_during_realgrad and optimized_b is not None:
        logger.info("Note: Model B was optimized during real gradient computation.")
        evaluate_task_vectors(
            mod_openclip_b=optimized_b,
            alphas=np.linspace(0.1, 1, args.eval_alphas),
            metric_prefix="opt_B_",
            **eval_kwargs,
        )

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
