import math
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, roc_auc_score
from functools import partial
import torchmetrics.functional as tm_f
import pytorch_lightning as pl
import torch.nn as nn

import numpy as np

def _student_t_map(mu, sigma, nu):
    """Map raw outputs to Student-t parameters.

    ``sigma`` is passed through softplus (making it positive) and ``nu`` is
    passed through softplus and shifted by 2.0. ``mu`` is left unchanged.
    Each of the three results then has its last axis squeezed, which is a
    no-op unless that axis has size 1.

    Returns:
        Tuple ``(mu, sigma, nu)`` of the transformed tensors.
    """
    scale = F.softplus(sigma)
    dof = F.softplus(nu) + 2.0
    return tuple(param.squeeze(axis=-1) for param in (mu, scale, dof))

def student_t_loss(outs, y):
    """Mean negative log-likelihood of ``y`` under a Student-t distribution.

    Args:
        outs: Tensor whose last axis holds, in order, the raw location,
            raw scale and raw degrees-of-freedom outputs. The raw values are
            converted by ``_student_t_map`` before use.
        y: Targets; the last axis is squeezed before comparison.

    Returns:
        Scalar tensor: the negated mean of the per-element log-likelihood.
    """
    loc, scale, dof = _student_t_map(outs[..., 0], outs[..., 1], outs[..., 2])
    target = y.squeeze(axis=-1)

    half_dof_plus_one = (dof + 1.0) / 2.0
    scaled_sq_residual = 1.0 / dof * torch.square((target - loc) / scale)

    # Log of the Student-t normalizing constant.
    log_normalizer = (
        torch.lgamma(half_dof_plus_one)
        - torch.lgamma(dof / 2.0)
        - 0.5 * torch.log(math.pi * dof)
        - torch.log(scale)
    )

    log_likelihood = log_normalizer - half_dof_plus_one * torch.log1p(scaled_sq_residual)
    return -log_likelihood.mean()

def gaussian_ll_loss(outs, y):
    """Mean negative log-likelihood of ``y`` under a Gaussian.

    Args:
        outs: Tensor whose last axis holds the mean at index 0 and the raw
            standard deviation at index 1 (softplus is applied to the latter).
        y: Targets; the last axis is squeezed before comparison.

    Returns:
        Scalar tensor: the negated mean of the per-element log-likelihood.
    """
    mean = outs[..., 0]
    std = F.softplus(outs[..., 1])
    target = y.squeeze(axis=-1)

    per_element_nll = (
        torch.log(std)
        + 0.5 * math.log(2 * math.pi)
        + 0.5 * torch.square((target - mean) / std)
    )
    log_likelihood = -1.0 * per_element_nll
    return -log_likelihood.mean()

def binary_cross_entropy(logits, y):
    """Binary cross-entropy computed from logits.

    The last axis of ``logits`` is squeezed so that it lines up with ``y``,
    and ``y`` is cast to float because the underlying loss treats the target
    as a probability.
    """
    squeezed_logits = logits.squeeze(-1)
    float_targets = y.float()
    return F.binary_cross_entropy_with_logits(squeezed_logits, float_targets)


def binary_accuracy(logits, y):
    """Fraction of entries where the thresholded logit (``>= 0``) equals ``y``.

    The last axis of ``logits`` is squeezed before thresholding.
    """
    predicted_positive = logits.squeeze(-1) >= 0
    matches = torch.eq(predicted_positive, y)
    return matches.float().mean()


def cross_entropy(logits, y):
    """Cross-entropy after flattening all leading axes.

    ``logits`` is viewed as ``(-1, num_classes)`` where ``num_classes`` is the
    size of its last axis, and ``y`` is viewed as a 1-D vector of class
    indices.
    """
    num_classes = logits.shape[-1]
    return F.cross_entropy(logits.view(-1, num_classes), y.view(-1))

def cross_entropy_ignore_background(logits, y):
    """Cross-entropy that skips targets equal to class 0.

    Inputs are flattened the same way as for plain cross-entropy: ``logits``
    to ``(-1, num_classes)`` and ``y`` to 1-D. Entries whose target is 0 are
    excluded through ``ignore_index=0``.
    """
    num_classes = logits.shape[-1]
    flat_logits = logits.view(-1, num_classes)
    flat_targets = y.view(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=0)

def cross_entropy_weighted(logits, y,
                           class_weights_path='/data/amos22_processed_normalized/class_weights.json'):
    """Class-weighted cross-entropy with weights read from a JSON file.

    Args:
        logits: Tensor whose last axis has exactly 16 classes; all leading
            axes are flattened.
        y: Class-index targets; flattened to 1-D.
        class_weights_path: Path of a JSON object whose values, in file
            order, are the per-class weights. Defaults to the previously
            hard-coded location.

    Returns:
        Scalar tensor with the weighted cross-entropy.

    Raises:
        AssertionError: If the last axis of ``logits`` is not 16.
        OSError: If the weights file cannot be opened.
    """
    import json

    logits = logits.view(-1, logits.shape[-1])
    assert logits.shape[-1] == 16
    y = y.view(-1)

    # NOTE: the file is re-read on every call.
    with open(class_weights_path, 'r') as f:
        # Saved as a dict; only the values (in order) are used.
        weight_by_class = json.load(f)

    # Build the weights on the same device/dtype as the logits instead of
    # unconditionally calling .cuda(), so the loss also works on CPU and on
    # non-default GPUs.
    class_weights = torch.tensor(
        list(weight_by_class.values()),
        dtype=logits.dtype,
        device=logits.device,
    )

    return F.cross_entropy(logits, y, weight=class_weights)

def dice(y_pred, y, smooth=1e-6):
    """Per-class Dice scores of the argmax prediction, plus a foreground mean.

    Args:
        y_pred: Class scores; viewed as ``(N, num_classes)``.
        y: Class-index targets; viewed as ``(N,)``.
        smooth: Constant added to numerator and denominator of each score so
            that a class absent from both prediction and target scores 1
            instead of dividing by zero.

    Returns:
        dict mapping each class index (int) to its Dice score tensor, plus
        the key ``'meanDice'`` holding the mean over all classes except
        class 0 (the background). With a single class ``'meanDice'`` is 0.
    """
    y = y.view(-1)
    y_pred = y_pred.view(-1, y_pred.shape[-1])
    num_classes = y_pred.size(1)

    # Hard predictions: index of the highest score for every sample.
    _, y_pred_bin = torch.max(y_pred, dim=1)

    dice_scores = {}
    for class_label in range(num_classes):
        pred_mask = y_pred_bin == class_label
        gt_mask = y == class_label
        intersection = torch.sum(torch.logical_and(pred_mask, gt_mask))
        dice_scores[class_label] = (2 * intersection + smooth) / (
            torch.sum(pred_mask) + torch.sum(gt_mask) + smooth
        )

    # The mean is computed once, after all classes are scored. Doing it
    # inside the loop fed earlier 'meanDice' values back into the sum and
    # divided by a count that included the background class.
    foreground_scores = [dice_scores[c] for c in range(1, num_classes)]
    if foreground_scores:
        dice_scores['meanDice'] = torch.stack(foreground_scores).mean()
    else:
        dice_scores['meanDice'] = torch.zeros((), device=y_pred.device)

    return dice_scores

