import logging

import time
from copy import deepcopy


import numpy as np
import torch

from torch_geometric.loader import TemporalDataLoader


from utils.attack import tgnw_attack


def train_one_epoch(
    model, data, train_mask, batch_size, optimizer, criterion, assoc, device
):
    r"""
    Train ``model`` for one epoch on the events of ``data`` selected by ``train_mask``.

    Events are processed in chronological batches. Each positive edge
    ``(src, dst)`` is paired with one negative ``(src, random_dst)``. The model
    scores both, and ``criterion`` pushes positives towards 1 and negatives
    towards 0. After the loss is computed, the model's memory and neighbor
    loader are updated with the ground-truth batch. Memory and neighbor loader
    are reset at the start of the epoch.

    Parameters:
        model: link predictor exposing ``memory`` (``reset_state``,
            ``update_state``, ``detach``) and ``neighbor_loader``
            (``reset_state``, ``insert``). It is called as
            ``model(data, edge_index, edge_weight, assoc)`` and must return one
            score row per column of ``edge_index``.
        data: full temporal dataset with ``src``, ``dst``, ``t``, ``msg`` and
            ``edge_weights`` fields.
        train_mask: mask/index selecting the training events of ``data``.
        batch_size: number of events per ``TemporalDataLoader`` batch.
        optimizer: torch optimizer over the model parameters.
        criterion: loss callable invoked as ``criterion(output, target)`` with
            all-ones / all-zeros targets shaped like ``output``.
        assoc: forwarded unchanged to ``model``.
        device: device that batches and sampled negatives are placed on.
    Returns:
        float: training loss averaged per event over the epoch, or ``0.0`` if
        the training split contains no events.
    """
    model.train()

    train_data = data[train_mask]
    train_loader = TemporalDataLoader(train_data, batch_size=batch_size)

    # Negatives are drawn from the [min, max] range of destination ids in the
    # full dataset so that they fall within the destination id space.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    model.memory.reset_state()  # Start with a fresh memory.
    model.neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pos_src, pos_dst, t, msg, weight = (
            batch.src,
            batch.dst,
            batch.t,
            batch.msg,
            batch.edge_weights,
        )

        # One uniformly sampled negative destination per positive edge.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (pos_src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_pos = pos_dst.size(0)
        # Each source appears twice: first with its true dst, then with its negative.
        src = pos_src.repeat(2)
        dst = torch.cat([pos_dst, neg_dst], dim=0)
        assert src.shape == dst.shape
        edge_index = torch.stack([src, dst], dim=0).to(device)
        # NOTE(review): no edge weights are used for scoring. Open question:
        # should negatives get unit weight (e.g. torch.ones_like(src))?
        edge_weight = None
        y_pred = model(data, edge_index, edge_weight, assoc)

        # Rows follow the column order of edge_index: positives, then negatives.
        pos_out = y_pred[:n_pos]
        neg_out = y_pred[n_pos:]

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state
        # (done before backward, matching the established training order).
        model.memory.update_state(pos_src, pos_dst, t, msg, weight)
        model.neighbor_loader.insert(pos_src, pos_dst)

        loss.backward()
        optimizer.step()
        # Cut the autograd graph through memory so it does not grow across batches.
        model.memory.detach()
        total_loss += float(loss) * batch.num_events

    num_events = train_data.num_events
    # Guard against an empty training split instead of dividing by zero.
    return total_loss / num_events if num_events else 0.0


