import pandas as pd, numpy as np, tensorflow as tf, os, joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, PolynomialFeatures
from sklearn.metrics import (r2_score, mean_absolute_percentage_error,
                             mean_absolute_error, mean_squared_error, explained_variance_score)
from sklearn.inspection import permutation_importance
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from imblearn.over_sampling import SMOTE
from tensorflow.keras import (Model, Input, regularizers, optimizers, utils, callbacks, initializers)
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Dense, Dropout, BatchNormalization, Multiply, Add, concatenate)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import pydotplus
import gc
import random

# Seed every RNG the script touches so runs are reproducible: Python's
# `random` module, NumPy's global generator (used below for augmentation
# noise and epoch shuffling) and TensorFlow's global seed.
random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)
# Request deterministic TensorFlow kernels where available.
os.environ['TF_DETERMINISTIC_OPS'] = '1'


class DataProcessor:
    """Preprocess the small (high-quality) and large (low-quality) datasets.

    ``process_data`` fits the polynomial expander and the feature/target
    scalers; ``transform`` and ``inverse_transform_y`` reuse them afterwards.
    """

    def __init__(self):
        self.scaler = RobustScaler()  # Scaler for the polynomial-expanded features
        self.y_scaler = RobustScaler()  # Scaler for target variable y
        # Degree-2, interaction-only expansion: bias term, original features and
        # pairwise products (no squared terms).
        self.poly = PolynomialFeatures(degree=2, interaction_only=True)
        self.feature_columns = ['Longitude', 'Latitude', 'Tem', 'Pre', 'rate',
                                'SOC', 'Clay', 'pH', 'BD', 'irrigation', 'data_source']
        self.categorical_cols = ['data_source']  # Columns that receive no augmentation noise
        self.poly_feature_names = None  # Filled in by process_data()

    def transform(self, X_df):
        """Expand and scale new data with the transformers fitted in ``process_data``.

        Args:
            X_df: DataFrame containing every column listed in ``feature_columns``.

        Returns:
            Scaled polynomial features as a NumPy array.
        """
        X_df = X_df[self.feature_columns].values
        X_poly = self.poly.transform(X_df)
        return self.scaler.transform(X_poly)

    def process_data(self, small_df, large_df, noise, target_col=None):
        """Split, augment, expand and scale the two datasets.

        The small set is split 80/20 into train/test. The training part is
        augmented by 10x replication with Gaussian noise (``noise`` times each
        feature's std; none on categorical columns). Features are then expanded
        with pairwise interactions and robust-scaled. Targets are robust-scaled
        with a scaler fitted on the augmented small training targets.

        The caller's DataFrames are not modified (work is done on copies).

        Args:
            small_df: high-quality data. It must contain the feature columns
                (``data_source`` is added here) and the target column.
            large_df: low-quality data with the same columns. Rows containing
                any NaN are dropped.
            noise: augmentation noise as a fraction of each feature's std.
            target_col: name of the target column. Defaults to the
                module-level ``target`` global set by the ``__main__`` script.

        Returns:
            ``(X_small, y_small, X_test, raw_X_test, y_test, X_large, y_large)``.
            The ``X_*`` arrays are scaled polynomial features, the ``y_*``
            arrays are scaled targets, and ``raw_X_test`` is the unscaled test
            feature DataFrame.

        Raises:
            ValueError: if a feature column is missing from either input.
        """
        if target_col is None:
            target_col = target  # module-level global set in the __main__ block
        self.feature_columns = [str(col) for col in self.feature_columns]

        # Label data sources: 1 for small (high-quality), 0 for large (low-quality).
        # Copies keep the caller's DataFrames untouched.
        small_df = small_df.copy()
        large_df = large_df.copy()
        small_df['data_source'] = 1
        large_df['data_source'] = 0
        large_df = large_df.dropna()

        for df_name, df in (('small_df', small_df), ('large_df', large_df)):
            missing = [col for col in self.feature_columns if col not in df.columns]
            if missing:
                raise ValueError(f"{df_name} is missing feature columns: {missing}")

        # Split small dataset into train/test; select features explicitly so the
        # column order always matches feature_columns.
        smalldf_train, smalldf_test = train_test_split(small_df, test_size=0.2, random_state=42)
        raw_X_train = smalldf_train[self.feature_columns]
        raw_y_train = smalldf_train[target_col].values
        X_test = smalldf_test[self.feature_columns]
        y_test = smalldf_test[target_col].values

        # Data augmentation via replication + noise
        replication_factor = 10
        X_rep = np.repeat(raw_X_train.values, replication_factor, axis=0)
        feature_stds_scaled = raw_X_train.std(axis=0).values * noise
        col_indices = [self.feature_columns.index(col) for col in self.categorical_cols]
        noise_array = np.random.normal(loc=0, scale=feature_stds_scaled, size=X_rep.shape)
        noise_array[:, col_indices] = 0  # No noise for categorical features
        X_small = np.vstack([raw_X_train.values, X_rep + noise_array])
        y_small = np.concatenate([raw_y_train, np.repeat(raw_y_train, replication_factor)])

        # Polynomial feature expansion. The expander is fitted on an ndarray, so
        # pass ndarrays everywhere; DataFrames would trigger feature-name warnings.
        X_small_poly = self.poly.fit_transform(X_small)
        self.poly_feature_names = self.poly.get_feature_names_out(input_features=self.feature_columns)
        print(len(self.poly_feature_names))
        X_test_poly = self.poly.transform(X_test.values)
        X_large_poly = self.poly.transform(large_df[self.feature_columns].values)
        y_large = large_df[target_col].values

        # Feature and target scaling (fitted on the augmented small training set only)
        self.scaler.fit(X_small_poly)
        y_small_scaled = self.y_scaler.fit_transform(y_small.reshape(-1, 1)).flatten()
        y_large_scaled = self.y_scaler.transform(y_large.reshape(-1, 1)).flatten()
        y_test_scaled = self.y_scaler.transform(y_test.reshape(-1, 1)).flatten()
        return (
            self.scaler.transform(X_small_poly), y_small_scaled,
            self.scaler.transform(X_test_poly), X_test, y_test_scaled,
            self.scaler.transform(X_large_poly), y_large_scaled
        )

    def inverse_transform_y(self, y_scaled):
        """Inverse transform scaled predictions back to original scale"""
        return self.y_scaler.inverse_transform(y_scaled.reshape(-1, 1)).flatten()