def dice_loss(predicted, target, smooth=1e-6):
    """Soft Dice loss: one minus the mean per-class Dice score.

    Args:
        predicted: Per-class scores; viewed as ``(N, num_classes)``. The raw
            values are used as soft memberships (presumably probabilities;
            confirm against the caller).
        target: Class-index targets; viewed as ``(N,)``.
        smooth: Constant added to numerator and denominator of each score.

    Returns:
        Scalar tensor ``1 - mean(dice over all classes)``.
    """
    flat_target = target.view(-1)
    flat_pred = predicted.view(-1, predicted.shape[-1])
    n_classes = flat_pred.size(1)

    per_class_dice = torch.zeros(n_classes, device=flat_pred.device)
    for cls, class_scores in enumerate(flat_pred.unbind(dim=1)):
        class_mask = (flat_target == cls).float()
        overlap = torch.sum(class_scores * class_mask)
        total = torch.sum(class_scores) + torch.sum(class_mask)
        per_class_dice[cls] = (2.0 * overlap + smooth) / (total + smooth)

    return 1.0 - torch.mean(per_class_dice)

def cross_entropy_with_dice(y_pred, y, lambd=0.5):
    """Convex combination of the Dice loss and the cross-entropy loss.

    Args:
        y_pred: Predictions forwarded unchanged to both losses.
        y: Targets forwarded unchanged to both losses.
        lambd: Relative weight of the Dice loss, in ``[0, 1]``; the
            cross-entropy term is weighted by ``1 - lambd``.
    """
    assert 0 <= lambd <= 1

    dice_term = lambd * dice_loss(y_pred, y)
    ce_term = (1 - lambd) * cross_entropy(y_pred, y)
    return dice_term + ce_term



def ss_log_loss(y_pred, y):
    """Negative log-likelihood of observed state transitions.

    Args:
        y_pred: Indexed as ``y_pred[batch, previous_state, current_state]``.
            Per the original comment it has shape
            (batch_size, num_states, num_states).
        y: 4-D tensor; states are recovered with an argmax over the last
            axis. Per the original comment it has shape (batch_size,
            loan_pool_size, num_samples_per_trajectory, num_states).

    Returns:
        The negated sum of the log of the selected entries, divided by
        ``y.shape[0] * y.shape[1] * y.shape[2]``.
    """
    batch_size, pool_size, seq_len = y.shape[0], y.shape[1], y.shape[2]

    states = torch.argmax(y, dim=3)
    from_states = states[:, :, :-1]
    to_states = states[:, :, 1:]

    # Shaped (batch, 1, 1) so it broadcasts against the state index tensors.
    batch_index = torch.arange(batch_size).unsqueeze(1).unsqueeze(2)

    log_probs = torch.log(y_pred[batch_index, from_states, to_states])
    return -log_probs.sum() / (seq_len * pool_size * batch_size)

def next_step_log_loss(y_pred, y):
    """Next-step cross-entropy rescaled by the count of terminal transitions.

    The prediction at time ``t`` is scored against the argmax state of ``y``
    at time ``t + 1``.

    Args:
        y_pred: 4-D logits whose last axis has ``num_states`` entries. Per
            the original comment it may arrive as (loan_pool_size,
            batch_size, num_samples_per_trajectory, num_states); the first
            two axes are swapped when they mirror those of ``y``.
        y: 4-D targets; argmax over the last axis gives the state. Per the
            original comment: (batch_size, loan_pool_size,
            num_samples_per_trajectory, num_states).

    Returns:
        Scalar tensor ``CE * N / (N + 1 - num_terminal)``, where ``N`` is the
        number of scored transitions and ``num_terminal`` counts targets in
        a terminal state.

    Raises:
        ValueError: If ``num_states`` is not 3, 8 or 10 (previously this
            surfaced later as an ``UnboundLocalError``).
        AssertionError: If ``y`` and ``y_pred`` differ in element count.
    """
    # Terminal state indices for every supported state-space size.
    terminal_states_by_size = {3: (2,), 8: (7,), 10: (8, 9)}

    num_states = y_pred.shape[3]
    if num_states not in terminal_states_by_size:
        raise ValueError(
            f"next_step_log_loss supports num_states in "
            f"{sorted(terminal_states_by_size)}, got {num_states}"
        )

    if y.shape[1] * y.shape[2] == y_pred.shape[0]:
        # Set-encoder case: y_pred stacks the trajectories in the batch dimension.
        y = torch.reshape(y, (-1, 1, y.shape[2], y.shape[3]))
    if y.shape[1] == y_pred.shape[0] and y.shape[0] == y_pred.shape[1]:
        y_pred = torch.permute(y_pred, (1, 0, 2, 3))
    assert torch.numel(y) == torch.numel(y_pred)

    # Targets are the states one step ahead of each prediction.
    y = torch.argmax(y[:, :, 1:, :], dim=3)
    num_terminal_transitions = sum(
        torch.sum(torch.eq(y, state)) for state in terminal_states_by_size[num_states]
    )

    y_pred = y_pred[:, :, :-1, :]
    y = y.flatten()
    y_pred = y_pred.flatten(0, 2)
    out = F.cross_entropy(y_pred, y)

    # Rescale the mean towards "per non-terminal transition". The +1 keeps
    # the denominator positive when every transition is terminal.
    out = out * (y.shape[0]) / (y.shape[0] + 1 - num_terminal_transitions)

    assert not torch.is_complex(out), f"loss became complex: dtype={out.dtype}"
    return out


def next_step_log_loss_valid_indices2(y_pred, y, valid_indices):
    """Next-step cross-entropy for the 8-state model, rescaled for terminal targets.

    ``valid_indices`` is accepted but not used by this variant.

    Args:
        y_pred: 4-D logits; the last axis must have 8 entries.
        y: 4-D targets; argmax over the last axis gives the state.
        valid_indices: Ignored.

    Returns:
        Scalar tensor ``CE * N / (N + 1 - num_terminal)``, where ``N`` is the
        number of scored transitions and ``num_terminal`` counts targets
        equal to state 7.
    """
    state_count = y_pred.shape[3]

    stacked_in_batch_dim = y.shape[1] * y.shape[2] == y_pred.shape[0]
    if stacked_in_batch_dim:
        # Set-encoder case: y_pred stacks the trajectories in the batch dimension.
        y = torch.reshape(y, (-1, 1, y.shape[2], y.shape[3]))
    leading_axes_swapped = y.shape[1] == y_pred.shape[0] and y.shape[0] == y_pred.shape[1]
    if leading_axes_swapped:
        y_pred = torch.permute(y_pred, (1, 0, 2, 3))

    # Both tensors must hold the same number of elements from here on.
    assert torch.numel(y) == torch.numel(y_pred)

    next_states = torch.argmax(y[:, :, 1:, :], dim=3)
    assert state_count == 8
    terminal_count = torch.sum(torch.eq(next_states, 7))

    flat_targets = next_states.flatten()
    flat_logits = y_pred[:, :, :-1, :].flatten(0, 2)
    loss = F.cross_entropy(flat_logits, flat_targets)

    # Multiply by the number of transitions and divide by (that number + 1
    # minus the terminal ones).
    n_transitions = flat_targets.shape[0]
    return loss * (n_transitions) / (n_transitions + 1 - terminal_count)


