import math
from random import sample

import numpy as np
from k_means_constrained import KMeansConstrained
from sklearn.ensemble import RandomForestRegressor
from sklearn.kernel_ridge import KernelRidge


def train_model_feature(M, X_tr, Y_tr, number_partitions=-1):
    """
    Train M pairs of (Kernel Ridge model, residual Random Forest model).

    The training data is split by position into two halves: the first
    ``ceil(n_samples / 2)`` rows are used to fit the Kernel Ridge models and
    the remaining rows are used to fit the residual models.

    For each of the M pairs:
    - the kernel is drawn from the list of enabled kernels (currently only
      "rbf"),
    - the regularization strength ``alpha`` is drawn uniformly from
      [0.1, 0.4),
    - if ``number_partitions`` is positive, the Kernel Ridge model is fitted
      on one randomly chosen equal-size cluster of the first half, otherwise
      on the whole first half,
    - a Random Forest regressor is fitted on the second half to predict the
      absolute residuals of that Kernel Ridge model.

    Parameters
    ----------
    M : int
        Number of model pairs to train.
    X_tr : array-like with ``.shape`` and slice/mask indexing
        Training features, one row per sample.
    Y_tr : array-like
        Training targets aligned with ``X_tr``.
    number_partitions : int, default -1
        ``-1`` disables clustering. A positive value ``k`` splits the first
        half into ``k`` clusters of identical size using constrained k-means.

    Returns
    -------
    list of tuple
        M tuples ``(kernel_ridge_model, residual_random_forest)``.

    Raises
    ------
    ValueError
        If ``number_partitions`` is neither -1 nor positive, if the size of
        the first half is not divisible by ``number_partitions``, or if
        models are requested (``M > 0``) with fewer than 2 samples.
    """
    n_samples = X_tr.shape[0]
    n1 = int(np.ceil(n_samples / 2))

    # Split data into prediction and validation sets
    X_p, Y_p = X_tr[:n1], Y_tr[:n1]
    X_v, Y_v = X_tr[n1:], Y_tr[n1:]

    clusters = []
    k = 0
    if number_partitions > 0:
        k = number_partitions
        n_p = X_p.shape[0]

        if n_p % k != 0:
            raise ValueError(f"Training samples ({n_p}) must be divisible by {k}")

        # Constrained k-means clustering: size_min == size_max forces every
        # cluster to contain exactly n_p // k samples.
        kmeans = KMeansConstrained(
            n_clusters=k, size_min=n_p // k, size_max=n_p // k, verbose=False
        )
        kmeans.fit(X_p)

        # One boolean mask over X_p per cluster.
        clusters = [kmeans.labels_ == i for i in range(k)]

    elif number_partitions != -1:
        raise ValueError("number_partitions must be -1 or positive integer")

    # With fewer than 2 samples one of the two halves is empty, so either the
    # Kernel Ridge fit or the residual fit would fail deep inside scikit-learn.
    # Fail early with an explicit message instead.
    if M > 0 and n_samples < 2:
        raise ValueError(
            f"At least 2 training samples are required, got {n_samples}"
        )

    # Kernels eligible for random selection. Only "rbf" is enabled; the draw
    # is kept so that further kernels can be added to this list.
    kernel_choices = ["rbf"]

    Mmodels = []

    for _ in range(M):
        kernel = np.random.choice(kernel_choices)
        alpha = np.random.uniform(0.1, 0.4)

        # Cluster selection
        if number_partitions > 0:
            cluster_idx = np.random.randint(k)
            train_idx = clusters[cluster_idx]
            X_train = X_p[train_idx]
            Y_train = Y_p[train_idx]
        else:
            X_train, Y_train = X_p, Y_p

        # Train Kernel Ridge with selected parameters
        krr = KernelRidge(alpha=alpha, kernel=kernel).fit(X_train, Y_train)

        # Train residual model on the held-out half: it learns the absolute
        # error of this Kernel Ridge model as a function of the features.
        residuals = np.abs(Y_v - krr.predict(X_v))
        rf = RandomForestRegressor().fit(X_v, residuals)

        Mmodels.append((krr, rf))

    return Mmodels
