import math
from typing import Tuple

import random
import numpy as np
import torch
import torch.nn.functional as F
import torch_sparse
from torch import Tensor
from torch_geometric.typing import Adj
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.utils import dropout_adj
from tqdm import tqdm


class Dict(dict):
    """``dict`` subclass that also exposes its keys as attributes.

    ``d.key`` reads ``d.get('key')`` -- so a missing key gives ``None`` rather
    than raising -- and ``d.key = v`` stores ``v`` under the key ``'key'``.
    """

    def __getattr__(self, key):
        # Only reached when ordinary attribute lookup fails.
        return self.get(key, None)

    def __setattr__(self, key, value):
        # Route every attribute assignment into the mapping itself.
        self.__setitem__(key, value)


def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device if present).

    A seed of ``-1`` is a sentinel meaning "leave all RNGs untouched".
    """
    if seed != -1:
        for seeder in (random.seed, np.random.seed, torch.manual_seed):
            seeder(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


def dropout_edge(edge_index: Adj, p: float, training: bool = True):
    """Randomly drop edges of an adjacency structure while training.

    Args:
        edge_index: either a ``SparseTensor`` or an edge-index tensor accepted
            by ``torch_geometric.utils.dropout_adj``.
        p: drop probability.
        training: when ``False`` (or when ``p == 0``) the input is returned
            unchanged.

    Returns:
        The perturbed adjacency, in the same kind of container as the input.
    """
    if not training or p == 0.:
        return edge_index

    if isinstance(edge_index, SparseTensor):
        value = edge_index.storage.value()
        if value is not None:
            # Weighted graph: the sparsity pattern is kept. ``F.dropout``
            # zeroes a fraction ``p`` of the weights and rescales the
            # surviving ones by 1 / (1 - p).
            value = F.dropout(value, p=p, training=True)
            return edge_index.set_value(value, layout='coo')
        # Unweighted graph: physically remove each non-zero with probability p.
        mask = torch.rand(edge_index.nnz(), device=edge_index.storage.row().device) > p
        return edge_index.masked_select_nnz(mask, layout='coo')

    # ``dropout_adj`` returns ``(edge_index, edge_attr)``; only the indices are used.
    edge_index, _ = dropout_adj(edge_index, p=p)
    return edge_index


def directed_norm(adj):
    r"""
    Applies the normalization for directed graphs:
        \mathbf{D}_{out}^{-1/2} \mathbf{A} \mathbf{D}_{in}^{-1/2}.

    Args:
        adj: sparse adjacency; expected to be a ``torch_sparse.SparseTensor``
            (the type ``torch_sparse.sum`` / ``torch_sparse.mul`` operate on).

    Returns:
        The normalized adjacency. Rows with zero out-degree and columns with
        zero in-degree are scaled by 0 instead of ``inf``.
    """
    # Out-of-place ``pow`` so a degree tensor returned by ``torch_sparse.sum``
    # is never modified in place.
    in_deg_inv_sqrt = torch_sparse.sum(adj, dim=0).pow(-0.5)
    in_deg_inv_sqrt.masked_fill_(in_deg_inv_sqrt == float("inf"), 0.0)

    out_deg_inv_sqrt = torch_sparse.sum(adj, dim=1).pow(-0.5)
    out_deg_inv_sqrt.masked_fill_(out_deg_inv_sqrt == float("inf"), 0.0)

    # Scale row i by out_deg[i]^-1/2, then column j by in_deg[j]^-1/2.
    adj = torch_sparse.mul(adj, out_deg_inv_sqrt.view(-1, 1))
    adj = torch_sparse.mul(adj, in_deg_inv_sqrt.view(1, -1))
    return adj


def adj_norm(adj, norm='rw', add_self_loop=True):
    r"""Normalize a sparse adjacency matrix.

    Args:
        adj: sparse adjacency; expected to be a ``torch_sparse.SparseTensor``
            (it is passed to ``torch_sparse.fill_diag`` and ``torch_sparse.sum``).
        norm: ``'sym'`` for D^{-1/2} A D^{-1/2}, ``'rw'`` for D^{-1} A (row
            normalization), or ``'dir'`` for D_out^{-1/2} A D_in^{-1/2}.
            For ``'sym'``/``'rw'``, D is the row-sum degree.
        add_self_loop: if ``True``, diagonal entries are filled with 1.0
            (via ``torch_sparse.fill_diag``) before normalizing.

    Returns:
        The normalized adjacency; zero-degree rows/columns are scaled by 0.

    Raises:
        AssertionError: for an unknown ``norm`` (``NotImplementedError`` if
            assertions are disabled).
    """
    assert norm in ['sym', 'rw', 'dir']

    if add_self_loop:
        adj = torch_sparse.fill_diag(adj, 1.0)

    if norm == 'dir':
        # Directed normalization uses separate in/out degrees.
        return directed_norm(adj)

    deg = adj.sum(dim=1)
    if norm == 'sym':
        # Out-of-place ``pow`` keeps the returned degree tensor untouched.
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    if norm == 'rw':
        deg_inv = deg.pow(-1)
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)
        return deg_inv.view(-1, 1) * adj

    raise NotImplementedError


def group_values(values: Tensor, num_groups=5, group_by='sample', log_scale=False):
    """Partition a tensor of scalar values into ``num_groups`` groups.

    Args:
        values: one value per row; must satisfy
            ``values.size(0) == values.view(-1).size(0)``.
        num_groups: number of groups.
        group_by:
            ``'sample'`` -- equal-count groups. Values are sorted ascending and
            cut into consecutive runs of ``len(values) // num_groups``
            elements. The ``len(values) % num_groups`` leftover (largest)
            values join the last group.
            ``'domain'`` -- equal-width bins over ``[min, max]`` of the values
            (of ``log10(values)`` when ``log_scale``). Bin ``i`` covers
            ``(lo_i, lo_{i+1}]``. The first bin also contains the minimum and
            the last bin extends up to the maximum.
        log_scale: only used by ``'domain'``; bin on a log10 scale.

    Returns:
        ``(group_val, group_idx, indices)``:

        * ``group_val`` -- one value per group. For ``'sample'`` it is the mean
          of the group's members; for ``'domain'`` it is the bin centre, mapped
          back through ``10 ** x`` when ``log_scale``.
        * ``group_idx`` -- group id per element. For ``'sample'`` it is aligned
          with the *sorted* order; for ``'domain'`` with the original order.
          Elements no bin matches (e.g. NaN) keep id 0.
        * ``indices`` -- the position in ``values`` of each ``group_idx``
          entry: the argsort for ``'sample'``, ``arange`` for ``'domain'``.
    """
    assert group_by in ['sample', 'domain']
    assert values.size(0) == values.view(-1).size(0)
    if group_by == 'sample':
        values_, indices = values.sort()
        num_values = values.size(0)
        # Exact integer division (``int(n / g)`` goes through a float).
        group_size = num_values // num_groups
        # Build the index on ``values``' device so ``scatter`` also works for
        # CUDA inputs.
        group_idx = torch.arange(num_groups, device=values.device).repeat_interleave(group_size)
        remainder = torch.full((num_values % num_groups,), num_groups - 1,
                               dtype=torch.long, device=values.device)
        group_idx = torch.cat([group_idx, remainder])
        group_val = scatter(values_, group_idx, reduce='mean')
    elif group_by == 'domain':
        group_idx = torch.zeros_like(values, dtype=torch.long)
        if log_scale:
            values = torch.log10(values)
        min_val, max_val = values.min(), values.max()
        step = (max_val - min_val) / num_groups
        group_val = []
        for i in range(num_groups):
            mask = values > min_val + i * step
            # The last bin has no upper bound: ``min_val + num_groups * step``
            # can round to just below ``max_val``. A finite bound would leave
            # the largest values unassigned (silently stuck at group 0).
            if i < num_groups - 1:
                mask &= values <= min_val + (i + 1) * step
            if i == 0:
                mask |= values == min_val
            group_idx[mask] = i
            group_val.append((min_val + (2 * i + 1) / 2 * step))
        group_val = torch.tensor(group_val)
        if log_scale:
            group_val = torch.pow(10, group_val)
        indices = torch.arange(values.size(0))
    else:
        raise NotImplementedError
    return group_val, group_idx, indices


def pred_fn(y_hat, y) -> Tuple[Tensor, Tensor]:
    """Turn raw model outputs into predictions comparable with ``y``.

    * ``y_hat`` with a single column: returned unchanged (binary scores,
      e.g. for ROC-AUC).
    * 1-D ``y``: multi-class; the prediction is the argmax over the last dim.
    * otherwise: multi-label; a label is predicted (1.0) when its logit is > 0.

    Returns:
        ``(pred, y)`` -- ``y`` is passed through untouched.
    """
    if y_hat.shape[1] == 1:
        return y_hat, y
    if y.dim() == 1:
        return y_hat.argmax(dim=-1), y
    return y_hat.gt(0).float(), y


def loss_fn(y_hat, y) -> Tensor:
    """Compute the training loss, choosing the criterion from the shapes.

    * ``y_hat`` with one column: binary BCE-with-logits; ``y`` is cast to the
      logits' dtype and both sides are flattened.
    * ``y`` with more than one dim: multi-label BCE-with-logits; ``y`` is
      cast to the logits' dtype.
    * otherwise: multi-class cross-entropy with integer class targets.
    """
    if y_hat.shape[1] == 1:  # binary
        # Flatten instead of ``squeeze()``: ``squeeze()`` turns a batch of one
        # into a 0-d tensor, which no longer matches the shape of ``y``.
        return F.binary_cross_entropy_with_logits(
            y_hat.reshape(-1), y.to(y_hat.dtype).reshape(-1))
    if y.dim() > 1:  # multi-label
        # BCE requires floating-point targets; labels often arrive as int/bool.
        return F.binary_cross_entropy_with_logits(y_hat, y.to(y_hat.dtype))
    return F.cross_entropy(y_hat, y)  # multi-class


def distance_matrix(x, y=None, p=2):
    """Return the ``(n, m)`` matrix of pairwise p-norm distances.

    Entry ``[i, j]`` is ``||x[i] - y[j]||_p``; ``y`` defaults to ``x``.
    Inputs are ``(rows, features)``. The full ``(n, m, d)`` difference
    tensor is materialized, so memory grows with ``n * m * d``.
    """
    other = x if y is None else y
    rows, cols, feat = x.size(0), other.size(0), x.size(1)
    lhs = x[:, None].expand(rows, cols, feat)
    rhs = other[None].expand(rows, cols, feat)
    return torch.linalg.vector_norm(lhs - rhs, ord=p, dim=2)


class KNN:
    """Brute-force k-nearest-neighbour search against a fixed set of points.

    Args:
        train_pts: reference points, one per row.
        k: number of neighbours returned per query.
        p: order of the norm used as distance (passed to ``distance_matrix``).
    """

    def __init__(self, train_pts, k=1, p=2):
        self.train_pts = train_pts
        self.k = k
        self.p = p

    def predict(self, test_pts, batch_size=1024):
        """Find the ``k`` nearest training points for every row of ``test_pts``.

        Queries are processed in chunks of at most ``batch_size`` rows. Each
        chunk materializes a ``(chunk, len(train_pts), d)`` difference tensor.

        Returns:
            ``(distances, indices)``, each of shape ``(len(test_pts), k)``,
            ordered by increasing distance and moved to CPU.
        """
        top_vals, top_idx = [], []
        # ``split`` gives fixed-size chunks (the last one may be shorter).
        # ``tensor_split(n // batch_size)`` fails when n < batch_size (zero
        # sections) and can make chunks almost twice ``batch_size``.
        for batch_test in tqdm(test_pts.split(batch_size)):
            dist = distance_matrix(batch_test, self.train_pts, self.p)
            vals, idx = dist.topk(self.k, largest=False)
            top_vals.append(vals)
            top_idx.append(idx)
        top_vals, top_idx = torch.cat(top_vals), torch.cat(top_idx)
        return top_vals.cpu(), top_idx.cpu()


def padding_test_to_full(test_pts, num_full_pts, test_mask):
    """Scatter per-test values back into a zero tensor covering all points.

    Args:
        test_pts: values for the selected points. The first dim matches the
            number of positions ``test_mask`` selects; trailing dims (e.g. the
            ``k`` columns of a KNN result) are preserved.
        num_full_pts: size of the output's first dimension.
        test_mask: index into the output's first dim (e.g. a bool mask of
            length ``num_full_pts`` or an index tensor).

    Returns:
        Tensor of shape ``(num_full_pts, *test_pts.shape[1:])`` with
        ``test_pts``' dtype and device. It is zero except at ``test_mask``.
    """
    out = test_pts.new_zeros((num_full_pts, *test_pts.shape[1:]))
    out[test_mask] = test_pts
    return out


def mask_to_index(mask):
    """Convert a boolean mask into the flat tensor of its ``True`` positions.

    Any non-boolean input is treated as an index already and returned as-is.
    """
    if mask.dtype != torch.bool:
        return mask
    return torch.nonzero(mask, as_tuple=False).view(-1)


def index_to_mask(index: Tensor, size: int = None) -> Tensor:
    """Convert an index tensor into a boolean mask.

    Args:
        index: integer positions to mark (flattened first). A bool tensor is
            treated as a mask already and returned unchanged.
        size: mask length. It defaults to ``index.max() + 1``, or 0 for an
            empty index.

    Returns:
        Bool tensor on ``index``'s device, ``True`` exactly at ``index``.
    """
    if index.dtype == torch.bool:
        return index
    index = index.view(-1)
    if size is None:
        # ``max()`` is undefined for an empty tensor, so an empty index gives
        # an empty mask.
        size = int(index.max()) + 1 if index.numel() > 0 else 0
    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask
