"""
Reward Combiner Module

This module provides a flexible framework for combining multiple reward signals
into a single reward, either with manually specified linear weights or with learned
models: linear regression, gradient-boosted decision trees, and MLPs.
"""

import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression


class CombinationFunction(ABC):
    """
    Interface for a reward combination function g(r_1, r_2, ...).

    Implementations turn a dictionary of per-objective rewards into one scalar
    and can be trained, named, saved and restored. Subclasses are expected to
    set ``self.objective_names``, which the default ``get_active_objectives``
    returns.
    """

    @abstractmethod
    def combine(self, rewards: Dict[str, float]) -> float:
        """
        Merge per-objective rewards into a single reward.

        Args:
            rewards: Mapping from objective name to reward value

        Returns:
            The combined reward
        """

    @abstractmethod
    def fit(self, training_data: List[Dict[str, float]], targets: List[float]):
        """
        Train the combination function.

        Args:
            training_data: One reward dictionary per training example
            targets: Desired combined reward for each example
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the identifier of this combination function."""

    @abstractmethod
    def save(self, filepath: str):
        """Persist the model to ``filepath``."""

    @abstractmethod
    def load(self, filepath: str):
        """Restore the model from ``filepath``."""

    def get_active_objectives(self) -> List[str]:
        """
        Return the objectives that contribute to the combined score.

        Subclasses that can detect zero weights/coefficients override this to
        drop those objectives; this fallback reports every objective.
        """
        all_objectives = self.objective_names
        return all_objectives


class LinearFunction(CombinationFunction):
    """
    Simple linear combination with manually specified weights:
    g(r_1, r_2, ...) = w_1*r_1 + w_2*r_2 + ... + bias

    This is a non-trainable variant: weights are supplied directly, so the
    instance counts as fitted from construction and ``fit`` is a no-op.
    """

    def __init__(self, objective_names: List[str], weights: Dict[str, float], bias: float = 0.0):
        """
        Initialize with objective names and manually specified weights.

        Args:
            objective_names: List of objective names in consistent order
            weights: Mapping from each objective name to its weight; must have
                exactly the same keys as objective_names
            bias: Optional bias term (default 0.0)

        Raises:
            ValueError: If the weight keys and the objective names differ.
        """
        # Validate first so a rejected configuration leaves no partially-set state.
        missing = set(objective_names) - set(weights)
        extra = set(weights) - set(objective_names)
        if missing or extra:
            raise ValueError(f"Objective names and weights don't match. Missing: {missing}, Extra: {extra}")

        # Store copies so later mutation of the caller's list/dict cannot silently
        # change the model; use update_weights() to change weights deliberately.
        self.objective_names = list(objective_names)
        self.weights = dict(weights)
        self.bias = bias
        self.is_fitted = True  # Always considered fitted since weights are manually specified

    def combine(self, rewards: Dict[str, float]) -> float:
        """
        Return bias + sum of weight * reward over the configured objectives.

        Extra keys in ``rewards`` are ignored.

        Raises:
            ValueError: If a configured objective is missing from ``rewards``.
        """
        total = self.bias
        for objective in self.objective_names:
            if objective not in rewards:
                raise ValueError(f"Missing reward for objective: {objective}")
            total += self.weights[objective] * rewards[objective]
        return total

    def fit(self, training_data: List[Dict[str, float]], targets: List[float]):
        """
        No-op for LinearFunction since weights are manually specified.
        Included for interface compatibility; prints a warning when called.
        """
        print("Warning: LinearFunction uses manually specified weights and does not require fitting.")

    def get_name(self) -> str:
        return "linear_manual"

    def save(self, filepath: str):
        """Pickle the objective names, weights, bias and fitted flag to ``filepath``."""
        model_data = {
            'objective_names': self.objective_names,
            'weights': self.weights,
            'bias': self.bias,
            'is_fitted': self.is_fitted
        }
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)

    def load(self, filepath: str):
        """
        Restore the configuration written by ``save``.

        Note: pickle can execute arbitrary code while loading; only load trusted files.
        """
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)

        self.objective_names = model_data['objective_names']
        self.weights = model_data['weights']
        self.bias = model_data['bias']
        self.is_fitted = model_data['is_fitted']

    def update_weights(self, new_weights: Dict[str, float], new_bias: Optional[float] = None):
        """
        Replace the weights (and optionally the bias) after initialization.

        Args:
            new_weights: New mapping with exactly the current objectives as keys
            new_bias: New bias value; the current bias is kept when None

        Raises:
            ValueError: If the keys of ``new_weights`` differ from the objectives.
        """
        missing = set(self.objective_names) - set(new_weights)
        extra = set(new_weights) - set(self.objective_names)
        if missing or extra:
            raise ValueError(
                "New weights must have the same objectives as the current configuration"
                f" (missing: {missing}, extra: {extra})"
            )
        # Copy for the same aliasing reason as in __init__.
        self.weights = dict(new_weights)
        if new_bias is not None:
            self.bias = new_bias

    def get_active_objectives(self) -> List[str]:
        """Return objectives with non-zero weights."""
        return [name for name in self.objective_names if self.weights.get(name, 0.0) != 0.0]


