import torch
import torch.nn as nn
import torch.nn.functional as F

from deephfts.mats import dis_matrix, difference_matrix

def sblock_loss(y:torch.Tensor, y_pred: torch.Tensor, losstype: str) -> torch.Tensor:
    """S-Block loss function.

    Args:
        y (torch.Tensor): argmin of the MSE loss from the R-Matrix.
        y_pred (torch.Tensor): Predictions of y.
        losstype (str): Error type for the S-Block loss; either "mse" or "NLL"
            (case-sensitive).

    Returns:
        torch.Tensor: Scalar loss of the S-Block.

    Raises:
        ValueError: If ``losstype`` is not one of the supported loss types.
    """
    if losstype == "mse":
        # MSE is symmetric, so the (y, y_pred) argument order does not matter.
        return nn.MSELoss()(y, y_pred)

    if losstype == "NLL":
        # nn.NLLLoss expects log-probabilities as input. The original naming
        # ("softmax_spreds") hints y_pred may hold plain softmax outputs; if so,
        # a log should be applied upstream. NOTE(review): confirm with caller.
        log_probs = torch.flatten(y_pred)
        # The index of the largest entry of y is used as the target class.
        target = torch.argmax(y.float())
        return nn.NLLLoss()(log_probs, target)

    # Previously an unknown losstype surfaced as an accidental UnboundLocalError.
    raise ValueError(
        f"Unsupported losstype {losstype!r} for sblock_loss; expected 'mse' or 'NLL'."
    )

def sblock_DWCL(
    s_preds : torch.Tensor,
    r_preds : torch.Tensor,
    y : torch.Tensor
    ) -> torch.Tensor:
    """Placeholder for the S-Block Distance-Weighted Contrastive Loss.

    This function is not implemented yet. It ignores all arguments and
    returns ``None``, despite the ``torch.Tensor`` return annotation.

    Args:
        s_preds (torch.Tensor): Predictions from the S-Block.
        r_preds (torch.Tensor): Predictions from the R-Block.
        y (torch.Tensor): Next time step.

    Returns:
        torch.Tensor: Intended to be the distance-weighted contrastive loss for
            the S-Block; currently always ``None``.
    """
    # TODO: implement the distance-weighted contrastive loss.
    return None


def rblock_loss(
    pred_matrix : torch.Tensor,
    y : torch.Tensor,
    losstype : str,
    coordinates : list,
    alpha : float=0.2) -> tuple:
    """Generates a loss for a given R-block in the R-matrix.

    The loss is the sum of two terms:

    * a prediction loss between the R-block at ``coordinates`` and ``y``;
    * a distance-weighted contrastive term that compares every R-block's
      predictions against the selected block's predictions and weights the
      result by ``dis_matrix``.

    Args:
        pred_matrix (torch.Tensor): Matrix of predictions, indexed as
            ``pred_matrix[x, y, :, :]`` (so at least 4-D).
        y (torch.Tensor): Ground truth values for the selected block.
        losstype (str): "forecasting" or "reconstruction" (both use MSE), or
            "sMAPE".
        coordinates (list): ``[x_coord, y_coord]`` of the current block.
            Currently supports 2 dimensions.
        alpha (float): Currently unused; kept for interface compatibility.

    Returns:
        tuple: ``(loss, loss_k, loss_block)``. Here ``loss = loss_k + loss_block``,
        ``loss_k`` is the prediction loss of the selected block, and
        ``loss_block`` is the mean distance-weighted contrastive term.

    Raises:
        ValueError: If ``losstype`` is not supported.
    """
    x_coord, y_coord = coordinates
    pred_k = pred_matrix[x_coord, y_coord, :, :]

    if losstype in ("forecasting", "reconstruction"):
        loss_k = nn.MSELoss()(pred_k, y)
    elif losstype == "sMAPE":
        # NOTE(review): yields NaN where both y and pred_k are exactly 0 (0/0).
        loss_k = torch.mean(2 * (y - pred_k).abs() / (y.abs() + pred_k.abs()))
    else:
        # Previously an unknown losstype surfaced as an UnboundLocalError.
        raise ValueError(
            f"Unsupported losstype {losstype!r} for rblock_loss; expected "
            "'forecasting', 'reconstruction' or 'sMAPE'."
        )

    # Computed once; the original recomputed it three times with identical args.
    dist_block = dis_matrix(
        x_shape=pred_matrix.shape[0],
        y_shape=pred_matrix.shape[1],
        x_coord=x_coord,
        y_coord=y_coord,
    )

    # Contrastive term: -log(exp(pred_k) / exp(pred_matrix)) == pred_matrix - pred_k.
    # Computing the difference directly avoids exp overflow (inf/inf -> NaN)
    # for large predictions. Broadcasting also keeps the device/dtype of the
    # inputs, whereas torch.ones(...) allocated on the default (CPU) device.
    sim_block = pred_matrix - pred_k
    sim_block = torch.mean(sim_block, dim=-1)  # average across the last dim (timesteps)
    # Assumes the remaining per-block elements collapse to dist_block's shape
    # (e.g. a trailing size-1 dim); TODO confirm against dis_matrix's output.
    sim_block = sim_block.reshape(dist_block.shape)

    loss_block = torch.mean(sim_block * dist_block)
    loss = loss_k + loss_block

    return loss, loss_k, loss_block
        