def build_hybrid_model(input_dim, seed):
    """Build the attention-gated residual MLP with a regression and a source head.

    Layers, in construction order:
      * a sigmoid Dense(input_dim) "attention" gate multiplied into the input;
      * Dense(512, relu), then residual blocks of width 512 and 256
        (Dense-relu -> BatchNorm -> Dropout(0.5) -> Dense-linear, plus a skip
        connection that is projected only when the width changes);
      * ``main``: Dense(1, linear) regression output;
      * ``aux``: Dense(2, softmax) classification output.

    Every kernel initializer is a GlorotUniform seeded with ``seed`` plus a
    fixed offset. Both residual blocks reuse the same offsets (seed+1..seed+4).

    Args:
        input_dim: number of input features.
        seed: base integer seed for the initializers and dropout.

    Returns:
        An uncompiled ``Model`` with outputs ``[main, aux]``.
    """
    def seeded(offset):
        return GlorotUniform(seed=seed + offset)

    def residual_block(tensor, units):
        skip = tensor
        if skip.shape[-1] != units:
            # Linear projection so the skip path matches the block width.
            skip = Dense(units, kernel_regularizer=l2(1e-4), kernel_initializer=seeded(1))(skip)
        hidden = Dense(units, activation='relu', kernel_regularizer=l2(1e-4),
                       kernel_initializer=seeded(2))(tensor)
        hidden = BatchNormalization()(hidden)
        hidden = Dropout(0.5, seed=seed + 3)(hidden)
        hidden = Dense(units, activation='linear', kernel_regularizer=l2(1e-4),
                       kernel_initializer=seeded(4))(hidden)
        return Add()([skip, hidden])

    inputs = Input(shape=(input_dim,))
    print(inputs)
    gate = Dense(input_dim, activation='sigmoid', kernel_regularizer=l2(1e-4),
                 kernel_initializer=seeded(0))(inputs)
    features = Multiply()([inputs, gate])

    features = Dense(512, activation='relu', kernel_regularizer=l2(1e-4),
                     kernel_initializer=seeded(5))(features)
    for width in (512, 256):
        features = residual_block(features, width)

    main_head = Dense(1, activation='linear', name='main', kernel_initializer=seeded(6))(features)
    aux_head = Dense(2, activation='softmax', name='aux', kernel_initializer=seeded(7))(features)
    return Model(inputs=inputs, outputs=[main_head, aux_head])


