import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MultiTaskLoss(nn.Module):
    """Combine two task losses using learned, uncertainty-based weights.

    Each task ``i`` owns a learnable scalar ``log_sigma{i}`` that stores the
    log of its variance, log(σᵢ²). Both start at 0.0, i.e. σᵢ² = 1. The
    combined objective is

        L = L₁ / (2σ₁²) + L₂ / (2σ₂²) + log(σ₁²)/2 + log(σ₂²)/2

    The log-variance terms stop the optimiser from inflating σᵢ² without
    bound just to shrink a task's weight.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they must not change.
        self.log_sigma1 = nn.Parameter(torch.tensor(0.0))  # log(σ1^2)
        self.log_sigma2 = nn.Parameter(torch.tensor(0.0))  # log(σ2^2)

    def forward(self, task1_loss, task2_loss):
        """Return the uncertainty-weighted sum of the two task losses."""
        # Map the stored log-variances back to variances σᵢ².
        var_task1 = torch.exp(self.log_sigma1)
        var_task2 = torch.exp(self.log_sigma2)

        weighted_sum = (1 / (2 * var_task1)) * task1_loss + (1 / (2 * var_task2)) * task2_loss
        log_var_penalty = self.log_sigma1 / 2 + self.log_sigma2 / 2
        return weighted_sum + log_var_penalty
    

def compute_value_r2(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """
    Compute the R² (coefficient of determination) of value-function predictions.

    R² = 1 - SSres/SStot
    where SSres = Σ(y_true - y_pred)² and SStot = Σ(y_true - mean(y_true))²

    Args:
        y_pred: Predicted values. If it has the same number of elements as
            ``y_true`` (e.g. a ``(N, 1)`` value-head output vs. ``(N,)``
            targets), it is reshaped to ``y_true``'s shape first. Without
            this, the two would broadcast to ``(N, N)`` and give a
            meaningless score. Other shapes keep normal broadcasting rules.
        y_true: Target values. Integer tensors are promoted to float.

    Returns:
        R² as a Python float:
        - 1.0 is a perfect fit.
        - 0.0 is no better than predicting the mean.
        - Negative values are worse than predicting the mean.

        If ``y_true`` is constant (SStot == 0), R² is undefined. In that case
        this returns 1.0 for a perfect prediction and 0.0 otherwise, matching
        scikit-learn's ``force_finite`` convention. This replaces the huge
        negative value a tiny-epsilon denominator would produce.
    """
    # Metric only: no gradients needed through this computation.
    with torch.no_grad():
        if not torch.is_floating_point(y_true):
            y_true = y_true.float()
        if y_pred.numel() == y_true.numel():
            # Guard against silent (N,1) x (N,) -> (N,N) broadcasting.
            y_pred = y_pred.reshape(y_true.shape)

        ss_tot = torch.sum((y_true - y_true.mean()) ** 2)
        ss_res = torch.sum((y_true - y_pred) ** 2)

        if ss_tot.item() == 0.0:
            # Constant targets: R² is undefined, so return a finite stand-in.
            return 1.0 if ss_res.item() == 0.0 else 0.0
        return (1 - ss_res / ss_tot).item()


class AdaptiveLossWeights:
    """Schedule imitation-learning (IL) and reinforcement-learning (RL) loss weights.

    The schedule has two phases:

    * Warm-up (``iteration < warmup_steps``): the IL weight is held at
      ``il_initial``. The RL weight ramps linearly from 0 to ``rl_initial``.
    * Decay: over the next ``decay_steps`` iterations, both weights move
      linearly from their initial to their final values, then stay there.

    On top of the schedule, the RL weight is scaled down whenever the value
    function's R² is below ``min_value_quality``.
    """

    def __init__(
        self,
        il_initial=0.8,
        il_final=0.1,
        rl_initial=0.2,
        rl_final=1.0,
        warmup_steps=5,
        decay_steps=50,
        min_value_quality=0.5  # Minimum R² value for value function
    ):
        self.il_initial = il_initial
        self.il_final = il_final
        self.rl_initial = rl_initial
        self.rl_final = rl_final
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.min_value_quality = min_value_quality

    def get_weights(self, iteration: int, value_r2: float):
        """Get IL and RL weights based on training progress and value quality.

        Args:
            iteration: Current training iteration.
            value_r2: R² of the value function's predictions. Below
                ``min_value_quality``, the RL weight is multiplied by
                ``value_r2``. That factor is clamped at 0 so a negative R²
                can never flip the sign of the RL loss. A NaN R² is treated
                as zero quality.

        Returns:
            Tuple ``(il_weight, rl_weight)``.
        """
        if self.warmup_steps > 0 and iteration < self.warmup_steps:
            # Warm-up phase: heavily rely on IL while RL ramps in from 0.
            frac = iteration / self.warmup_steps
            il_weight = self.il_initial
            rl_weight = self.rl_initial * frac
        else:
            # Decay phase: gradually shift from IL to RL.
            # A non-positive decay length means "jump straight to the final weights".
            if self.decay_steps > 0:
                progress = (iteration - self.warmup_steps) / self.decay_steps
                progress = max(0.0, min(progress, 1.0))
            else:
                progress = 1.0
            il_weight = self.il_initial + progress * (self.il_final - self.il_initial)
            rl_weight = self.rl_initial + progress * (self.rl_final - self.rl_initial)

        # Scale RL weight by value-function quality.
        # `not (x >= y)` (rather than `x < y`) is deliberate: it is also True
        # for NaN, so a broken R² cannot bypass the quality gate.
        if not (value_r2 >= self.min_value_quality):
            # R² is unbounded below. Clamp to 0 so rl_weight never turns negative.
            quality = value_r2 if value_r2 > 0.0 else 0.0
            rl_weight *= quality

        return il_weight, rl_weight

def get_const_alpha(iteration):
    """Return a constant alpha of 1.0.

    ``iteration`` is accepted but ignored. The schedule does not change over
    training.
    """
    constant_alpha = 1.0
    return constant_alpha

def get_linear_alpha(iteration, num_iterations, alpha_start=10, alpha_end=0.1):
    """Linearly anneal alpha from ``alpha_start`` to ``alpha_end``.

    Args:
        iteration: Current iteration.
        num_iterations: Length of the annealing period. A value <= 0 means
            there is no annealing period, so ``alpha_end`` is returned
            immediately instead of dividing by zero.
        alpha_start: Value at ``iteration <= 0``.
        alpha_end: Value once ``iteration >= num_iterations``.

    Returns:
        The interpolated alpha. It always lies between ``alpha_start`` and
        ``alpha_end``.
    """
    if num_iterations <= 0:
        fraction = 1.0
    else:
        # Clamp to [0, 1] so alpha never overshoots either endpoint
        # (e.g. for a negative iteration or one past the end of the schedule).
        fraction = min(max(iteration / num_iterations, 0.0), 1.0)
    # linearly go from alpha_start to alpha_end
    alpha = alpha_start + fraction * (alpha_end - alpha_start)
    return alpha
