import logging

import numpy as np
from sklearn.linear_model import LinearRegression, SGDRegressor

from utils import DataTransformer  # Assuming utils.py is in online_setting or PYTHONPATH

def _is_sklearn_model_fitted_and_usable(model):
    """Return True if ``model`` looks fitted and all of its parameters are finite.

    A model counts as usable when it exposes non-None ``coef_`` and
    ``intercept_`` attributes (which sklearn estimators set when fitted) and
    neither contains NaN or +/-inf. Non-finite parameters can only produce
    non-finite predictions, so they are treated as unusable just like NaN.

    Args:
        model: an estimator instance, or None.

    Returns:
        bool: True if the model can be used for prediction, False otherwise.
    """
    if model is None:
        return False

    coef = getattr(model, "coef_", None)
    intercept = getattr(model, "intercept_", None)
    if coef is None or intercept is None:
        return False

    # np.asarray covers both a scalar intercept (single-target LinearRegression)
    # and an array intercept (SGD, multi-target LinearRegression). An empty
    # coef_ array passes because np.all over an empty array is True.
    intercept_ok = np.all(np.isfinite(np.asarray(intercept)))
    coef_ok = np.all(np.isfinite(np.asarray(coef)))
    return bool(intercept_ok and coef_ok)

def _create_and_fit_initial_model(
    model_config, X_all_features_for_model, Y_all, tinit, transformer, random_seed
):
    """Build the model described by ``model_config`` and fit it on the warm-up period.

    Model types:

    * ``"ar"``: ``LinearRegression`` fit on rows ``[0, tinit)``.
    * ``"online_sgd"``: ``SGDRegressor`` given a single ``partial_fit`` call on
      rows ``[0, tinit)``. The defaults ``warm_start=True``, ``average=False``
      and ``random_state=random_seed`` may be overridden by entries in
      ``params["sgd_params"]``.
    * ``"rolling_lm"``: ``LinearRegression`` fit on the last
      ``params["window_size"]`` (default 100) rows before ``tinit``.

    The model is returned without fitting when the log transform is enabled
    but the relevant data contain non-positive selected features or targets,
    when there are too few rows or no feature columns, or when fitting raises
    (logged as a warning; the model may then be left unfitted).

    Args:
        model_config: dict with "model_type" and optional "params".
        X_all_features_for_model: feature matrix whose rows align with ``Y_all``.
        Y_all: target series.
        tinit: number of leading rows available for the initial fit.
        transformer: object providing ``use_log_transform``, ``num_lags``,
            ``_select_features``, ``transform_features`` and ``transform_target``.
        random_seed: default ``random_state`` for ``SGDRegressor``.

    Returns:
        The (possibly unfitted) model instance.

    Raises:
        ValueError: if ``model_type`` is not one of the types above.
    """
    model_type = model_config["model_type"]
    params = model_config.get("params", {})

    def log_transform_feasible(X_orig, Y_orig):
        # With the log transform on, every selected feature and every target must be > 0.
        if not transformer.use_log_transform:
            return True
        selected = transformer._select_features(X_orig)
        return not (np.any(selected <= 0) or np.any(Y_orig <= 0))

    def n_feature_cols(X):
        # Treat a non-2-D result as "no feature columns" instead of failing on .shape[1].
        return X.shape[1] if X.ndim > 1 else 0

    def try_fit(fit_method, X_fit, Y_fit):
        # Best effort, matching the online update path: a failed initial fit
        # must not abort the whole run.
        try:
            fit_method(X_fit, Y_fit)
        except Exception:
            logging.getLogger(__name__).warning(
                "Initial fit failed for model_type=%s; model may be left unfitted",
                model_type,
                exc_info=True,
            )

    X_init_orig = X_all_features_for_model[:tinit]
    Y_init_orig = Y_all[:tinit]
    init_feasible = log_transform_feasible(X_init_orig, Y_init_orig)
    X_init_tr = transformer.transform_features(X_init_orig)
    Y_init_tr = transformer.transform_target(Y_init_orig)

    # One row per coefficient plus the intercept, and enough rows to cover the lags.
    n_init_features = n_feature_cols(X_init_tr)
    min_samples_needed = n_init_features + 1 if n_init_features > 0 else 2
    if transformer.num_lags > 0:
        min_samples_needed = max(min_samples_needed, transformer.num_lags + 1)

    if model_type == "ar":
        model = LinearRegression(fit_intercept=True)
        if (init_feasible and n_init_features > 0
                and X_init_tr.shape[0] >= min_samples_needed):
            try_fit(model.fit, X_init_tr, Y_init_tr)

    elif model_type == "online_sgd":
        sgd_kwargs = {"warm_start": True, "average": False, "random_state": random_seed}
        # Entries in sgd_params override the defaults instead of raising a
        # duplicate-keyword TypeError.
        sgd_kwargs.update(params.get("sgd_params", {}))
        model = SGDRegressor(**sgd_kwargs)
        # sklearn's input validation rejects arrays with zero feature columns.
        if (init_feasible and X_init_tr.shape[0] > 0 and Y_init_tr.size > 0
                and n_init_features > 0):
            try_fit(model.partial_fit, X_init_tr, Y_init_tr.ravel())

    elif model_type == "rolling_lm":
        model = LinearRegression(fit_intercept=True)
        window = params.get("window_size", 100)
        start_idx = max(0, tinit - window)

        X_roll_orig = X_all_features_for_model[start_idx:tinit]
        Y_roll_orig = Y_all[start_idx:tinit]
        roll_feasible = log_transform_feasible(X_roll_orig, Y_roll_orig)
        X_roll_tr = transformer.transform_features(X_roll_orig)
        Y_roll_tr = transformer.transform_target(Y_roll_orig)

        n_roll_features = n_feature_cols(X_roll_tr)
        if (roll_feasible and tinit > start_idx and n_roll_features > 0
                and X_roll_tr.shape[0] >= n_roll_features + 1):
            try_fit(model.fit, X_roll_tr, Y_roll_tr)
    else:
        raise ValueError(f"Unknown model type during initial fit: {model_type}")
    return model

