from copy import deepcopy
import torch
import logging
import time

import numpy as np
from utils.adv_attacks import (
    add_negative_neighbor_wrt_pos_src_neg_sample,
    add_negative_neighbor_wrt_pos_src_neg_sample_nat,
    fgsm_tgn,
    fgsm_tgnw,
    fgsm_nat,
    random_attack,
    PRBCD,
    GRBCD,
    PRBCD_NAT,
    GRBCD_NAT,
)


def _empty_memstranding_edges(data, device):
    """Return the empty ``(sources, destinations, timestamps)`` triple used when nothing is injected."""
    return (
        torch.tensor([], device=device, dtype=torch.long),
        torch.tensor([], device=device, dtype=torch.long),
        torch.tensor([], device=device, dtype=data.t.dtype),
    )


def memstranding_attack(
    model,
    data,
    assoc,
    pos_batch,
    attack_params,
    device,
    debug=False,
):
    """
    Run the MemStranding attack: craft fake edges for high-degree victim nodes,
    splice them into ``data`` at the model's current edge index and push them
    through ``model.update_memory`` / ``model.insert_neighbor``.

    The model is deliberately left in the corrupted state: its memory and
    neighbor loader are reset, the historical events are replayed, and the fake
    edges are applied on top. Nothing is restored afterwards.

    Args:
        model: Model exposing ``get_current_edge_index()``, ``memory`` (callable,
            with ``reset_state()``), ``neighbor_loader.reset_state()``,
            ``update_memory(src, dst, t, msg)`` and ``insert_neighbor(src, dst)``.
        data: Temporal data with ``src``, ``dst``, ``t`` and ``msg`` tensors;
            these four attributes are replaced in place when edges are injected.
        assoc: Association tensor for node mapping (not used by this function).
        pos_batch: Current positive batch; only ``pos_batch.src`` is read, to
            size the per-batch budget in distributed mode.
        attack_params: Mapping with the optional keys ``attack_mode``
            ("single_shot" or "distributed"), ``budget_pct``,
            ``attack_timestamp``, ``fake_neighbors_per_victim``, ``adv_budget``
            and ``bipartite``. The keys ``original_fake_edges`` and
            ``fake_edge_index`` are created/updated here to track progress of a
            distributed attack across calls.
        device: Device on which the generated tensors are created.
        debug: When true, log extra per-victim diagnostics.

    Returns:
        ``(attack_params, adv_edges)``. ``adv_edges`` is the 4-tuple
        ``(sources, destinations, timestamps, messages)`` of the edges injected
        by this call, or a 3-tuple of empty tensors
        ``(sources, destinations, timestamps)`` when nothing was injected.

    Raises:
        ValueError: If ``attack_mode`` is neither "single_shot" nor "distributed".
    """
    # Extract attack parameters
    attack_mode = attack_params.get("attack_mode", "single_shot")
    budget_pct = attack_params.get("budget_pct", 0.01)
    attack_timestamp = attack_params.get("attack_timestamp", None)
    fake_neighbors_per_victim = attack_params.get("fake_neighbors_per_victim", 10)
    adv_budget = attack_params.get("adv_budget", 0.02)  # For distributed mode
    is_bipartite = attack_params.get("bipartite", False)  # Explicit bipartite flag

    # Fail fast: an unknown mode would otherwise reset the model state and then
    # silently return None, which breaks callers that unpack the result.
    if attack_mode not in ("single_shot", "distributed"):
        raise ValueError(f"Unknown memstranding attack_mode: {attack_mode!r}")

    logging.info(f"Executing MemStranding attack in {attack_mode} mode")

    # Bookkeeping for distributed mode (persisted across calls via attack_params)
    if "original_fake_edges" not in attack_params:
        attack_params["original_fake_edges"] = []
        attack_params["fake_edge_index"] = 0

    pos_src = pos_batch.src

    # Step 1: Slice Historical Events
    cur_idx = model.get_current_edge_index()

    # If attack_timestamp not specified, use the current timestamp
    if attack_timestamp is None:
        attack_timestamp = data.t[cur_idx].min().item()

    # Historical events are those strictly before the attack timestamp
    historical_mask = data.t[:cur_idx] < attack_timestamp
    E_known_indices = torch.where(historical_mask)[0]

    logging.info(f"Attack timestamp: {attack_timestamp}")
    logging.info(f"Historical events: {len(E_known_indices)}")

    if len(E_known_indices) == 0:
        # Without history there are no victims to pick (and min()/max() below
        # would raise on empty tensors), so leave model and data untouched.
        logging.warning(
            "No historical events before the attack timestamp - skipping memstranding attack"
        )
        return attack_params, _empty_memstranding_edges(data, device)

    # Step 2: Select Victim Nodes
    historical_src = data.src[E_known_indices]
    historical_dst = data.dst[E_known_indices]

    src_min, src_max = historical_src.min().item(), historical_src.max().item()
    dst_min, dst_max = historical_dst.min().item(), historical_dst.max().item()

    logging.info(
        f"Node ID ranges - Source: [{src_min}, {src_max}], Destination: [{dst_min}, {dst_max}]"
    )
    logging.info(f"Bipartite graph mode: {is_bipartite}")

    # Degree of every node that appears in the history (either endpoint)
    all_nodes = torch.cat([historical_src, historical_dst])
    unique_nodes, counts = torch.unique(all_nodes, return_counts=True)
    sorted_nodes = unique_nodes[torch.argsort(counts, descending=True)]

    if is_bipartite:
        # Bipartite: rank source nodes only, by their out-degree
        src_unique_nodes, src_counts = torch.unique(historical_src, return_counts=True)
        sorted_src_indices = torch.argsort(src_counts, descending=True)
        num_root_victims = max(1, int(len(src_unique_nodes) * budget_pct))
        root_victims = src_unique_nodes[sorted_src_indices[:num_root_victims]]

        logging.info(
            f"Selected {len(root_victims)} root victims from source nodes (bipartite)"
        )
    else:
        # Non-bipartite: a third of the budget goes to roots, since each root
        # can pull in up to two neighbors below
        num_root_victims = max(1, int(len(unique_nodes) * budget_pct / 3))
        root_victims = sorted_nodes[:num_root_victims]
        logging.info(
            f"Selected {len(root_victims)} root victims from all nodes (non-bipartite)"
        )

    # Bidirectional adjacency as a (2, 2 * num_events) edge list
    edge_index = torch.stack([historical_src, historical_dst])
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    if is_bipartite:
        # Only the root victims (source nodes) act as sources of fake edges
        victim_nodes = root_victims
        logging.info(
            f"Using only root victims as source nodes for fake edges (bipartite)"
        )
    else:
        # Add the two highest-degree neighbors of every root victim
        degree_of = dict(zip(unique_nodes.tolist(), counts.tolist()))
        victim_set = set(root_victims.tolist())
        for root_id in root_victims.tolist():
            neighbor_ids = edge_index[1, edge_index[0] == root_id].unique().tolist()
            # Stable sort: ties keep ascending node-id order
            neighbor_ids.sort(key=degree_of.__getitem__, reverse=True)
            victim_set.update(neighbor_ids[:2])

        victim_nodes = torch.tensor(list(victim_set), device=device)
    logging.info(f"Total victim nodes: {len(victim_nodes)}")

    # Step 3: Estimate Converged Memory State
    # The original model is modified directly (no deepcopy, no restoration).
    # NOTE(review): this reset + replay runs on every call, including every
    # batch of a distributed attack - confirm that is intended.
    model.memory.reset_state()
    model.neighbor_loader.reset_state()

    hist_src = data.src[E_known_indices]
    hist_dst = data.dst[E_known_indices]
    hist_t = data.t[E_known_indices]
    hist_msg = data.msg[E_known_indices]
    model.update_memory(hist_src, hist_dst, hist_t, hist_msg)
    model.insert_neighbor(hist_src, hist_dst)

    # One batched memory query for all victims
    victim_id_list = victim_nodes.tolist()
    converged_memories, _ = model.memory(victim_nodes)
    converged_memories_dict = {
        victim_id: memory.clone()
        for victim_id, memory in zip(victim_id_list, converged_memories)
    }

    # Step 4: Simulate Fake Neighbors
    total_fake_neighbors_needed = len(victim_nodes) * fake_neighbors_per_victim
    victim_ids_set = set(victim_id_list)

    if is_bipartite:
        # Fake neighbors are drawn from destination nodes seen in the history
        candidate_nodes = historical_dst.unique()
        logging.info(
            f"Using actual destination nodes from data for fake neighbors (bipartite)"
        )
    else:
        # Fake neighbors are drawn from every node seen in the history
        candidate_nodes = unique_nodes
        logging.info(f"Using all nodes from data for fake neighbors (non-bipartite)")

    available_nodes = torch.tensor(
        [n for n in candidate_nodes.tolist() if n not in victim_ids_set],
        device=device,
        dtype=torch.long,
    )

    shuffled = torch.randperm(len(available_nodes), device=device)
    if len(available_nodes) < total_fake_neighbors_needed:
        logging.warning(
            f"Not enough available nodes ({len(available_nodes)}) for {total_fake_neighbors_needed} fake neighbors"
        )
        logging.warning(
            "Allowing neighbor sharing across victims to maximize attack effectiveness"
        )
        # Use every available node; IDs are reused round-robin below
        fake_neighbor_ids = available_nodes[shuffled]
        logging.info(
            f"Using all {len(available_nodes)} available nodes with sharing allowed"
        )
    else:
        # Enough nodes: every fake neighbor gets a distinct ID
        fake_neighbor_ids = available_nodes[shuffled[:total_fake_neighbors_needed]]

    fake_neighbor_id_list = fake_neighbor_ids.tolist()
    num_fake_ids = len(fake_neighbor_id_list)
    if num_fake_ids == 0:
        logging.error(
            "No fake neighbors can be generated - insufficient available nodes"
        )

    fake_neighbors_dict = {}
    num_gaussian = fake_neighbors_per_victim // 2
    idx = 0  # running position in fake_neighbor_id_list, shared by all victims

    for victim_id in victim_id_list:
        current_memory = converged_memories_dict[victim_id]

        # Spread of the memories of ALL current neighbors of this victim
        neighbors = edge_index[1, edge_index[0] == victim_id].unique()
        if len(neighbors) > 0:
            neighbor_memories, _ = model.memory(neighbors)

            if debug:
                logging.info(
                    f"Victim {victim_id}: {len(neighbors)} total neighbors found"
                )
                logging.info(f"Neighbor memories shape: {neighbor_memories.shape}")
                if torch.isnan(neighbor_memories).any():
                    logging.warning(
                        f"NaN detected in neighbor memories for victim {victim_id}"
                    )

            if len(neighbors) > 1:
                # Unbiased std (torch default) across neighbors
                memory_std = torch.std(neighbor_memories, dim=0)
            else:
                # A single sample has no meaningful std; use a fraction of the
                # neighbor's memory magnitude as a proxy
                memory_std = torch.abs(neighbor_memories[0]) * 0.1 + 0.01

            if debug and torch.isnan(memory_std).any():
                logging.warning(
                    f"NaN in memory_std for victim {victim_id}. This should not happen!"
                )
        else:
            # No neighbors in historical data - use default
            memory_std = torch.ones_like(current_memory) * 0.1

        # First half of the fake neighbors get Gaussian memories, the rest zeros
        fake_neighbors = []
        if num_fake_ids > 0:
            for slot in range(fake_neighbors_per_victim):
                # Modulo allows ID reuse when there are fewer IDs than slots
                fake_neighbor_id = fake_neighbor_id_list[idx % num_fake_ids]
                if slot < num_gaussian:
                    fake_memory = torch.normal(0, memory_std)
                else:
                    fake_memory = torch.zeros_like(current_memory)
                fake_neighbors.append((fake_neighbor_id, fake_memory))
                idx += 1

        fake_neighbors_dict[victim_id] = fake_neighbors

    # Step 5: Define Target Noisy Memory (mean of the fake neighbor memories)
    target_memories = {}
    for victim_id, fake_neighbors in fake_neighbors_dict.items():
        if fake_neighbors:
            target_memories[victim_id] = torch.stack(
                [mem for _, mem in fake_neighbors]
            ).mean(dim=0)
        else:
            target_memories[victim_id] = converged_memories_dict[victim_id]

    # Step 6: Solve for Fake Messages
    msg_dim = data.msg.size(-1)

    fake_sources_list = []
    fake_destinations_list = []
    fake_messages_list = []

    for victim_id in victim_id_list:
        memory_diff = target_memories[victim_id] - converged_memories_dict[victim_id]

        for fake_neighbor_id, _ in fake_neighbors_dict[victim_id]:
            # Small random message nudged towards the target memory; the
            # overlap between message and memory dimensions is used
            fake_msg = torch.randn(msg_dim, device=device) * 0.1
            if msg_dim >= len(memory_diff):
                fake_msg[: len(memory_diff)] += memory_diff * 0.1
            else:
                fake_msg += memory_diff[:msg_dim] * 0.1

            fake_sources_list.append(victim_id)
            fake_destinations_list.append(fake_neighbor_id)
            fake_messages_list.append(fake_msg)

    if not fake_sources_list:
        logging.warning("No fake edges generated for memstranding attack")
        # Return empty tensors for consistency with other attacks
        return attack_params, _empty_memstranding_edges(data, device)

    fake_sources = torch.tensor(fake_sources_list, device=device)
    fake_destinations = torch.tensor(fake_destinations_list, device=device)
    fake_timestamps = torch.full(
        (len(fake_sources_list),),
        attack_timestamp,
        device=device,
        dtype=data.t.dtype,
    )
    fake_messages = torch.stack(fake_messages_list)

    if attack_mode == "single_shot":
        # Single-shot mode: inject all edges at once
        logging.info(
            f"Generated {len(fake_sources)} fake edges for single-shot memstranding attack"
        )
        inj_sources = fake_sources
        inj_destinations = fake_destinations
        inj_timestamps = fake_timestamps
        inj_messages = fake_messages
    else:
        # Distributed mode: the edge pool is fixed on the first call and then
        # consumed in budget-sized slices, one slice per call
        if len(attack_params["original_fake_edges"]) == 0:
            attack_params["original_fake_edges"] = list(
                zip(fake_sources, fake_destinations, fake_timestamps, fake_messages)
            )
            logging.info(
                f"Stored {len(attack_params['original_fake_edges'])} fake edges for distributed attack"
            )

        n_edges_to_inject = max(1, int(len(pos_src) * adv_budget))
        available_edges = attack_params["original_fake_edges"]
        current_idx = attack_params["fake_edge_index"]

        if current_idx >= len(available_edges):
            logging.info("All fake edges have been injected in distributed mode")
            return attack_params, _empty_memstranding_edges(data, device)

        end_idx = min(current_idx + n_edges_to_inject, len(available_edges))
        batch_edges = available_edges[current_idx:end_idx]
        attack_params["fake_edge_index"] = end_idx

        # torch.stack keeps the dtype/device of the stored per-edge tensors
        inj_sources = torch.stack([e[0] for e in batch_edges])
        inj_destinations = torch.stack([e[1] for e in batch_edges])
        inj_timestamps = torch.stack([e[2] for e in batch_edges])
        inj_messages = torch.stack([e[3] for e in batch_edges])

        logging.info(
            f"Injecting {len(batch_edges)} fake edges (batch budget: {n_edges_to_inject}) - Progress: {end_idx}/{len(available_edges)}"
        )

    # Step 7: Inject the fake edges into the data stream at the current index.
    # NOTE(review): tgnw_attack's generic path also extends data.edge_weights;
    # this function does not - confirm whether it should.
    data.src = torch.cat([data.src[:cur_idx], inj_sources, data.src[cur_idx:]])
    data.dst = torch.cat([data.dst[:cur_idx], inj_destinations, data.dst[cur_idx:]])
    data.t = torch.cat([data.t[:cur_idx], inj_timestamps, data.t[cur_idx:]])
    data.msg = torch.cat([data.msg[:cur_idx], inj_messages, data.msg[cur_idx:]])

    # Feed the fake edges to the model. The original state is NOT restored
    # afterwards - that would undo the attack.
    model.update_memory(inj_sources, inj_destinations, inj_timestamps, inj_messages)
    model.insert_neighbor(inj_sources, inj_destinations)

    if attack_mode == "single_shot":
        logging.info("MemStranding attack completed - memory corruption injected")

    return attack_params, (
        inj_sources,
        inj_destinations,
        inj_timestamps,
        inj_messages,
    )


