import torch
import torch.nn as nn
import torch.nn.functional as F
import xgboost as xgb
import lightgbm as lgb
import catboost as cb
import numpy as np
from .fds_layer import FDSLayer

class MLP_DEPRECATED(nn.Module):
    """Legacy feed-forward regressor built from Linear -> ReLU -> BatchNorm1d -> Dropout blocks.

    Deprecated: kept for backward compatibility only. Use MLPadv instead.
    An optional FDS layer can be applied to the last hidden representation
    while the module is in training mode.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, dropout_rate=0.1,
                 use_fds=False, fds_config=None, fds_discretizer=None):
        """
        Args:
            input_dim: Size of the last axis of the input tensor.
            hidden_dims: Sequence of hidden layer widths, one block per entry.
            output_dim: Width of the final linear layer.
            dropout_rate: Dropout probability used in every hidden block.
            use_fds: Request an FDS layer on top of the last hidden block.
            fds_config: FDS configuration dictionary (required when use_fds is set).
            fds_discretizer: Fitted discretizer handed to FDSLayer (required when use_fds is set).
        """
        super().__init__()
        self.use_fds = use_fds and (FDSLayer is not None)
        self.feature_layers = nn.ModuleList()
        self.final_layer = None
        self.fds_layer = None
        self.current_epoch = 0

        width = input_dim
        for out_width in hidden_dims:
            self.feature_layers.extend([
                nn.Linear(width, out_width),
                nn.ReLU(),
                nn.BatchNorm1d(out_width),
                nn.Dropout(dropout_rate),
            ])
            width = out_width

        # FDS is best-effort: any missing input or construction failure just
        # disables it (with a printed notice) instead of aborting model creation.
        if self.use_fds:
            if fds_config is not None and fds_discretizer is not None:
                try:
                    self.fds_layer = FDSLayer(feature_dim=width, fds_config=fds_config, discretizer=fds_discretizer)
                    print(f"MLP: FDS layer enabled (feature_dim={width})")
                except Exception as err:
                    print(f"FDSLayer initialization error: {err}. FDS disabled.")
                    self.use_fds = False
            else:
                print("Warning: FDS disabled due to missing fds_config or fds_discretizer.")
                self.use_fds = False

        self.final_layer = nn.Linear(width, output_dim)

    def forward(self, x, targets=None, epoch=None, return_features=False):
        """Run the network on ``x``.

        Args:
            x: Input tensor fed to the first linear layer.
            targets: Targets forwarded to the FDS layer (only used while training with FDS).
            epoch: Epoch forwarded to the FDS layer; ``self.current_epoch`` is used when omitted.
            return_features: If true, also return the last hidden representation.

        Returns:
            ``(output, hidden)`` when ``return_features`` is true or when the
            module is training with FDS enabled, otherwise ``output`` alone.
            ``hidden`` is the representation before the FDS layer is applied.
        """
        hidden = x
        for layer in self.feature_layers:
            hidden = layer(hidden)

        head_input = hidden
        if self.use_fds and self.fds_layer is not None and self.training:
            fds_epoch = self.current_epoch if epoch is None else epoch
            head_input = self.fds_layer(head_input, targets, fds_epoch)

        prediction = self.final_layer(head_input)

        if return_features or (self.training and self.use_fds):
            return prediction, hidden
        return prediction

    def update_fds_epoch(self):
        """Advance the internal epoch counter used as the default FDS epoch."""
        self.current_epoch = self.current_epoch + 1


class MLPadv(nn.Module):
    """Feed-forward regressor built from Linear -> LayerNorm -> GELU -> Dropout blocks.

    All linear layers get Xavier-uniform weights and zero biases; the output
    layer uses a reduced gain of 0.1. When ``use_residual`` is requested and
    the model has ``input_dim != 1`` and ``output_dim == 1``, a linear
    projection of the raw input, scaled by 0.1, is added to the output.
    An optional FDS layer is applied to the last hidden representation while
    the module is in training mode.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, dropout_rate=0.1,
                 use_fds=False, fds_config=None, fds_discretizer=None, use_residual=True):
        """
        Args:
            input_dim: Size of the last axis of the input tensor.
            hidden_dims: Sequence of hidden layer widths, one block per entry.
            output_dim: Width of the output layer.
            dropout_rate: Dropout probability used in every hidden block.
            use_fds: Request an FDS layer on top of the last hidden block.
            fds_config: FDS configuration dictionary (required when use_fds is set).
            fds_discretizer: Fitted discretizer handed to FDSLayer (required when use_fds is set).
            use_residual: Request the scaled input->output residual projection
                (only honoured when input_dim != 1 and output_dim == 1).
        """
        super().__init__()
        self.use_fds = use_fds and (FDSLayer is not None)
        self.use_residual = use_residual and (input_dim != 1) and (output_dim == 1)
        self.feature_layers = nn.ModuleList()
        self.final_layer = None
        self.fds_layer = None
        self.current_epoch = 0

        width = input_dim
        for out_width in hidden_dims:
            self.feature_layers.extend([
                self._xavier_linear(width, out_width),
                nn.LayerNorm(out_width),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            ])
            width = out_width

        # FDS is best-effort: any missing input or construction failure just
        # disables it (with a printed notice) instead of aborting model creation.
        if self.use_fds:
            if fds_config is not None and fds_discretizer is not None:
                try:
                    self.fds_layer = FDSLayer(feature_dim=width, fds_config=fds_config, discretizer=fds_discretizer)
                    print(f"MLPadv: FDS layer enabled (feature_dim={width})")
                except Exception as err:
                    print(f"FDSLayer initialization error: {err}. FDS disabled.")
                    self.use_fds = False
            else:
                print("Warning: FDS disabled due to missing fds_config or fds_discretizer.")
                self.use_fds = False

        # Small gain keeps initial predictions close to zero.
        self.final_layer = self._xavier_linear(width, output_dim, gain=0.1)

        if self.use_residual:
            self.res_proj = self._xavier_linear(input_dim, output_dim, gain=0.1)

    @staticmethod
    def _xavier_linear(in_features, out_features, gain=1.0):
        """Build an ``nn.Linear`` with Xavier-uniform weights (scaled by ``gain``) and a zero bias."""
        layer = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x, targets=None, epoch=None, return_features=False):
        """Run the network on ``x``.

        Args:
            x: Input tensor fed to the first linear layer (and to the residual projection).
            targets: Targets forwarded to the FDS layer (only used while training with FDS).
            epoch: Epoch forwarded to the FDS layer; ``self.current_epoch`` is used when omitted.
            return_features: If true, also return the last hidden representation.

        Returns:
            ``(output, hidden)`` when ``return_features`` is true or when the
            module is training with FDS enabled, otherwise ``output`` alone.
            ``hidden`` is the representation before the FDS layer is applied.
        """
        hidden = x
        for layer in self.feature_layers:
            hidden = layer(hidden)

        head_input = hidden
        if self.use_fds and self.fds_layer is not None and self.training:
            fds_epoch = self.current_epoch if epoch is None else epoch
            head_input = self.fds_layer(head_input, targets, fds_epoch)

        prediction = self.final_layer(head_input)

        if self.use_residual:
            prediction = prediction + 0.1 * self.res_proj(x)

        if return_features or (self.training and self.use_fds):
            return prediction, hidden
        return prediction

    def update_fds_epoch(self):
        """Advance the internal epoch counter used as the default FDS epoch."""
        self.current_epoch = self.current_epoch + 1


