import torch
from torch.functional import Tensor
import torch.nn as nn

def coord_matrix(
    x_shape: int,
    y_shape: int) -> torch.Tensor:
    """Generate the coordinate matrix of a block.

    Entry ``[i, j]`` of the result holds the pair ``(i, j)``, with ``i``
    running over ``range(x_shape)`` and ``j`` over ``range(y_shape)``.

    Args:
        x_shape (int): Block width (length of the first axis).
        y_shape (int): Block height (length of the second axis).
    Returns:
        torch.Tensor: Coordinate matrix of shape ``(x_shape, y_shape, 2)``
        in the default floating point dtype. A size of zero yields an empty
        tensor that still has this three-axis layout.
    """
    # Use the default dtype so the result matches what the legacy
    # torch.Tensor(list) constructor would have produced.
    dtype = torch.get_default_dtype()
    # First component: the index along axis 0, repeated across axis 1.
    first = torch.arange(x_shape, dtype=dtype).unsqueeze(1).expand(x_shape, y_shape)
    # Second component: the index along axis 1, repeated across axis 0.
    second = torch.arange(y_shape, dtype=dtype).unsqueeze(0).expand(x_shape, y_shape)
    return torch.stack((first, second), dim=-1)


def dis_matrix(
    x_shape: int,
    y_shape: int,
    x_coord: int,
    y_coord: int) -> torch.Tensor:
    """Generate the distance matrix of a block relative to a coordinate.

    Entry ``[i, j]`` of the result is the Euclidean distance between the
    grid position ``(i, j)`` and the reference pair ``(x_coord, y_coord)``.

    NOTE(review): the first axis has length ``y_shape`` but its index is
    compared against ``x_coord`` (and the second axis, of length
    ``x_shape``, against ``y_coord``). For non-square blocks, confirm that
    this axis convention is the intended one.

    Args:
        x_shape (int): Block width (length of the second axis).
        y_shape (int): Block height (length of the first axis).
        x_coord (int): X-coordinate for current block.
        y_coord (int): Y-coordinate for current block.
    Returns:
        torch.Tensor: Distance matrix of shape ``(y_shape, x_shape)``.
    """
    dtype = torch.get_default_dtype()
    # Grid of (i, j) index pairs with shape (y_shape, x_shape, 2).
    first = torch.arange(y_shape, dtype=dtype).unsqueeze(1).expand(y_shape, x_shape)
    second = torch.arange(x_shape, dtype=dtype).unsqueeze(0).expand(y_shape, x_shape)
    grid = torch.stack((first, second), dim=-1)
    # The reference pair, viewed at every grid position without copying.
    target = torch.tensor([x_coord, y_coord], dtype=dtype).expand_as(grid)
    # PairwiseDistance adds its default eps (1e-6) to the difference before
    # taking the norm, so the entry at the reference position is tiny but
    # not exactly zero.
    pdist = nn.PairwiseDistance(p=2)
    return pdist(grid, target)

def block_argmax(block_preds: torch.Tensor) -> torch.Tensor:
    """Generate the argmax mask of a block.

    Despite the name, this returns a mask rather than an index: every
    element equal to the global maximum of the block is marked with 1.0 and
    all other elements with 0.0. Ties therefore produce several ones.

    Args:
        block_preds (torch.Tensor): Block of predictions. Must be non-empty,
            since the maximum of an empty tensor is undefined.

    Returns:
        torch.Tensor: float32 tensor with the same shape as ``block_preds``.
    """
    block_max = torch.max(block_preds)
    # Casting the boolean comparison yields the 0/1 mask directly.
    return (block_preds == block_max).float()

def block_argmin(block_preds: torch.Tensor) -> torch.Tensor:
    """Generate the argmin mask of a block.

    Despite the name, this returns a mask rather than an index: every
    element equal to the global minimum of the block is marked with 1.0 and
    all other elements with 0.0. Ties therefore produce several ones.

    Args:
        block_preds (torch.Tensor): Block of predictions. Must be non-empty,
            since the minimum of an empty tensor is undefined.

    Returns:
        torch.Tensor: float32 tensor with the same shape as ``block_preds``.
    """
    block_min = torch.min(block_preds)
    # Casting the boolean comparison yields the 0/1 mask directly.
    return (block_preds == block_min).float()



def difference_matrix(y, rmatrix_preds):
    """Build the argmin mask of the mean squared error between y and the
    R matrix predictions.

    The element-wise squared difference is averaged over its last axis, the
    last remaining axis is moved to the front, a leading axis of size one is
    added, and the result is handed to ``block_argmin``.

    Args:
        y (torch.Tensor): output.
        rmatrix_preds (torch.Tensor): Predictions. ``y - rmatrix_preds`` must
            have four axes, because the permutation below expects exactly
            three axes after the mean.

    Returns:
        torch.Tensor: Difference matrix as produced by ``block_argmin``.
    """
    squared_error = (y - rmatrix_preds) ** 2
    mean_error = squared_error.mean(dim=-1)
    # (a, b, c) -> (c, a, b) -> (1, c, a, b)
    stacked = mean_error.permute(-1, 0, 1).unsqueeze(0)
    # The smallest mean squared error is selected via block_argmin.
    return block_argmin(stacked)