def next_step_log_loss_valid_indices(y_pred, y, valid_indices):
    """Next-step cross-entropy restricted to a window of time steps.

    The prediction at time ``t`` is scored against the argmax state of ``y``
    at ``t + 1``, using only the shifted steps in
    ``[valid_indices[0][0], valid_indices[0][-1] - 1)``.

    Args:
        y_pred: Logits; reshaped/permuted to the shape of ``y`` when the
            leading axes are stacked or swapped.
        y: 4-D targets; argmax over the last axis gives the state.
        valid_indices: Sequence whose first element is an ordered sequence of
            time indices; only its first and last entries are read.

    Returns:
        Scalar tensor ``CE * N / max(N - num_terminal, 1)``, where
        ``num_terminal`` counts targets equal to state 7. If the window is
        empty, returns a fresh ``0.0`` tensor with ``requires_grad=True``
        (it is not connected to ``y_pred``).

    Raises:
        AssertionError: If shapes cannot be aligned or the window falls
            outside the shifted time axis.
    """
    # -- align y and y_pred to a common (batch, units, time, states) layout --
    if y.shape[1] * y.shape[2] == y_pred.shape[0]:
        y = torch.reshape(y, (-1, 1, y.shape[2], y.shape[3]))
    if y.shape[1] == y_pred.shape[0] and y.shape[0] == y_pred.shape[1]:
        y_pred = torch.permute(y_pred, (1, 0, 2, 3))
    if y.shape[0] * y.shape[1] == y_pred.shape[0]:
        y_pred = torch.reshape(y_pred, y.shape)

    assert y.shape == y_pred.shape

    # Shift so that the prediction at t lines up with the target at t + 1.
    y = y[:, :, 1:, :]
    y_pred = y_pred[:, :, :-1, :]

    # Map valid_indices onto the shifted time axis.
    start_valid = valid_indices[0][0]
    end_valid = valid_indices[0][-1] - 1
    assert start_valid >= 0
    assert end_valid <= y.shape[2]

    if start_valid >= end_valid:
        # Empty window: return a 0.0 that still requires grad. The leftover
        # breakpoint() that used to sit here was removed because it halted
        # any non-interactive run.
        return torch.tensor(0.0,
                            dtype=torch.float32,
                            device=y_pred.device,
                            requires_grad=True)

    # Keep only the valid time steps.
    y = torch.argmax(y, dim=3)
    y = y[:, :, start_valid:end_valid]
    y_pred = y_pred[:, :, start_valid:end_valid, :]

    num_terminal_transitions = (y == 7).sum()

    y_flat = y.flatten()
    y_pred_flat = y_pred.flatten(0, 2)

    out = F.cross_entropy(y_pred_flat, y_flat)

    # Rescale the mean to "per non-terminal transition"; the clamp avoids a
    # division by zero when every transition is terminal.
    total_transitions = y_flat.shape[0]
    non_terminal = total_transitions - num_terminal_transitions
    non_terminal = torch.clamp(non_terminal, min=1)

    out = out * (total_transitions / non_terminal)

    return out


def next_step_log_loss_ignore_terminal(y_pred, y):
    """Next-step cross-entropy that skips targets in the last state.

    The prediction at time ``t`` is scored against the argmax state of ``y``
    at ``t + 1``. Targets equal to ``num_states - 1`` are excluded through
    ``ignore_index``, so the mean runs over the remaining transitions only.

    Args:
        y_pred: 4-D logits whose last axis has ``num_states`` entries. Per
            the original comment it may arrive as (loan_pool_size,
            batch_size, num_samples_per_trajectory, num_states).
        y: 4-D targets; argmax over the last axis gives the state. Per the
            original comment: (batch_size, loan_pool_size,
            num_samples_per_trajectory, num_states).

    Returns:
        Scalar tensor with the mean cross-entropy over non-ignored targets.

    Raises:
        AssertionError: If ``y`` and ``y_pred`` differ in element count.
    """
    num_states = y_pred.shape[3]
    if y.shape[1] * y.shape[2] == y_pred.shape[0]:
        # Set-encoder case: y_pred stacks the trajectories in the batch dimension.
        y = torch.reshape(y, (-1, 1, y.shape[2], y.shape[3]))
    if y.shape[1] == y_pred.shape[0] and y.shape[0] == y_pred.shape[1]:
        y_pred = torch.permute(y_pred, (1, 0, 2, 3))
    assert torch.numel(y) == torch.numel(y_pred)

    y = torch.argmax(y[:, :, 1:, :], dim=3)
    y_pred = y_pred[:, :, :-1, :]
    y = y.flatten()
    y_pred = y_pred.flatten(0, 2)

    # The deprecated size_average=True argument was dropped: it only
    # duplicated reduction='mean' and triggered a deprecation warning.
    out = F.cross_entropy(y_pred, y, ignore_index=num_states - 1, reduction='mean')
    return out

def excess_Xentropy_loss(y_pred, y, M):
    """Difference between the predicted and the ``M``-based next-step log loss.

    Returns ``next_step_log_loss(y_pred, y)`` minus
    ``next_step_log_loss_GT(y_pred, y, M)``.
    """
    model_loss = next_step_log_loss(y_pred, y)
    reference_loss = next_step_log_loss_GT(y_pred, y, M)
    return model_loss - reference_loss

def log_loss_not_using_path_variable(y_pred, y, M):
    """Return the constant 1.234; all arguments are ignored.

    NOTE(review): this looks like a placeholder implementation.
    """
    constant_value = 1.234
    return constant_value

def performance_decay(y_pred, y, M, I):
    """Excess-loss gap between samples tagged 900 and samples tagged 600.

    Args:
        y_pred: 4-D predictions, indexed along axis 0 by sample.
        y: 4-D targets, indexed along axis 0 by sample.
        M: 4-D tensor forwarded to ``excess_Xentropy_loss``, indexed along
            axis 0 by sample.
        I: 1-D tensor with one tag per sample; only the values 600
            ("beginning") and 900 ("end") are used.

    Returns:
        ``0`` if neither tag occurs. If only tag 600 occurs, the negated
        excess loss of those samples; if only tag 900 occurs, the excess loss
        of those samples. Otherwise the count-weighted difference
        ``(n_end * excess_end - n_begin * excess_begin) / len(I)``.
    """
    if (900 not in I) and (600 not in I):
        return 0

    indices_beginning = torch.nonzero(torch.eq(I, 600)).flatten()
    indices_end = torch.nonzero(torch.eq(I, 900)).flatten()
    y_pred_s, y_s = y_pred[indices_beginning, :, :, :], y[indices_beginning, :, :, :]
    y_pred_e, y_e = y_pred[indices_end, :, :, :], y[indices_end, :, :, :]
    M_s, M_e = M[indices_beginning, :, :, :], M[indices_end, :, :, :]

    if len(indices_end) == 0:
        return -excess_Xentropy_loss(y_pred_s, y_s, M_s)
    if len(indices_beginning) == 0:
        return excess_Xentropy_loss(y_pred_e, y_e, M_e)

    # This branch used to evaluate the expression without returning it, so
    # the function returned None whenever both tags were present.
    return (
        len(indices_end) * excess_Xentropy_loss(y_pred_e, y_e, M_e)
        - len(indices_beginning) * excess_Xentropy_loss(y_pred_s, y_s, M_s)
    ) / I.shape[0]


def next_step_log_loss_GT(y_pred, y, M):
    """Next-step log loss computed from ``M`` instead of from the predictions.

    For every time step after the first, the largest entry of ``M * y`` over
    the last axis is taken, its log is averaged and negated. With one-hot
    ``y`` that entry is the value of ``M`` at the observed state.

    Args:
        y_pred: Only ``y_pred.shape[3]`` (the number of states) is read.
        y: 4-D targets; argmax over the last axis gives the state.
        M: Tensor multiplied elementwise with ``y`` (presumably reference
            transition probabilities; confirm against the caller).

    Returns:
        Scalar tensor ``loss * N / (N + 1 - num_terminal)``, where ``N`` is
        the number of scored transitions and ``num_terminal`` counts targets
        in a terminal state.

    Raises:
        ValueError: If ``num_states`` is not 3, 8 or 10. The 8-state case
            (terminal state 7) was added to match ``next_step_log_loss``;
            other sizes previously failed with ``UnboundLocalError``.
    """
    # Terminal state indices for every supported state-space size.
    terminal_states_by_size = {3: (2,), 8: (7,), 10: (8, 9)}

    num_states = y_pred.shape[3]
    if num_states not in terminal_states_by_size:
        raise ValueError(
            f"next_step_log_loss_GT supports num_states in "
            f"{sorted(terminal_states_by_size)}, got {num_states}"
        )

    y2 = torch.argmax(y[:, :, 1:, :], dim=3)
    num_terminal_transitions = sum(
        torch.sum(torch.eq(y2, state)) for state in terminal_states_by_size[num_states]
    )

    out, _ = torch.max((M * y)[:, :, 1:, :], dim=3)
    out = torch.log(out)
    out = -torch.mean(out)

    # Rescale towards "per non-terminal transition". The +1 keeps the
    # denominator positive when every transition is terminal.
    out = out * (torch.numel(y2)) / (torch.numel(y2) + 1 - num_terminal_transitions)
    return out