class SimpleThreeMLPEnsemble(nn.Module):
    """Average the predictions of ``num_models`` independently initialised MLPadv members.

    Every member is built with identical hyperparameters; they differ only in
    their random initialisation.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, dropout_rate=0.1,
                 use_fds=False, fds_config=None, fds_discretizer=None, use_residual=True,
                 num_models=3):
        super().__init__()
        self.num_models = num_models
        members = []
        for _ in range(num_models):
            members.append(
                MLPadv(input_dim, hidden_dims, output_dim, dropout_rate,
                       use_fds=use_fds, fds_config=fds_config, fds_discretizer=fds_discretizer,
                       use_residual=use_residual)
            )
        self.models = nn.ModuleList(members)

    def forward(self, x, targets=None, epoch=None, return_features=False):
        """Return the element-wise mean of all member predictions.

        Members may return ``(output, features)`` tuples (e.g. while training
        with FDS); only the output part is averaged. ``return_features`` is
        accepted for interface compatibility but ignored: the ensemble always
        returns a single tensor.
        """
        predictions = []
        for member in self.models:
            result = member(x, targets, epoch, return_features=False)
            predictions.append(result[0] if isinstance(result, tuple) else result)
        return torch.stack(predictions).mean(0)

    def update_fds_epoch(self):
        """Advance the FDS epoch counter of every member."""
        for member in self.models:
            member.update_fds_epoch()


# Public name kept for callers: ``MLP`` is the same class object as ``MLPadv``
# (the older implementation remains available as ``MLP_DEPRECATED``).
MLP = MLPadv


class XGBoostWrapper:
    """Wrap ``xgb.XGBRegressor`` behind a ``fit``/``predict`` interface.

    The objective is forced to ``'reg:squarederror'`` and the evaluation
    metric to ``'mae'``; every other hyperparameter comes from ``params``.
    """

    def __init__(self, params):
        """
        Args:
            params: XGBoost hyperparameter dictionary. A shallow copy is
                stored, so the caller's dictionary is left unmodified.
        """
        # Copy before applying the fixed overrides: mutating the caller's dict
        # would leak them into any other wrapper or fold that reuses it.
        self.params = dict(params)
        self.params['objective'] = 'reg:squarederror'
        self.params['eval_metric'] = 'mae'
        self.model = None

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """Train a fresh XGBoost regressor.

        The training pair is always part of ``eval_set``; the validation pair
        is appended only when both ``X_val`` and ``y_val`` are given.
        """
        eval_set = [(X_train, y_train)]
        if X_val is not None and y_val is not None:
            eval_set.append((X_val, y_val))

        self.model = xgb.XGBRegressor(**self.params)
        self.model.fit(X_train, y_train, eval_set=eval_set, verbose=False)

    def predict(self, X):
        """Return predictions for ``X``.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Model not trained yet.")
        return self.model.predict(X)

