import numpy as np
from scipy.stats import gaussian_kde
import torch
import torch.distributed as dist
from torch.nn import Dropout, Module

import math

from src.utils.log_utils import _debug_values

# Module-level size constant, 1280 (= 128 * 10). It is not referenced in the
# code visible here; presumably used by mocks/tests elsewhere — TODO confirm.
MOCK_SIZE = 128 * 10


def get_dist(data: np.ndarray, num_points: int = 100):
    """Estimate a discretized probability distribution of ``data`` via Gaussian KDE.

    A Gaussian kernel density estimate is fitted to ``data`` and evaluated on
    ``num_points`` evenly spaced points spanning ``[data.min(), data.max()]``.
    The resulting densities are normalized so that they sum to 1.

    Args:
        data: samples, presumably 1-D (array-like; converted with ``np.asarray``).
        num_points: number of evaluation grid points (default 100, the value
            previously hard-coded here).

    Returns:
        np.ndarray of shape (num_points,) of non-negative weights summing to 1.
    """
    # Accept lists/tuples too; the original called ``.min()`` directly on the input.
    data = np.asarray(data)
    kde = gaussian_kde(data)
    grid = np.linspace(data.min(), data.max(), num_points)
    p = kde(grid)
    return p / p.sum()


class PositionalEncoding(Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    Row ``pos`` of the precomputed table holds ``sin(pos * w_i)`` in even
    columns and ``cos(pos * w_i)`` in odd columns, where
    ``w_i = base ** (-2i / d_model)``.

    Args:
        d_model: embedding size (last dimension of inputs to ``forward``).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed.
        base: frequency base. Defaults to 100.0, the value this module has
            always used. The classic Transformer value is 10000.0.
    """

    def __init__(self, d_model, dropout=0, max_len=5000, base=100.0):
        super().__init__()
        self.dropout = Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(base) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than frequencies;
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def get_pe(self):
        """Return the precomputed (max_len, d_model) encoding buffer."""
        return self.pe

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions, then apply dropout.

        Positions are taken along dim 1, and the table broadcasts over the
        last dim, which must equal ``d_model``. Presumably ``x`` is
        (batch, seq_len, d_model) — TODO confirm with callers.
        """
        x = x + self.pe[: x.size(1), :]
        return self.dropout(x)


def get_1d_sincos_pos_embed(embed_dim, pos):
    """Build 1-D sine/cosine embeddings for a set of positions.

    Args:
        embed_dim: output dimension D for each position; must be even.
        pos: positions to encode. Any array-like is accepted and flattened,
            so M is its total number of elements (a 1-D input of size M
            behaves exactly as before).

    Returns:
        np.ndarray of shape (M, D). The first D/2 columns are
        ``sin(pos * omega)`` and the last D/2 are ``cos(pos * omega)``, with
        ``omega_i = 1 / 10000 ** (2i / D)``.

    Raises:
        ValueError: if ``embed_dim`` is odd. An ``assert`` would be stripped
            under ``python -O``.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # einsum "m,..." requires a 1-D operand, so flatten any grid of positions.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)


def apply_masks(
    batchs: torch.Tensor,
    masks: list,
    feature_space: bool = False,
    cardinalities: list = None,
):
    """
    Apply binary masks to a batch, either in latent space or in feature space.

    @args:
    batchs: if feature_space=False, a tensor of dimension [B, D, E]
            (B: batch size, D: number of features, E: hidden dim).
            If feature_space=True, the raw feature values. The code calls
            ``batchs.squeeze()``, so presumably [B, D] or [B, D, 1] — TODO confirm.
    masks: binary masks, 0: mask feature, 1: visible feature.
           If feature_space=False: a tensor of dimension [B, num_preds, D_exp],
           where each categorical feature occupies ``cardinality``
           consecutive columns (so D_exp >= D). It is transposed to iterate
           over the num_preds masks.
           If feature_space=True: an iterable of [B, D] mask tensors.
    feature_space: whether processing is done in the feature space
                   or in the latent space.
    cardinalities: iterable of (feature_idx, cardinality) pairs, in any
                   order (defaults to no categorical features). Only used when
                   feature_space=False: for each listed feature, only the
                   first of its ``cardinality`` mask columns is kept.

    @returns:
        if not feature_space:
            a list of num_preds tensors of dimension [B, D, E], each equal to
            ``batchs`` multiplied by one mask.
        else:
            a list with one tensor per mask, whose last dimension is the pair
            (value * mask, mask). With the 1 = visible convention, visible
            features become (value, 1) and masked ones (0, 0).

    @raises:
        IndexError: if the cardinalities describe more mask columns than
                    ``masks`` has.
    """
    if cardinalities is None:
        cardinalities = []
    all_batches = []
    if not feature_space:
        masks = masks.transpose(0, 1)

        # feature index -> cardinality, built once. A dict lookup also removes
        # the implicit requirement that ``cardinalities`` be sorted.
        card_by_feature = {int(c[0]): int(c[1]) for c in cardinalities}

        n_cols = masks.shape[-1]
        columns_to_keep = torch.ones(n_cols, dtype=torch.bool, device=masks.device)
        col = 0
        orig_idx = 0
        while col < n_cols:
            # Mask columns this feature occupies beyond its first one.
            extra = max(card_by_feature.get(orig_idx, 1) - 1, 0)
            if extra:
                if col + extra >= n_cols:
                    raise IndexError(
                        f"feature {orig_idx} with cardinality {extra + 1} starting "
                        f"at mask column {col} exceeds the {n_cols} mask columns"
                    )
                columns_to_keep[col + 1 : col + 1 + extra] = False
                col += extra
            col += 1
            orig_idx += 1

        masks = masks[:, :, columns_to_keep]

        for mask in masks:
            all_batches.append(mask.unsqueeze(dim=-1) * batchs)

    else:
        for mask in masks:
            # NOTE(review): squeeze() drops *every* size-1 dim (including
            # B == 1 or D == 1), which relies on broadcasting with ``mask``.
            masked_batch = batchs.squeeze() * mask

            masked_batch = masked_batch.unsqueeze(dim=-1)
            mask = mask.unsqueeze(dim=-1)
            masked_batch = torch.concat((masked_batch, mask), dim=-1)
            all_batches.append(masked_batch)

    return all_batches


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with a normal(mean, std) sample truncated to [a, b].

    Thin public wrapper around ``_no_grad_trunc_normal_``, which does the
    work under ``torch.no_grad()``. Returns ``tensor`` itself.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


def init_weights(m, init_type="trunc_normal"):
    """Initialize the parameters of a single module in place.

    Intended for use with ``model.apply(...)``.
    - ``LayerNorm`` modules always get weight=1 and bias=0, whatever ``init_type``.
    - ``Linear`` modules get their weight initialized per ``init_type`` and
      their bias (if any) set to 0.
    - ``Embedding`` modules get their weight initialized per ``init_type``,
      except for "xavier_normal", which leaves them untouched.
    - Unknown ``init_type`` values and other module types are silently ignored.

    Args:
        m: the module to initialize.
        init_type: one of "trunc_normal", "xavier_uniform", "xavier_normal",
            "kaiming_normal", "normal".
    """
    if isinstance(m, torch.nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) (or bias=False) stores None here,
        # and constant_ would crash on it.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1.0)

    def _trunc(w):
        trunc_normal_(w, std=0.02)

    def _kaiming(w):
        torch.nn.init.kaiming_normal_(w, a=math.sqrt(5))

    def _normal(w):
        torch.nn.init.normal_(w, std=0.02)

    # init_type -> (Linear weight initializer, Embedding weight initializer).
    # None means that module type is left untouched by the scheme.
    # NOTE(review): "xavier_normal" has never initialized Embedding weights,
    # unlike every other scheme; kept as-is to preserve behavior.
    schemes = {
        "trunc_normal": (_trunc, _trunc),
        "xavier_uniform": (
            torch.nn.init.xavier_uniform_,
            torch.nn.init.xavier_uniform_,
        ),
        "xavier_normal": (torch.nn.init.xavier_normal_, None),
        "kaiming_normal": (_kaiming, _kaiming),
        "normal": (_normal, _normal),
    }
    linear_init, embedding_init = schemes.get(init_type, (None, None))

    if isinstance(m, torch.nn.Linear):
        if linear_init is not None:
            linear_init(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.Embedding) and embedding_init is not None:
        embedding_init(m.weight)


def apply_masks_from_idx(x, masks):
    """Gather the kept feature representations for each index mask.

    :param x: tensor of shape [B (batch-size), N (num-feature), D (feature-dim)]
    :param masks: list of index tensors (typically [B, K]) holding, per sample,
        the positions in [N] to keep
    :return: the gathered tensors for all masks, concatenated along dim 0
        (typically [len(masks) * B, K, D])
    """
    feat_dim = x.size(-1)
    # Broadcast each index over the feature dim so gather picks whole rows.
    gathered = [
        torch.gather(x, dim=1, index=m.unsqueeze(-1).repeat(1, 1, feat_dim))
        for m in masks
    ]
    return torch.cat(gathered, dim=0)


def get_idx_from_mask(mask):
    """Convert binary masks into per-row tensors of non-zero column indices.

    For every element ``m`` of ``mask`` (each expected to be 2-D), this
    collects the column indices of ``m``'s non-zero entries, grouped by row,
    and stacks them. Rows with no non-zero entries are skipped. All remaining
    rows must have the same number of non-zero entries, otherwise
    ``torch.stack`` fails.

    Returns:
        list with one index tensor per element of ``mask``, of shape
        (rows_with_nonzeros, nonzeros_per_row).
    """

    def _row_columns(m):
        coords = torch.nonzero(m)  # (k, 2) rows of (row, col)
        rows, cols = coords.t()
        _, per_row = torch.unique(rows, return_counts=True)
        return torch.stack(torch.split_with_sizes(cols, per_row.tolist()))

    return [_row_columns(m) for m in mask]


class AllReduce(torch.autograd.Function):
    """
    Average a tensor across all distributed processes (identity otherwise).

    The forward pass divides by the world size and then all-reduces, which
    yields the mean with the default SUM reduction. The backward pass returns
    the incoming gradient unchanged.

    from https://github.com/facebookresearch/ijepa
    """

    @staticmethod
    def forward(ctx, x):
        # Short-circuit order matters: only query the process group when the
        # distributed package is available and initialized.
        in_multi_process = (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )
        if not in_multi_process:
            return x
        averaged = x.contiguous() / dist.get_world_size()
        dist.all_reduce(averaged)
        return averaged

    @staticmethod
    def backward(ctx, grads):
        # Straight-through: gradients are not communicated.
        return grads


def get_distributed_dataloader(
    batchsize, dataset, distributed_args, data_loader_nprocs, mask_collator, pin_memory
):
    """Build a DataLoader that reads only this process's shard of ``dataset``.

    Args:
        batchsize: samples per batch on this process.
        dataset: dataset to shard. It must support ``len()`` for DistributedSampler.
        distributed_args: mapping with ``"world_size"`` and ``"rank"`` entries.
        data_loader_nprocs: number of DataLoader worker processes
            (0 loads in the main process).
        mask_collator: ``collate_fn`` applied to each batch (None selects the
            DataLoader default).
        pin_memory: forwarded to ``DataLoader``.

    Returns:
        torch.utils.data.DataLoader driven by a DistributedSampler built with
        the sampler's default shuffle/seed settings. The last partial batch is kept.
    """
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=distributed_args["world_size"],
        rank=distributed_args["rank"],
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batchsize,
        num_workers=data_loader_nprocs,
        collate_fn=mask_collator,
        pin_memory=pin_memory,
        sampler=sampler,
        drop_last=False,
        # Keep workers alive across epochs instead of re-spawning them
        # (https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110).
        # DataLoader raises if persistent_workers=True with num_workers == 0,
        # so only enable it when there are worker processes.
        persistent_workers=data_loader_nprocs > 0,
    )

    print("Successfully loaded distributed batch dataloader.")

    return dataloader