def _init_anomaly_detection(
    data, mask, neg_sampler, split_mode, attack_params, preload_batch_size=200
):
    """Create the edge tracker and anomaly detector, and preload context edges.

    For ``split_mode == "test"``, the events of ``data`` outside ``mask`` are
    added to the tracker as normal edges, in chunks of ``preload_batch_size``.
    Each chunk gets its own consecutive batch id starting at 0. These events are
    presumably the train + validation edges; confirm with the caller.

    Parameters:
        data: temporal dataset to take the preloaded (``~mask``) events from.
        mask: boolean mask of the split being evaluated.
        neg_sampler: used only for its optional ``dataset_name`` attribute.
        split_mode: split identifier; only ``"test"`` triggers preloading.
        attack_params: attack configuration, or ``None``/empty for a clean run.
        preload_batch_size: number of edges per preloaded batch id.
    Returns:
        tuple ``(edge_tracker, anomaly_detector, num_preload_batches)``.
    """
    # Imported lazily so the dependency is only needed when detection is enabled.
    from utils.edge_tracker import EdgeTracker, UnifiedAnomalyDetector

    dataset_name = getattr(neg_sampler, "dataset_name", "unknown")
    if attack_params:
        attack_type_string = (
            f"tncnw-{attack_params.attack_type}-{attack_params.attack_mode}"
            f"-B_PCT{attack_params.budget_pct}"
            f"-FN{attack_params.fake_neighbors_per_victim}"
            f"-ADV_B{attack_params.adv_budget}"
        )
    else:
        # Clean (no-attack) run: there are no attack fields to describe.
        attack_type_string = "tncnw-none"
    edge_tracker = EdgeTracker(dataset_name, attack_type_string)
    anomaly_detector = UnifiedAnomalyDetector()
    logging.info(
        f"Anomaly detection enabled for dataset: {dataset_name}, attack: {attack_type_string}"
    )

    # Preload training and validation edges for better anomaly detection context.
    logging.info("Preloading training and validation edges for anomaly detection...")

    num_preload_batches = 0
    if split_mode == "test":
        train_val_edges = data[~mask]
        num_edges = len(train_val_edges.src)
        # Ceiling division: the last chunk may be smaller than preload_batch_size.
        num_preload_batches = (
            num_edges + preload_batch_size - 1
        ) // preload_batch_size
        has_msg = hasattr(train_val_edges, "msg")
        has_weights = hasattr(train_val_edges, "edge_weights")

        for batch_idx in range(num_preload_batches):
            start_idx = batch_idx * preload_batch_size
            end_idx = min(start_idx + preload_batch_size, num_edges)

            edge_tracker.set_current_batch_id(batch_idx)
            edge_tracker.add_normal_edges(
                train_val_edges.src[start_idx:end_idx],
                train_val_edges.dst[start_idx:end_idx],
                train_val_edges.t[start_idx:end_idx],
                train_val_edges.msg[start_idx:end_idx] if has_msg else None,
                (
                    train_val_edges.edge_weights[start_idx:end_idx]
                    if has_weights
                    else None
                ),
            )

        logging.info(
            f"Loaded {num_edges} training edges across {num_preload_batches} batches of size {preload_batch_size}"
        )

    logging.info(
        f"Total edges preloaded: {edge_tracker.total_normal_edges} including {len(edge_tracker.normal_messages)} normal messages"
    )
    return edge_tracker, anomaly_detector, num_preload_batches


def _log_detection_results(detection_results):
    """Log the ``combined_analysis`` summary returned by the anomaly detector, if present."""
    if "combined_analysis" not in detection_results:
        logging.info("No attacks detected in this run - skipping anomaly detection")
        return

    combined = detection_results["combined_analysis"]
    logging.info("Unified Detection Results:")
    logging.info(
        f"  - Structural detection rate: {combined['structural_detection_rate']:.3f}"
    )
    logging.info(f"  - Feature anomaly rate: {combined['feature_anomaly_rate']:.3f}")
    logging.info(
        f"  - Ensemble score: {combined['equal_weight_ensemble_score']:.3f}"
    )
    logging.info(f"  - Confidence: {combined['detection_confidence']}")
    logging.info(f"  - Recommendation: {combined['recommendation']}")