def next_step_mse(y_pred, y):
    """Mean squared error between the prediction at ``t`` and ``y`` at ``t + 1``.

    Args:
        y_pred: 4-D predictions. Per the original comment it may arrive as
            (loan_pool_size, batch_size, num_samples_per_trajectory,
            num_states); the first two axes are swapped when they mirror
            those of ``y``.
        y: 4-D targets. Per the original comment: (batch_size,
            loan_pool_size, num_samples_per_trajectory, num_states).

    Returns:
        Scalar tensor with the MSE over all shifted, flattened elements
        (both sides cast to float).

    Raises:
        AssertionError: If ``y`` and ``y_pred`` differ in element count.
    """
    if y.shape[1] * y.shape[2] == y_pred.shape[0]:
        # Set-encoder case: y_pred stacks the trajectories in the batch dimension.
        y = torch.reshape(y, (-1, 1, y.shape[2], y.shape[3]))

    leading_axes_swapped = y.shape[1] == y_pred.shape[0] and y.shape[0] == y_pred.shape[1]
    if leading_axes_swapped:
        y_pred = torch.permute(y_pred, (1, 0, 2, 3))

    assert torch.numel(y) == torch.numel(y_pred)

    # Drop the first target step and the last prediction step.
    targets = y[:, :, 1:, :].flatten().float()
    predictions = y_pred[:, :, :-1, :].flatten().float()
    return F.mse_loss(predictions, targets)

import torch
import torch.nn.functional as F

def res_loss(y_pred, y, alpha=0.1, eps=1e-8, **kwargs):
    """Expose the explained variance stored in the model state.

    Reads ``kwargs['state'][0]["explained_variance"]`` and returns it under
    the key ``"explained_variance"``. ``y_pred``, ``y``, ``alpha`` and
    ``eps`` are not used.

    Raises:
        KeyError: If ``state`` or ``"explained_variance"`` is missing.
    """
    first_state = kwargs['state'][0]
    return {"explained_variance": first_state["explained_variance"]}

def next_step_sharpe_ratio_with_residual_loss(y_pred, y, alpha=0.1, eps=1e-8, **kwargs):
    """Sharpe-based loss on residual plus factor weights, minus explained variance.

    The factor portfolio from ``kwargs['state'][0]`` is rearranged (axes
    ``(2, 0, 1)`` plus a trailing singleton axis) and added to ``y_pred``.
    The loss is then::

        -alpha_sharpe * sharpe_with_costs
        + alpha_leverage * |mean L1 leverage - 1|
        + alpha_turnover * mean L1 turnover
        - explained_variance

    with ``alpha_sharpe = 0.01`` and the leverage/turnover weights currently
    set to 0.

    Args:
        y_pred: Predicted (residual) portfolio weights; the L1 norms are
            taken over axis 0 and time is assumed to be axis 2.
        y: Asset returns (or targets), forwarded to the Sharpe computation.
        alpha: Forwarded to ``next_step_sharpe_ratio_positive``.
        eps: Forwarded to ``next_step_sharpe_ratio_positive``.
        **kwargs: Must contain ``state``, whose first element provides
            ``"explained_variance"`` and ``"factor_portfolio"``.

    Returns:
        The combined loss described above.
    """
    model_state = kwargs['state'][0]
    explained_variance = model_state["explained_variance"]

    # Same result as permuting with (0, 2, 1) and then with (1, 0, 2).
    factor_portfolio = torch.permute(model_state["factor_portfolio"], (2, 0, 1))
    factor_portfolio = torch.unsqueeze(factor_portfolio, dim=-1)

    # Mixing weights of the two weight sources (both fixed to 1).
    y_pred_weight = 1
    factor_portfolio_weight = 1
    y_pred = y_pred * y_pred_weight + factor_portfolio * factor_portfolio_weight

    # Coefficients of the individual loss terms.
    alpha_sharpe = 0.01
    alpha_leverage = 0
    alpha_turnover = 0
    alpha_transaction_costs = 1
    use_transaction_costs = True

    # Deviation of the mean L1 norm (over axis 0) from the target leverage.
    goal_leverage = 1
    leverage = torch.norm(y_pred, p=1, dim=0).mean()
    leverage_penalty = torch.abs(leverage - goal_leverage)

    # Mean L1 norm of the weight change between consecutive steps on axis 2.
    weight_change = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
    turnover = torch.mean(torch.norm(weight_change, p=1, dim=0))

    sharpe = next_step_sharpe_ratio_positive(
        y_pred, y, alpha, eps,
        transaction_costs=use_transaction_costs,
        alpha_transaction_costs=alpha_transaction_costs,
    )

    loss = (-alpha_sharpe * sharpe
            + alpha_leverage * leverage_penalty
            + alpha_turnover * turnover
            - explained_variance)
    return loss

def next_step_sharpe_ratio_with_transaction_cost(y_pred, y, alpha=0.1, eps=1e-8):
    """Negated annualized Sharpe ratio with transaction costs enabled.

    Thin wrapper around ``next_step_sharpe_ratio_positive`` called with
    ``transaction_costs=True``; the sign is flipped so that the value can be
    minimized.
    """
    sharpe_after_costs = next_step_sharpe_ratio_positive(
        y_pred, y, alpha, eps, transaction_costs=True
    )
    return -sharpe_after_costs

def next_step_sharpe_ratio(y_pred, y, alpha=0.1, eps=1e-8):
    """Negated annualized Sharpe ratio, usable as a loss.

    Thin wrapper around ``next_step_sharpe_ratio_positive`` (without
    transaction costs) that flips the sign of its result.

    Args:
        y_pred (Tensor): Predicted portfolio weights.
        y (Tensor): Asset returns (or targets).
        alpha (float): Forwarded unchanged.
        eps (float): Forwarded unchanged; stabilizes the std in the ratio.

    Returns:
        Tensor: The negative of the positive Sharpe ratio.
    """
    positive_sharpe = next_step_sharpe_ratio_positive(y_pred, y, alpha, eps)
    return -positive_sharpe
    
    