class LinearRegressionFunction(CombinationFunction):
    """
    Learned linear combination: g(r_1, r_2, ...) = w_0 + w_1*r_1 + w_2*r_2 + ...

    Fitted with scikit-learn ``LinearRegression(positive=True)``, i.e. least
    squares with the objective coefficients constrained to be non-negative.
    """

    def __init__(self, objective_names: List[str]):
        """
        Initialize with objective names to ensure consistent ordering.

        Args:
            objective_names: List of objective names; fixes the feature column
                order used by both fit() and combine()
        """
        self.objective_names = objective_names
        # Non-negative least squares: the objectives-discovery prompt asks for
        # positive trends (higher is better), so negative coefficients would not
        # make sense. Without the constraint, multicollinearity between objectives
        # can still produce negative weights.
        self.model = LinearRegression(positive=True)
        self.is_fitted = False

    def combine(self, rewards: Dict[str, float]) -> float:
        """
        Predict the combined reward for one set of objective rewards.

        Raises:
            ValueError: If the model has not been fitted yet.
            KeyError: If an objective is missing from ``rewards``.
        """
        if not self.is_fitted:
            raise ValueError("Model must be fitted before combining rewards")

        # One-row feature matrix in objective_names order.
        features = np.array([rewards[name] for name in self.objective_names]).reshape(1, -1)
        return float(self.model.predict(features)[0])

    def fit(self, training_data: List[Dict[str, float]], targets: List[float]):
        """
        Train the linear regression model.

        Args:
            training_data: One reward dictionary per example
            targets: Target combined reward per example

        Raises:
            ValueError: If ``training_data`` is empty.
        """
        if len(training_data) == 0:
            # np.array([]) is 1-D, which sklearn would report as a confusing
            # "Expected 2D array" error; fail with the real cause instead.
            raise ValueError("training_data must contain at least one sample")

        X = np.array([[data[name] for name in self.objective_names] for data in training_data])
        y = np.array(targets)

        self.model.fit(X, y)
        self.is_fitted = True

    def get_name(self) -> str:
        return "linear_regression"

    def save(self, filepath: str):
        """Pickle the fitted sklearn model, objective names and fitted flag."""
        model_data = {
            'model': self.model,
            'objective_names': self.objective_names,
            'is_fitted': self.is_fitted
        }
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)

    def load(self, filepath: str):
        """
        Restore a model written by ``save``.

        Note: pickle can execute arbitrary code while loading; only load trusted files.
        """
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)

        self.model = model_data['model']
        self.objective_names = model_data['objective_names']
        self.is_fitted = model_data['is_fitted']

    def get_active_objectives(self) -> List[str]:
        """Return objectives with non-zero coefficients (all of them before fitting)."""
        if not self.is_fitted:
            return self.objective_names
        return [name for i, name in enumerate(self.objective_names) if self.model.coef_[i] != 0.0]


