import torch
import torch.nn as nn
import torch.nn.functional as F

def calculate_culumative_rewards(reward_tensor, gamma=0.9):
    """
    Compute the discounted cumulative return (reward-to-go) for every step.

    Working backwards from the last step, each step's value is
    ``r_t + gamma * G_{t+1}``, where ``G`` after the final step is 0.

    Args:
        reward_tensor (torch.Tensor): Step-wise rewards of shape
            [batch_size, context_length, 1].
        gamma (float): Discount factor applied to future rewards.

    Returns:
        torch.Tensor: Discounted returns with the same shape as
        ``reward_tensor``. Floating-point inputs keep their dtype; integer
        (or bool) inputs produce a result in the default floating dtype so
        fractional returns are not truncated.
    """
    batch_size, context_length, _ = reward_tensor.shape
    device = reward_tensor.device
    # Writing discounted (fractional) returns into an integer buffer would
    # silently truncate them, so promote non-float inputs.
    result_dtype = (
        reward_tensor.dtype
        if reward_tensor.is_floating_point()
        else torch.get_default_dtype()
    )
    calculated_result = torch.zeros_like(reward_tensor, dtype=result_dtype)
    # Running return of the step after the current one; allocated directly on
    # the input's device instead of allocating on CPU and copying.
    cumulative_rewards = torch.zeros(batch_size, 1, device=device)
    # Iterate backwards so each step reuses the return of the following step.
    for i in reversed(range(context_length)):
        cumulative_rewards = reward_tensor[:, i, :] + gamma * cumulative_rewards
        calculated_result[:, i, :] = cumulative_rewards

    return calculated_result

class LossFunctions:
    """Collection of loss functions for training models."""

    @staticmethod
    def mse_loss(outputs, targets):
        """
        Mean squared error loss.

        Args:
            outputs (torch.Tensor): Model predictions.
            targets (torch.Tensor): Ground truth values, same shape as
                ``outputs``.

        Returns:
            torch.Tensor: Scalar loss value (mean over all elements).
        """
        return F.mse_loss(outputs, targets)

    @staticmethod
    def bce_loss(outputs, targets):
        """
        Binary cross entropy loss.

        Args:
            outputs (torch.Tensor): Predicted probabilities in [0, 1]
                (``F.binary_cross_entropy`` does not accept raw logits).
            targets (torch.Tensor): Ground truth values in [0, 1], same shape
                as ``outputs``.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        return F.binary_cross_entropy(outputs, targets)

    @staticmethod
    def cross_entropy_loss(outputs, targets):
        """
        Cross entropy loss (for classification).

        Args:
            outputs (torch.Tensor): Unnormalized logits.
            targets (torch.Tensor): Ground truth class indices.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        return F.cross_entropy(outputs, targets)

    @staticmethod
    def preference_loss(outputs, preferences):
        """
        Pairwise logistic preference loss.

        ``outputs`` is read as interleaved pairs: even rows hold the preferred
        item's score and odd rows the non-preferred item's score. The loss is
        ``-mean(log(sigmoid(preferred - non_preferred)))``. It is evaluated
        with ``logsigmoid`` so large negative margins give a finite loss
        instead of ``inf``.

        Args:
            outputs (torch.Tensor): Scores whose first dimension has even
                length.
            preferences (torch.Tensor): Currently unused; kept for API
                compatibility.

        Returns:
            torch.Tensor: Scalar loss value.

        Raises:
            ValueError: If ``outputs`` has an odd number of rows. Broadcasting
                would otherwise silently pair the wrong items.
        """
        # NOTE: placeholder formulation; adapt to the real preference layout.
        if outputs.shape[0] % 2 != 0:
            raise ValueError(
                "preference_loss expects an even number of rows "
                f"(interleaved pairs), got {outputs.shape[0]}"
            )
        preferred, non_preferred = outputs[::2], outputs[1::2]
        return -torch.mean(F.logsigmoid(preferred - non_preferred))