def next_step_sharpe_ratio_positive(y_pred, y, alpha=0.01, eps=1e-8, transaction_costs=False, alpha_transaction_costs=1):
    """Annualized Sharpe ratio of the portfolio implied by ``y_pred``.

    Weights at step ``t`` are applied to the returns at step ``t + 1``; the
    products are summed over axis 1 and all resulting per-step returns are
    pooled into one sample.

    Args:
        y_pred (Tensor): Predicted portfolio weights. Reshaped to
            ``y.shape`` when its first axis equals
            ``y.shape[0] * y.shape[1]``.
        y (Tensor): Asset returns (or targets); time is axis 2.
        alpha (float): Unused.
        eps (float): Added to the standard deviation to avoid division by
            zero.
        transaction_costs (bool): If True, subtract per-step costs of
            0.0005 x L1 turnover plus 0.0001 x L1 norm of short positions.
        alpha_transaction_costs: Multiplier applied to the total cost.

    Returns:
        Tensor: ``mean / (std + eps) * sqrt(252)`` of the pooled returns
        (not negated).

    Raises:
        AssertionError: If ``y`` and ``y_pred`` do not end up with the same
            shape.
    """
    stacked_in_batch_dim = y.shape[0] * y.shape[1] == y_pred.shape[0]
    if stacked_in_batch_dim:
        y_pred = torch.reshape(y_pred, y.shape)

    # Sanity checks: element count first, then the exact shape.
    assert torch.numel(y) == torch.numel(y_pred), \
        f"Shape mismatch: numel(y)={torch.numel(y)}, numel(y_pred)={torch.numel(y_pred)}"
    assert y.shape == y_pred.shape

    # Drop the first return step and the last weight step so they line up.
    realized = y[:, :, 1:, :]
    weights = y_pred[:, :, :-1, :]

    step_returns = (weights * realized).sum(dim=1)

    if transaction_costs:
        # 0.0005 x ||w_t - w_{t-1}||_1, with the L1 norm taken over axis 1.
        weight_change = weights[:, :, 1:, :] - weights[:, :, :-1, :]
        turnover_costs = 0.0005 * weight_change.abs().sum(dim=1)

        # 0.0001 x ||max(-w_t, 0)||_1, i.e. the size of the short positions.
        short_exposure = torch.maximum(-weights, torch.zeros_like(weights))
        short_costs = 0.0001 * short_exposure.sum(dim=1)

        # The first step has no predecessor, so its turnover cost is zero.
        first_step_cost = torch.zeros_like(step_returns[:, :1])
        turnover_costs_padded = torch.cat([first_step_cost, turnover_costs], dim=1)

        step_returns = step_returns - alpha_transaction_costs * (turnover_costs_padded + short_costs)

    pooled_returns = step_returns.flatten(start_dim=0)

    sharpe_ratio = pooled_returns.mean() / (pooled_returns.std() + eps)
    return sharpe_ratio * torch.sqrt(torch.tensor(252.0))

def _annualized_return_vol_sharpe(flat_returns, eps):
    """Return (yearly return, yearly vol, yearly Sharpe) assuming 252 steps per year."""
    std_return = flat_returns.std() + eps
    yearly_return = flat_returns.mean() * 252
    yearly_vol = std_return * math.sqrt(252)
    return yearly_return, yearly_vol, yearly_return / (yearly_vol + eps)


def _beta_to_market(flat_returns, flat_market_returns, market_variance):
    """Return mean-based covariance with the market divided by ``market_variance``."""
    covariance = torch.mean((flat_returns - flat_returns.mean()) *
                            (flat_market_returns - flat_market_returns.mean()))
    return covariance / market_variance


def _l1_turnover(weights):
    """Return the L1 norm (over axis 1) of weight changes between consecutive steps on axis 2."""
    return (weights[:, :, 1:, :] - weights[:, :, :-1, :]).abs().sum(dim=1)