class GradientBoostingFunction(CombinationFunction):
    """
    Reward combination via scikit-learn's gradient-boosted regression trees.

    Features are the objective rewards taken in ``objective_names`` order; the
    regressor predicts the combined reward.
    """

    def __init__(self, objective_names: List[str], n_estimators: int = 100, max_depth: int = 3, learning_rate: float = 0.1):
        """
        Store the objective order and build an unfitted regressor.

        Args:
            objective_names: Objective names; fixes the feature column order
            n_estimators: Number of boosting stages
            max_depth: Maximum depth of each tree
            learning_rate: Shrinkage applied to each tree's contribution
        """
        self.objective_names = objective_names
        booster_kwargs = {
            'n_estimators': n_estimators,
            'max_depth': max_depth,
            'learning_rate': learning_rate,
        }
        # Fixed random_state so repeated fits on the same data are deterministic.
        self.model = GradientBoostingRegressor(random_state=42, **booster_kwargs)
        self.is_fitted = False

    def combine(self, rewards: Dict[str, float]) -> float:
        """
        Predict the combined reward for one set of objective rewards.

        Raises:
            ValueError: If the model has not been fitted yet.
            KeyError: If an objective is missing from ``rewards``.
        """
        if self.is_fitted:
            ordered = [rewards[name] for name in self.objective_names]
            prediction = self.model.predict(np.array(ordered).reshape(1, -1))
            return float(prediction[0])
        raise ValueError("Model must be fitted before combining rewards")

    def fit(self, training_data: List[Dict[str, float]], targets: List[float]):
        """Fit the regressor: one feature row per reward dict, one target per row."""
        feature_rows = [
            [sample[name] for name in self.objective_names]
            for sample in training_data
        ]
        self.model.fit(np.array(feature_rows), np.array(targets))
        self.is_fitted = True

    def get_name(self) -> str:
        return "gradient_boosting"

    def save(self, filepath: str):
        """Pickle the regressor, objective order and fitted flag to ``filepath``."""
        payload = dict(
            model=self.model,
            objective_names=self.objective_names,
            is_fitted=self.is_fitted,
        )
        with open(filepath, 'wb') as out:
            pickle.dump(payload, out)

    def load(self, filepath: str):
        """
        Restore state written by ``save``.

        Note: pickle can execute arbitrary code while loading; only load trusted files.
        """
        with open(filepath, 'rb') as src:
            payload = pickle.load(src)

        self.model = payload['model']
        self.objective_names = payload['objective_names']
        self.is_fitted = payload['is_fitted']


