import logging

import time
from tqdm import tqdm
from copy import deepcopy


import numpy as np
import torch

from torch_geometric.loader import TemporalDataLoader

from utils.attack import tgn_attack

def train_one_epoch(model, data, train_mask, batch_size, optimizer,
                    criterion, assoc, device):
    r"""
    Run a single training pass over the events selected by ``train_mask``.

    Each positive edge of a batch is paired with one randomly drawn negative
    destination; positives and negatives are scored in one forward call and
    the two criterion terms are summed into the batch loss.

    Parameters:
        model: callable as ``model(data, src, dst, assoc)``; must expose
            ``memory`` (``reset_state``, ``update_state``, ``detach``) and
            ``neighbor_loader`` (``reset_state``, ``insert``)
        data: the full temporal dataset; indexed with ``train_mask`` and also
            handed to the model unchanged
        train_mask: index/mask selecting the training events of ``data``
        batch_size: number of events per ``TemporalDataLoader`` batch
        optimizer: optimizer stepped once per batch
        criterion: loss called as ``criterion(prediction, target)``
        assoc: forwarded untouched to the model
            (NOTE(review): presumably a node-id association buffer - confirm)
        device: device batches and sampled negatives are placed on
    Returns:
        The sum over batches of ``loss * batch.num_events`` divided by
        ``train_data.num_events``

    """
    model.train()

    train_data = data[train_mask]
    train_loader = TemporalDataLoader(train_data, batch_size=batch_size)

    # Negatives are drawn from the id range spanned by the destinations
    # of the whole dataset; ``neg_high`` is exclusive, as randint expects.
    neg_low = int(data.dst.min())
    neg_high = int(data.dst.max()) + 1

    # Every epoch begins with a blank memory and an empty neighbor graph.
    model.memory.reset_state()
    model.neighbor_loader.reset_state()

    running_loss = 0

    for batch in tqdm(train_loader, disable=True):
        batch = batch.to(device)
        optimizer.zero_grad()

        pos_src = batch.src
        pos_dst = batch.dst
        t = batch.t
        msg = batch.msg

        # One random negative destination per positive source.
        neg_dst = torch.randint(
            neg_low,
            neg_high,
            (pos_src.size(0),),
            dtype=torch.long,
            device=device,
        )

        # Score positives and negatives together: the first ``n_pos`` rows
        # belong to the true edges, the remainder to the sampled negatives.
        n_pos = pos_dst.size(0)
        src = torch.cat([pos_src, pos_src], dim=0)
        dst = torch.cat([pos_dst, neg_dst], dim=0)
        assert src.shape == dst.shape
        scores = model(data, src, dst, assoc)
        pos_out, neg_out = scores[:n_pos], scores[n_pos:]

        pos_loss = criterion(pos_out, torch.ones_like(pos_out))
        neg_loss = criterion(neg_out, torch.zeros_like(neg_out))
        loss = pos_loss + neg_loss

        # The model state is advanced with the ground-truth events only.
        model.memory.update_state(pos_src, pos_dst, t, msg)
        model.neighbor_loader.insert(pos_src, pos_dst)

        loss.backward()
        optimizer.step()
        model.memory.detach()

        # Weight the batch loss by its event count for the final average.
        running_loss += loss.item() * batch.num_events

    return running_loss / train_data.num_events


@torch.no_grad()
def test(model, data, mask, batch_size, neg_sampler, split_mode, assoc, metric, evaluator, device, debug=False, attack_params=None):
    r"""
    Evaluate dynamic link prediction, optionally under an adversarial attack.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        model: callable as ``model(data, src, dst, assoc)``; must expose ``eval``,
            ``update_memory`` and ``insert_neighbor``
        data: the full temporal dataset; it is deep-copied so the caller's object is not the
            one handed to the attack
        mask: index/mask selecting the events of the evaluated split
        batch_size: number of positive events per ``TemporalDataLoader`` batch
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
        assoc: forwarded untouched to the model and to ``tgn_attack``
        metric: name of the metric requested from, and read back out of, ``evaluator.eval``
        evaluator: object whose ``eval(input_dict)`` returns a mapping containing ``metric``
        device: device the per-edge source/destination tensors are created on
        debug: forwarded to ``tgn_attack``
        attack_params: falsy to evaluate without an attack; otherwise a config object read via
            ``.attack_type``, ``.adv_budget`` and ``.get("budget_pct" / "attack_mode", ...)``
    Returns:
        perf_metric: the mean of the per-positive-edge metric values
    """
    model.eval()
    full_data = deepcopy(data)
    split_data = full_data[mask]
    loader = TemporalDataLoader(split_data, batch_size=batch_size)

    perf_list = []
    attack_times = []
    # The model snapshot is only ever read by ``tgn_attack``; it stays None
    # (and the deepcopy is skipped) when no attack is configured.
    model1 = None
    logging.info(f"Attack params: {attack_params}")
    if attack_params:  # att_pos=front
        attack_type = attack_params.attack_type
        # Use budget_pct for MemStranding, adv_budget for other attacks
        if attack_type == "memstranding":
            budget = attack_params.get("budget_pct", 0.01)
        else:
            budget = attack_params.adv_budget
        memstranding_executed = False  # Track if MemStranding has been executed
        model1 = deepcopy(model)

    for pos_batch in tqdm(loader, disable=True):
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        # Special handling for MemStranding - execute only once at the beginning
        should_attack = False
        if attack_params and (n_adv_edges := round(len(pos_src) * budget)) > 0:
            if attack_type == "memstranding":
                attack_mode = attack_params.get("attack_mode", "single_shot")
                if attack_mode == "single_shot":
                    # Single-shot mode: only execute once at the beginning of test phase
                    if not memstranding_executed:
                        should_attack = True
                        memstranding_executed = True
                        logging.info("Executing MemStranding attack (single-shot at test start)")
                elif attack_mode == "distributed":
                    # Distributed mode: execute at every batch
                    should_attack = True
                # Any other attack_mode value leaves should_attack False.
            else:
                # Other attacks execute at every batch as before
                should_attack = True

        if should_attack:  # att_pos=front
            # perf_counter is monotonic, so the measured interval cannot be
            # distorted by wall-clock adjustments.
            t1 = time.perf_counter()
            logging.info("Attacking")
            # tgn_attack returns the attack parameters to use from here on.
            attack_params = tgn_attack(model, model1, full_data, assoc,  pos_batch, attack_type, n_adv_edges, attack_params, device, debug)
            attack_times.append(time.perf_counter() - t1)
            logging.info(f"Attacked in {attack_times[-1]}")

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        # Convert the positive destinations to NumPy once per batch instead of
        # once per positive edge.
        pos_dst_np = pos_dst.cpu().numpy()

        for idx, neg_batch in enumerate(neg_batch_list):
            # Row 0 is the positive edge, rows 1.. are its negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst_np[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            y_pred = model(full_data, src, dst, assoc)

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model.update_memory(pos_src, pos_dst, pos_t, pos_msg)
        model.insert_neighbor(pos_src, pos_dst)

    if attack_params:
        # Avoid NumPy's "Mean of empty slice" warning when attacks were
        # configured but none was actually executed.
        mean_attack_time = np.mean(attack_times) if attack_times else float("nan")
        logging.info(f"average attack time for the run: {mean_attack_time} for {len(attack_times)} attacks.")
    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
