from copy import deepcopy
import numpy as np
import torch
from  utils.adv_attacks import add_negative_neighbor, add_negative_neighbor_wrt_pos, fgsm, PRBCD
import logging
from torch_geometric.loader import TemporalDataLoader
import random


def train_one_epoch(model, data, train_mask, batch_size, neighbor_loader, optimizer, 
          criterion, assoc, device):
    r"""
    Run one training epoch of the TGN model on the training split.

    For every temporal batch one random negative destination is drawn per
    positive edge, embeddings are computed from the memory and the GNN, and the
    link predictor is trained to score positives as 1 and negatives as 0.

    Parameters:
        model: mapping with the entries 'memory', 'gnn' and 'link_pred'
        data: the full temporal dataset (``src``, ``dst``, ``t``, ``msg`` are read)
        train_mask: index/mask selecting the training events from ``data``
        batch_size: number of events per temporal batch
        neighbor_loader: neighbor sampler; reset at the start of the epoch and
            extended with the positive edges of every batch
        optimizer: optimizer stepped once per batch
        criterion: loss applied to the positive and the negative scores
        assoc: lookup tensor, overwritten in place to map node ids to rows of
            the per-batch embedding matrix
        device: device the batches and the sampled negatives are placed on
    Returns:
        the per-batch losses averaged over all training events, each batch
        weighted by its number of events
    """
    memory = model['memory']
    gnn = model['gnn']
    link_pred = model['link_pred']
    for component in (memory, gnn, link_pred):
        component.train()

    train_events = data[train_mask]
    batches = TemporalDataLoader(train_events, batch_size=batch_size)

    # Negatives are drawn from the id range spanned by real destination nodes.
    neg_low = int(data.dst.min())
    neg_high = int(data.dst.max()) + 1  # randint's upper bound is exclusive

    # Every epoch starts from an empty memory and an empty graph.
    memory.reset_state()
    neighbor_loader.reset_state()

    weighted_loss = 0
    for batch in batches:
        batch = batch.to(device)
        optimizer.zero_grad()

        src = batch.src
        pos_dst = batch.dst

        # One random negative destination per positive edge.
        neg_dst = torch.randint(neg_low, neg_high, (src.size(0),),
                                dtype=torch.long, device=device)

        seed_nodes = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(seed_nodes)
        # Map global node ids to their row in the embedding matrix below.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Current memory of all involved nodes, refined by the GNN.
        node_state, last_update = memory(n_id)
        edge_times = data.t[e_id].to(device)
        edge_msgs = data.msg[e_id].to(device)
        z = gnn(node_state, last_update, edge_index, edge_times, edge_msgs)

        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
        neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

        pos_loss = criterion(pos_out, torch.ones_like(pos_out))
        neg_loss = criterion(neg_out, torch.zeros_like(neg_out))
        loss = pos_loss + neg_loss

        # Feed the ground-truth events into memory and neighbor loader.
        memory.update_state(src, pos_dst, batch.t, batch.msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Cut the autograd graph so the next batch does not backprop into this one.
        memory.detach()
        weighted_loss += float(loss) * batch.num_events

    return weighted_loss / train_events.num_events


def _unpack_attack_spec(adv_edges_param):
    """Split ``adv_edges_param`` into ``(attack_type, budget, adv_edges, time_sd, epsilon, att_pos)``.

    Fields that the given attack type does not carry are returned as ``None``.
    Raises ``NotImplementedError`` for 'prbcd' and ``ValueError`` for unknown
    attack types or for a 'back' position combined with a non-'negatt' attack.
    """
    attack_type, budget, attack_params = adv_edges_param
    adv_edges = time_sd = epsilon = None
    if attack_type == "negatt-fgsm":
        adv_edges, time_sd, epsilon, att_pos = attack_params
    elif attack_type == "negatt":
        adv_edges, time_sd, att_pos = attack_params
    elif attack_type == "fgsm":
        epsilon, att_pos = attack_params
    elif attack_type == "prbcd":
        raise NotImplementedError("the 'prbcd' attack is not implemented in test()")
    else:
        raise ValueError(f"unknown attack type: {attack_type!r}")
    if att_pos == "back" and not attack_type.startswith("negatt"):
        # The 'back' path injects sampled negative edges and needs the edge bank.
        raise ValueError(f"att_pos='back' is only supported for 'negatt' attacks, got {attack_type!r}")
    return attack_type, budget, adv_edges, time_sd, epsilon, att_pos


def _splice_events(full_data, at, src, dst, t, msg):
    """Insert the given events into ``full_data`` in front of row ``at``."""
    full_data.src = torch.cat([full_data.src[:at], src, full_data.src[at:]])
    full_data.dst = torch.cat([full_data.dst[:at], dst, full_data.dst[at:]])
    full_data.t = torch.cat([full_data.t[:at], t, full_data.t[at:]])
    full_data.msg = torch.cat([full_data.msg[:at], msg, full_data.msg[at:]])


def _edge_bank_entries(src, dst, t, msg):
    """Build ``{(src, dst, t): msg}`` edge-bank entries, one per event."""
    return {
        (s.reshape(-1), d.reshape(-1), ts.reshape(-1)): m.reshape(1, -1)
        for (s, d, ts, m) in zip(src, dst, t, msg)
    }


@torch.no_grad()
def test(model, data, mask, batch_size, neighbor_loader, neg_sampler, split_mode, assoc, metric, evaluator, device, debug=True, adv_edges_param=None):
    r"""
    Evaluate dynamic link prediction, optionally under an adversarial attack.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        model: mapping with the entries 'memory', 'gnn' and 'link_pred'
        data: the full temporal dataset; it is deep-copied, so injected edges never touch the caller's object
        mask: index/mask selecting the events of the evaluated split
        batch_size: number of positive events per temporal batch
        neighbor_loader: neighbor sampler, extended in place with every processed batch
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
        assoc: lookup tensor, overwritten in place to map node ids to embedding rows
        metric: name of the metric requested from ``evaluator``
        evaluator: object whose ``eval(input_dict)`` returns a mapping containing ``metric``
        device: device used for the evaluation tensors
        debug: if True, log the mean cosine similarity / euclidean distance between
            original and FGSM-perturbed messages (averaged over all attacked batches)
        adv_edges_param: ``None`` for a clean evaluation, otherwise
            ``(attack_type, budget, attack_params)`` where ``budget`` is the fraction of
            each batch that is attacked and ``attack_params`` is
              - ``(adv_edges, time_sd, att_pos)`` for 'negatt'
              - ``(adv_edges, time_sd, epsilon, att_pos)`` for 'negatt-fgsm'
              - ``(epsilon, att_pos)`` for 'fgsm'
            ``att_pos`` is 'front' (attack before a batch is scored) or 'back'
            (attack after it is scored, 'negatt' variants only).
    Returns:
        perf_metric: the result of the performance evaluaiton (mean over all positive edges)
    Raises:
        NotImplementedError: for the 'prbcd' attack type
        ValueError: for an unknown attack type or an unsupported attack position
    """
    memory = model['memory']
    gnn = model['gnn']
    link_pred = model['link_pred']
    memory.eval()
    gnn.eval()
    link_pred.eval()

    full_data = deepcopy(data)
    split_data = full_data[mask]
    loader = TemporalDataLoader(split_data, batch_size=batch_size)

    perf_list = []
    # Collected across ALL batches so the debug summary covers the whole run.
    msg_cosine_dsts = []
    msg_euclidean_dsts = []

    if adv_edges_param:
        # Unpacked once: `adv_edges` is the edge bank and has to persist across
        # batches, otherwise the additions made below would be thrown away.
        attack_type, budget, adv_edges, time_sd, epsilon, att_pos = _unpack_attack_spec(adv_edges_param)
        logging.info(f"attacking at {att_pos}")
        if debug:
            cosine = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
            pdist = torch.nn.PairwiseDistance(p=2, eps=0.0)

    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        if adv_edges_param:
            # Number of edges to attack in this batch.
            n = round(len(pos_src) * budget)

            if att_pos == "front" and n > 0:
                if attack_type.startswith("negatt"):
                    # Inject n negative edges before the batch is scored.
                    adv_sources, adv_dests, adv_t, adv_msgs = add_negative_neighbor(
                        adv_edges, time_sd, n, pos_msg.shape[-1], pos_t[0], device=device)
                    adv_sources = adv_sources.to(device)
                    adv_dests = adv_dests.to(device)
                    adv_msgs = adv_msgs.to(device)
                    adv_t = adv_t.to(device)

                    if attack_type.endswith("fgsm"):
                        # Perturb the injected messages; injected edges are labelled 0.0.
                        old_adv_msgs = deepcopy(adv_msgs)
                        with torch.enable_grad():
                            # fgsm receives copies so it cannot alter the evaluated state.
                            adv_msgs = fgsm(deepcopy(model), deepcopy(full_data), deepcopy(neighbor_loader), deepcopy(assoc),
                                adv_sources, adv_dests, adv_t, adv_msgs, epsilon, 0.0, device)
                        if debug:
                            msg_cosine_dsts.append(cosine(old_adv_msgs, adv_msgs).mean().item())
                            msg_euclidean_dsts.append(pdist(old_adv_msgs, adv_msgs).mean().item())

                    # Store the adversarial events at the position the neighbor
                    # loader is about to hand out as edge ids.
                    cur_idx = neighbor_loader.cur_e_id
                    _splice_events(full_data, cur_idx, adv_sources, adv_dests, adv_t, adv_msgs)

                    # Add the adversarial edges and the positives of this batch to the bank.
                    adv_edges = adv_edges | _edge_bank_entries(adv_sources, adv_dests, adv_t, adv_msgs)
                    adv_edges = adv_edges | _edge_bank_entries(pos_src, pos_dst, pos_t, pos_msg)

                    # The model observes the adversarial edges as if they were real events.
                    memory.update_state(adv_sources, adv_dests, adv_t, adv_msgs)
                    neighbor_loader.insert(adv_sources, adv_dests)

                elif attack_type == "fgsm":
                    # Perturb the messages of n randomly chosen positive edges (label 1.0).
                    attack_idx = torch.randperm(pos_src.size(0))[:n]
                    with torch.enable_grad():
                        adv_pos_msg = fgsm(deepcopy(model), deepcopy(full_data), deepcopy(neighbor_loader), deepcopy(assoc),
                            deepcopy(pos_src[attack_idx]), deepcopy(pos_dst[attack_idx]), deepcopy(pos_t[attack_idx]),
                            deepcopy(pos_msg[attack_idx]), epsilon, 1.0, device)

                    if debug:
                        msg_cosine_dsts.append(cosine(pos_msg[attack_idx], adv_pos_msg).mean().item())
                        msg_euclidean_dsts.append(pdist(pos_msg[attack_idx], adv_pos_msg).mean().item())

                    pos_msg[attack_idx] = adv_pos_msg

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        # Converted once per batch instead of once per positive edge.
        pos_dst_np = pos_dst.cpu().numpy()

        for idx, neg_batch in enumerate(neg_batch_list):
            # Row 0 is the positive edge, the remaining rows are its negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst_np[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = memory(n_id)

            z = gnn(
                z,
                last_update,
                edge_index,
                full_data.t[e_id].to(device),
                full_data.msg[e_id].to(device),
            )

            y_pred = link_pred(z[assoc[src]], z[assoc[dst]])

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # `n > 0` mirrors the guard of the 'front' branch.
        if adv_edges_param and att_pos == "back" and n > 0:
            adv_sources, adv_dests, adv_t, adv_msgs = add_negative_neighbor(adv_edges, time_sd, n, pos_msg.shape[-1], pos_t[-1], device=device)

            # Replace n randomly chosen positives of the already scored batch.
            idx = random.sample(range(len(pos_src)), n)
            pos_src[idx] = adv_sources
            pos_dst[idx] = adv_dests
            pos_msg[idx] = adv_msgs
            pos_t[idx] = adv_t
            logging.info(f"len(pos_src), {len(pos_src)}")

            # NOTE(review): the adversarial events are also inserted into
            # full_data as extra rows although they replaced positives above, so
            # full_data presumably grows by n rows more than the neighbor loader
            # receives below -- confirm the edge-id alignment is intended.
            cur_idx = neighbor_loader.cur_e_id
            _splice_events(full_data, cur_idx, adv_sources, adv_dests, adv_t, adv_msgs)

            # Add the (partly replaced) batch, including the adversarial edges, to the bank.
            adv_edges = adv_edges | _edge_bank_entries(pos_src, pos_dst, pos_t, pos_msg)

        # Update memory and neighbor loader with the (possibly attacked) batch.
        memory.update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    # Skip the summary when nothing was perturbed (avoids a mean over an empty list).
    if debug and adv_edges_param and "fgsm" in attack_type and msg_cosine_dsts:
        mean_cosine_sim = np.mean(msg_cosine_dsts)
        mean_euclidean_dist = np.mean(msg_euclidean_dsts)
        logging.info(f"mean cosine similarity = {mean_cosine_sim}")
        logging.info(f"mean euclidean distance = {mean_euclidean_dist}")
    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics



#TGN as a torch model

class TGN(torch.nn.Module):
    """Bundle the TGN components and a neighbor loader into one ``torch.nn.Module``.

    Parameters:
        model: mapping with the entries "memory", "gnn" and "link_pred"
        neighbor_loader: neighbor sampler stored alongside the components
    """

    def __init__(self, model, neighbor_loader):
        # nn.Module has to be initialised before sub-modules can be assigned
        # as attributes; without this call the assignments below raise.
        super().__init__()
        self.memory = model["memory"]
        self.gnn = model["gnn"]
        self.link_pred = model["link_pred"]
        self.neighbor_loader = neighbor_loader

    def forward(self, data, edge_index, edge_weight):
        """Placeholder forward pass: only validates ``edge_weight`` and returns None.

        ``edge_weight`` must consist solely of 0 and 1 entries.
        """
        assert  torch.all(torch.logical_or(edge_weight == 0, edge_weight == 1))
        # TODO: the actual forward computation is not implemented yet; the
        # values below are unpacked but not used.
        src = edge_index[0]
        dest = edge_index[1]
        t = data.t
        msg = data.msg
        return None

    def train(self, mode=True):
        """Put all components into training mode (evaluation mode if ``mode`` is False).

        Accepts ``mode`` and returns ``self`` to stay compatible with
        ``torch.nn.Module.train``; calling ``train()`` without arguments
        behaves as before.
        """
        self.training = mode
        for component in (self.memory, self.gnn, self.link_pred):
            if mode:
                component.train()
            else:
                component.eval()
        return self

    def eval(self):
        """Put all components into evaluation mode and return ``self``."""
        return self.train(False)
        