class MLPFunction(CombinationFunction):
    """
    Multi-layer perceptron (MLP) neural network combination.

    Network: input -> [Linear -> ReLU -> Dropout] for each entry of hidden_sizes
    -> Linear(prev, 1). Trained with Adam on mean squared error.
    """

    def __init__(self, objective_names: List[str], hidden_sizes: Optional[List[int]] = None,
                 dropout_rate: float = 0.1, learning_rate: float = 0.001):
        """
        Initialize the MLP.

        Args:
            objective_names: List of objective names in consistent order; its
                length is the network's input width
            hidden_sizes: Hidden layer sizes; None (the default) means [64, 32]
            dropout_rate: Dropout probability after each hidden layer
            learning_rate: Adam learning rate
        """
        self.objective_names = objective_names
        # Default/copy here instead of a list-literal default argument, which would
        # be a single list object shared by every instance created with the default.
        self.hidden_sizes = [64, 32] if hidden_sizes is None else list(hidden_sizes)
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate

        # Build the network
        self.input_size = len(objective_names)
        self.model = self._build_network()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()
        self.is_fitted = False

    def _build_network(self) -> nn.Module:
        """Build the MLP network from input_size, hidden_sizes and dropout_rate."""
        layers = []

        # Input layer
        prev_size = self.input_size

        # Hidden layers
        for hidden_size in self.hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(self.dropout_rate)
            ])
            prev_size = hidden_size

        # Output layer: a single scalar reward
        layers.append(nn.Linear(prev_size, 1))

        return nn.Sequential(*layers)

    def combine(self, rewards: Dict[str, float]) -> float:
        """
        Predict the combined reward for one set of objective rewards.

        Puts the model in eval mode (dropout disabled) and runs without gradients.

        Raises:
            ValueError: If the model has not been fitted yet.
            KeyError: If an objective is missing from ``rewards``.
        """
        if not self.is_fitted:
            raise ValueError("Model must be fitted before combining rewards")

        # Convert to a (1, input_size) tensor in consistent order
        features = torch.tensor([rewards[name] for name in self.objective_names], dtype=torch.float32).unsqueeze(0)

        self.model.eval()
        with torch.no_grad():
            output = self.model(features)

        return float(output.item())

    def fit(self, training_data: List[Dict[str, float]], targets: List[float], epochs: int = 200, batch_size: int = 32):
        """
        Train the MLP with shuffled mini-batch Adam on MSE loss.

        Training continues from the current weights and optimizer state, so a
        second call resumes rather than restarts. The average loss is printed
        every 20 epochs.

        Args:
            training_data: One reward dictionary per example
            targets: Target combined reward per example
            epochs: Number of passes over the data
            batch_size: Mini-batch size

        Raises:
            ValueError: If ``training_data`` is empty or its length differs from ``targets``.
        """
        n_samples = len(training_data)
        if n_samples == 0:
            # Otherwise the DataLoader yields no batches and the loss average
            # below divides by zero.
            raise ValueError("training_data must contain at least one sample")
        if n_samples != len(targets):
            raise ValueError(
                f"training_data and targets must have the same length, got {n_samples} and {len(targets)}"
            )

        # Convert training data to tensors
        X = torch.tensor([[data[name] for name in self.objective_names] for data in training_data], dtype=torch.float32)
        y = torch.tensor(targets, dtype=torch.float32).unsqueeze(1)

        # Create data loader
        dataset = torch.utils.data.TensorDataset(X, y)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.model.train()

        for epoch in range(epochs):
            total_loss = 0.0
            for batch_X, batch_y in dataloader:
                self.optimizer.zero_grad()

                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)

                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            if epoch % 20 == 0:
                avg_loss = total_loss / len(dataloader)
                print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")

        self.is_fitted = True

    def get_name(self) -> str:
        return "mlp"

    def save(self, filepath: str):
        """Save the network weights and the hyperparameters needed to rebuild it."""
        model_data = {
            'state_dict': self.model.state_dict(),
            'objective_names': self.objective_names,
            'hidden_sizes': self.hidden_sizes,
            'dropout_rate': self.dropout_rate,
            'learning_rate': self.learning_rate,
            'is_fitted': self.is_fitted
        }
        torch.save(model_data, filepath)

    def load(self, filepath: str):
        """
        Rebuild the network from a file written by ``save`` and load its weights.

        The optimizer is recreated, so Adam's moment estimates are not restored.
        Note: torch.load may unpickle data; only load trusted files.
        """
        model_data = torch.load(filepath, map_location='cpu')

        self.objective_names = model_data['objective_names']
        self.hidden_sizes = model_data['hidden_sizes']
        self.dropout_rate = model_data['dropout_rate']
        self.learning_rate = model_data['learning_rate']
        self.is_fitted = model_data['is_fitted']

        # Rebuild network with loaded parameters
        self.input_size = len(self.objective_names)
        self.model = self._build_network()
        self.model.load_state_dict(model_data['state_dict'])
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)