def _get_prediction_from_model(model, model_type, X_pred_orig, transformer, Y_history_orig_for_fallback):
    """Produce a one-step-ahead forecast on the original (untransformed) scale.

    Uses ``model`` when ``_is_sklearn_model_fitted_and_usable`` accepts it and
    prediction succeeds. Otherwise it falls back to a mean of the target
    history, always returned on the original scale.

    Args:
        model: estimator exposing ``predict``; may be unfitted.
        model_type: model type string from the config. "online_sgd" is routed
            to ``model.predict`` even when the transformed features have zero
            columns; other types use the history-mean fallback in that case.
        X_pred_orig: untransformed feature row(s) for the step being predicted.
            Only the first prediction returned by ``model.predict`` is used.
        transformer: object providing ``use_log_transform``,
            ``transform_features``, ``transform_target`` and
            ``inverse_transform_prediction``.
        Y_history_orig_for_fallback: untransformed targets observed before this step.

    Returns:
        The forecast. Any non-finite result is replaced by the plain mean of
        the history, or 0.0 when the history is empty.
    """
    Y_hist = Y_history_orig_for_fallback

    def history_mean_forecast():
        # Fallback forecast, always on the ORIGINAL scale.
        if len(Y_hist) == 0:
            return 0.0
        if transformer.use_log_transform and np.any(Y_hist <= 0):
            # The history cannot be log-transformed, so its plain mean is
            # already on the original scale and must NOT go through the
            # inverse transform (that would put it on the wrong scale).
            return np.mean(Y_hist)
        return transformer.inverse_transform_prediction(
            np.mean(transformer.transform_target(Y_hist))
        )

    if not _is_sklearn_model_fitted_and_usable(model):
        pred_t = history_mean_forecast()
    else:
        X_pred_transformed = transformer.transform_features(X_pred_orig)
        if X_pred_transformed.shape[1] == 0 and model_type != "online_sgd":
            pred_t = history_mean_forecast()
        else:
            try:
                pred_transformed = model.predict(X_pred_transformed)[0]
            except Exception:
                # Best effort: a failing predict falls back to the history mean
                # rather than aborting the forecasting loop.
                logging.getLogger(__name__).debug(
                    "predict failed for model_type=%s; using history-mean fallback",
                    model_type,
                    exc_info=True,
                )
                pred_t = history_mean_forecast()
            else:
                pred_t = transformer.inverse_transform_prediction(pred_transformed)

    # Catch NaN and +/-inf, which the model or the inverse transform may produce.
    if not np.isfinite(pred_t):
        pred_t = np.mean(Y_hist) if len(Y_hist) > 0 else 0.0
    return pred_t

