import einops
from torch.nn import functional as F
import torch.nn as nn
from sympy import prod
import torch
import math


def positionalencoding2d(d_model, height, width):
    """Build a 2-D sinusoidal positional encoding.

    The first half of the channels encodes the column (width) index and the
    second half encodes the row (height) index. Within each half, even
    channels hold ``sin`` terms and odd channels hold ``cos`` terms, using
    frequencies ``10000 ** (-k / (d_model / 2))`` for even ``k``.

    :param d_model: number of channels; must be divisible by 4
    :param height: number of row positions
    :param width: number of column positions
    :return: float tensor of shape (d_model, height, width)
    :raises ValueError: if ``d_model`` is not divisible by 4
    """
    # Each axis needs an even number of channels (sin/cos pairs), and the two
    # axes split d_model in half, hence the multiple-of-4 requirement.
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         f"dimension not divisible by 4 (got dim={d_model})")
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of the channels.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2) *
                         -(math.log(10000.0) / half))  # (half/2,)
    pos_w = torch.arange(0., width).unsqueeze(1)   # (W, 1)
    pos_h = torch.arange(0., height).unsqueeze(1)  # (H, 1)
    # Column phases as (half/2, 1, W): broadcast across rows on assignment.
    w_phase = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
    # Row phases as (half/2, H, 1): broadcast across columns on assignment.
    h_phase = (pos_h * div_term).transpose(0, 1).unsqueeze(2)
    pe[0:half:2, :, :] = torch.sin(w_phase)
    pe[1:half:2, :, :] = torch.cos(w_phase)
    pe[half::2, :, :] = torch.sin(h_phase)
    pe[half + 1::2, :, :] = torch.cos(h_phase)

    return pe


@torch.no_grad()
def board_accuracy(
    logits: torch.Tensor,  # (B, 81, 9) or (N, B, 81, 9)
    x: torch.Tensor,       # (B, 81) or (N, B, 81) in {0..9}
    y: torch.Tensor,       # (B, 81) or (N, B, 81) in {0..8}
    filled: torch.Tensor   # mask, same shape as y
) -> float:
    """Fraction of boards that are fully solved.

    A board counts as solved when every cell is either marked in ``filled``
    or has an argmax prediction equal to ``y``. The mean is taken over all
    leading (board) dimensions, i.e. over B or over N*B boards.

    ``x`` is accepted for interface compatibility but is not used.
    """
    pred = logits.argmax(dim=-1)   # (B, 81) or (N, B, 81) in {0..8}
    tgt = y.long()                 # (B, 81) or (N, B, 81) in {0..8}

    correct = (pred == tgt)

    # Reduce over the cell axis (last). Using dim=1 here would collapse the
    # batch axis for the (N, B, 81) layout instead of the 81 cells.
    solved_mask_blank = (filled | correct).all(dim=-1)  # (B,) or (N, B)
    board_acc = solved_mask_blank.float().mean().item()

    return board_acc


@torch.no_grad()
def digit_accuracy(
    logits: torch.Tensor,  # (B, 81, 9)
    target: torch.Tensor,  # (B, 81), -100 marks ignored cells
):
    """Per-cell accuracy of argmax predictions, skipping ignored cells.

    Cells whose target equals -100 are excluded from both the numerator and
    the denominator. Returns 0.0 when no cell remains after masking.
    """
    keep = target.ne(-100)
    predictions = logits.argmax(dim=-1)
    hits = predictions[keep].eq(target[keep]).sum().item()
    n_kept = keep.sum().item()
    if n_kept > 0:
        return hits / n_kept
    return 0.0

# From AKOrN
# https://github.com/autonomousvision/akorn/blob/main/source/layers/kutils.py


def reshape(x: torch.Tensor, n: int):
    """Split the channel axis into groups of size ``n``.

    A 3-D input is treated as tokens ``(B, T, C)`` and is first moved to
    channels-first ``(B, C, T)``. Any other input is assumed to already be
    channels-first ``(B, C, ...)``. The result has shape ``(B, C // n, n, ...)``.
    """
    channels_first = x.transpose(1, 2) if x.ndim == 3 else x
    return channels_first.unflatten(1, (-1, n))


def reshape_back(x):
    """Merge the (group, group-size) axes 1 and 2 back into one channel axis.

    A 4-D input is treated as grouped tokens and is returned as ``(B, T, C)``.
    Any other rank is returned channels-first, as ``(B, C, ...)``.
    """
    merged = x.flatten(1, 2)
    return merged.transpose(1, 2) if x.ndim == 4 else merged


def _l2normalize(x):
    """L2-normalize ``x`` along axis 2 (PyTorch's default eps guards zero vectors)."""
    return F.normalize(x, p=2.0, dim=2)


def norm(n, x, dim=2, keepdim=True):
    """Return the norm of ``x`` after grouping its channels into size-``n`` groups.

    ``x`` is first regrouped by ``reshape(x, n)``. With the default ``dim=2``
    the norm is taken across each group's ``n`` members.
    """
    grouped = reshape(x, n)
    return torch.linalg.norm(grouped, dim=dim, keepdim=keepdim)


def normalize(x: torch.Tensor, n):
    """L2-normalize each size-``n`` channel group of ``x`` and restore the layout.

    This is the composition ``reshape`` -> ``_l2normalize`` -> ``reshape_back``.
    """
    return reshape_back(_l2normalize(reshape(x, n)))
