

from copy import deepcopy
import sys, traceback
import numpy as np
import torch
from  utils.adv_attacks import add_negative_neighbor

from torch_geometric.loader import TemporalDataLoader

def train_one_epoch(model, data, train_mask, batch_size, neighbor_loader, optimizer,
                    criterion, assoc, device):
    r"""
    Train the DyRep model for one epoch over the events selected by ``train_mask``.

    The memory module and the neighbor loader are reset first, so every epoch
    replays the training events from an empty state. Each chronological batch is
    scored on its true destinations (target 1) and on the same number of randomly
    sampled destinations (target 0); the batch's ground-truth events are then
    written into the memory and the neighbor loader, and the optimizer steps.

    Parameters:
        model: dict holding the 'memory', 'gnn' and 'link_pred' modules.
        data: full temporal event data with ``src``, ``dst``, ``t`` and ``msg``.
        train_mask: mask/index selecting the training events from ``data``.
        batch_size: number of events per ``TemporalDataLoader`` batch.
        neighbor_loader: neighbor sampler; reset here, then fed every batch.
        optimizer: optimizer over the model parameters.
        criterion: loss called as ``criterion(scores, targets)`` with 1/0 targets
            (presumably ``BCEWithLogitsLoss`` -- confirm with the caller).
        assoc: preallocated tensor mapping global node ids to local row indices;
            overwritten for every batch.
        device: device each batch and the sampled negatives are placed on.
    Returns:
        The event-weighted mean of the per-batch losses, as a float.
    Raises:
        ZeroDivisionError: if ``train_mask`` selects no events.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    train_data = data[train_mask]
    train_loader = TemporalDataLoader(train_data, batch_size=batch_size)

    # Restrict negatives to the id range spanned by destination nodes of the whole
    # dataset (ids inside that range that never occur as destinations can still
    # be drawn).
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Uniformly sample one negative destination per positive event. There is
        # no collision check, so a negative may coincide with the true destination.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        # Expand the involved nodes with their sampled neighbors and map global
        # node ids to row indices of the memory output.
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Score using the memory state *before* this batch's events are applied.
        z, last_update = model['memory'](n_id)

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update the memory with the ground-truth events, using GNN embeddings of
        # the current neighborhood.
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)

        # Make this batch's edges visible to neighbor sampling for later batches.
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach the memory state so the next batch's backward pass does not reach
        # into this one (assumes PyG TGNMemory-style detach()).
        model['memory'].detach()
        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


def _splice_events(full_data, idx, src, dst, t, msg):
    """Insert the given events into ``full_data`` just before position ``idx``.

    ``idx`` is expected to be ``neighbor_loader.cur_e_id`` (the next edge id the
    loader will hand out), so that once the same events are inserted into the
    loader, ``full_data.t[e_id]`` / ``full_data.msg[e_id]`` still line up with the
    loader's edge ids (assumes PyG LastNeighborLoader semantics).
    """
    full_data.src = torch.cat([full_data.src[:idx], src, full_data.src[idx:]])
    full_data.dst = torch.cat([full_data.dst[:idx], dst, full_data.dst[idx:]])
    full_data.t = torch.cat([full_data.t[:idx], t, full_data.t[idx:]])
    full_data.msg = torch.cat([full_data.msg[:idx], msg, full_data.msg[idx:]])


def _update_memory_and_neighbors(model, neighbor_loader, assoc, full_data,
                                 src, dst, t, msg, device):
    """Write a batch of events into the memory module, then into the neighbor loader."""
    n_id = torch.cat([src, dst]).unique()
    n_id, edge_index, e_id = neighbor_loader(n_id)
    assoc[n_id] = torch.arange(n_id.size(0), device=device)
    z, last_update = model['memory'](n_id)

    z = model['gnn'](
        z,
        last_update,
        edge_index,
        full_data.t[e_id].to(device),
        full_data.msg[e_id].to(device),
    )
    model['memory'].update_state(src, dst, t, msg, z, assoc)
    neighbor_loader.insert(src, dst)


@torch.no_grad()
def test(model, data, mask, batch_size, neighbor_loader, neg_sampler, split_mode, assoc,
         metric, evaluator, device, adv_edges_param=None, adv_ratio=0.05):
    r"""
    Evaluate dynamic link prediction on the events selected by ``mask``.

    Evaluation is 'one vs. many': each positive edge is scored against the
    negative destinations that ``neg_sampler`` returns for it. After a batch has
    been scored, its events are written into the memory module and the neighbor
    loader, so later batches see them. Neither the memory nor the neighbor loader
    is reset here; both continue from the caller's state and are updated in place.

    Optionally (``adv_edges_param``), before each batch is scored, adversarial
    events produced by ``add_negative_neighbor`` are injected into the memory, the
    neighbor loader and a private copy of ``data``.

    Parameters:
        model: dict holding the 'memory', 'gnn' and 'link_pred' modules.
        data: full temporal event data; deep-copied, the caller's object is not modified.
        mask: mask/index selecting the evaluation events from ``data``.
        batch_size: number of positive events per batch.
        neighbor_loader: neighbor sampler, continued (not reset) from the caller's state.
        neg_sampler: object whose ``query_batch`` gives the negatives for each positive edge.
        split_mode: whether this is the validation or the test split, so the matching
            negatives are loaded.
        assoc: preallocated tensor mapping global node ids to local rows; overwritten.
        metric: key of the metric to read from ``evaluator.eval``'s result.
        evaluator: object whose ``eval(input_dict)`` returns a dict keyed by ``metric``.
        device: device for the model computations.
        adv_edges_param: optional ``(adv_edges, time_sd)`` pair. ``adv_edges`` is the
            edge bank handed to ``add_negative_neighbor`` (a mapping merged with ``|``);
            every injected edge is added to it for subsequent batches. Falsy disables
            the attack.
        adv_ratio: number of injected events per batch as a fraction of the batch
            size (rounded); only used when ``adv_edges_param`` is given.
    Returns:
        perf_metric: mean of the per-edge metric values (NaN if nothing was scored).
    Raises:
        Any exception from the post-scoring memory/neighbor-loader update is re-raised
        after diagnostics are printed.
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    # Adversarial events are spliced into this copy; the caller's data stays intact.
    full_data = deepcopy(data)
    loader = TemporalDataLoader(full_data[mask], batch_size=batch_size)

    attack = bool(adv_edges_param)
    if attack:
        # Unpacked once, outside the batch loop: unpacking per batch discarded every
        # edge that had been added to the bank.
        adv_edges, time_sd = adv_edges_param

    perf_list = []

    for pos_batch in loader:
        pos_batch = pos_batch.to(device)
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        if attack:
            # Inject adversarial events before scoring; ``pos_t[0]`` (the batch's first
            # timestamp) is passed to add_negative_neighbor -- presumably the injection time.
            n_adv = round(len(pos_src) * adv_ratio)
            adv_src, adv_dst, adv_t, adv_msg = add_negative_neighbor(
                adv_edges, time_sd, n_adv, pos_msg.shape[-1], pos_t[0], device=device
            )
            adv_src = adv_src.to(device)
            adv_dst = adv_dst.to(device)
            adv_msg = adv_msg.to(device)
            adv_t = adv_t.to(device)

            _splice_events(full_data, neighbor_loader.cur_e_id,
                           adv_src, adv_dst, adv_t, adv_msg)

            # Grow the bank with the injected edges (non-mutating: a new mapping is built).
            adv_edges = adv_edges | {
                (s.reshape(-1), d.reshape(-1), ts.reshape(-1)): m.reshape(1, -1)
                for s, d, ts, m in zip(adv_src, adv_dst, adv_t, adv_msg)
            }

            _update_memory_and_neighbors(model, neighbor_loader, assoc, full_data,
                                         adv_src, adv_dst, adv_t, adv_msg, device)

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        # Convert once per batch rather than once per query.
        pos_dst_np = pos_dst.cpu().numpy()

        for idx, neg_batch in enumerate(neg_batch_list):
            # Row 0 is the positive destination, rows 1.. are its negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate([pos_dst_np[idx:idx + 1], np.array(neg_batch)], axis=0),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, _, _ = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, _ = model['memory'](n_id)

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Only after scoring: write the ground-truth batch into memory and the
        # neighbor loader, so predictions never see the events they predict.
        try:
            _update_memory_and_neighbors(model, neighbor_loader, assoc, full_data,
                                         pos_src, pos_dst, pos_t, pos_msg, device)
        except Exception:
            # Print diagnostics, then re-raise: a half-applied update leaves the memory
            # and neighbor loader out of sync with ``full_data``, which would silently
            # corrupt every metric computed afterwards.
            print("-" * 60)
            traceback.print_exc(file=sys.stdout)
            print("-" * 60)
            print(pos_src.shape, pos_dst.shape, pos_t.shape, pos_msg.shape)
            raise

    perf_metric = float(torch.tensor(perf_list).mean())

    return perf_metric


# TODO: check collision
# TODO: include all train,val edges encountered so far