class RewardCombiner:
    """
    Main reward combiner class that orchestrates the combination of multiple rewards.

    Wraps a CombinationFunction g(r_1, r_2, r_3, ...) and optionally:
      * normalize_inputs: clamps every objective score to ``input_range`` and
        rescales it linearly to [0, 1] before it reaches g().
      * normalize_outputs: trains g() on targets rescaled from ``output_range`` to
        [0, 1], and rescales g()'s predictions back to ``output_range``.
    Output rescaling is skipped when ``output_range`` is (0.0, 1.0), where it
    would be the identity.
    """

    def __init__(self,
                 combination_function: CombinationFunction,
                 normalize_inputs: bool = False,
                 normalize_outputs: bool = False,
                 input_range: tuple = (1.0, 10.0),
                 output_range: tuple = (0.0, 1.0)):
        """
        Initialize the RewardCombiner.

        Args:
            combination_function: The ML model g() to use for combining rewards
            normalize_inputs: Whether to clamp and rescale input scores from
                input_range to [0, 1]
            normalize_outputs: Whether to train on targets rescaled to [0, 1] and
                rescale predictions back to output_range
            input_range: Expected range of input reward values (min, max)
            output_range: Desired range of output reward values (min, max)
        """
        self.combination_function = combination_function
        self.normalize_inputs = normalize_inputs
        self.normalize_outputs = normalize_outputs
        self.input_range = input_range
        self.output_range = output_range

    def _uses_output_scaling(self) -> bool:
        """True when targets/predictions are mapped between [0, 1] and output_range."""
        # tuple() so a list such as [0.0, 1.0] compares equal to the tuple default.
        return bool(self.normalize_outputs) and tuple(self.output_range) != (0.0, 1.0)

    def combine_rewards(self, objective_scores: Dict[str, float]) -> float:
        """
        Combine multiple objective scores into a single reward value.

        Args:
            objective_scores: Dictionary mapping objective names to their scores

        Returns:
            Combined reward value; rescaled into output_range when
            normalize_outputs is set.
        """
        if self.normalize_inputs:
            # warnings.warn instead of print: under the default warning filters the
            # notice is shown once per call site rather than on every evaluation.
            warnings.warn('Normalizing inputs may lead to unexpected results if input_range is not set correctly.',
                          stacklevel=2)
            normalized_scores = self._normalize_scores(objective_scores)
        else:
            normalized_scores = objective_scores.copy()

        # Use the combination function to compute the combined score
        combined_score = self.combination_function.combine(normalized_scores)

        # Must mirror fit(): map back to output_range only when targets were mapped to [0, 1].
        if self._uses_output_scaling():
            warnings.warn('Normalizing outputs to [0,1] and then scaling to output_range may lead to unexpected results.',
                          stacklevel=2)
            combined_score = self._scale_to_output_range(combined_score)
        return combined_score

    def combine_batch_rewards(self, batch_objective_scores: List[Dict[str, float]]) -> torch.Tensor:
        """
        Combine rewards for a batch of objective scores.

        Args:
            batch_objective_scores: List of dictionaries, each mapping objective names to scores

        Returns:
            float32 tensor of combined reward values with shape (batch_size,)
        """
        batch_size = len(batch_objective_scores)
        combined_rewards = torch.zeros(batch_size, dtype=torch.float32)

        for i, objective_scores in enumerate(batch_objective_scores):
            combined_rewards[i] = self.combine_rewards(objective_scores)

        return combined_rewards

    def fit(self, training_data: List[Dict[str, float]], targets: List[float], **kwargs):
        """
        Train the combination function on training data.

        Args:
            training_data: List of objective score dictionaries
            targets: List of target combined reward values (in output_range units)
            **kwargs: Additional arguments passed to the combination function's fit method

        Raises:
            ValueError: If output scaling is active and output_range has zero width.
        """
        # Normalize training data if requested
        if self.normalize_inputs:
            normalized_training_data = [self._normalize_scores(data) for data in training_data]
        else:
            normalized_training_data = training_data

        # Scale targets to [0, 1] only when combine_rewards() will scale predictions
        # back; otherwise the model would learn a different scale than callers receive.
        if self._uses_output_scaling():
            min_out, max_out = self.output_range
            span = max_out - min_out
            if span == 0:
                raise ValueError(f"output_range must have distinct endpoints, got {self.output_range}")
            normalized_targets = [(t - min_out) / span for t in targets]
        else:
            normalized_targets = targets

        self.combination_function.fit(normalized_training_data, normalized_targets, **kwargs)

    def _normalize_scores(self, scores: Dict[str, float]) -> Dict[str, float]:
        """
        Clamp each score to input_range and rescale it linearly to [0, 1].

        Raises:
            ValueError: If input_range does not satisfy min < max.
        """
        min_val, max_val = self.input_range
        range_val = max_val - min_val
        if range_val <= 0:
            raise ValueError(f"input_range must satisfy min < max to normalize, got {self.input_range}")

        return {
            objective: (max(min_val, min(max_val, score)) - min_val) / range_val
            for objective, score in scores.items()
        }

    def _scale_to_output_range(self, score: float) -> float:
        """Scale normalized score to the desired output range."""
        min_out, max_out = self.output_range
        return min_out + score * (max_out - min_out)

    def update_combination_function(self, new_function: CombinationFunction):
        """Update the combination function g()."""
        self.combination_function = new_function

    def save(self, filepath: str):
        """
        Save the combiner as two files: ``<filepath>_model.pkl`` (written by the
        combination function's own save) and ``<filepath>_config.pkl``.
        """
        # Save the combination function
        model_path = filepath + "_model.pkl"
        self.combination_function.save(model_path)

        # Save combiner configuration
        config = {
            'combination_function_type': self.combination_function.get_name(),
            'normalize_inputs': self.normalize_inputs,
            'normalize_outputs': self.normalize_outputs,
            'input_range': self.input_range,
            'output_range': self.output_range
        }

        with open(filepath + "_config.pkl", 'wb') as f:
            pickle.dump(config, f)

    def load(self, filepath: str, combination_function: CombinationFunction):
        """
        Load a combiner written by ``save`` into ``combination_function`` and this object.

        Note: pickle can execute arbitrary code while loading; only load trusted files.
        """
        # Load combination function
        model_path = filepath + "_model.pkl"
        combination_function.load(model_path)
        self.combination_function = combination_function

        # Load combiner configuration
        with open(filepath + "_config.pkl", 'rb') as f:
            config = pickle.load(f)

        self.normalize_inputs = config['normalize_inputs']
        # Files saved before normalize_outputs was persisted lack the key; keep the current value.
        self.normalize_outputs = config.get('normalize_outputs', self.normalize_outputs)
        self.input_range = config['input_range']
        self.output_range = config['output_range']

    def get_info(self) -> Dict[str, Any]:
        """Get information about the current configuration."""
        return {
            "combination_function": self.combination_function.get_name(),
            "normalize_inputs": self.normalize_inputs,
            "normalize_outputs": self.normalize_outputs,
            "input_range": self.input_range,
            "output_range": self.output_range,
            "is_fitted": self.combination_function.is_fitted
        }

    def get_active_objectives(self) -> List[str]:
        """Return list of objectives with non-zero weights from the combination function."""
        return self.combination_function.get_active_objectives()