class LightGBMWrapper:
    """Wrap ``lgb.LGBMRegressor`` behind a ``fit``/``predict`` interface.

    The objective is forced to ``'regression_l1'`` with an ``'mae'`` metric
    and ``verbosity=-1``. Column names of DataFrame inputs are rewritten so
    that every non-alphanumeric character becomes ``_`` before reaching
    LightGBM; this is done on copies, never on the caller's objects.
    """

    def __init__(self, params):
        """
        Args:
            params: LightGBM hyperparameter dictionary. A shallow copy is
                stored, so the caller's dictionary is left unmodified.
                ``random_state`` defaults to 42 when not supplied.
        """
        # Copy before applying the fixed overrides so they do not leak into a
        # dict the caller may share with other wrappers or folds.
        self.params = dict(params)
        self.params['objective'] = 'regression_l1'
        self.params['metric'] = 'mae'
        self.params['verbosity'] = -1
        self.params.setdefault('random_state', 42)
        self.model = None

    @staticmethod
    def _sanitize_columns(X):
        """Return ``X`` with non-alphanumeric characters in column names replaced by ``_``.

        Inputs without a ``columns`` attribute (e.g. NumPy arrays, lists) are
        returned unchanged. Objects with columns are shallow-copied first so
        the caller's object keeps its original column names.
        """
        if not hasattr(X, 'columns'):
            return X
        X = X.copy(deep=False)
        X.columns = ["".join(c if c.isalnum() else "_" for c in str(name)) for name in X.columns]
        return X

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """Train a fresh LightGBM regressor.

        The validation pair is passed as ``eval_set`` only when both
        ``X_val`` and ``y_val`` are given. Neither ``X_train`` nor ``X_val``
        is modified.
        """
        X_train = self._sanitize_columns(X_train)
        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = [(self._sanitize_columns(X_val), y_val)]

        self.model = lgb.LGBMRegressor(**self.params)
        self.model.fit(X_train, y_train, eval_set=eval_set)

    def predict(self, X):
        """Return predictions for ``X`` (column names sanitised on a copy).

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Model not trained yet.")
        return self.model.predict(self._sanitize_columns(X))


class CatBoostWrapper:
    """Wrap ``cb.CatBoostRegressor`` behind a ``fit``/``predict`` interface.

    Loss function and evaluation metric are forced to ``'MAE'``, ``verbose``
    is disabled and ``allow_writing_files`` is set to False.
    """

    def __init__(self, params):
        """
        Args:
            params: CatBoost hyperparameter dictionary. A shallow copy is
                stored, so the caller's dictionary is left unmodified. The
                seed is taken from ``random_state`` if present, otherwise
                from ``random_seed``, otherwise 42.
        """
        # Copy before applying the fixed overrides so they do not leak into a
        # dict the caller may share with other wrappers or folds.
        self.params = dict(params)
        self.params['loss_function'] = 'MAE'
        self.params['eval_metric'] = 'MAE'
        self.params['verbose'] = False
        self.params['allow_writing_files'] = False
        # CatBoost treats random_state as a synonym of random_seed and rejects
        # parameter sets that specify both, so collapse them into random_seed.
        if 'random_state' in self.params:
            self.params['random_seed'] = self.params.pop('random_state')
        else:
            self.params.setdefault('random_seed', 42)
        self.model = None

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        """Train a fresh CatBoost regressor.

        When both ``X_val`` and ``y_val`` are given they form the evaluation
        set and the best iteration on it is kept; otherwise the model is
        trained without an evaluation set.
        """
        has_val = X_val is not None and y_val is not None
        eval_set = [(X_val, y_val)] if has_val else None

        self.model = cb.CatBoostRegressor(**self.params)
        # use_best_model requires a validation dataset, so only request it
        # when one was actually supplied.
        self.model.fit(X_train, y_train, eval_set=eval_set, use_best_model=has_val, verbose=False)

    def predict(self, X):
        """Return predictions for ``X``.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Model not trained yet.")
        return self.model.predict(X)