def compute_portfolio_metrics(y_pred, y, eps=1e-8, return_raw_only=False, **kwargs):
    """
    Computes various portfolio metrics from predicted weights and actual returns.

    Weights at step ``t`` are applied to the returns at step ``t + 1``. If
    ``kwargs['state'][0]`` provides ``"factor_portfolio"`` and
    ``"l1_norm_factor_portfolio_weight"``, the factor portfolio is added to
    ``y_pred`` and separate metrics for the residual (original ``y_pred``)
    and factor components are reported as well. If reading or combining that
    state fails for any reason, the function silently falls back to the
    plain metrics.

    Args:
        y_pred (Tensor): Predicted portfolio weights. Reshaped to ``y.shape``
            when its first axis equals ``y.shape[0] * y.shape[1]``.
        y (Tensor): Asset returns (or targets); assets are summed over
            axis 1 and time is axis 2.
        eps (float): Small constant for numerical stability.
        return_raw_only (bool): If True, only calculate and return the raw
            return tensors.
        **kwargs: Optional ``state`` as described above.

    Returns:
        tuple or dict: If return_raw_only is True, the tuple
            ``(flat_returns, flat_market_returns)`` where the market is the
            equal-weighted portfolio. Otherwise a dictionary of metrics
            (22 entries, plus 11 component entries when the factor
            portfolio is available).

    Raises:
        AssertionError: If ``y`` and ``y_pred`` do not end up with the same
            shape.
    """
    try:
        factor_portfolio = kwargs['state'][0]["factor_portfolio"]
        factor_portfolio = torch.permute(factor_portfolio, (0, 2, 1))
        factor_portfolio = torch.permute(factor_portfolio, (1, 0, 2))
        factor_portfolio = torch.unsqueeze(factor_portfolio, dim=-1)
        weight_factor_portfolio = kwargs['state'][0]["l1_norm_factor_portfolio_weight"]
        y_pred_weight = 1
        factor_portfolio_weight = 1
        y_res = y_pred
        y_pred = y_res * y_pred_weight + factor_portfolio * factor_portfolio_weight
        y_res = torch.reshape(y_res, y.shape)
        y_factor = torch.reshape(factor_portfolio, y.shape)
        use_factor_portfolio = True
    except Exception:
        # Best-effort: without a usable factor portfolio only the combined
        # metrics are produced. `except Exception` (rather than a bare
        # except) lets KeyboardInterrupt/SystemExit propagate.
        use_factor_portfolio = False

    # --- 1) Reshape y_pred to match y if needed ---
    if y.shape[0] * y.shape[1] == y_pred.shape[0]:
        y_pred = torch.reshape(y_pred, y.shape)

    # Sanity check
    assert torch.numel(y) == torch.numel(y_pred), \
        f"Shape mismatch: numel(y)={torch.numel(y)}, numel(y_pred)={torch.numel(y_pred)}"
    assert y.shape == y_pred.shape, \
        f"Shapes differ after reshape: y.shape={y.shape}, y_pred.shape={y_pred.shape}"

    # --- 2) Slice off the first step from y and the last step from y_pred ---
    y = y[:, :, 1:, :]
    y_pred = y_pred[:, :, :-1, :]

    # --- 3) Portfolio returns and equal-weighted "market" returns ---
    portfolio_returns = (y_pred * y).sum(dim=1)
    flat_returns = portfolio_returns.flatten(start_dim=0)

    n_assets = y.shape[1]
    equal_weights = torch.ones_like(y_pred) / n_assets
    market_returns = (equal_weights * y).sum(dim=1)
    flat_market_returns = market_returns.flatten(start_dim=0)

    if return_raw_only:
        # Return both portfolio returns and market portfolio returns.
        return flat_returns, flat_market_returns

    # --- Beta: covariance(portfolio, market) / variance(market) ---
    market_variance = flat_market_returns.var() + eps
    portfolio_beta = _beta_to_market(flat_returns, flat_market_returns, market_variance)

    # --- 4) Annualized return, volatility and Sharpe (252 steps per year) ---
    yearly_return, yearly_vol, yearly_sharpe = _annualized_return_vol_sharpe(flat_returns, eps)

    # --- 5) Average leverage (sum of absolute weights over axis 1) and turnover ---
    avg_leverage = y_pred.abs().sum(dim=1).mean()
    daily_turnover = _l1_turnover(y_pred)
    avg_turnover = daily_turnover.mean()

    # --- 6) Fraction of short positions: share of negative weights overall ---
    frac_short_positions = (y_pred < 0).float().mean()

    # --- Sharpe ratio with transaction costs applied per step ---
    # Short cost: 0.0001 x ||max(-w_t, 0)||_1
    short_positions = torch.maximum(-y_pred, torch.zeros_like(y_pred))
    short_costs = 0.0001 * short_positions.sum(dim=1)

    # Turnover cost: 0.0005 x ||w_t - w_{t-1}||_1, zero-padded for the first
    # step so that it matches the shape of portfolio_returns.
    padding = torch.zeros_like(portfolio_returns[:, :1])
    turnover_costs_padded = torch.cat([padding, daily_turnover * 0.0005], dim=1)

    portfolio_returns_after_costs = portfolio_returns - (turnover_costs_padded + short_costs)
    _, _, sharpe_ratio_transaction_costs = _annualized_return_vol_sharpe(
        portfolio_returns_after_costs.flatten(start_dim=0), eps
    )

    # --- 7) Weight distribution statistics ---
    median_weights = y_pred.median()
    percentile_99_weights = y_pred.quantile(0.99)
    percentile_1_weights = y_pred.quantile(0.01)
    max_weights = y_pred.max()
    min_weights = y_pred.min()
    percentile_90_weights = y_pred.quantile(0.90)
    percentile_10_weights = y_pred.quantile(0.10)

    # --- 8) Correlations between returns and weights ---
    flat_y = y.flatten(start_dim=0)
    flat_y_pred = y_pred.flatten(start_dim=0)

    # NOTE(review): this loops over axis 1 of y[0, :, :, 0] and correlates
    # along axis 0, for the first batch element only. If y is
    # (batch, assets, time, 1) that is one cross-asset correlation per time
    # step rather than one per asset - confirm the intent.
    y_new = y[0, :, :, 0]
    y_new_pred = y_pred[0, :, :, 0]
    correlations = [
        torch.corrcoef(torch.stack((y_new[:, i], y_new_pred[:, i])))[0][1]
        for i in range(y_new.shape[1])
    ]
    max_correlation = max(correlations)
    min_correlation = min(correlations)
    correlation_tensor = torch.tensor(correlations)
    percentile_90_correlation = torch.quantile(correlation_tensor, 0.90)
    percentile_10_correlation = torch.quantile(correlation_tensor, 0.10)
    median_correlation = torch.median(correlation_tensor)
    percentile_99_correlation = torch.quantile(correlation_tensor, 0.99)
    # NOTE(review): a 1st-percentile correlation was computed here before but
    # never returned; it is left out to keep the result keys unchanged.

    # Correlation between all returns and all weights.
    correlation = torch.corrcoef(torch.stack((flat_y, flat_y_pred)))[0][1]

    metrics = {
        'yearly_sharpe': yearly_sharpe,
        'yearly_sharpe_w_transaction_costs': sharpe_ratio_transaction_costs,
        'yearly_return': yearly_return,
        'yearly_vol': yearly_vol,
        'avg_leverage': avg_leverage,
        'avg_turnover': avg_turnover,
        'frac_short_positions': frac_short_positions,
        'median_weights': median_weights,
        'percentile_99_weights': percentile_99_weights,
        'percentile_1_weights': percentile_1_weights,
        'max_weights': max_weights,
        'min_weights': min_weights,
        'percentile_90_weights': percentile_90_weights,
        'percentile_10_weights': percentile_10_weights,
        'overall_correlation': correlation,
        'max_correlation': max_correlation,
        'min_correlation': min_correlation,
        'percentile_90_correlation': percentile_90_correlation,
        'percentile_10_correlation': percentile_10_correlation,
        'median_correlation': median_correlation,
        'percentile_99_correlation': percentile_99_correlation,
        'portfolio_beta': portfolio_beta,
    }

    if not use_factor_portfolio:
        return metrics

    # --- 9) Separate metrics for the residual and factor components ---
    # Slice the components like the combined portfolio; y is already sliced.
    y_res_sliced = y_res[:, :, :-1, :]
    y_factor_sliced = y_factor[:, :, :-1, :]

    flat_residual_returns = (y_res_sliced * y).sum(dim=1).flatten(start_dim=0)
    flat_factor_returns = (y_factor_sliced * y).sum(dim=1).flatten(start_dim=0)

    yearly_residual_return, yearly_residual_vol, yearly_residual_sharpe = \
        _annualized_return_vol_sharpe(flat_residual_returns, eps)
    yearly_factor_return, yearly_factor_vol, yearly_factor_sharpe = \
        _annualized_return_vol_sharpe(flat_factor_returns, eps)

    metrics.update({
        'yearly_residual_sharpe': yearly_residual_sharpe,
        'yearly_residual_return': yearly_residual_return,
        'yearly_residual_vol': yearly_residual_vol,
        'residual_beta': _beta_to_market(flat_residual_returns, flat_market_returns, market_variance),
        'avg_residual_turnover': _l1_turnover(y_res_sliced).mean(),
        'yearly_factor_sharpe': yearly_factor_sharpe,
        'yearly_factor_return': yearly_factor_return,
        'yearly_factor_vol': yearly_factor_vol,
        'factor_beta': _beta_to_market(flat_factor_returns, flat_market_returns, market_variance),
        'avg_factor_turnover': _l1_turnover(y_factor_sliced).mean(),
        'weight_factor_portfolio': weight_factor_portfolio,
    })
    return metrics

def sharpe_ratio_market(y_pred, y, alpha=0.1, eps=1e-8):
    """
    Annualized Sharpe ratio of the equal-weighted "market" portfolio.

    Only the returns in ``y`` are used: they are averaged with equal weight
    over the first axis of ``y[0, :, :, 0]`` and the Sharpe ratio of the
    resulting series is scaled by ``sqrt(252)``.

    Args:
        y_pred (Tensor): Unused. Accepted so the function has the same
            ``(outs, y)`` calling convention as the other metrics.
        y (Tensor): Returns, indexed as ``y[0, :, :, 0]``. The first axis of
            that slice is averaged away (presumably the asset axis and the
            remaining one the sample/time axis -- confirm against the caller).
        alpha (float): Unused. Kept for backward compatibility.
        eps (float): Added to the standard deviation so that a constant
            return series yields a finite value instead of nan/inf.

    Returns:
        Tensor: Scalar annualized Sharpe ratio (a metric, not a loss).
    """
    # Equal-weighted average across the first axis of the selected slice.
    market_returns = y[0, :, :, 0].mean(dim=0)
    # eps keeps the ratio finite when the series has zero variance.
    sharpe = market_returns.mean() / (market_returns.std() + eps)
    return sharpe * math.sqrt(252.0)

    
def soft_cross_entropy(logits, y, label_smoothing=0.0):
    """Cross entropy where only the logits are flattened to 2-D.

    The target ``y`` is forwarded to ``F.cross_entropy`` exactly as given,
    together with ``label_smoothing``.
    """
    num_classes = logits.shape[-1]
    flat_logits = logits.view(-1, num_classes)
    # Unlike the logits, the target is deliberately not reshaped here.
    return F.cross_entropy(flat_logits, y, label_smoothing=label_smoothing)

def accuracy(logits, y):
    """Fraction of rows whose arg-max class matches the label.

    ``logits`` is flattened to ``(-1, num_classes)``. When ``y`` holds more
    elements than there are rows (e.g. Mixup targets), its arg-max over the
    last axis is taken as the label.
    """
    flat_logits = logits.view(-1, logits.shape[-1])
    has_class_axis = y.numel() > flat_logits.shape[0]
    labels = y.argmax(dim=-1) if has_class_axis else y
    predicted = flat_logits.argmax(dim=-1)
    hits = predicted == labels.view(-1)
    return hits.float().mean()