# Convenience functions for creating combiners
def create_linear_combiner(objective_names: List[str], weights: Dict[str, float], bias: float = 0.0) -> RewardCombiner:
    """
    Build a RewardCombiner that uses fixed, manually chosen linear weights.

    Args:
        objective_names: Objective names
        weights: Mapping from each objective name to its weight
        bias: Constant added to the weighted sum (default 0.0)

    Returns:
        RewardCombiner wrapping a LinearFunction, with default normalization settings
    """
    linear_fn = LinearFunction(objective_names=objective_names, weights=weights, bias=bias)
    return RewardCombiner(combination_function=linear_fn)

def create_linear_regression_combiner(objective_names: List[str]) -> RewardCombiner:
    """Build a RewardCombiner whose linear weights are learned by regression."""
    regression_fn = LinearRegressionFunction(objective_names)
    return RewardCombiner(combination_function=regression_fn)

def create_gradient_boosting_combiner(objective_names: List[str],
                                      n_estimators: int = 100,
                                      max_depth: int = 3,
                                      learning_rate: float = 0.1) -> RewardCombiner:
    """
    Build a RewardCombiner backed by gradient-boosted regression trees.

    Args:
        objective_names: Objective names, in feature order
        n_estimators: Number of boosting stages
        max_depth: Maximum depth of each tree
        learning_rate: Boosting learning rate

    Returns:
        RewardCombiner wrapping a GradientBoostingFunction
    """
    booster = GradientBoostingFunction(
        objective_names,
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
    )
    return RewardCombiner(combination_function=booster)