@torch.no_grad()
def test(
    model,
    data,
    mask,
    batch_size,
    neg_sampler,
    split_mode,
    assoc,
    metric,
    evaluator,
    device,
    debug=False,
    attack_params=None,
    detect_anomalies=False,
):
    r"""
    Evaluate dynamic link prediction on the events of ``data`` selected by ``mask``.

    Evaluation is 'one vs. many'. Each positive edge is scored together with the
    negatives returned by ``neg_sampler.query_batch``, and ``evaluator.eval``
    computes ``metric`` for that group. The per-positive values are averaged.
    After a batch is scored, the model's memory and neighbor store are updated
    with the ground-truth batch. The model's state is NOT reset first, so
    evaluation continues from the caller's state.

    When ``attack_params`` is given, ``tgnw_attack`` may run before a batch is
    scored. MemStranding with ``attack_mode == "single_shot"`` runs once. Every
    other configuration runs on every batch whose edge budget rounds to > 0.

    Parameters:
        model: victim link predictor (``model(data, edge_index, edge_weight,
            assoc)``, ``update_memory``, ``insert_neighbor``,
            ``neighbor_loader``).
        data: full temporal dataset. It is deep-copied, and the copy is what
            the model scores against and what the attack receives.
        mask: boolean mask selecting the evaluation events.
        batch_size: number of events per evaluation batch.
        neg_sampler: provides negatives via ``query_batch``. Its optional
            ``dataset_name`` is used for logging and output paths.
        split_mode: 'val'/'test'-style split name used to load the right
            negatives; ``"test"`` also enables context preloading for anomaly
            detection.
        assoc: forwarded unchanged to ``model`` and ``tgnw_attack``.
        metric: metric name requested from ``evaluator``.
        evaluator: object whose ``eval(input_dict)[metric]`` yields a score.
        device: device for the scoring tensors.
        debug: forwarded to ``tgnw_attack``.
        attack_params: attack configuration (attribute and ``.get`` access),
            or ``None`` for a clean evaluation. It may be replaced by the value
            returned from ``tgnw_attack``.
        detect_anomalies: track normal/adversarial edges and run the unified
            anomaly detector at the end.
    Returns:
        float: mean of the per-positive ``metric`` values (nan if the split
        is empty).
    """
    model.eval()
    # Approach 1 for scaling: a separate attacker copy of the victim model,
    # re-synced from ``model`` before every attack.
    model1 = deepcopy(model)
    full_data = deepcopy(data)
    split_data = full_data[mask]
    loader = TemporalDataLoader(split_data, batch_size=batch_size)

    perf_list = []
    logging.info(f"Attack params: {type(attack_params)}")

    edge_tracker = None
    anomaly_detector = None
    num_preload_batches = 0
    if detect_anomalies:
        edge_tracker, anomaly_detector, num_preload_batches = _init_anomaly_detection(
            data, mask, neg_sampler, split_mode, attack_params
        )

    if attack_params is not None:
        for k in attack_params.keys():
            if k in ("neg_set", "adv_test_edges"):
                logging.info(f"{k}: {len(attack_params[k])}")
            else:
                logging.info(f"{k}: {attack_params[k]}")

    attack_type = None
    budget = None
    attack_times = []
    memstranding_executed = False  # Tracks whether MemStranding has been run.
    if attack_params:  # att_pos=front
        attack_type = attack_params.attack_type
        # MemStranding uses budget_pct; other attacks use adv_budget.
        if attack_type == "memstranding":
            budget = attack_params.get("budget_pct", 0.01)
        else:
            budget = attack_params.adv_budget

    # Hard-coded debug switches (both off).
    plot = False
    if plot:
        pos_prob_list = []
        pos_logit_list = []
        breakpoint()
    log = False
    if log:
        test_edges = []

    # Batch ids continue after any preloaded context batches (0 otherwise).
    batch_id = num_preload_batches
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg, pos_weights = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
            pos_batch.edge_weights,
        )
        neg_batch_list = neg_sampler.query_batch(
            pos_src, pos_dst, pos_t, split_mode=split_mode
        )

        should_attack = False
        if attack_params and (n_adv_edges := round(len(pos_src) * budget)) > 0:
            if attack_type == "memstranding":
                attack_mode = attack_params.get("attack_mode", "single_shot")
                if attack_mode == "single_shot":
                    # Single-shot mode: only execute once at the start of the test phase.
                    if not memstranding_executed:
                        should_attack = True
                        memstranding_executed = True
                        logging.info(
                            "Executing MemStranding attack (single-shot at test start)"
                        )
                elif attack_mode == "distributed":
                    # Distributed mode: execute at every batch.
                    should_attack = True
            else:
                # Other attacks execute at every batch.
                should_attack = True

        if should_attack:  # att_pos=front
            model1.load_state_dict(model.state_dict())
            model1.neighbor_loader = deepcopy(model.neighbor_loader)
            t1 = time.time()
            logging.info("Attacking")
            attack_params, adv_edges = tgnw_attack(
                model,
                model1,
                full_data,
                assoc,
                pos_batch,
                attack_type,
                n_adv_edges,
                attack_params,
                device,
                debug,
                neg_batch_list=neg_batch_list,
            )
            attack_times.append(time.time() - t1)
            logging.info(f"Attacked in {attack_times[-1]}")

            # Track attack edges for anomaly detection.
            if edge_tracker is not None and adv_edges:
                edge_tracker.set_current_batch_id(batch_id)
                adv_src, adv_dst, adv_t, adv_msg = adv_edges
                edge_tracker.add_attack_edges(
                    adv_src, adv_dst, adv_t, msg=adv_msg, attack_type=attack_type
                )

            if log and adv_edges:
                adv_src, adv_dst, adv_t, adv_msg = adv_edges
                test_edges.extend(
                    (adv_s.item(), adv_d.item(), adv_tt.item(), "adv")
                    for adv_s, adv_d, adv_tt in zip(adv_src, adv_dst, adv_t)
                )

        # Move pos_dst to host once per batch, not once per positive edge.
        pos_dst_np = pos_dst.cpu().numpy()
        for idx, neg_batch in enumerate(neg_batch_list):
            # Candidate set: the true destination first, followed by its negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    (np.array([pos_dst_np[idx]]), np.array(neg_batch)), axis=0
                ),
                device=device,
            )

            edge_index = torch.cat([src.unsqueeze(0), dst.unsqueeze(0)], axis=0).to(
                device
            )
            edge_weight = None
            y_pred = model(full_data, edge_index, edge_weight, assoc)

            # Row 0 scores the positive edge; the remaining rows score the negatives.
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
            if plot:
                pos_logit_list.append(y_pred[0, :].cpu().item())
                pos_prob_list.append(y_pred[0, :].sigmoid().cpu().item())

        # Track normal edges for anomaly detection.
        if edge_tracker is not None:
            edge_tracker.set_current_batch_id(batch_id)
            edge_tracker.add_normal_edges(pos_src, pos_dst, pos_t, pos_msg, pos_weights)

        # Update memory and neighbor loader with ground-truth state.
        model.update_memory(pos_src, pos_dst, pos_t, pos_msg, edge_weight=pos_weights)
        model.insert_neighbor(pos_src, pos_dst)

        batch_id += 1

    if attack_params:
        # np.mean([]) warns and returns nan; report nan explicitly when no attack ran.
        avg_attack_time = np.mean(attack_times) if attack_times else float("nan")
        logging.info(
            f"average attack time for the run: {avg_attack_time} for {len(attack_times)} attacks."
        )
    perf_metrics = float(torch.tensor(perf_list).mean())

    # Run unified (structural + feature-based) anomaly detection if enabled.
    if edge_tracker is not None and anomaly_detector is not None:
        logging.info("Running unified anomaly detection...")
        is_bipartite = attack_params.get("bipartite", False) if attack_params else False
        detection_results = anomaly_detector.analyze(
            edge_tracker, generate_plot=True, is_bipartite=is_bipartite
        )
        _log_detection_results(detection_results)

    if log:
        import pandas as pd
        from datetime import datetime as dt

        df = pd.DataFrame(test_edges, columns=["src", "dst", "t", "type"])
        df.to_csv(
            f"test_edges/test_edges__tncnw-{neg_sampler.dataset_name}-{dt.now()}.csv"
        )
    if plot:
        import seaborn as sns

        ax = sns.scatterplot(x=pos_prob_list, y=perf_list)
        ax.set_xlabel("y_pred")
        ax.set_ylabel("mrr")
        ax.set_title(f"TNCNW mrr vs prob. on {neg_sampler.dataset_name} dataset")
        fig = ax.get_figure()
        fig.savefig(f"images/probVsMRR_tncn_{neg_sampler.dataset_name}")
        ax.clear()

        ax1 = sns.scatterplot(x=pos_logit_list, y=perf_list)
        ax1.set_xlabel("y_pred")
        ax1.set_ylabel("mrr")
        ax1.set_title(f"TNCNW mrr vs logit on {neg_sampler.dataset_name} dataset")
        fig = ax1.get_figure()
        fig.savefig(f"images/LogitVsMRR_tncn_{neg_sampler.dataset_name}")
        ax1.clear()

    return perf_metrics