def accuracy_ignore_index(logits, y, ignore_index=-100):
    """Micro-averaged accuracy that skips targets equal to ``ignore_index``.

    Args:
        logits: Scores whose last dimension is the class dimension; all
            leading dimensions are flattened.
        y: Integer targets; flattened to 1-D.
        ignore_index: Target value passed to torchmetrics as ``ignore_index``.

    Returns:
        The value returned by ``torchmetrics.functional.classification.accuracy``.
    """
    num_classes = logits.shape[-1]
    # Flatten BEFORE taking the arg-max so predictions and targets are both
    # 1-D; previously `preds` kept the leading dims of `logits` while `y` was
    # flattened, which mismatched for inputs with more than two dimensions.
    logits = logits.view(-1, num_classes)
    preds = torch.argmax(logits, dim=-1)
    y = y.view(-1)
    return tm_f.classification.accuracy(preds, y, num_classes=num_classes, ignore_index=ignore_index, average='micro')


def accuracy_at_k(logits, y, k=1):
    """Fraction of rows whose label is among the ``k`` highest-scoring classes.

    ``logits`` is flattened to ``(-1, num_classes)``. When ``y`` holds more
    elements than there are rows (e.g. Mixup targets), its arg-max over the
    last axis is taken as the label.
    """
    flat_logits = logits.view(-1, logits.shape[-1])
    if y.numel() > flat_logits.shape[0]:
        # Mixup-style targets: collapse to the arg-max class.
        y = y.argmax(dim=-1)
    label_column = y.view(-1).unsqueeze(-1)
    _, top_classes = torch.topk(flat_logits, k, dim=-1)
    label_in_top_k = top_classes.eq(label_column).any(dim=-1)
    return label_in_top_k.float().mean()


def f1_binary(logits, y):
    """F1 score with ``average="binary"`` on arg-max predictions.

    ``logits`` is flattened to ``(-1, num_classes)`` and ``y`` to 1-D; both
    are moved to CPU and handed to scikit-learn's ``f1_score``.
    """
    predicted = torch.argmax(logits.view(-1, logits.shape[-1]), dim=-1)
    targets_np = y.view(-1).cpu().numpy()
    predicted_np = predicted.cpu().numpy()
    return f1_score(targets_np, predicted_np, average="binary")


def f1_macro(logits, y):
    """F1 score with ``average="macro"`` on arg-max predictions.

    ``logits`` is flattened to ``(-1, num_classes)`` and ``y`` to 1-D; both
    are moved to CPU and handed to scikit-learn's ``f1_score``.
    """
    predicted = torch.argmax(logits.view(-1, logits.shape[-1]), dim=-1)
    targets_np = y.view(-1).cpu().numpy()
    predicted_np = predicted.cpu().numpy()
    return f1_score(targets_np, predicted_np, average="macro")


def f1_micro(logits, y):
    """F1 score with ``average="micro"`` on arg-max predictions.

    ``logits`` is flattened to ``(-1, num_classes)`` and ``y`` to 1-D; both
    are moved to CPU and handed to scikit-learn's ``f1_score``.
    """
    predicted = torch.argmax(logits.view(-1, logits.shape[-1]), dim=-1)
    targets_np = y.view(-1).cpu().numpy()
    predicted_np = predicted.cpu().numpy()
    return f1_score(targets_np, predicted_np, average="micro")


def roc_auc_macro(logits, y):
    """ROC AUC (``average="macro"``) using the softmax score of class 1.

    ``logits`` is flattened to ``(-1, num_classes)`` and ``y`` to 1-D before
    both are moved to CPU for scikit-learn's ``roc_auc_score``.
    """
    # Detached so the metric can be evaluated while training (no autograd
    # graph may be attached when converting to numpy).
    flat_logits = logits.view(-1, logits.shape[-1]).detach()
    targets_np = y.view(-1).cpu().numpy()
    class_probs = F.softmax(flat_logits, dim=-1).cpu().numpy()
    return roc_auc_score(targets_np, class_probs[:, 1], average="macro")


def roc_auc_micro(logits, y):
    """ROC AUC (``average="micro"``) using the softmax score of class 1.

    Args:
        logits: Scores whose last dimension is the class dimension; flattened
            to ``(-1, num_classes)``. Column 1 of the softmax is used as the
            score, so at least two classes are required.
        y: Targets; flattened to 1-D.

    Returns:
        The value returned by scikit-learn's ``roc_auc_score``.
    """
    # Detach first: `.numpy()` raises on tensors that require grad, which is
    # the case when this metric is evaluated during training. This matches
    # what `roc_auc_macro` already does.
    logits = logits.view(-1, logits.shape[-1]).detach()
    y = y.view(-1)
    return roc_auc_score(
        y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="micro"
    )


def mse(outs, y, len_batch=None):
    """Mean squared error, optionally restricted to a prefix of each sample.

    Args:
        outs: Predictions. If it has more dimensions than ``y``, its last
            dimension must have size 1 and is squeezed away.
        y: Targets.
        len_batch: Optional iterable of lengths. When given, only the first
            ``len_batch[i]`` positions along axis 1 of sample ``i`` contribute
            to the loss.
            # TODO document the use case of this

    Returns:
        Scalar tensor with the mean squared error over the selected elements.
    """
    if len(y.shape) < len(outs.shape):
        assert outs.shape[-1] == 1
        outs = outs.squeeze(-1)
    if len_batch is None:
        return F.mse_loss(outs, y)

    mask = torch.zeros_like(outs, dtype=torch.bool)
    for i, length in enumerate(len_batch):
        # Index only the batch and length axes: this covers every trailing
        # dimension and also works when `outs` was squeezed down to 2-D
        # (the former `mask[i, :l, :]` raised IndexError in that case).
        mask[i, :length] = True
    outs_masked = torch.masked_select(outs, mask)
    y_masked = torch.masked_select(y, mask)
    return F.mse_loss(outs_masked, y_masked)

def forecast_rmse(outs, y, len_batch=None):
    """Root of the squared error averaged over axis 1, then averaged again.

    ``len_batch`` is accepted but not used.
    """
    # TODO: generalize, currently for Monash dataset
    squared_error = F.mse_loss(outs, y, reduction='none')
    rmse_over_axis1 = torch.sqrt(squared_error.mean(1))
    return rmse_over_axis1.mean()

def mae(outs, y, len_batch=None):
    """Mean absolute error, optionally restricted to a prefix of each sample.

    Args:
        outs: Predictions. If it has more dimensions than ``y``, its last
            dimension must have size 1 and is squeezed away.
        y: Targets.
        len_batch: Optional iterable of lengths. When given, only the first
            ``len_batch[i]`` positions along axis 1 of sample ``i`` contribute
            to the loss.

    Returns:
        Scalar tensor with the mean absolute error over the selected elements.
    """
    if len(y.shape) < len(outs.shape):
        assert outs.shape[-1] == 1
        outs = outs.squeeze(-1)
    if len_batch is None:
        return F.l1_loss(outs, y)

    mask = torch.zeros_like(outs, dtype=torch.bool)
    for i, length in enumerate(len_batch):
        # Index only the batch and length axes: this covers every trailing
        # dimension and also works when `outs` was squeezed down to 2-D
        # (the former `mask[i, :l, :]` raised IndexError in that case).
        mask[i, :length] = True
    outs_masked = torch.masked_select(outs, mask)
    y_masked = torch.masked_select(y, mask)
    return F.l1_loss(outs_masked, y_masked)
    


# Metrics that can depend on the loss
def loss(x, y, loss_fn):
    """Report ``loss_fn(x, y)`` as a metric.

    Useful because the training loss may include extra regularization terms
    (e.g. weight decay implemented as an L2 penalty); evaluating the bare
    loss function as a metric leaves those additional terms out.
    """
    value = loss_fn(x, y)
    return value