class AdaptiveLossWeightCallback(tf.keras.callbacks.Callback):
    """
    Combines adaptive weight scheduling and weighted loss computation.

    ``alpha`` weights the main (MSE) loss and ``1 - alpha`` the auxiliary
    (categorical cross-entropy) loss, which is further multiplied by
    ``aux_loss_scale``. At the start of each epoch ``alpha`` is multiplied by
    ``alpha_growth_rate`` and clipped to ``[alpha_min, alpha_max]``.
    """

    def __init__(self, initial_alpha=0.4, alpha_min=0.4, alpha_max=0.8, alpha_growth_rate=1.005,
                 aux_loss_scale=1.0):
        super().__init__()
        self.initial_alpha = initial_alpha
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.alpha_growth_rate = alpha_growth_rate
        # Accepted as a constructor argument so that from_config(get_config())
        # round-trips (get_config() has always emitted this key).
        self.aux_loss_scale = aux_loss_scale

        self.alpha = tf.Variable(initial_alpha, trainable=False, dtype=tf.float32, name="adaptive_loss_alpha")
        self.history = {'alpha': []}

        self.mse = tf.keras.losses.MeanSquaredError()
        self.ce = tf.keras.losses.CategoricalCrossentropy()

    def on_epoch_begin(self, epoch, logs=None):
        """Grow alpha geometrically, clip it to its bounds and record it."""
        current_alpha = self.alpha.numpy()
        new_alpha = np.clip(current_alpha * self.alpha_growth_rate, self.alpha_min, self.alpha_max)
        self.alpha.assign(new_alpha)
        self.history['alpha'].append(new_alpha)

    def compute_weighted_loss(self, y_true_main, y_pred_main, y_true_aux, y_pred_aux):
        """Return ``(total_loss, main_loss, aux_loss, alpha)`` for one batch.

        NOTE: reads alpha via ``.numpy()``, so this must run eagerly (not inside
        a ``tf.function``).
        """
        main_loss = self.mse(y_true_main, y_pred_main)
        aux_loss = self.ce(y_true_aux, y_pred_aux)

        current_alpha = self.alpha.numpy()
        current_beta = 1.0 - current_alpha

        weighted_main_loss = current_alpha * main_loss
        weighted_aux_loss = current_beta * aux_loss * self.aux_loss_scale

        total_loss = weighted_main_loss + weighted_aux_loss
        return total_loss, main_loss, aux_loss, current_alpha

    def get_config(self):
        return {
            "initial_alpha": self.initial_alpha,
            "alpha_min": self.alpha_min,
            "alpha_max": self.alpha_max,
            "alpha_growth_rate": self.alpha_growth_rate,
            "aux_loss_scale": self.aux_loss_scale,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class R2TestCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically reports R² on a held-out test set.

    Every ``interval`` epochs it predicts ``X_test`` and takes the model's
    first output. It scores that output against ``y_test`` with ``r2_score``,
    stores the score in ``self.r2_scores``, prints it and appends it to
    ``r2_log.txt``.
    """

    def __init__(self, X_test, y_test, interval=50):
        super().__init__()
        self.X_test, self.y_test = X_test, y_test
        self.interval = interval
        self.r2_scores = []

    def on_epoch_end(self, epoch, logs=None):
        epoch_number = epoch + 1
        if epoch_number % self.interval:
            return  # not a reporting epoch

        main_head = self.model.predict(self.X_test, verbose=0)[0]
        score = r2_score(self.y_test, main_head.flatten())
        self.r2_scores.append(score)
        print(f"\nEpoch {epoch_number}: Test R² = {score:.4f}")

        with open('r2_log.txt', 'a') as log_file:
            log_file.write(f'Epoch {epoch_number}: {score:.4f}\n')


def drop_large(model, X_large, y_large, drop_probability_base=0.3, random_state=None):
    """
    Dynamically filter large dataset based on prediction error.
    Higher error → higher drop probability.

    Each sample is dropped with probability
    ``drop_probability_base * (1 - exp(-|err| / MAE))``, clipped to [0, 0.99].

    Args:
        model: callable returning a sequence whose first element is the main
            prediction tensor (``model(X, training=False)[0].numpy()``).
        X_large: array of candidate samples (indexable by an index array).
        y_large: matching targets; 1-D or column-shaped.
        drop_probability_base: scale of the drop probability.
        random_state: int seed or ``np.random.RandomState`` for the keep/drop
            draws. It defaults to a generator seeded with the number of
            samples, which gives the same draws as before. NumPy's *global* RNG
            is no longer reseeded as a side effect.

    Returns:
        ``(filtered_X, filtered_y, kept_indices)``. If prediction fails, the
        inputs are returned unchanged with all indices.
    """
    if len(X_large) == 0:
        return X_large, y_large, np.array([], dtype=int)

    try:
        y_pred_main_all = model(X_large, training=False)[0].numpy().flatten()
    except Exception as e:
        print(f"Error during prediction in drop_large: {e}")
        return X_large, y_large, np.arange(len(X_large))

    # ravel() keeps a column-shaped y from broadcasting into an n x n matrix.
    absolute_errors = np.abs(y_pred_main_all - np.ravel(y_large))
    current_mae = np.mean(absolute_errors)
    if current_mae <= 1e-12:
        current_mae = 1e-12
        print("Warning: Current MAE is very close to zero in drop_large. Using small epsilon.")

    normalized_errors = absolute_errors / current_mae
    drop_probabilities = drop_probability_base * (1.0 - np.exp(-normalized_errors))
    drop_probabilities = np.clip(drop_probabilities, 0.0, 0.99)

    # Local legacy generator: same MT19937 stream as np.random.seed(n) +
    # np.random.rand(n), but without clobbering the global RNG state.
    if random_state is None:
        random_state = len(absolute_errors)
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    random_vals = rng.rand(len(absolute_errors))

    keep_mask = random_vals > drop_probabilities
    indices_kept = np.where(keep_mask)[0]
    filtered_X_large = X_large[indices_kept]
    filtered_y_large = y_large[indices_kept]

    print(f"  - Large samples before filtering: {len(X_large)}, after filtering: {len(filtered_X_large)} "
          f"(Avg MAE: {current_mae:.4f})")
    return filtered_X_large, filtered_y_large, indices_kept


class KerasModelWrapper:
    """Expose a two-output Keras model through a minimal scikit-learn estimator API.

    Only the first (main / regression) output is returned by ``predict``. This
    lets the model be passed to ``sklearn.inspection.permutation_importance``.
    """

    def __init__(self, keras_model):
        self.model = keras_model

    def predict(self, X):
        """Return the main-head predictions for ``X`` as a 1-D array."""
        return self.model.predict(X, verbose=0)[0].flatten()

    def fit(self, X, y):
        """No-op: the wrapped model is trained elsewhere.

        scikit-learn's scoring helpers check for a ``fit`` method. Returning
        ``self`` follows the scikit-learn estimator convention.
        """
        return self


class EnsembleModel:
    """Ensemble of hybrid Keras regressors trained with ``train_model``.

    The ensemble prediction is the element-wise median of the members'
    main-head outputs. It is mapped back to the original target scale with the
    fitted ``DataProcessor``'s ``inverse_transform_y``.
    """

    def __init__(self, epoch=300, num_models=3, learning_rate=1e-6, drop_base=0.3, processor=None):
        self.models = []
        self.num_models = num_models
        self.epoch = epoch
        self.learning_rate = learning_rate
        self.min = 0  # NOTE(review): not read anywhere in this file
        self.max = 1  # NOTE(review): not read anywhere in this file
        self.drop_base = drop_base
        # Fitted DataProcessor. When None, the module-level ``processor`` global
        # (created in the __main__ block) is used, as before.
        self.processor = processor

    def _get_processor(self):
        """Return the DataProcessor used for target (inverse) scaling."""
        own = getattr(self, 'processor', None)
        return own if own is not None else processor

    @staticmethod
    def _aggregate_importance(raw_importance, processor):
        """Fold per-polynomial-term importances back onto the original features.

        ``PolynomialFeatures.get_feature_names_out`` joins interaction terms
        with a space (e.g. ``"Tem Pre"``). ``'*'`` is accepted too. A term's
        importance is split equally among its factors. The bias term ``"1"``
        matches no feature and is ignored.
        """
        columns = list(processor.feature_columns)
        index_of = {name: i for i, name in enumerate(columns)}
        aggregated = np.zeros(len(columns))
        for term_idx, term in enumerate(processor.poly_feature_names):
            factors = str(term).replace('*', ' ').split()
            for factor in factors:
                if factor in index_of:
                    aggregated[index_of[factor]] += raw_importance[term_idx] / len(factors)
        return aggregated

    def train_ensemble(self, X_small, y_small, X_test, y_test, X_large, y_large):
        """Build ``num_models`` networks, train each with ``train_model`` and save it.

        Models are saved to ``saved_models/<target>/epoch-<epoch>/model_<i>``,
        where ``target`` is the module-level global set in the __main__ block.
        """
        input_dim = X_small.shape[1]
        self.models = [build_hybrid_model(input_dim, seed + 44) for seed in range(self.num_models)]
        for i, model in enumerate(self.models):
            print(f"Training model {i + 1}/{self.num_models}")
            model = train_model(X_small, y_small, X_test, y_test, X_large, y_large, model, self.epoch,
                                self.learning_rate, drop_base=self.drop_base)
            self.models[i] = model  # keep the ensemble pointing at the trained model
            model.save(f"saved_models/{target}/epoch-{self.epoch}/model_{i + 1}", save_format="tf")
        K.clear_session()
        gc.collect()

    def fit(self, X, y):
        pass

    def predict(self, X, raw_X):
        """Return median ensemble predictions on the original target scale.

        Args:
            X: processed (scaled) features.
            raw_X: unused; kept for interface compatibility.

        Raises:
            RuntimeError: if the ensemble has no models.
        """
        if not self.models:
            raise RuntimeError("EnsembleModel has no models; call train_ensemble() or load_models() first.")
        member_preds = [model.predict(X, verbose=0)[0].flatten() for model in self.models]
        y_pred_scaled = np.median(member_preds, axis=0)
        return self._get_processor().inverse_transform_y(y_pred_scaled)

    def compute_feature_importance_table(self, model, X_test, y_test, processor, n_repeats=5, random_state=42):
        """
        Compute and return feature importance as a DataFrame.
        Aggregates polynomial/interaction feature importance back to original features.

        ``y_test`` must be on the same scale as ``model``'s main output, which
        is the scaled target.

        Returns:
            DataFrame with columns ``Feature``, ``Importance`` and
            ``Importance_Normalized`` (min-max scaled), sorted by importance in
            descending order.
        """
        wrapped_model = KerasModelWrapper(model)

        result = permutation_importance(
            wrapped_model,
            X_test,
            y_test,
            n_repeats=n_repeats,
            random_state=random_state,
            n_jobs=1,
            scoring='neg_mean_absolute_error'
        )

        feature_importance = self._aggregate_importance(result.importances_mean, processor)

        min_imp = feature_importance.min()
        max_imp = feature_importance.max()
        if max_imp - min_imp > 1e-8:
            normalized_importance = (feature_importance - min_imp) / (max_imp - min_imp)
        else:
            normalized_importance = np.zeros_like(feature_importance)

        df_importance = pd.DataFrame({
            'Feature': processor.feature_columns,
            'Importance': feature_importance,
            'Importance_Normalized': normalized_importance
        }).sort_values(by='Importance_Normalized', ascending=False).reset_index(drop=True)

        return df_importance

    def save_predictions_to_excel(self, X_test, raw_X_test, y_test, output_path="predictions.xlsx"):
        """
        Save test set features, true values, predictions, and errors to Excel.

        ``y_test`` is the scaled target and is inverse-transformed here.
        NOTE: 'Absolute_Error' holds the signed error (true - predicted).
        """
        y_pred = self.predict(X_test, raw_X_test)
        y_test_original = self._get_processor().inverse_transform_y(y_test)

        df_result = raw_X_test.copy()
        df_result['True_Value'] = y_test_original
        df_result['Predicted_Value'] = y_pred
        df_result['Absolute_Error'] = df_result['True_Value'] - df_result['Predicted_Value']
        df_result['Relative_Error_(%)'] = np.where(
            df_result['True_Value'] != 0,
            (df_result['Absolute_Error'] / df_result['True_Value']) * 100,
            np.nan
        )

        df_result.to_excel(output_path, index=False, float_format="%.4f")
        print(f"✅ Predictions saved to: {output_path}")

    def visualize_performance(self, X_test, y_test, epoch):
        """Plot prediction quality and feature importance, and export the importance table.

        Args:
            X_test: processed (scaled) test features.
            y_test: test targets on the ORIGINAL scale, which is the scale that
                ``self.predict`` returns.
            epoch: used only in the output file names.
        """
        proc = self._get_processor()
        y_true = np.asarray(y_test, dtype=float).ravel()
        y_pred = self.predict(X_test, None)  # raw_X is unused by predict

        # The first member predicts the *scaled* target, so permutation
        # importance must be scored against scaled truth. The table is computed
        # once and reused for the plot.
        y_true_scaled = proc.y_scaler.transform(y_true.reshape(-1, 1)).ravel()
        importance_df = self.compute_feature_importance_table(
            self.models[0],
            X_test,
            y_true_scaled,
            proc,
            random_state=42
        )
        plot_importance = (importance_df.set_index('Feature')
                           .loc[proc.feature_columns, 'Importance_Normalized']
                           .to_numpy())

        plt.figure(figsize=(10, 5))
        # Prediction vs Truth
        plt.subplot(1, 3, 1)
        plt.scatter(y_true, y_pred, alpha=0.6)
        plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')
        plt.xlabel('True Values')
        plt.ylabel('Predictions')
        plt.title(f'Prediction vs Truth (R²={r2_score(y_true, y_pred):.2f})')
        plt.grid(True)

        # Error distribution
        plt.subplot(1, 3, 2)
        plt.hist(y_pred - y_true, bins=30)
        plt.xlabel('Prediction Error')
        plt.ylabel('Count')
        plt.title('Error Distribution')
        plt.grid(True)

        # Feature importance
        plt.subplot(1, 3, 3)
        plt.barh(proc.feature_columns, plot_importance)
        plt.title('Feature Importance')
        plt.tight_layout()

        os.makedirs('./training_process', exist_ok=True)
        plt.savefig(f'./training_process/model_performance_epoch{epoch}.png', dpi=300)
        plt.close()

        # Save feature importance table
        importance_df.to_excel(f'feature_importance_epoch{epoch}.xlsx', index=False, float_format="%.6f")
        print(f"✅ Feature importance saved to: feature_importance_epoch{epoch}.xlsx")

        print("\nTop 5 Most Important Features:")
        print(importance_df.head())

    def save_models(self):
        """Save every member under ``saved_models/epoch-<epoch>/model_<i>``.

        NOTE(review): train_ensemble saves under ``saved_models/<target>/...``;
        the two layouts differ.
        """
        for i, model in enumerate(self.models):
            model.save(f"saved_models/epoch-{self.epoch}/model_{i + 1}", save_format="tf")

    def load_models(self, base_path="saved_models"):
        """Load ``num_models`` members from ``<base_path>/model_<i>`` (uncompiled).

        Load failures are printed and skipped, so fewer models may be loaded.
        """
        self.models = []
        for i in range(self.num_models):
            model_path = os.path.join(base_path, f"model_{i + 1}")
            try:
                self.models.append(
                    tf.keras.models.load_model(
                        model_path,
                        custom_objects={'AdaptiveLossWeightCallback': AdaptiveLossWeightCallback},
                        compile=False
                    )
                )
                print(f"Loaded model {i + 1} from {model_path}")
            except Exception as e:
                print(f"Error loading model {i + 1}: {e}")


def _plot_training_curves(history, epochs):
    """Save MAE, loss-component and alpha curves to training_process/."""
    os.makedirs('training_process', exist_ok=True)
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.plot(history['main_mae'], label='Train MAE')
    plt.plot(history['val_main_mae'], label='Validation MAE')
    plt.title('MAE Progression')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.legend()
    plt.grid(True)

    # history['main_loss'] / ['aux_loss'] hold the *validation* loss components.
    plt.subplot(1, 3, 2)
    plt.plot(history['loss'], label='Total Loss')
    plt.plot(history['main_loss'], label='Main Loss')
    plt.plot(np.array(history['aux_loss']), label='Aux Loss')
    plt.title('Loss Components')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 3, 3)
    plt.plot(history['alpha'], label='Alpha')
    plt.title('Alpha Parameter')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'training_process/training_curves_epoch{epochs}_custom.png', dpi=300)
    plt.close()


def train_model(X_small, y_small, X_test, y_test, X_large_initial, y_large_initial,
                model=None, epochs=300, rate=1e-6,
                drop_start_epoch=10, drop_interval=10, drop_base=0.3):
    """
    Train model with dynamic large-data filtering, adaptive loss weighting, and LR scheduling.

    Each epoch runs a custom loop over the small training split plus the
    current large subset. It minimises
    ``alpha * MSE(main) + (1 - alpha) * CE(aux)``, where ``alpha`` comes from
    ``AdaptiveLossWeightCallback``. The auxiliary head classifies the data
    source with one-hot labels: class 0 = small data, class 1 = large data.

    Validation MAE is computed on a 20% hold-out of the small data:
      * if it has not improved for 30 epochs, training stops and the best
        weights (including BatchNorm statistics) are restored;
      * every 15 epochs without improvement, the learning rate is halved,
        down to a floor of 1e-7.
    Starting at ``drop_start_epoch``, every ``drop_interval`` epochs the large
    set is thinned with ``drop_large``, but only while the test-set R² is
    positive.

    NOTE(review): gating ``drop_large`` on the test set leaks test
    information into training.

    Returns:
        The trained model (the same object as ``model`` when one is passed).
    """
    # Schedule settings (previously read from Keras callbacks that were never
    # attached to a model).
    early_stop_patience = 30
    lr_patience = 15
    lr_factor = 0.5
    min_lr = 1e-7
    r2_interval = 10
    batch_size = 1024

    if model is None:
        print("Model not provided; building default hybrid model.")
        model = build_hybrid_model(X_small.shape[1], seed=42)

    optimizer = tf.keras.optimizers.Adam(learning_rate=rate)
    # compile() attaches the optimizer/losses to the model; the parameter
    # updates themselves are done by the custom loop below.
    model.compile(
        optimizer=optimizer,
        loss={'main': 'mse', 'aux': 'categorical_crossentropy'},
        metrics={'main': ['mae'], 'aux': 'accuracy'},
    )

    # Train/validation split
    X_small_train, X_val, y_small_train, y_val = train_test_split(
        X_small, y_small, test_size=0.2, random_state=42
    )

    # Initialize large dataset
    X_large_current = X_large_initial.copy()
    y_large_current = y_large_initial.copy()
    print(f"Initial large dataset size: {len(X_large_current)}")

    # Auxiliary labels: 0 for small, 1 for large
    y_aux_train_small = to_categorical(np.zeros(len(X_small_train)), num_classes=2)

    adaptive_weight_callback = AdaptiveLossWeightCallback()

    # Training history
    history = {
        'loss': [], 'main_loss': [], 'aux_loss': [],
        'main_mae': [], 'val_main_mae': [],
        'lr': [], 'alpha': []
    }

    # Losses and metrics
    main_loss_fn = tf.keras.losses.MeanSquaredError()
    aux_loss_fn = tf.keras.losses.CategoricalCrossentropy()
    train_mae_metric = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
    val_mae_metric = tf.keras.metrics.MeanAbsoluteError(name='val_main_mae')

    # Validation tensors (validation rows all come from the small set -> class 0)
    X_val_tensor = tf.constant(X_val, dtype=tf.float32)
    y_val_main_tensor = tf.constant(y_val, dtype=tf.float32)
    y_val_aux_tensor = tf.constant(to_categorical(np.zeros(len(X_val)), num_classes=2), dtype=tf.float32)

    best_val_main_mae = np.inf
    best_weights = None
    patience_counter = 0
    wait_for_lr_reduction = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Update alpha
        adaptive_weight_callback.on_epoch_begin(epoch=epoch)
        current_alpha = adaptive_weight_callback.alpha.numpy()
        current_beta = (1 - current_alpha)

        # Combine datasets
        y_aux_train_large_current = to_categorical(np.ones(len(X_large_current)), num_classes=2)
        X_train_combined = np.vstack([X_small_train, X_large_current])
        y_train_combined_main = np.concatenate([y_small_train, y_large_current])
        y_train_combined_aux = np.vstack([y_aux_train_small, y_aux_train_large_current])

        # Shuffle
        indices = np.random.permutation(len(X_train_combined))
        X_train_tensor = tf.constant(X_train_combined[indices], dtype=tf.float32)
        y_train_main_tensor = tf.constant(y_train_combined_main[indices], dtype=tf.float32)
        y_train_aux_tensor = tf.constant(y_train_combined_aux[indices], dtype=tf.float32)

        # Training batches
        total_loss = 0
        num_batches = 0
        train_mae_metric.reset_states()

        for i in range(0, len(X_train_tensor), batch_size):
            x_batch = X_train_tensor[i:i + batch_size]
            y_main_batch = y_train_main_tensor[i:i + batch_size]
            y_aux_batch = y_train_aux_tensor[i:i + batch_size]

            with tf.GradientTape() as tape:
                main_pred, aux_pred = model(x_batch, training=True)
                main_loss_value = main_loss_fn(y_main_batch, main_pred)
                aux_loss_value = aux_loss_fn(y_aux_batch, aux_pred)
                total_batch_loss = current_alpha * main_loss_value + current_beta * aux_loss_value
            grads = tape.gradient(total_batch_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            total_loss += total_batch_loss
            train_mae_metric.update_state(y_main_batch, main_pred)
            num_batches += 1

        avg_epoch_loss = total_loss / num_batches

        # Validation. Reset first so the MAE covers only this epoch; without the
        # reset it was a running mean over every epoch so far.
        val_mae_metric.reset_states()
        val_main_pred, val_aux_pred = model(X_val_tensor, training=False)
        val_main_loss_value = main_loss_fn(y_val_main_tensor, val_main_pred)
        val_aux_loss_value = aux_loss_fn(y_val_aux_tensor, val_aux_pred)
        val_mae_metric.update_state(y_val_main_tensor, val_main_pred)
        current_val_main_mae = val_mae_metric.result().numpy()

        # Log history
        history['loss'].append(avg_epoch_loss.numpy())
        history['main_loss'].append(val_main_loss_value.numpy())
        history['aux_loss'].append(val_aux_loss_value.numpy())
        history['main_mae'].append(train_mae_metric.result().numpy())
        history['val_main_mae'].append(current_val_main_mae)
        history['lr'].append(optimizer.learning_rate.numpy())
        history['alpha'].append(current_alpha)

        print(f" - loss: {avg_epoch_loss:.4f} "
              f"- main_loss: {val_main_loss_value:.4f} "
              f"- aux_loss: {val_aux_loss_value:.4f} "
              f"- main_mae: {train_mae_metric.result():.4f} "
              f"- val_main_mae: {current_val_main_mae:.4f} "
              f"- lr: {optimizer.learning_rate.numpy():.2e} "
              f"- alpha: {current_alpha:.4f}")

        # Early stopping & LR reduction
        if current_val_main_mae < best_val_main_mae:
            best_val_main_mae = current_val_main_mae
            # get_weights() also captures non-trainable BatchNorm statistics.
            best_weights = model.get_weights()
            patience_counter = 0
            wait_for_lr_reduction = 0
        else:
            patience_counter += 1
            wait_for_lr_reduction += 1
            if wait_for_lr_reduction >= lr_patience:
                old_lr = float(optimizer.learning_rate.numpy())
                new_lr = max(old_lr * lr_factor, min_lr)
                if new_lr < old_lr:
                    K.set_value(optimizer.learning_rate, new_lr)
                    print(f"    - Learning rate reduced to {new_lr:.2e}")
                wait_for_lr_reduction = 0

        if patience_counter >= early_stop_patience:
            print(f"    - Early stopping at epoch {epoch + 1}")
            if best_weights is not None:
                model.set_weights(best_weights)
            break

        # Drop large samples periodically (only if R² > 0)
        if epoch >= drop_start_epoch and (epoch - drop_start_epoch) % drop_interval == 0:
            y_pred_test = model(X_test, training=False)[0].numpy().flatten()
            current_r2 = r2_score(y_test, y_pred_test)
            print(f"[Drop Large Check] Current Test R² = {current_r2:.4f}")

            if current_r2 > 0:
                print(f"\n[Drop Large] Filtering large dataset at epoch {epoch + 1} (R²={current_r2:.4f} > 0)...")
                X_large_current, y_large_current, kept_indices = drop_large(
                    model, X_large_current, y_large_current, drop_probability_base=drop_base)
                print(f"[Drop Large] Kept {len(kept_indices)} samples out of {len(X_large_initial)}.")
            else:
                print(f"[Drop Large Skipped] R²={current_r2:.4f} <= 0, skipping drop_large.")

        # R² logging
        if (epoch + 1) % r2_interval == 0:
            y_pred_test = model(X_test, training=False)[0].numpy().flatten()
            current_r2 = r2_score(y_test, y_pred_test)
            print(f"Epoch {epoch + 1}: Test R² = {current_r2:.4f}")
            try:
                with open('r2_log.txt', 'a') as f:
                    f.write(f'Epoch {epoch + 1}: {current_r2:.4f}\n')
            except Exception as e:
                print(f"Warning: Could not write R2 log: {e}")

    _plot_training_curves(history, epochs)

    # Save preprocessor (module-level ``processor`` global set in __main__)
    try:
        os.makedirs('./saved_models', exist_ok=True)
        joblib.dump(processor, './saved_models/preprocessor.pkl')
    except Exception as e:
        print(f"Warning: Could not save preprocessor: {e}")

    return model


if __name__ == "__main__":
    for target in ['Yield']:
        select_columns = ['Longitude', 'Latitude', 'Tem', 'Pre', 'rate', 'SOC', 'Clay', 'pH', 'BD', 'irrigation',
                          target]
        small_df = pd.read_excel("HA Data.xlsx", sheet_name=target).loc[:, select_columns]
        large_df = pd.read_excel("LA Data.xlsx").loc[:, select_columns]
        print(f"High-accuracy data: {len(small_df)} samples")
        print(f"Low-accuracy data: {len(large_df)} samples")
        processor = DataProcessor()
        with open("output.txt", "w", encoding="utf-8") as file:
            pass
        with open('r2_log.txt', 'w') as f:
            pass
        ensemble = EnsembleModel(num_models=1, learning_rate=1e-5, drop_base=1)
        X_small, y_small, X_test, raw_X_test, y_test, X_large, y_large = processor.process_data(
            small_df, large_df, noise=0.005)
        for epoch in [1000]:
            ensemble.epoch = epoch
            ensemble.train_ensemble(X_small, y_small, X_test, y_test, X_large, y_large)
            y_pred = ensemble.predict(X_test, raw_X_test)
            y_test_original = processor.inverse_transform_y(y_test)
            mae = mean_absolute_error(y_test_original, y_pred)
            mse = mean_squared_error(y_test_original, y_pred)
            ensemble.visualize_performance(X_test, y_test_original, epoch)
            ensemble.save_predictions_to_excel(X_test, raw_X_test, y_test,
                                               output_path=f"predictions_epoch{epoch}.xlsx")
            rmse = np.sqrt(mse)
            explained_variance = explained_variance_score(y_test_original, y_pred)
            r2 = r2_score(y_test_original, y_pred)
            mape = mean_absolute_percentage_error(y_test_original, y_pred)
            with open("output.txt", "a", encoding="utf-8") as file:
                file.write(f"---epochs:{epoch}---base:{1}---\n")
                file.write(f"MAE: {mae:.4f}\n")
                file.write(f"MSE: {mse:.4f}\n")
                file.write(f"RMSE: {rmse:.4f}\n")
                file.write(f"Explained Variance: {explained_variance:.4f}\n")
                file.write(f"R²: {r2:.4f}\n")
                file.write(f"MAPE: {mape:.4f}\n\n")