def _update_online_model(
    model, model_config, X_all_features_for_model, Y_all,
    current_t_data_idx, current_t_pred_idx, transformer
):
    """Update ``model`` in place after observing time step ``current_t_data_idx``.

    Behaviour by ``model_config["model_type"]``:

    * ``"ar"``: refit on all rows ``[0, current_t_data_idx]`` (expanding
      window). This happens only when ``transformer.num_lags > 0`` and at least
      ``num_lags + 1`` rows are available.
    * ``"online_sgd"``: one ``partial_fit`` call on row ``current_t_data_idx``.
    * ``"rolling_lm"``: when ``(current_t_pred_idx + 1)`` is a multiple of
      ``params["retrain_freq"]`` (default 1), refit on the last
      ``params["window_size"]`` (default 100) rows ending at ``current_t_data_idx``.

    Other model types, and a ``None`` model, are ignored. An update is skipped
    when the log transform is enabled and its training data contain
    non-positive selected features or targets. An exception raised by
    ``fit``/``partial_fit`` is logged at DEBUG level and otherwise ignored.

    Args:
        model: estimator to update, or None.
        model_config: dict with "model_type" and optional "params".
        X_all_features_for_model: feature matrix whose rows align with ``Y_all``.
        Y_all: target series.
        current_t_data_idx: index of the newly observed row.
        current_t_pred_idx: 0-based index of the current prediction step.
        transformer: object providing ``use_log_transform``, ``num_lags``,
            ``_select_features``, ``transform_features`` and ``transform_target``.
    """
    if model is None:
        return

    model_type = model_config["model_type"]
    params = model_config.get("params", {})
    logger = logging.getLogger(__name__)

    def log_transform_feasible(X_orig, Y_orig):
        # With the log transform on, every selected feature and every target must be > 0.
        if not transformer.use_log_transform:
            return True
        selected = transformer._select_features(X_orig)
        return not (np.any(selected <= 0) or np.any(Y_orig <= 0))

    def n_feature_cols(X):
        # Treat a non-2-D result as "no feature columns" instead of failing on .shape[1].
        return X.shape[1] if X.ndim > 1 else 0

    def attempt(update_method, X_fit, Y_fit, what):
        # Best effort: a failed update must not abort the forecasting loop.
        try:
            update_method(X_fit, Y_fit)
        except Exception:
            logger.debug(
                "%s failed at t=%d for model_type=%s",
                what, current_t_data_idx, model_type,
                exc_info=True,
            )

    if model_type == "ar":
        num_lags = transformer.num_lags
        n_seen = current_t_data_idx + 1
        # Only lagged AR models are refit, and only once num_lags + 1 rows exist.
        if num_lags <= 0 or n_seen < num_lags + 1:
            return
        X_orig = X_all_features_for_model[:n_seen, :]
        Y_orig = Y_all[:n_seen]
        if not log_transform_feasible(X_orig, Y_orig):
            return
        X_tr = transformer.transform_features(X_orig)
        Y_tr = transformer.transform_target(Y_orig)
        if n_feature_cols(X_tr) > 0:
            attempt(model.fit, X_tr, Y_tr, "AR refit")

    elif model_type == "online_sgd":
        X_point = X_all_features_for_model[current_t_data_idx, :].reshape(1, -1)
        Y_point = np.array([Y_all[current_t_data_idx]])
        if not log_transform_feasible(X_point, Y_point):
            return
        X_tr = transformer.transform_features(X_point)
        Y_tr = transformer.transform_target(Y_point)
        if Y_tr.size > 0:
            attempt(model.partial_fit, X_tr, Y_tr.ravel(), "SGD partial_fit")

    elif model_type == "rolling_lm":
        retrain_freq = params.get("retrain_freq", 1)
        if (current_t_pred_idx + 1) % retrain_freq != 0:
            return
        window = params.get("window_size", 100)
        end = current_t_data_idx + 1
        start = max(0, end - window)

        X_orig = X_all_features_for_model[start:end]
        Y_orig = Y_all[start:end]
        if not log_transform_feasible(X_orig, Y_orig):
            return
        X_tr = transformer.transform_features(X_orig)
        Y_tr = transformer.transform_target(Y_orig)

        # Need one row per coefficient plus the intercept.
        n_feat = n_feature_cols(X_tr)
        if n_feat > 0 and X_tr.shape[0] >= n_feat + 1:
            attempt(model.fit, X_tr, Y_tr, "rolling refit")