def bpb(x, y, loss_fn):
    """Bits per byte (image density estimation, speech generation, char LM).

    Divides ``loss_fn(x, y)`` by ``ln 2``; this is a nats-to-bits conversion
    provided the loss is expressed in nats.
    """
    loss_value = loss_fn(x, y)
    return loss_value / math.log(2)


def ppl(x, y, loss_fn):
    """Perplexity: the exponential of ``loss_fn(x, y)``."""
    loss_value = loss_fn(x, y)
    return torch.exp(loss_value)


# Lookup table from metric/loss name to the callable implementing it.
# NOTE: a hand-maintained dict; there should be a better way to do this.
output_metric_fns = {
    "binary_cross_entropy": binary_cross_entropy,
    "cross_entropy": cross_entropy,
    "binary_accuracy": binary_accuracy,
    "accuracy": accuracy,
    "accuracy_ignore_index": accuracy_ignore_index,
    # Top-k accuracy with k fixed via functools.partial.
    'accuracy@3': partial(accuracy_at_k, k=3),
    'accuracy@5': partial(accuracy_at_k, k=5),
    'accuracy@10': partial(accuracy_at_k, k=10),
    "eval_loss": loss,
    "mse": mse,
    "mae": mae,
    "forecast_rmse": forecast_rmse,
    "f1_binary": f1_binary,
    "f1_macro": f1_macro,
    "f1_micro": f1_micro,
    "roc_auc_macro": roc_auc_macro,
    "roc_auc_micro": roc_auc_micro,
    "soft_cross_entropy": soft_cross_entropy,  # only for pytorch 1.10+
    "student_t": student_t_loss,
    "gaussian_ll": gaussian_ll_loss,
    "cross_entropy_ignore_background": cross_entropy_ignore_background,  # multi-dimensional cross entropy (image segmentation)
    "dice": dice,
    "cross_entropy_weighted": cross_entropy_weighted,
    "dice_loss": dice_loss,
    "cross_entropy_with_dice": cross_entropy_with_dice,
    "ss_log_loss": ss_log_loss,
    "next_step_log_loss": next_step_log_loss,
    "next_step_log_loss_GT": next_step_log_loss_GT,
    "excess_Xentropy_loss": excess_Xentropy_loss,
    "performance_decay": performance_decay,
    "log_loss_not_using_path_variable": log_loss_not_using_path_variable,
    "next_step_log_loss_ignore_terminal": next_step_log_loss_ignore_terminal,
    "next_step_log_loss_valid_indices": next_step_log_loss_valid_indices,
    "next_step_mse": next_step_mse,
    "next_step_sharpe_ratio_positive": next_step_sharpe_ratio_positive,
    "next_step_sharpe_ratio": next_step_sharpe_ratio,
    "next_step_sharpe_ratio_with_residual_loss": next_step_sharpe_ratio_with_residual_loss,
    "res_loss": res_loss,
    "sharpe_ratio_market": sharpe_ratio_market,
    "compute_portfolio_metrics": compute_portfolio_metrics,
    "next_step_sharpe_ratio_with_transaction_cost": next_step_sharpe_ratio_with_transaction_cost,
}

# Optional segmentation metrics: these entries are registered only when
# `segmentation_models_pytorch` can be imported; otherwise they are skipped.
try:
    from segmentation_models_pytorch.utils.functional import iou
    from segmentation_models_pytorch.losses.focal import focal_loss_with_logits

    def iou_with_logits(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
        """Apply a sigmoid to `pr`, then forward all arguments to `iou`."""
        return iou(pr.sigmoid(), gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels)

    # Both IoU variants are registered with the threshold fixed at 0.5.
    output_metric_fns["iou"] = partial(iou, threshold=0.5)
    output_metric_fns["iou_with_logits"] = partial(iou_with_logits, threshold=0.5)
    output_metric_fns["focal_loss"] = focal_loss_with_logits
except ImportError:
    # Import failed: leave the three entries above unregistered.
    pass

# Metrics whose callables take the loss function as a third argument
# (see `loss`, `bpb` and `ppl`).
loss_metric_fns = {
    "loss": loss,
    "bpb": bpb,
    "ppl": ppl,
}
# Combined lookup table; on a duplicate key the `loss_metric_fns` entry wins
# because it is unpacked last.
metric_fns = {**output_metric_fns, **loss_metric_fns}  # TODO py3.9 (presumably: switch to the dict union operator)

from torchmetrics import Metric

class PortfolioSharpeRatioMetric(Metric):
    """Sharpe ratio of the portfolio returns pooled over every ``update`` call.

    Raw returns are stored per batch and the ratio is only computed in
    ``compute``, over the concatenation of everything stored, then scaled by
    ``sqrt(annualization_factor)``.
    """

    def __init__(self, annualization_factor=252.0, eps=1e-8):
        super().__init__(dist_sync_on_step=False)
        self.annualization_factor = annualization_factor
        self.eps = eps

        # List states holding one tensor of raw returns per batch.
        # NOTE: `all_market_returns` is filled in `update` but is not read
        # by `compute` in this class.
        self.add_state("all_returns", default=[], dist_reduce_fx=None)
        self.add_state("all_market_returns", default=[], dist_reduce_fx=None)

    def update(self, y_pred, y):
        """Store the raw portfolio and market returns of one batch."""
        raw_returns = compute_portfolio_metrics(
            y_pred, y, eps=self.eps, return_raw_only=True
        )
        portfolio_returns, market_returns = raw_returns
        # Detached copies only: no autograd graph is kept alive between batches.
        self.all_returns.append(portfolio_returns.detach())
        self.all_market_returns.append(market_returns.detach())

    def compute(self):
        """Return the annualized Sharpe ratio of all stored portfolio returns."""
        if not self.all_returns:
            # Nothing has been accumulated yet.
            return torch.tensor(0.0)

        pooled = torch.cat(self.all_returns, dim=0)
        # eps guards the division against a zero standard deviation.
        raw_sharpe = pooled.mean() / (pooled.std() + self.eps)
        scale = torch.sqrt(torch.tensor(self.annualization_factor))
        return raw_sharpe * scale

class MarketSharpeRatioMetric(Metric):
    """Sharpe ratio of the market returns pooled over every ``update`` call.

    Raw market returns are stored per batch and the ratio is only computed in
    ``compute``, over the concatenation of everything stored, then scaled by
    ``sqrt(annualization_factor)``.
    """

    def __init__(self, annualization_factor=252.0, eps=1e-8):
        super().__init__(dist_sync_on_step=False)
        self.annualization_factor = annualization_factor
        self.eps = eps

        # List state holding one tensor of raw market returns per batch.
        self.add_state("all_market_returns", default=[], dist_reduce_fx=None)

    def update(self, y_pred, y):
        """Store the raw market returns of one batch."""
        raw_returns = compute_portfolio_metrics(
            y_pred, y, eps=self.eps, return_raw_only=True
        )
        # The first element (portfolio returns) is not needed here.
        _, market_returns = raw_returns
        self.all_market_returns.append(market_returns.detach())

    def compute(self):
        """Return the annualized Sharpe ratio of all stored market returns."""
        if not self.all_market_returns:
            # Nothing has been accumulated yet.
            return torch.tensor(0.0)

        pooled = torch.cat(self.all_market_returns, dim=0)
        # eps guards the division against a zero standard deviation.
        raw_sharpe = pooled.mean() / (pooled.std() + self.eps)
        scale = torch.sqrt(torch.tensor(self.annualization_factor))
        return raw_sharpe * scale