def tgnw_attack(
    model,
    model1,
    data,
    assoc,
    pos_batch,
    attack_type,
    n_adv_edges,
    attack_params,
    device,
    debug,
    neg_batch_list=None,
):
    """
    Craft an adversarial perturbation for the current batch and apply it.

    Depending on ``attack_type`` this either perturbs the messages of positive
    edges in place ("fgsm"), delegates to ``memstranding_attack``
    ("memstranding"), or generates adversarial edges ("random", "negatt*",
    "prbcd*", "grbcd*"). Generated edges are spliced into ``data`` at the
    current edge index and pushed through ``model.update_memory`` /
    ``model.insert_neighbor``. Some entries of ``attack_params`` may be
    modified (e.g. ``adv_test_edges``).

    Args:
        model: Victim model; its memory and neighbor state are updated.
        model1: Model handed to the PRBCD/GRBCD attack class.
        data: Temporal data with ``src``, ``dst``, ``t``, ``msg`` and
            ``edge_weights``; these are replaced when edges are injected.
        assoc: Association tensor, forwarded to the attack helpers.
        pos_batch: Current positive batch (``src``, ``dst``, ``t``, ``msg``).
        attack_type: Attack selector, see above. A type ending in "fgsm"
            additionally runs FGSM on the generated messages.
        n_adv_edges: Budget (number of edges to attack / inject).
        attack_params: Attack configuration. Read with ``[...]`` for the
            fgsm/random/negatt paths and with attribute access for the
            prbcd/grbcd path.
        device: Device for generated tensors.
        debug: When true, compute message-distance diagnostics.
        neg_batch_list: Per-positive negative destinations; required for
            prbcd/grbcd with ``neg_init`` set and ``hist_init`` unset.

    Returns:
        - "fgsm": ``attack_params`` only.
        - otherwise ``(attack_params, adv_edges)`` where ``adv_edges`` is
          ``(sources, destinations, timestamps, messages)``, or a 3-tuple of
          empty tensors when no edge could be generated.
        - prbcd/grbcd with zero flipped edges: a 3-element tuple
          ``(attack_params, empty_triple, (adv_sources, adv_dests, adv_t))``.
          NOTE(review): this shape differs from every other path - confirm
          callers rely on it before normalising.

    Raises:
        ValueError: If ``attack_type`` is not recognised.
    """
    pos_src, pos_dst, pos_t, pos_msg = (
        pos_batch.src,
        pos_batch.dst,
        pos_batch.t,
        pos_batch.msg,
    )
    if debug:
        cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        pdist = torch.nn.PairwiseDistance(p=2, eps=0.0)
        msg_cosine_dsts = []
        msg_euclidean_dsts = []

    cur_idx = model.get_current_edge_index()
    t1 = time.time()
    if attack_type == "memstranding":
        attack_params, adv_edges = memstranding_attack(
            model, data, assoc, pos_batch, attack_params, device, debug
        )
        return attack_params, adv_edges
    if attack_type == "fgsm":
        epsilon = attack_params["epsilon"]
        edge_class = 1.0
        # random positive picks
        attack_idx = torch.randperm(pos_src.size(0))[:n_adv_edges]

        # TODO: weakly confident positive picks, strongly confident positive picks
        with torch.enable_grad():
            adv_pos_msg = fgsm_tgnw(
                deepcopy(model),
                deepcopy(data),
                deepcopy(assoc),
                deepcopy(pos_src[attack_idx]),
                deepcopy(pos_dst[attack_idx]),
                deepcopy(pos_t[attack_idx]),
                deepcopy(pos_msg[attack_idx]),
                epsilon,
                edge_class,
                edge_weights=None,
                device=device,
            )

        if debug:
            msg_cosine_dsts.append(
                cosine(pos_msg[attack_idx], adv_pos_msg).mean().item()
            )
            msg_euclidean_dsts.append(
                pdist(pos_msg[attack_idx], adv_pos_msg).mean().item()
            )

        # The batch messages are overwritten in place; no edges are added
        pos_msg[attack_idx] = adv_pos_msg
        return attack_params

    if attack_type == "random":
        is_msg_gaussian = attack_params["is_msg_gaussian"]
        adv_sources, adv_dests, adv_t, adv_msgs = random_attack(
            data, budget=n_adv_edges, device=device, is_msg_gaussian=is_msg_gaussian
        )
    elif attack_type.startswith("negatt"):
        adv_edges = attack_params["adv_test_edges"]
        rel_time_sd = attack_params["train_time_rel_sd"]
        neg_set = attack_params["neg_set"]
        edge_class = 0.0

        cur_t = data.t[cur_idx]

        adv_edge_components = add_negative_neighbor_wrt_pos_src_neg_sample(
            model,
            data,
            assoc,
            pos_batch,
            n_adv_edges,
            cur_t,
            device,
            adv_edges=adv_edges,
            rel_time_sd=rel_time_sd,
            neg_set=neg_set,
        )
        logging.info(f"neg edge in {time.time()-t1}")
        if adv_edge_components is None:
            logging.warning(" no adversarial edge could be found")
            # Return empty tensors for consistency
            empty_sources = torch.tensor([], device=device, dtype=torch.long)
            empty_dests = torch.tensor([], device=device, dtype=torch.long)
            empty_timestamps = torch.tensor([], device=device, dtype=data.t.dtype)
            return attack_params, (empty_sources, empty_dests, empty_timestamps)

        adv_sources, adv_dests, adv_t, adv_msgs = adv_edge_components
        adv_sources = adv_sources.to(device)
        adv_dests = adv_dests.to(device)
        adv_msgs = adv_msgs.to(device)
        adv_t = adv_t.to(device)

        # Add the positives of the current batch to the edge bank (the
        # adversarial edges themselves are not banked)
        batch_pos = {
            (s.item(), d.item(), t.item()): m.reshape(1, -1)
            for (s, d, t, m) in zip(pos_src, pos_dst, pos_t, pos_msg)
        }
        adv_edges = adv_edges | batch_pos
        attack_params["adv_test_edges"] = adv_edges

    elif attack_type.startswith("prbcd") or attack_type.startswith("grbcd"):
        # The metric in PRBCD is assumed to be best if lower (like a loss).
        def accuracy(pred, y, *args):
            return (pred.argmax(-1) == y).float().mean()

        def prbcd_metric(*args, **kwargs):
            return -accuracy(*args, **kwargs)

        block_size = attack_params.block_size
        n_attack_epochs = attack_params.n_attack_epochs
        bipartite = attack_params.bipartite
        attack_target = attack_params.attack_target
        rel_time_sd = attack_params.train_time_rel_sd
        neg_init = attack_params.neg_init
        fixed_t_m = attack_params.fixed_t_m
        hist_init = attack_params.hist_init
        hist_ratio = attack_params.hist_ratio

        # Defaults shared by both targets; previously these were only bound in
        # the negative-target branch, so attack_target == "pos" hit an
        # unbound local when calling bcd.attack below.
        neg_edges = None
        use_rand_negs = False

        if attack_target == "pos":
            bcd_labels = 1
            bcd_edge_index = torch.cat((pos_src.unsqueeze(0), pos_dst.unsqueeze(0)))
        else:
            if neg_init and not hist_init:
                hist_ratio = None
                logging.info(
                    "Attacking with hst_rnd negative edges (same as evaluation)"
                )
                assert neg_batch_list is not None
                pos_list = pos_src.tolist()
                neg_pairs = list(
                    {
                        (src_id, y)
                        for i, src_id in enumerate(pos_list)
                        for y in neg_batch_list[i]
                    }
                )
                # Converted to a (2, N) tensor, so count pairs before converting
                neg_edges = torch.tensor(list(zip(*neg_pairs))).to(device)
                logging.info(f"num neg edges: {len(neg_pairs)}")
            elif hist_init:
                logging.info(
                    "Attacking with historical edges (from accumulated train+val)"
                )
                adv_edges = attack_params["adv_test_edges"]
                # Candidate (src, dst) pairs from the bank BEFORE this batch's
                # positives are added to it
                hist_edges = {(k[0], k[1]) for k in adv_edges}
                batch_pos = {
                    (s.item(), d.item(), t.item()): m.reshape(1, -1)
                    for (s, d, t, m) in zip(pos_src, pos_dst, pos_t, pos_msg)
                }
                adv_edges = adv_edges | batch_pos
                attack_params["adv_test_edges"] = adv_edges
                if neg_init:
                    logging.info(
                        f"Including random negatives with hist_ratio={hist_ratio}"
                    )
                    batch_pos_src = {s.item() for s in pos_src}
                    batch_pos_edges = {
                        (s.item(), d.item()) for s, d in zip(pos_src, pos_dst)
                    }

                    # Keep historical pairs whose source is in this batch and
                    # which are not positives of this batch
                    hist_edges = {e for e in hist_edges if e[0] in batch_pos_src}
                    hist_edges = hist_edges - batch_pos_edges

                    # hist_ratio is applied here only (not inside bcd.attack)
                    use_rand_negs = True
                    n_hist = int(hist_ratio * block_size)
                    hist_negatives = torch.empty(
                        (2, 0), dtype=torch.int64, device=device
                    )
                    rand_negatives = torch.empty(
                        (2, 0), dtype=torch.int64, device=device
                    )
                    # Skip when there is nothing to sample from: an empty set
                    # would yield a 1-D tensor and break size(1)/randint
                    if n_hist > 0 and hist_edges:
                        hist_negatives = torch.tensor(
                            list(zip(*list(hist_edges)))
                        ).to(device)
                        # Historical negatives are usually few, so upsample
                        # with replacement to exactly n_hist columns
                        idx = torch.randint(0, hist_negatives.size(1), (n_hist,))
                        hist_negatives = hist_negatives[:, idx]

                    # Fill the remainder of the block with random negatives
                    n_random = block_size - hist_negatives.size(1)
                    if n_random > 0:
                        first_dst = data.dst.min().item()
                        last_dst = data.dst.max().item()
                        cand_dst_set = np.arange(first_dst, last_dst + 1)

                        n_neg_per_pos = 200

                        for i, pos in enumerate(pos_src):
                            # Exclude the source itself and its true destination
                            dst_set = np.setdiff1d(
                                cand_dst_set, [pos.item(), pos_dst[i].item()]
                            )
                            if len(dst_set) > n_neg_per_pos:
                                pos_dst_set = np.random.choice(
                                    dst_set, n_neg_per_pos, replace=False
                                )
                            else:
                                pos_dst_set = dst_set

                            rand_src = torch.full(
                                (len(pos_dst_set),), pos, device=device
                            )
                            rand_dst = torch.tensor(
                                pos_dst_set,
                                device=device,
                            )

                            pos_rand_negs = torch.cat(
                                [rand_src.unsqueeze(0), rand_dst.unsqueeze(0)],
                                dim=0,
                            ).to(device)

                            rand_negatives = torch.cat(
                                [rand_negatives, pos_rand_negs], dim=1
                            )

                        idx = torch.randperm(rand_negatives.size(1))[:n_random]
                        rand_negatives = rand_negatives[:, idx]

                    hist_edges = torch.cat([hist_negatives, rand_negatives], dim=1)
                # NOTE(review): without neg_init this is still a Python set of
                # (src, dst) tuples, not a tensor - confirm bcd.attack accepts it
                neg_edges = hist_edges
            bcd_edge_index = torch.Tensor([])
            bcd_labels = 0

        loss = "prob_margin"  # torch.nn.functional.binary_cross_entropy_with_logits
        t1 = time.time()
        with torch.enable_grad():
            # Match the same prefix test that selected this branch, so suffixed
            # "prbcd..." types do not silently fall through to GRBCD
            attack_class = PRBCD if attack_type.startswith("prbcd") else GRBCD
            bcd = attack_class(
                model1,
                block_size=block_size,
                data=deepcopy(data),
                epochs=n_attack_epochs,
                epochs_resampling=n_attack_epochs - 50,
                bipartite=bipartite,
                device=device,
                is_undirected=False,
                metric=prbcd_metric,
                rel_time_sd=rel_time_sd,
                edge_label=bcd_labels,
                loss=loss,
                neg_edge_sampling=neg_init or hist_init,
                fixed_t_msg=fixed_t_m,
                log=True,
            )
            perturbed_edge_index, flipped_edges, adv_t, adv_msgs = bcd.attack(
                deepcopy(pos_batch),
                edge_index=deepcopy(bcd_edge_index),
                budget=n_adv_edges,
                assoc=deepcopy(assoc),
                neg_edges=neg_edges,
                hist_ratio=hist_ratio if (hist_init and not use_rand_negs) else None,
            )
        logging.info(f"attack crafted in {time.time()-t1}")

        flipped_edges = flipped_edges.detach()
        adv_sources = flipped_edges[0].to(int)
        adv_dests = flipped_edges[1].to(int)

        # Insertion point for this path comes from the neighbor loader
        cur_idx = model.neighbor_loader.cur_e_id
        logging.info(f"{len(adv_sources)} edges flipped")
        if len(adv_sources) == 0:
            # Return empty tensors for consistency
            empty_sources = torch.tensor([], device=device, dtype=torch.long)
            empty_dests = torch.tensor([], device=device, dtype=torch.long)
            empty_timestamps = torch.tensor([], device=device, dtype=data.t.dtype)
            return (
                attack_params,
                (empty_sources, empty_dests, empty_timestamps),
                (adv_sources, adv_dests, adv_t),
            )
    else:
        raise ValueError(f"Unknown attack_type: {attack_type!r}")

    adv_e_weights = torch.ones(
        adv_sources.size(0), dtype=data.edge_weights.dtype, device=device
    )
    # Optional FGSM refinement of the generated messages (e.g. negatt-fgsm)
    if attack_type.endswith("fgsm"):
        t1 = time.time()
        epsilon = attack_params["epsilon"]
        edge_class = 1.0
        if debug:
            old_adv_msgs = deepcopy(adv_msgs)
        with torch.enable_grad():
            adv_msgs = fgsm_tgnw(
                deepcopy(model),
                deepcopy(data),
                deepcopy(assoc),
                deepcopy(adv_sources),
                deepcopy(adv_dests),
                deepcopy(adv_t),
                deepcopy(adv_msgs),
                epsilon,
                edge_class,
                adv_e_weights,
                device,
            )
        if debug:
            msg_cosine_dsts.append(cosine(old_adv_msgs, adv_msgs).mean().item())
            msg_euclidean_dsts.append(pdist(old_adv_msgs, adv_msgs).mean().item())
        logging.info(f"FGSM in {time.time()-t1}")

    # Splice the adversarial edges into the stream at cur_idx
    data.src = torch.cat([data.src[:cur_idx], adv_sources, data.src[cur_idx:]])
    data.dst = torch.cat([data.dst[:cur_idx], adv_dests, data.dst[cur_idx:]])
    data.t = torch.cat([data.t[:cur_idx], adv_t, data.t[cur_idx:]])
    data.msg = torch.cat([data.msg[:cur_idx], adv_msgs, data.msg[cur_idx:]])
    data.edge_weights = torch.cat(
        [data.edge_weights[:cur_idx], adv_e_weights, data.edge_weights[cur_idx:]]
    )

    # Attack TGN
    model.update_memory(adv_sources, adv_dests, adv_t, adv_msgs, adv_e_weights)
    model.insert_neighbor(adv_sources, adv_dests)
    logging.info(f"attacked in {time.time()-t1}")

    return attack_params, (adv_sources, adv_dests, adv_t, adv_msgs)