class CulumativeRewardLoss:
    """A collection of loss functions for cumulative reward-based training."""

    @staticmethod
    def mse_loss(predicted_return, step_rewards, gamma=0.9):
        """
        Mean squared error between predicted and discounted cumulative rewards.

        The regression target is built from ``step_rewards`` by
        ``calculate_culumative_rewards`` using the same ``gamma``.

        Args:
            predicted_return (torch.Tensor): Predicted cumulative rewards of
                shape (batch_size, horizon, 1).
            step_rewards (torch.Tensor): Ground truth step-wise rewards of
                shape (batch_size, horizon, 1).
            gamma (float): Discount factor used to build the target returns.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        ground_truth_culumative_rewards = calculate_culumative_rewards(step_rewards, gamma)
        return F.mse_loss(predicted_return, ground_truth_culumative_rewards)

class WeightedPolicyLoss:
    """A collection of loss functions for weighted policy-based training.

    Each loss is computed per element (``reduction="none"``), multiplied by a
    per-sample weight taken from ``returns``, and then averaged.
    """

    @staticmethod
    def _align_weights(weights, loss):
        """Reshape ``weights`` so it broadcasts over ``loss`` per sample."""
        # Drop trailing singleton dims the loss lacks, e.g. (batch, 1) weights
        # for a (batch,) loss; otherwise broadcasting yields (batch, batch).
        while weights.dim() > loss.dim() and weights.shape[-1] == 1:
            weights = weights.squeeze(-1)
        # Append singleton dims so (batch,) weights scale every trailing
        # element of e.g. a (batch, action_dim) loss.
        missing = loss.dim() - weights.dim()
        if missing > 0:
            weights = weights.reshape(tuple(weights.shape) + (1,) * missing)
        return weights

    @staticmethod
    def mse_loss(predicted_actions, true_actions, returns):
        """
        Per-sample weighted mean squared error.

        Args:
            predicted_actions (torch.Tensor): Predicted actions, e.g.
                (batch_size, action_dim).
            true_actions (torch.Tensor): Ground truth actions, same shape as
                ``predicted_actions``.
            returns (torch.Tensor): Per-sample weights of shape (batch_size,)
                or (batch_size, 1).

        Returns:
            torch.Tensor: Scalar mean of the weighted element-wise errors.
        """
        # Keep element-wise errors so each sample gets its own weight; the
        # default "mean" reduction would collapse them before weighting.
        loss = F.mse_loss(predicted_actions, true_actions, reduction="none")
        weights = WeightedPolicyLoss._align_weights(returns, loss)
        return (loss * weights).mean()

    @staticmethod
    def bce_loss(predicted_actions, true_actions, returns):
        """
        Per-sample weighted cross entropy loss.

        Despite the name, this uses ``F.cross_entropy`` (multi-class logits
        with class-index targets), not binary cross entropy.

        Args:
            predicted_actions (torch.Tensor): Unnormalized logits, e.g.
                (batch_size, num_actions).
            true_actions (torch.Tensor): Ground truth class indices, e.g.
                (batch_size,).
            returns (torch.Tensor): Per-sample weights of shape (batch_size,)
                or (batch_size, 1).

        Returns:
            torch.Tensor: Scalar mean of the weighted per-sample losses.
        """
        loss = F.cross_entropy(predicted_actions, true_actions, reduction="none")
        weights = WeightedPolicyLoss._align_weights(returns, loss)
        return (loss * weights).mean()
    
class PreferenceLoss:
    """Loss functions for pairwise preference-based training."""

    @staticmethod
    def loglikelyhood_loss(value_1, value_2):
        """
        Negative log-likelihood that ``value_1`` is preferred over ``value_2``.

        Mathematically equal to ``-mean(log(sigmoid(value_1 - value_2)))``.
        It is written as ``softplus(value_2 - value_1)``, which stays
        numerically stable for large score differences.

        Args:
            value_1 (torch.Tensor): Scores of the preferred items, shape
                (batch_size, 1).
            value_2 (torch.Tensor): Scores of the other items, shape
                (batch_size, 1).

        Returns:
            torch.Tensor: 0-dim (scalar) tensor holding the mean loss.
        """
        reversed_margin = value_2 - value_1
        per_pair_loss = F.softplus(reversed_margin)
        return per_pair_loss.mean()