import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from utils import subsample_src, subsample_dst, preprocess
import copy
import tqdm


class Validator:
    """Evaluate a temporal graph model on the train / val / test event streams.

    Batches are iterated in order, and after each batch the memory module and
    neighbor loader are updated with that batch's events, so running an
    evaluation also advances the model state.
    """

    def __init__(
        self,
        val_loader,
        train_loader,
        test_loader,
        device,
    ):
        """Store the data loaders and the device batches are moved to.

        Args:
            val_loader: Iterable of validation batches.
            train_loader: Iterable of training batches.
            test_loader: Iterable of test batches.
            device: Device every batch and every freshly created tensor is
                placed on.
        """
        self.__val_loader = val_loader
        self.__train_loader = train_loader
        self.__test_loader = test_loader
        self.__device = device
        # Destination-scoring tensors cached by the most recent ``val_loss``
        # batch and read by ``link_pred_from_emb``; None = val_loss not run.
        self.__idx_pos_dst = None
        self.__scores_dst = None
        self.__log_probs_dst = None
        self.__log_probs_pos_dst = None

    def _select_loader(self, name):
        """Return the loader for ``name``: "train", "val", anything else -> test."""
        if name == "train":
            return self.__train_loader
        if name == "val":
            return self.__val_loader
        return self.__test_loader

    @torch.no_grad()
    def val_loss(
        self,
        embd_to_score_src,
        embd_to_score_dst,
        memory,
        gnn,
        feats_model,
        embd_to_h0,
        neighbor_loader,
        min_dst_idx,
        max_dst_idx,
        all_src,
        all_dst,
        rand_all_src,
        rand_smpsrc,
        tmap_src,
        n_sampled_src,
        rand_all_dst,
        n_sampled_dst,
        rand_smpdst,
        tmap,
        assoc,
        data,
        eps,
        val_data,
        loader,
    ):
        """Compute the average evaluation loss over one split.

        Each batch's loss is the sum of three negative log-likelihoods:

        * the positive sources among the candidates from ``subsample_src``
          (log-softmax over dim 0 of ``embd_to_score_src``);
        * the positive destinations among the candidates from
          ``subsample_dst``, scored per source (log-softmax over dim 1 of
          ``embd_to_score_dst``);
        * the (noised, preprocessed) event features under ``feats_model``.

        The sum is divided by the number of terms plus one. Batch losses are
        then averaged with weights ``batch.num_events``.

        Side effects: all six modules are put in eval mode, ``memory`` and
        ``neighbor_loader`` are advanced with every batch, ``assoc`` is
        overwritten, and the destination tensors of the last batch are cached
        for ``link_pred_from_emb``.

        Args:
            loader: "train", "val", or anything else for the test loader.
            val_data: Kept for backward compatibility. It is no longer used for
                normalisation: the loss is divided by the number of events
                actually iterated, which is correct for every ``loader``.
            eps: Scale of the Gaussian noise added to the non-categorical
                feature columns (column 0 is never perturbed).
            The remaining arguments are forwarded unchanged to the models and
            to ``subsample_src`` / ``subsample_dst``.

        Returns:
            ``(mean_loss, mean_feats_loss, best_model, best_loss)``. The last
            two are always None (placeholders). Each mean is NaN when nothing
            was accumulated (for example, an empty loader).
        """
        for module in (
            memory,
            gnn,
            embd_to_score_src,
            embd_to_score_dst,
            feats_model,
            embd_to_h0,
        ):
            module.eval()

        total_loss = 0.0
        total_events = 0
        total_loss_feats = 0.0
        total_norm_feats = 0.0

        # FIXME: no model selection is done here; these are placeholders in
        # the return tuple. They are set before the loop so an empty loader
        # does not raise NameError.
        best_loss = None
        best_model = None

        for batch in tqdm.tqdm(self._select_loader(loader)):
            batch = batch.to(self.__device)

            loss_unnormalized = torch.tensor([0.0])
            # NOTE(review): the normaliser starts at 1, not 0, which adds one
            # phantom term per batch (presumably a division-by-zero guard).
            # Confirm it matches the training objective before changing it.
            norm = torch.tensor(1.0)

            src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

            # Random destinations; only used to enlarge the node set below.
            neg_dst = torch.randint(
                min_dst_idx,
                max_dst_idx + 1,
                (src.size(0),),
                dtype=torch.long,
                device=self.__device,
            )

            # Sample a subset of origins, including the positive ones.
            sampled_src, idx_pos_src = subsample_src(
                src,
                all_src,
                rand_all_src,
                rand_smpsrc,
                tmap_src,
                n_sampled_src,
                self.__device,
                sample_all=False,
            )

            # Sample a subset of destinations, including the positive ones.
            sampled_dst, idx_pos_dst = subsample_dst(
                pos_dst,
                rand_all_dst,
                n_sampled_dst,
                rand_smpdst,
                all_dst,
                min_dst_idx,
                tmap,
                self.__device,
                sample_all=False,
            )

            # Embed every node this batch touches.
            n_id = torch.cat(
                [src, pos_dst, neg_dst, sampled_src.unique(), sampled_dst.unique()]
            ).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node ids to their row in ``z``.
            assoc[n_id] = torch.arange(n_id.size(0), device=self.__device)

            z, last_update = memory(n_id)
            z = gnn(
                z,
                last_update,
                edge_index,
                data.t[e_id].to(self.__device),
                data.msg[e_id].to(self.__device),
            )

            # Origins: NLL of the positive sources among the sampled ones.
            scores_src = embd_to_score_src(z[assoc[sampled_src]])
            log_probs_src = scores_src - scores_src.logsumexp(dim=0).unsqueeze(1)
            log_probs_pos_src = torch.gather(
                log_probs_src, dim=0, index=idx_pos_src.unsqueeze(dim=1)
            )
            loss_unnormalized += (-log_probs_pos_src.sum()).cpu()
            norm += log_probs_pos_src.numel()

            # Destinations: NLL of each positive destination given its source.
            scores_dst = embd_to_score_dst(z[assoc[src]], z[assoc[sampled_dst]])
            log_probs_dst = scores_dst - scores_dst.logsumexp(dim=1).unsqueeze(1)
            log_probs_pos_dst = torch.gather(
                log_probs_dst, dim=1, index=idx_pos_dst.unsqueeze(dim=1)
            )
            loss_unnormalized += (-log_probs_pos_dst.sum()).cpu()
            norm += log_probs_pos_dst.numel()

            # Cache for link_pred_from_emb (overwritten on every batch).
            self.__idx_pos_dst = idx_pos_dst
            self.__scores_dst = scores_dst
            self.__log_probs_dst = log_probs_dst
            self.__log_probs_pos_dst = log_probs_pos_dst

            # Features: add Gaussian noise of scale eps, except to column 0
            # and the categorical columns (the method is already no_grad).
            noise = torch.randn_like(msg) * eps
            noise[:, 0] = 0
            noise[:, feats_model.categorical_idx] = 0
            msg = msg + noise
            msgp = preprocess(msg, self.__device)
            # Inputs/targets are shifted by one along dim 1.
            inputs = msgp[:, :-1]
            targets = msgp[:, 1:]
            h_0 = embd_to_h0(z[assoc[src]], z[assoc[pos_dst]]).unsqueeze(0)
            x = feats_model(inputs, h_0)
            x = torch.nan_to_num(x, nan=0.01)

            loss_feats = -sum(
                c.ptdist(x)
                .log_prob(torch.nan_to_num(targets[:, c.mask, -1], nan=0.01))
                .sum()
                if c.distrib
                is torch.distributions.mixture_same_family.MixtureSameFamily
                else c.ptdist(x).log_prob(targets[:, c.mask]).sum()
                for c in feats_model.columns
            )

            n_targets = targets.numel()
            loss_unnormalized += loss_feats.cpu()
            norm += n_targets
            total_loss_feats += loss_feats
            total_norm_feats += n_targets

            # Using total losses (sums) and normalising here.
            loss = loss_unnormalized / norm

            # Advance memory and neighbor loader with this batch. The tensors
            # are already on self.__device (batch.to above). A hard-coded
            # .cuda() here broke CPU runs and non-default GPUs.
            # NOTE(review): ``msg`` includes the noise added above, so memory
            # is updated with noisy messages; confirm this matches training.
            memory.update_state(src, pos_dst, t, msg)
            neighbor_loader.insert(src, pos_dst)

            total_loss += float(loss) * batch.num_events
            total_events += batch.num_events

        mean_loss = total_loss / total_events if total_events else float("nan")
        mean_loss_feats = (
            total_loss_feats / total_norm_feats if total_norm_feats else float("nan")
        )
        return (
            mean_loss,
            mean_loss_feats,
            best_model,
            best_loss,
        )

    @torch.no_grad()
    def link_pred_from_emb(
        self,
        min_dst_idx,
        max_dst_idx,
        memory,
        gnn,
        assoc,
        data,
        neighbor_loader,
        dataset,
    ):
        """Return the mean average precision and ROC AUC over one split.

        The positive score is ``exp`` of the cached log-probability of the
        true destination. The negative score is ``exp`` of the log-probability
        at the next sampled-destination index (cyclically). Whether that index
        is always a true negative depends on ``subsample_dst`` — TODO confirm.

        NOTE(review): the scores come from the tensors cached by the *last*
        ``val_loss`` batch, not from the embeddings ``z`` computed here, so
        every batch contributes the same AP/AUC. Fixing this needs the
        destination scorer and sampler, which this method does not receive.

        Side effects: advances ``memory`` and ``neighbor_loader`` with every
        batch and overwrites ``assoc``.

        Args:
            dataset: "train", "val", or anything else for the test loader.

        Returns:
            ``(mean_ap, mean_auc)`` as floats (NaN for an empty loader).

        Raises:
            RuntimeError: if ``val_loss`` has not been run on this instance.
        """
        if self.__log_probs_dst is None:
            raise RuntimeError(
                "link_pred_from_emb needs the scores cached by val_loss; "
                "call val_loss first"
            )

        aps, aucs = [], []
        for batch in tqdm.tqdm(self._select_loader(dataset)):
            batch = batch.to(self.__device)
            src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

            neg_dst = torch.randint(
                min_dst_idx,
                max_dst_idx + 1,
                (src.size(0),),
                dtype=torch.long,
                device=self.__device,
            )

            n_id = torch.cat([src, pos_dst, neg_dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=self.__device)

            # NOTE(review): z is computed but not used below (see docstring).
            z, last_update = memory(n_id)
            z = gnn(
                z,
                last_update,
                edge_index,
                data.t[e_id].to(self.__device),
                data.msg[e_id].to(self.__device),
            )

            idx_neg_dst = (self.__idx_pos_dst + 1) % self.__scores_dst.shape[1]

            pos_out = self.__log_probs_pos_dst.exp()
            neg_out = torch.gather(
                self.__log_probs_dst,
                dim=1,
                index=idx_neg_dst.unsqueeze(dim=1),
            ).exp()
            y_pred = torch.cat([pos_out, neg_out], dim=0).cpu()
            y_true = torch.cat(
                [torch.ones(pos_out.size(0)), torch.zeros(neg_out.size(0))], dim=0
            )

            aps.append(average_precision_score(y_true, y_pred))
            aucs.append(roc_auc_score(y_true, y_pred))

            memory.update_state(src, pos_dst, t, msg)
            neighbor_loader.insert(src, pos_dst)

        return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())