def nat_attack(
    model,
    pos_batch,
    attack_type,
    n_adv_edges,
    attack_params,
    device,
    neg_batch_list=None,
    debug=False,
):
    "some of the attack_params could be modified, for eg. adv_edge"
    # here msg corresponds to e_id
    pos_src, pos_dst, pos_t, pos_eid = pos_batch  # start with 133854 for wiki
    cur_idx = min(pos_eid)
    cur_t = min(pos_t)
    current_e_feats = model.e_feat_th.data
    pos_msg = torch.cat([current_e_feats[e_id].view(1, -1) for e_id in pos_eid])
    t0 = time.time()
    if attack_type == "fgsm":
        epsilon = attack_params["epsilon"]
        edge_class = 1.0
        # random positive picks
        attack_idx = torch.randperm(len(pos_src))[:n_adv_edges]
        # TODO: weakly confident positive picks
        with torch.enable_grad():
            adv_pos_msg = fgsm_nat(
                deepcopy(model),
                deepcopy(pos_batch),
                deepcopy(pos_src[attack_idx]),
                deepcopy(pos_dst[attack_idx]),
                deepcopy(pos_t[attack_idx]),
                deepcopy(pos_msg[attack_idx]),
                epsilon,
                edge_class,
            )

        pos_msg[attack_idx] = adv_pos_msg

        current_e_feats[pos_eid] = pos_msg
        model.e_feat_th.data = current_e_feats
        model.edge_raw_embed = torch.nn.Embedding.from_pretrained(
            model.e_feat_th, padding_idx=0, freeze=True
        )
        return attack_params

    if attack_type.startswith("negatt"):
        # negatt
        adv_edges = attack_params["adv_test_edges"]
        rel_time_sd = attack_params["train_time_rel_sd"]
        neg_set = attack_params["neg_set"]
        edge_class = 0.0

        adv_edge_components = add_negative_neighbor_wrt_pos_src_neg_sample_nat(
            pos_batch,
            n_adv_edges,
            cur_t,
            device,
            adv_edges=adv_edges,
            rel_time_sd=rel_time_sd,
            neg_set=neg_set,
        )

        if adv_edge_components is None:
            print(" no adversarial edge could be found")
            # Return empty tensors for consistency
            empty_sources = torch.tensor([], device=device, dtype=torch.long)
            empty_dests = torch.tensor([], device=device, dtype=torch.long)
            empty_timestamps = torch.tensor([], device=device, dtype=data.t.dtype)
            return attack_params, (empty_sources, empty_dests, empty_timestamps)
        else:
            adv_sources, adv_dests, adv_t, adv_msgs = adv_edge_components

        # add positives of current batch to the bank and update attack_params
        batch_pos = {
            (s, d): m.reshape(1, -1)
            for (s, d, m) in zip(pos_src, pos_dst, pos_msg.detach().cpu().numpy())
        }
        adv_edges = adv_edges | batch_pos
        attack_params["adv_test_edges"] = adv_edges

    elif attack_type.startswith("prbcd") or attack_type.startswith("grbcd"):
        src_node_ids = attack_params["src_node_ids"]
        dst_node_ids = attack_params["dst_node_ids"]
        rel_time_sd = attack_params["train_time_rel_sd"]

        # make copy of model and load dict instead of deepcopy
        # The metric in PRBCD is assumed to be best if lower (like a loss).
        def accuracy(pred, y, *args):
            return (pred.argmax(-1) == y).float().mean()

        def prbcd_metric(*args, **kwargs):
            return -accuracy(*args, **kwargs)

        block_size = attack_params.block_size
        n_attack_epochs = attack_params.n_attack_epochs
        bipartite = attack_params.bipartite
        attack_target = attack_params.attack_target
        rel_time_sd = attack_params.train_time_rel_sd
        neg_init = attack_params.neg_init
        fixed_t_m = attack_params.fixed_t_m
        if attack_target == "pos":
            bcd_labels = 1
            bcd_edge_index = torch.tensor(np.stack((pos_src, pos_dst)))
        else:
            if neg_init:
                assert neg_batch_list is not None
                # initialize with negatt
                pos_list = pos_src.tolist()
                neg_edges = [
                    (pos_list[i], y)
                    for i in range(len(pos_list))
                    for y in neg_batch_list[i]
                ]
                neg_edges = list(set(neg_edges))
                neg_edges = torch.tensor(list(zip(*neg_edges))).to(device)
            else:
                neg_edges = None
            bcd_edge_index = torch.Tensor([])
            bcd_labels = 0
        loss = "prob_margin"  # torch.nn.functional.binary_cross_entropy_with_logits
        t1 = time.time()

        with torch.enable_grad():
            attack_class = PRBCD_NAT if attack_type == "prbcd" else GRBCD_NAT
            bcd = attack_class(
                deepcopy(model),
                block_size,
                src_node_ids,
                dst_node_ids,
                cur_e_id=cur_idx,
                cur_t=cur_t,
                epochs=n_attack_epochs,
                epochs_resampling=n_attack_epochs - 50,
                bipartite=bipartite,
                device=device,
                is_undirected=False,
                metric=prbcd_metric,
                rel_time_sd=rel_time_sd,
                edge_label=bcd_labels,
                log=True,
                loss=loss,
                neg_edge_sampling=neg_init,
                fixed_t_msg=fixed_t_m,
                pos_batch=pos_batch,
            )
            perturbed_edge_index, flipped_edges, adv_t, adv_msgs = bcd.attack(
                edge_index=deepcopy(bcd_edge_index),
                budget=n_adv_edges,
                neg_edges=neg_edges,
            )

        logging.info(f"attack crafted in {time.time()-t1}")

        flipped_edges = flipped_edges.detach()
        adv_sources = flipped_edges[0].to(int)
        adv_dests = flipped_edges[1].to(int)
        logging.info(f"{len(adv_sources)} edges found")
        if len(adv_sources) == 0:
            logging.info("could not flip edges")
            # Return empty tensors for consistency
            empty_sources = torch.tensor([], device=device, dtype=torch.long)
            empty_dests = torch.tensor([], device=device, dtype=torch.long)
            empty_timestamps = torch.tensor([], device=device, dtype=data.t.dtype)
            return attack_params, (empty_sources, empty_dests, empty_timestamps)

        # craft adv msg
        if adv_msgs is None:
            msg_idx = torch.randperm(cur_idx).to(device)[: len(adv_sources)]
            adv_msgs = model.edge_raw_embed(msg_idx)
        if adv_t is None:
            adv_t = np.array(
                cur_t
                + np.random.normal(scale=rel_time_sd, size=(len(adv_sources))).round(),
                dtype=np.int32,
            )

    current_weights = model.get_weights()
    adv_e_weights = torch.ones((len(adv_sources), 1)).to(device)
    if attack_type.endswith("fgsm"):
        t1 = time.time()
        epsilon = attack_params["epsilon"]
        with torch.enable_grad():
            adv_msgs = fgsm_nat(
                deepcopy(model),
                deepcopy(pos_batch),
                deepcopy(adv_sources),
                deepcopy(adv_dests),
                deepcopy(adv_t),
                deepcopy(adv_msgs),
                epsilon,
                edge_class,
            )

        logging.info(f"FGSM in {time.time()-t1}")
        # add adv edges to embedding

    updated_e_feats = torch.cat(
        [
            current_e_feats[:cur_idx],
            torch.tensor(adv_msgs, device=device, dtype=current_e_feats.dtype),
            current_e_feats[cur_idx:],
        ]
    )
    model.e_feat_th.data = updated_e_feats
    model.edge_raw_embed = torch.nn.Embedding.from_pretrained(
        model.e_feat_th, padding_idx=0, freeze=True
    )

    updated_weights = torch.cat(
        [current_weights[:cur_idx], adv_e_weights, current_weights[cur_idx:]]
    )
    model.set_weights(updated_weights)

    adv_e_id = np.arange(len(adv_sources)) + cur_idx
    pos_eid += len(adv_sources)

    # Attack - NAT
    model.contrast_modified(
        adv_sources, adv_dests, adv_t, adv_e_id, pos_edge=True, test=True
    )

    logging.info(f"attacked in {time.time()-t0}")

    return attack_params