def create_mlp_combiner(objective_names: List[str],
                        hidden_sizes: Optional[List[int]] = None,
                        dropout_rate: float = 0.1,
                        learning_rate: float = 0.001) -> RewardCombiner:
    """
    Create a RewardCombiner with an MLP.

    Args:
        objective_names: Objective names, in input order
        hidden_sizes: Hidden layer sizes; None (the default) means [64, 32]
        dropout_rate: Dropout probability after each hidden layer
        learning_rate: Adam learning rate

    Returns:
        RewardCombiner wrapping an MLPFunction
    """
    if hidden_sizes is None:
        # Fresh list per call: a list-literal default would be one object shared
        # by (and stored on) every MLP created with the default.
        hidden_sizes = [64, 32]
    return RewardCombiner(
        combination_function=MLPFunction(
            objective_names, hidden_sizes, dropout_rate, learning_rate
        )
    )

def create_reward_combiner(combiner_type: str,
                           objective_names: List[str],
                           manual_weights: Dict[str, float] = None,
                           manual_bias: float = 0.0,
                           **kwargs) -> RewardCombiner:
    """
    Build a RewardCombiner for the requested combiner type.

    Args:
        combiner_type: One of 'linear', 'linear_regression', 'gradient_boosting', 'mlp'
        objective_names: Objective names, in the order the model should use them
        manual_weights: Weights for the 'linear' type; when None, every objective
            gets weight 1 / len(objective_names)
        manual_bias: Bias term for the 'linear' type
        **kwargs: Extra hyperparameters, read only by these types:
            - gradient_boosting: n_estimators (default 100), max_depth (default 4),
              learning_rate (default 0.1)
            - mlp: hidden_sizes (default [64, 32]), dropout_rate (default 0.1),
              learning_rate (default 0.001)
            Any other keys are ignored.

    Returns:
        A new RewardCombiner

    Raises:
        ValueError: If combiner_type is not one of the supported names.
    """
    if combiner_type == 'linear':
        weights = manual_weights
        if weights is None:
            # Equal split; inside the comprehension so an empty list never divides by zero.
            weights = {name: 1.0 / len(objective_names) for name in objective_names}
        return create_linear_combiner(objective_names, weights, manual_bias)

    if combiner_type == 'linear_regression':
        return create_linear_regression_combiner(objective_names)

    if combiner_type == 'gradient_boosting':
        # NOTE(review): max_depth defaults to 4 here but to 3 in
        # create_gradient_boosting_combiner / GradientBoostingFunction — confirm intended.
        n_estimators = kwargs.get('n_estimators', 100)
        max_depth = kwargs.get('max_depth', 4)
        gb_learning_rate = kwargs.get('learning_rate', 0.1)
        return create_gradient_boosting_combiner(
            objective_names,
            n_estimators=n_estimators,
            max_depth=max_depth,
            learning_rate=gb_learning_rate,
        )

    if combiner_type == 'mlp':
        hidden_sizes = kwargs.get('hidden_sizes', [64, 32])
        dropout_rate = kwargs.get('dropout_rate', 0.1)
        mlp_learning_rate = kwargs.get('learning_rate', 0.001)
        return create_mlp_combiner(
            objective_names,
            hidden_sizes=hidden_sizes,
            dropout_rate=dropout_rate,
            learning_rate=mlp_learning_rate,
        )

    raise ValueError(f"Unknown combiner type: {combiner_type}. "
                     f"Choose from: 'linear', 'linear_regression', 'gradient_boosting', 'mlp'")