def generate_model_outputs(Y, X_dict, model_configs, tinit, random_seed):
    """Run each configured model as an online one-step-ahead forecaster over ``Y``.

    Warm-up: every model is built and fitted via ``_create_and_fit_initial_model``
    on the observations before ``tinit``. Then, for each time index ``t`` in
    ``[tinit, len(Y))`` and each model, a forecast for ``Y[t]`` is made from
    feature row ``t`` of that model's feature set and the targets ``Y[:t]``.
    Afterwards the model is updated with observation ``t``. A model is skipped
    (and the mean of ``Y[:t]``, or 0.0 when ``t == 0``, is recorded) when its
    feature matrix has no row ``t``.

    Args:
        Y: target series.
        X_dict: mapping from feature-set name to a feature matrix whose row
            index matches the index into ``Y``.
        model_configs: sequence of config dicts with "name", "feature_set" and
            "model_type". An optional "params" may hold "data_transform_params"
            (use_log_transform, num_lags, feature_indices) and model settings.
        tinit: number of leading observations used for the initial fit.
        random_seed: seeds NumPy's global RNG and is forwarded to model creation.

    Returns:
        A list with one dict per config, in config order, holding "model_name"
        and "raw_predictions". The latter is a float array of length
        ``len(Y) - tinit`` whose entry ``k`` is the forecast for ``Y[tinit + k]``.
    """
    np.random.seed(random_seed)
    n_obs = len(Y)
    n_steps = n_obs - tinit

    # Allocate every output slot before any model is built.
    outputs = [
        {"model_name": cfg["name"], "raw_predictions": np.full(n_steps, np.nan)}
        for cfg in model_configs
    ]

    feature_mats, transformers, models = [], [], []
    for cfg in model_configs:
        X_fs = X_dict[cfg["feature_set"]]
        tf_params = cfg.get("params", {}).get("data_transform_params", {})
        transformer = DataTransformer(
            use_log_transform=tf_params.get("use_log_transform", False),
            num_lags=tf_params.get("num_lags", 0),
            feature_indices=tf_params.get("feature_indices", None),
        )
        feature_mats.append(X_fs)
        transformers.append(transformer)
        models.append(
            _create_and_fit_initial_model(cfg, X_fs, Y, tinit, transformer, random_seed)
        )

    for step, t in enumerate(range(tinit, n_obs)):
        per_model = zip(model_configs, models, feature_mats, transformers, outputs)
        for cfg, model, X_fs, transformer, out in per_model:
            # Predict first, then update with the observation at t (no look-ahead).
            usable = model is not None and X_fs.shape[0] > t
            if usable:
                forecast = _get_prediction_from_model(
                    model,
                    cfg["model_type"],
                    X_fs[t, :].reshape(1, -1),
                    transformer,
                    Y[:t],
                )
            else:
                forecast = np.mean(Y[:t]) if t > 0 else 0.0

            out["raw_predictions"][step] = forecast

            if usable:
                _update_online_model(model, cfg, X_fs, Y, t, step, transformer)

    return outputs