import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor

from scipy.spatial.distance import cityblock, canberra, euclidean

from sklearn.linear_model import SGDRegressor

from sklearn.svm import SVR

from SCM.TargetFunctions import *
import copy

import random
from abc import ABC, abstractmethod

class Mapping(ABC):
    """
    Abstract interface shared by every mapper.

    A concrete mapper must provide ``map``, ``drift``, ``fit`` and
    ``generate_untrained_example``.  The constructor only seeds the state
    that all mappers have in common: the fitted flag, the running mean, the
    noise parameters and the list of candidate target-function classes.
    """

    def __init__(self):
        # Noise state.  Subclasses in this module refresh ``lastNoise`` on
        # each ``map`` call as ``rho * lastNoise + N(u, std)``.
        self.lastNoise = 0.0
        self.std = 1
        self.u = 0
        self.rho = 0.5
        # Bookkeeping shared by all mappers.
        self.current_mean = 0
        self.fitted = True
        # Candidate target-function classes (from SCM.TargetFunctions).
        self.functions = [LinearFunction, SineFunction, ThresholdFunction, RadialBasisFunction, CheckerboardFunction]

    @abstractmethod
    def map(self, X):
        """Return the mapper output for input ``X``."""
        raise NotImplementedError

    def get_current_mean(self) -> float:
        """Return the stored ``current_mean`` value."""
        return self.current_mean

    @abstractmethod
    def drift(self, X, y, new_label_func=None):
        """Change the mapper's concept (concept drift)."""
        raise NotImplementedError

    def is_fitted(self):
        """Return the ``fitted`` flag."""
        return self.fitted

    @abstractmethod
    def generate_untrained_example(self, X):
        """Produce an output for ``X`` before the mapper has been fitted."""
        raise NotImplementedError

    # Intentionally not abstract: only some mappers use a label function.
    def drift_label_function(self):
        """Swap the label function; subclasses that support it override this."""
        raise NotImplementedError

    @abstractmethod
    def fit(self, X, y):
        """Fit the mapper on inputs ``X`` and targets ``y``."""
        raise NotImplementedError

    def __str__(self):
        return type(self).__name__
    
class IncrementalMapping(Mapping):
    """
    Base class for mappers supporting incremental drift.

    Adds the ``partial_fit`` requirement (and an optional ``reset_mean``
    hook) on top of ``Mapping``.  ``get_current_mean``, ``is_fitted`` and
    ``__str__`` are inherited from ``Mapping``.
    """

    def __init__(self):
        super().__init__()

    def reset_mean(self):
        """Reset running-mean statistics; subclasses that track one override this.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # The original definition lacked ``self``, so calling it on an
        # instance raised TypeError instead of NotImplementedError.
        raise NotImplementedError

    @abstractmethod
    def partial_fit(self, X, y):
        """Incrementally update the mapper with inputs ``X`` and targets ``y``."""
        raise NotImplementedError
    
class NormalMapper(Mapping):
    """A parent-less mapper whose output follows a (smoothed) normal distribution.

    Each ``map`` call updates an exponentially weighted moving average of
    draws from ``N(a, std)`` and adds autocorrelated noise
    ``lastNoise = rho * lastNoise + N(u, std)``.

    Args:
        mean: centre of the normal draws; random integer in [-20, 20) if None.
        std: standard deviation of the draws *and* of the additive noise;
            random integer in [1, 10) if None.
        ewma_alpha: smoothing factor of the moving average.
        rho: autocorrelation coefficient of the additive noise.
    """

    def __init__(self, mean=None, std=None, ewma_alpha=0.05, rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = True  # nothing to learn, so the mapper is always usable
        self.a = mean if mean is not None else np.random.randint(-20, 20)
        self.std = std if std is not None else np.random.randint(1, 10)
        self.train_has_started = False
        self.ewma_alpha = ewma_alpha
        self.dynamic_mean = self.a
        self.current_value = 0

    def map(self, _):
        """Return the next value; the input argument is ignored."""
        draw = np.random.normal(self.a, self.std)
        self.dynamic_mean = (1 - self.ewma_alpha) * self.dynamic_mean + self.ewma_alpha * draw
        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)
        return self.dynamic_mean + self.lastNoise

    def generate_untrained_example(self, X):
        """Generates an untrained example by mapping the input X."""
        return self.map(X)

    def drift(self, X=None, y=None, new_label_func=None):
        """Apply concept drift by redrawing the mean uniformly from [-20, 20).

        ``X``, ``y`` and ``new_label_func`` are accepted (and ignored) so the
        method can be called through the generic ``Mapping.drift`` interface;
        calling it with no arguments still works.
        """
        self.a = np.random.uniform(-20, 20)

    def __str__(self):
        return "Normal Mapper"

    def fit(self, X, y):
        """No-op: this mapper has no trainable parameters."""
        pass
    
class UniformMapper(Mapping):
    """A parent-less mapper driven by draws from a uniform distribution.

    Each ``map`` call updates an exponentially weighted moving average
    (``center``) of draws from ``U(a, b)`` and adds autocorrelated noise
    ``lastNoise = rho * lastNoise + N(u, std)``.

    Args:
        low: lower bound; random integer in [-20, 0) if None.
        high: upper bound; random integer in [0, 20) if None.
        ewma_alpha: smoothing factor of the moving average.
        rho: autocorrelation coefficient of the additive noise.
    """

    def __init__(self, low=None, high=None, ewma_alpha=0.05, rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = True  # nothing to learn, so the mapper is always usable
        self.a = low if low is not None else np.random.randint(-20, 0)
        self.b = high if high is not None else np.random.randint(0, 20)
        self.train_has_started = False
        self.ewma_alpha = ewma_alpha
        self.center = (self.a + self.b) / 2
        self.current_value = 0

    def map(self, _):
        """Return the next value; the input argument is ignored."""
        draw = np.random.uniform(self.a, self.b)
        self.center = (1 - self.ewma_alpha) * self.center + self.ewma_alpha * draw
        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)
        return self.center + self.lastNoise

    def generate_untrained_example(self, X):
        """Generates an untrained example by mapping the input X."""
        return self.map(X)

    def drift(self, X=None, y=None, new_label_func=None):
        """Apply concept drift by redrawing both bounds of the distribution.

        ``X``, ``y`` and ``new_label_func`` are accepted (and ignored) so the
        method can be called through the generic ``Mapping.drift`` interface;
        calling it with no arguments still works.
        """
        self.a = np.random.randint(-20, 0)
        self.b = np.random.randint(0, 20)

    def __str__(self):
        return "Uniform Mapper"

    def fit(self, X, y):
        """No-op: this mapper has no trainable parameters."""
        pass

class RandomMLPMapper(Mapping):
    """A mapper backed by a randomly initialised (never trained) MLP.

    The network weights are drawn once (Xavier-uniform) and define the
    cause/effect relation; ``drift`` simply redraws them.

    Args:
        hidden_dims: sizes of the hidden layers.
        activation: ``'tanh'`` for tanh; any other value selects ReLU.
        rho: autocorrelation coefficient of the additive noise.
    """

    def __init__(self, hidden_dims=(10, 10), activation='tanh', rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = False
        self.train_has_started = False
        # The original wrote ``self.label_function: None`` -- a bare
        # annotation that never created the attribute.
        self.label_function = None
        self.old_function = None
        self.a = None
        self.b = None
        self.n_parents = None  # number of input features, set on first use

        # Random MLP
        self.hidden_dims = hidden_dims
        self.activation_fn = np.tanh if activation == 'tanh' else lambda x: np.maximum(0, x)

        self.weights = []
        self.biases = []
        self.input_dim = None  # set by _initialize_random_mlp

    def _initialize_random_mlp(self, input_dim, hidden_dims=None):
        """Draw fresh Xavier-uniform weights and zero biases.

        Args:
            input_dim: number of input features.
            hidden_dims: hidden layer sizes; defaults to ``self.hidden_dims``
                so the constructor argument is actually honoured.
        """
        if hidden_dims is None:
            hidden_dims = self.hidden_dims
        self.input_dim = input_dim
        self.weights = []
        self.biases = []
        # Hidden layers followed by a single linear output unit.
        dims = [input_dim] + list(hidden_dims) + [1]
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            limit = np.sqrt(6 / (fan_in + fan_out))
            self.weights.append(np.random.uniform(-limit, limit, size=(fan_in, fan_out)))
            self.biases.append(np.zeros(fan_out))

    def _forward(self, X):
        """Forward pass: activation on hidden layers, linear output, flattened."""
        out = X
        for w, b in zip(self.weights[:-1], self.biases[:-1]):
            out = self.activation_fn(out @ w + b)
        return (out @ self.weights[-1] + self.biases[-1]).ravel()

    def map(self, X):
        """Maps the cause and effect relation from the parents nodes to this vertex.

        Lazily initialises the network from ``X.shape[1]`` when the mapper
        has not been fitted yet.

        Raises:
            ValueError: if ``X`` does not have ``n_parents`` columns.
        """
        if not self.fitted:
            self._initialize_random_mlp(input_dim=X.shape[1])
            # Previously left unset, which made the check below fail with
            # AttributeError on an unfitted mapper.
            self.n_parents = X.shape[1]
            self.fitted = True

        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)
        return self._forward(X) + self.lastNoise

    def is_fitted(self):
        return self.fitted

    def fit(self, X=None, y=None):
        """Initialise the random network for ``X.shape[1]`` inputs; ``y`` is unused."""
        if X is not None:
            self._initialize_random_mlp(input_dim=X.shape[1])
            self.fitted = True
            self.n_parents = X.shape[1]

    def generate_untrained_example(self, X):
        """Map a single 1-D sample ``X`` and return the result as a float."""
        if self.a is None or self.b is None:
            self.a = np.random.uniform(-1, 1, size=X.shape[0])
            self.b = np.random.uniform(-1, 1, size=X.shape[0])
            self.train_has_started = True
            self._initialize_random_mlp(input_dim=X.shape[0])
            self.n_parents = X.shape[0]
            self.fitted = True
        pred = self.map(X.reshape(1, -1))
        # Index explicitly: float() on a non-0-d array is deprecated in NumPy.
        return float(pred[0])

    def drift(self, X=None, y=None, new_label_func=None):
        """Apply concept drift by re-drawing the network weights.

        The input width is taken from ``X`` when given, otherwise from the
        previously seen ``n_parents``.  If neither is known there is no
        network to re-draw and the call is a no-op.
        """
        if X is not None:
            input_dim = X.shape[1]
        elif self.n_parents is not None:
            input_dim = self.n_parents
        else:
            return
        self._initialize_random_mlp(input_dim=input_dim)
        self.n_parents = input_dim

    def __str__(self):
        return "Random MLP Mapper"

class MLPMapping(IncrementalMapping):
    """A mapping that uses a Multi-Layer Perceptron (MLP) for regression.

    Before training, outputs come from a randomly chosen target function
    (``label_function``); once fitted, from an ``MLPRegressor`` plus
    autocorrelated noise.
    """

    def __init__(self, rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = False
        self.train_has_started = False
        self.n_parents = None  # number of input features, set on first use
        self.a = None
        self.b = None
        self.label_function : TargetFunction = np.random.choice(self.functions)()
        self.old_function = None
        self.model = self._generate_model()

    def generate_untrained_example(self, X):
        """Return ``label_function`` evaluated on the 1-D sample ``X``.

        The coefficient vectors ``a`` and ``b`` are drawn on first use, one
        entry per element of ``X``.
        """
        if self.a is None or self.b is None:
            self.a = np.random.uniform(-1, 1, size=X.shape[0])
            self.b = np.random.uniform(-1, 1, size=X.shape[0])
            self.train_has_started = True

        return self.label_function.compute_function(X, self.a, self.b)

    def _synthesize_training_data(self, n_samples=1000):
        """Build (X, y) from standard-normal inputs labelled by ``label_function``.

        Raises:
            RuntimeError: if the number of parents cannot be determined.
        """
        if self.n_parents is None:
            if self.a is None:
                raise RuntimeError(
                    "Cannot synthesise training data: number of parents is unknown."
                )
            # ``a`` holds one coefficient per parent.
            self.n_parents = len(self.a)
        if self.a is None or self.b is None:
            # Same initialisation as generate_untrained_example.
            self.a = np.random.uniform(-1, 1, size=self.n_parents)
            self.b = np.random.uniform(-1, 1, size=self.n_parents)
        X = np.random.normal(0, 1, size=(n_samples, self.n_parents))
        y = np.array([self.label_function.compute_function(row, self.a, self.b) for row in X])
        return X, y

    def fit(self, X, y):
        """Fit a fresh model to data."""
        self.n_parents = X.shape[1]
        self.model = self._generate_model()
        self.model.fit(X, y)
        self.fitted = True
        self.current_mean = np.mean(X)

    def is_fitted(self):
        return self.fitted

    def start_incremental_drift(self):
        """Begin an incremental drift by switching the label function."""
        self.drift_label_function()

    def partial_fit(self, X=None, y=None):
        """Incrementally fit the model to new data.

        When ``X`` or ``y`` is missing, training data is synthesised from the
        current label function.  The original dereferenced ``X.shape`` before
        this check and could leave ``n_parents`` unset.
        """
        if X is None or y is None:
            X, y = self._synthesize_training_data()
        self.train_has_started = True
        self.model.partial_fit(X, y)
        self.n_parents = X.shape[1]
        self.fitted = True

    def map(self, X):
        """Maps the cause and effect relation from the parents nodes to this vertex.

        Raises:
            RuntimeError: if the model has not been fitted.
            ValueError: if ``X`` does not have ``n_parents`` columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)
        return self.model.predict(X) + self.lastNoise

    def _generate_model(self):
        """Create a new, untrained regressor."""
        return MLPRegressor(hidden_layer_sizes=(10,), max_iter=10, solver='adam', learning_rate_init=0.001, warm_start=True)

    def drift_label_function(self, new_func=None):
        """Replace the label function, remembering the previous one.

        When ``new_func`` is None a random function of a different kind
        (compared by ``str``) is drawn from ``self.functions``.
        """
        if new_func is None:
            new_func = np.random.choice(self.functions)()
            while str(new_func) == str(self.label_function):
                new_func = np.random.choice(self.functions)()
        self.old_function = copy.deepcopy(self.label_function)
        self.label_function = new_func

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: switch the label function and refit a fresh model.

        Without data the model is refitted on samples synthesised from the
        new label function (the original crashed on ``X=None``).
        """
        self.drift_label_function(new_label_func)
        if X is None or y is None:
            X, y = self._synthesize_training_data()
        self.fit(X, y)  # fit() creates the fresh model itself

    def __str__(self):
        return "MLP Mapper"

class TreeMapper(Mapping):
    """A mapping that uses a decision-tree regressor.

    Before training, outputs come from a randomly chosen target function
    (``label_function``); once fitted, from a ``DecisionTreeRegressor`` plus
    autocorrelated noise.
    """

    def __init__(self, rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = False
        self.n_parents = None  # number of input features, set by fit/drift
        self.a = None
        self.b = None
        self.train_has_started = False
        self.label_function : TargetFunction = np.random.choice(self.functions)()
        self.model = self._generate_model()

    def is_fitted(self):
        return self.fitted

    def generate_untrained_example(self, X):
        """Return ``label_function`` evaluated on the 1-D sample ``X``.

        The coefficient vectors ``a`` and ``b`` are drawn on first use, one
        entry per element of ``X``.
        """
        if self.a is None or self.b is None:
            self.a = np.random.uniform(-1, 1, size=X.shape[0])
            self.b = np.random.uniform(-1, 1, size=X.shape[0])
            self.train_has_started = True

        return self.label_function.compute_function(X, self.a, self.b)

    def fit(self, X, y):
        """Fit the tree to data (a new tree is created only on the first fit)."""
        self.n_parents = X.shape[1]
        if not self.fitted:
            self.model = self._generate_model()

        self.model.fit(X, y)
        self.fitted = True
        self.current_mean = np.mean(X)

    def map(self, X):
        """Maps the cause and effect relation from the parent(s) vertex(ices) to this vertex.

        Raises:
            RuntimeError: if the model has not been fitted.
            ValueError: if ``X`` does not have ``n_parents`` columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")
        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)
        return self.model.predict(X) + self.lastNoise

    def _generate_model(self):
        """Create a new tree with a random maximum depth in [5, 25)."""
        max_depth = np.random.randint(5, 25)
        return DecisionTreeRegressor(max_depth=max_depth)

    def drift_label_function(self, new_func=None):
        """Replace the label function.

        When ``new_func`` is None a random function of a different kind
        (compared by ``str``) is drawn from ``self.functions``.
        """
        if new_func is None:
            new_func = np.random.choice(self.functions)()
            while str(new_func) == str(self.label_function):
                new_func = np.random.choice(self.functions)()
        self.label_function = new_func

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: switch the label function and, given data, refit a new tree."""
        self.drift_label_function(new_func=new_label_func)
        if X is not None and y is not None:
            self.model = self._generate_model()
            self.model.fit(X, y)
            # Keep the bookkeeping in sync with the refitted model so map()
            # validates against the right width and accepts the mapper.
            self.n_parents = X.shape[1]
            self.fitted = True

    def __str__(self):
        return "Decision Tree Mapper"

class AbstractCategoricalMapper(Mapping):
    """Shared state and helpers for mappers that output a class label.

    Attributes set here and filled in by subclasses:
        K: number of components/prototypes.
        n_classes: number of distinct output classes.
        class_swaps: dict ``class -> class`` applied to outputs after a
            severe drift.
        embeddings: optional per-index embedding table (when ``embed``).
    """

    def __init__(self, min_classes=2, max_classes=20, embed=False):
        super().__init__()
        self.fitted = False
        self.n_parents = None
        self.K = None
        self.embed = embed
        self.embeddings = None
        self.min_classes = min_classes
        self.max_classes = max_classes
        self.class_swaps = {}
        self.n_classes = None

    def _sample_K(self):
        """Draw a class count: rounded Gamma(2, 2) + 2, clipped to [min_classes, max_classes]."""
        raw_k = int(np.round(np.random.gamma(2.0, 2.0))) + 2
        return int(np.clip(raw_k, self.min_classes, self.max_classes))

    def severe_drift(self):
        """Swap the output labels of two randomly chosen classes.

        Swaps compose with earlier ones through ``class_swaps``.  Does
        nothing when fewer than two classes exist (or the mapper has not
        been initialised yet).
        """
        # Guard on n_classes -- the quantity actually sampled from below.
        # The original tested ``K``, which can be >= 2 while n_classes is
        # still None or 1.
        if self.n_classes is None or self.n_classes < 2:
            return
        class_a, class_b = np.random.choice(range(self.n_classes), size=2, replace=False)
        swap_a = self.class_swaps.get(class_a, class_a)
        swap_b = self.class_swaps.get(class_b, class_b)
        self.class_swaps[class_a] = swap_b
        self.class_swaps[class_b] = swap_a

    def drift_label_function(self):
        """No-op: categorical mappers have no label function."""
        pass

    def sample_label(self):
        """Draw a label index in ``[0, K)``, weighted by ``class_weights`` if present."""
        weights = getattr(self, "class_weights", None)
        if weights is None:
            return np.random.randint(self.K)
        probs = weights / weights.sum()
        return np.random.choice(self.K, p=probs)

class PrototypeCategoricalMapper(AbstractCategoricalMapper):
    """Nearest-prototype categorical mapper.

    ``K`` prototypes live in parent space; each prototype is assigned to one
    of ``n_classes`` classes through ``prototype_to_class``.  ``map`` labels
    a sample with the class of its nearest prototype under the configured
    distance (``"euclidean"``, ``"manhattan"`` or ``"canberra"``).
    """

    def __init__(self, embed=False, min_classes=2, max_classes=20, distance="euclidean"):
        super().__init__(min_classes, max_classes, embed)
        self.prototypes = None
        self.parents_mean = None
        self.parents_count = 0
        self.distance = distance
        self.prototype_to_class = None

    def fit(self, X, y=None):
        """Sample class count, prototypes and prototype->class assignment from ``X``.

        ``y`` is unused.

        Raises:
            ValueError: if ``X`` is None.
        """
        if X is None:
            raise ValueError("X must not be None")

        self.n_parents = X.shape[1]

        self.n_classes = self._sample_K()
        self.K = np.random.randint(self.n_classes, self.max_classes + 1)

        self.prototype_to_class = np.random.choice(self.n_classes, size=self.K)

        # Shared with drift(): also guards against (near-)constant parents,
        # which previously collapsed every prototype onto the same point.
        self._initialize_centers(X)

        if self.embed:
            self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        self.fitted = True
        self.class_swaps = {}
        self.current_mean = np.mean(X)

    def is_fitted(self):
        return self.fitted

    def _initialize_centers(self, X):
        """Draw ``K`` prototypes around the column means of ``X``.

        The spread is half the column standard deviation; columns with a
        standard deviation below 1e-6 use 1.0 instead.
        """
        parent_mean = np.mean(X, axis=0)
        parent_std = np.std(X, axis=0)
        parent_std = np.where(parent_std < 1e-6, 1.0, parent_std)

        scaling_factor = 0.5

        self.prototypes = np.random.normal(
            loc=parent_mean,
            scale=scaling_factor * parent_std,
            size=(self.K, self.n_parents)
        )

    def drift_label_function(self):
        pass

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: move a prototype, change the distance, or redraw all prototypes.

        The three options are picked with probability ~1/3 each; redrawing
        needs ``X`` and falls back to moving a prototype without it.
        """
        num = np.random.rand()
        if num < 0.33:
            self._change_prototypes()
        elif num < 0.66:
            self._change_distance()
        elif X is not None:
            self._initialize_centers(X)
        else:
            self._change_prototypes()

    def start_incremental_drift(self):
        pass

    def partial_fit(self, X=None, y=None, step_size=0.1):
        """Incremental drift step: nudge one random prototype; ``X``/``y`` are unused."""
        shifting_class = np.random.choice(range(self.K))
        shift_vector = np.random.normal(0, step_size, size=self.prototypes[shifting_class].shape)

        # NOTE(review): step_size is applied twice (as the std above and as a
        # factor here), giving an effective scale of step_size**2 -- confirm
        # this is intended.
        self.prototypes[shifting_class] += shift_vector * step_size

    def _change_prototypes(self):
        """Shift one random prototype by a standard-normal vector."""
        max_shift = 1.0
        drifting_class = np.random.choice(range(self.K))
        shift_vector = np.random.normal(0, max_shift, size=self.prototypes[drifting_class].shape)
        self.prototypes[drifting_class] = self.prototypes[drifting_class] + shift_vector

    def _change_distance(self, new_distance=None):
        """Simulate concept drift by changing the distance function.

        Raises:
            ValueError: if ``new_distance`` is not a supported name.
        """
        all_distances = ["euclidean", "manhattan", "canberra"]

        if new_distance is None:
            options = [d for d in all_distances if d != self.distance]
            self.distance = np.random.choice(options)
        else:
            if new_distance not in all_distances:
                raise ValueError("Unsupported distance type.")
            self.distance = new_distance

    def generate_untrained_example(self, X):
        """Label a single sample before fitting and pull the winner towards it.

        Prototypes are created lazily around the first sample.  The nearest
        prototype (Euclidean) is moved halfway towards ``X``.
        """
        X = X.flatten()

        if self.prototypes is None:
            self.n_parents = X.shape[0]
            self.n_classes = self._sample_K()
            self.K = np.random.randint(self.n_classes, self.max_classes + 1)
            self.prototype_to_class = np.random.choice(self.n_classes, size=self.K)
            parent_mean = X
            parent_std = np.ones_like(parent_mean)

            scaling_factor = 0.5

            self.prototypes = np.random.normal(
                loc=parent_mean,
                scale=scaling_factor * parent_std,
                size=(self.K, self.n_parents)
            )

            if self.embed:
                self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        dists = np.linalg.norm(self.prototypes - X, axis=1)
        idx = np.argmin(dists)

        self.prototypes[idx] = (self.prototypes[idx] + X) / 2

        # NOTE(review): this returns the prototype index, whereas map()
        # returns the class index via prototype_to_class -- confirm intended.
        return self.embeddings[idx] if self.embed else idx

    def _pairwise_distance(self, X):
        """Return an ``(n_samples, K)`` matrix of sample-to-prototype distances."""
        X = np.asarray(X)
        diff = self.prototypes[None, :, :] - X[:, None, :]
        if self.distance == "euclidean":
            return np.linalg.norm(diff, axis=2)
        if self.distance == "manhattan":
            return np.sum(np.abs(diff), axis=2)
        if self.distance == "canberra":
            denom = np.abs(self.prototypes)[None, :, :] + np.abs(X)[:, None, :] + 1e-8
            return np.sum(np.abs(diff) / denom, axis=2)
        raise ValueError(f"Unsupported distance: {self.distance}")

    def map(self, X):
        """Label each row of ``X`` with the class of its nearest prototype.

        Returns a scalar (or one embedding row) for a single-row ``X``, as
        before, and an array with one entry per row otherwise.  The original
        broadcast all rows against the prototypes and took one global
        argmin, which was wrong or crashed for more than one row.

        Raises:
            RuntimeError: if the mapper has not been fitted.
            ValueError: if ``X`` is None or has the wrong number of columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X is None:
            raise ValueError("Cannot map from None input.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        nearest = np.argmin(self._pairwise_distance(X), axis=1)
        class_idx = self.prototype_to_class[nearest]

        # Apply severe drift swaps if any
        final_idx = np.array([self.class_swaps.get(c, c) for c in class_idx], dtype=int)
        if final_idx.shape[0] == 1:
            final_idx = final_idx[0]

        return self.embeddings[final_idx] if self.embed else final_idx

    def _compute_distance(self, X, prototypes):
        """Distance from ``X`` to every row of ``prototypes`` (broadcasting form).

        Kept for existing callers; ``map`` uses ``_pairwise_distance``.
        """
        if self.distance == "euclidean":
            return np.linalg.norm(prototypes - X, axis=1)
        elif self.distance == "manhattan":
            return np.sum(np.abs(prototypes - X), axis=1)
        elif self.distance == "canberra":
            return np.sum(np.abs(prototypes - X) / (np.abs(prototypes) + np.abs(X) + 1e-8), axis=1)
        else:
            raise ValueError(f"Unsupported distance: {self.distance}")

    def __str__(self):
        return "Categorical Mapper"

    def sample_label(self):
        """Draw a uniform random index in ``[0, K)``."""
        return np.random.randint(self.K)

class SGDMapper(IncrementalMapping):
    """A mapping that uses a linear ``SGDRegressor``.

    Before training, outputs come from a randomly chosen target function
    (``label_function``); once fitted, from the regressor plus
    autocorrelated noise.  A running mean of the outputs is kept in
    ``output_mean``.
    """

    def __init__(self, rho=0.5):
        super().__init__()
        self.rho = rho
        self.fitted = False
        self.train_has_started = False
        self.model = None
        self.n_parents = None  # number of input features, set on first use
        self.a = None
        self.b = None
        self.label_function : TargetFunction = np.random.choice(self.functions)()
        self.old_function = None
        self.output_mean = 0
        self.num_samples_seen = 0

    def generate_untrained_example(self, X):
        """Return ``label_function`` evaluated on the 1-D sample ``X``.

        The coefficient vectors ``a`` and ``b`` are drawn on first use, one
        entry per element of ``X``.
        """
        if self.a is None or self.b is None:
            self.a = np.random.uniform(-1, 1, size=X.shape[0])
            self.b = np.random.uniform(-1, 1, size=X.shape[0])
            self.train_has_started = True

        return self.label_function.compute_function(X, self.a, self.b)

    def start_incremental_drift(self):
        """Begin an incremental drift by switching the label function."""
        self.drift_label_function()

    def _synthesize_training_data(self, n_samples=1000):
        """Build (X, y) from standard-normal inputs labelled by ``label_function``.

        Raises:
            RuntimeError: if the number of parents cannot be determined.
        """
        if self.n_parents is None:
            if self.a is None:
                raise RuntimeError(
                    "Cannot synthesise training data: number of parents is unknown."
                )
            # ``a`` holds one coefficient per parent.
            self.n_parents = len(self.a)
        if self.a is None or self.b is None:
            # Same initialisation as generate_untrained_example.
            self.a = np.random.uniform(-1, 1, size=self.n_parents)
            self.b = np.random.uniform(-1, 1, size=self.n_parents)
        X = np.random.normal(0, 1, size=(n_samples, self.n_parents))
        y = np.array([self.label_function.compute_function(row, self.a, self.b) for row in X])
        return X, y

    def partial_fit(self, X=None, y=None):
        """Incrementally update the model; synthesise data when none is given."""
        # Create the model only when there is none.  The original keyed this
        # on ``train_has_started``, which generate_untrained_example also
        # sets -- leaving ``model`` as None -- and which made the first
        # partial_fit after fit() throw the fitted model away.
        if self.model is None:
            self.model = self._generate_model()
        self.train_has_started = True

        if X is None or y is None:
            X, y = self._synthesize_training_data()

        self.model.partial_fit(X, y)
        # Keep map() usable after incremental training alone.
        self.n_parents = X.shape[1]
        self.fitted = True

    def fit(self, X, y):
        """Fit the regressor to data (a new model is created only on the first fit)."""
        self.n_parents = X.shape[1]
        if not self.fitted:
            self.model = self._generate_model()

        self.model.fit(X, y)
        self.fitted = True

    def is_fitted(self):
        return self.fitted

    def reset_mean(self):
        """Reset the running output mean and its sample counter."""
        self.output_mean = 0
        self.num_samples_seen = 0

    def map(self, X):
        """Maps the cause and effect relation from the parent(s) vertex(ices) to this vertex.

        Also folds the (batch-averaged) prediction into ``output_mean``.

        Raises:
            RuntimeError: if the model has not been fitted.
            ValueError: if ``X`` does not have ``n_parents`` columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        self.lastNoise = self.rho * self.lastNoise + np.random.normal(self.u, self.std)

        prediction = self.model.predict(X) + self.lastNoise

        new_value = prediction.mean() if isinstance(prediction, np.ndarray) else prediction
        self.num_samples_seen += 1
        self.output_mean += (new_value - self.output_mean) / self.num_samples_seen

        return prediction

    def _generate_model(self):
        """Create a new, untrained regressor."""
        return SGDRegressor(max_iter=10)

    def drift_label_function(self, new_func=None):
        """Replace the label function, remembering the previous one.

        When ``new_func`` is None a random function of a different kind
        (compared by ``str``) is drawn from ``self.functions``.
        """
        if new_func is None:
            new_func = np.random.choice(self.functions)()
            while str(new_func) == str(self.label_function):
                new_func = np.random.choice(self.functions)()
        self.old_function = copy.deepcopy(self.label_function)
        self.label_function = new_func

    def drift(self, X=None, y=None, new_label_func=None):
        """Drifts the mapping function of the current vertex.

        Switches the label function and refits a fresh model on ``(X, y)``.

        Raises:
            RuntimeError: if no data is given and the number of parents is
                unknown.
        """
        self.drift_label_function(new_label_func)
        if X is None or y is None:
            if self.n_parents is None:
                raise RuntimeError("Cannot drift without data before the mapper has seen any input.")
            # NOTE(review): the fallback targets are pure noise, unrelated to
            # X or to the label function -- confirm this is intended.
            X = np.random.normal(0, 1, size=(1000, self.n_parents))
            y = np.random.normal(0, 1, 1000)
        self.model = self._generate_model()
        self.model.fit(X, y)
        # Keep the bookkeeping in sync with the refitted model.
        self.n_parents = X.shape[1]
        self.fitted = True

    def __str__(self):
        return "SGD Regressor Mapper"
    

class OnlineGaussianCategoricalMapper(AbstractCategoricalMapper):
    """Categorical mapper built from diagonal-Gaussian components.

    ``K`` components (``class_means`` / ``class_vars`` / ``class_counts``,
    one row each) are assigned to ``n_classes`` classes through
    ``component_to_class``.  A sample is labelled with the class whose
    components have the largest summed likelihood.
    """

    def __init__(self, min_classes=2, max_classes=20, embed=False):
        super().__init__(min_classes, max_classes, embed)
        self.class_means = None
        self.class_vars = None
        self.class_counts = None
        self.component_to_class = None

    def fit(self, X, y=None):
        """Sample class count, components and component->class assignment from ``X``.

        ``y`` is unused.

        Raises:
            ValueError: if ``X`` is None.
        """
        if X is None:
            raise ValueError("X must not be None")
        self.n_parents = X.shape[1]
        self.n_classes = self._sample_K()
        self.K = np.random.randint(self.n_classes, self.max_classes + 1)

        self.component_to_class = np.random.choice(self.n_classes, size=self.K)

        self._initialize_centers(X)

        if self.embed:
            self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        self.class_swaps = {}

        self.fitted = True

    def _initialize_centers(self, X):
        """Draw component means around the column means of ``X``; reset vars and counts to 1."""
        self.class_means = np.random.normal(loc=np.mean(X, axis=0), scale=0.5, size=(self.K, self.n_parents))
        self.class_vars = np.ones((self.K, self.n_parents))
        self.class_counts = np.ones(self.K)

    def _component_log_likelihoods(self, X):
        """Return a ``(K, n_samples)`` matrix of per-component Gaussian log-densities."""
        X = np.asarray(X, dtype=float).reshape(-1, self.n_parents)
        var = np.maximum(self.class_vars, 1e-6)
        diff = X[None, :, :] - self.class_means[:, None, :]
        exponent = -0.5 * np.sum(diff ** 2 / var[:, None, :], axis=2)
        coeff = -0.5 * np.sum(np.log(2 * np.pi * var), axis=1)
        return coeff[:, None] + exponent

    def _class_likelihoods(self, X):
        """Return ``(class_lik, log_lik)`` for the rows of ``X``.

        ``class_lik`` has shape ``(n_classes, n_samples)`` and holds the
        summed component likelihoods per class, rescaled per sample by the
        best component so distant inputs do not underflow to all zeros.  The
        rescaling is a positive per-sample factor, so the arg-max over
        classes is unchanged.
        """
        log_lik = self._component_log_likelihoods(X)
        scaled = np.exp(log_lik - log_lik.max(axis=0, keepdims=True))
        class_lik = np.zeros((self.n_classes, scaled.shape[1]))
        # Unbuffered add so that several components of one class accumulate.
        np.add.at(class_lik, self.component_to_class, scaled)
        return class_lik, log_lik

    def generate_untrained_example(self, X):
        """Label a single sample before fitting and update the winning component.

        Components are created lazily around the first sample.  Returns the
        winning class index (or its embedding).
        """
        X = X.flatten()

        if self.class_means is None:
            self.n_parents = X.shape[0]
            self.n_classes = self._sample_K()
            self.K = np.random.randint(self.n_classes, self.max_classes + 1)

            self.component_to_class = np.random.choice(self.n_classes, size=self.K)

            self.class_means = np.random.normal(loc=X, scale=0.5, size=(self.K, self.n_parents))
            self.class_vars = np.ones((self.K, self.n_parents))
            self.class_counts = np.ones(self.K)

            if self.embed:
                self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        class_lik, log_lik = self._class_likelihoods(X)
        class_idx = np.argmax(class_lik[:, 0])

        # Update the most likely component *of the winning class*.  The
        # original indexed the per-component arrays with the class index,
        # i.e. updated an unrelated component.
        members = np.flatnonzero(self.component_to_class == class_idx)
        component = members[np.argmax(log_lik[members, 0])]

        self.class_counts[component] += 1
        alpha = 1.0 / self.class_counts[component]

        old_mean = self.class_means[component].copy()
        self.class_means[component] = (1 - alpha) * old_mean + alpha * X
        self.class_vars[component] = (1 - alpha) * self.class_vars[component] + alpha * (X - old_mean) ** 2

        return self.embeddings[class_idx] if self.embed else class_idx

    def map(self, X):
        """Label each row of ``X`` with its most likely class.

        Returns a scalar (or one embedding row) for a single-row ``X``, as
        before, and an array with one entry per row otherwise.  The original
        accumulation failed for more than one row.

        Raises:
            RuntimeError: if the mapper has not been fitted.
            ValueError: if ``X`` is None or has the wrong number of columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X is None:
            raise ValueError("Cannot map from None input.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        class_lik, _ = self._class_likelihoods(X)
        class_idx = np.argmax(class_lik, axis=0)

        # Apply severe drift swaps if any
        final_idx = np.array([self.class_swaps.get(i, i) for i in class_idx], dtype=int)
        if final_idx.shape[0] == 1:
            final_idx = final_idx[0]

        return self.embeddings[final_idx] if self.embed else final_idx

    def drift_label_function(self):
        pass

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: shift one component mean, or redraw all components from ``X``.

        Each option is chosen with probability 0.5; redrawing needs ``X``
        and falls back to the shift when ``X`` is None.
        """
        num = np.random.rand()
        if num < 0.5 or X is None:
            max_shift = 1.0
            drifting_class = np.random.choice(range(self.K))
            shift_vector = np.random.normal(0, max_shift, size=self.class_means[drifting_class].shape)
            self.class_means[drifting_class] += shift_vector
        else:
            self._initialize_centers(X)

    def start_incremental_drift(self):
        pass

    def partial_fit(self, X=None, y=None, step_size=0.01):
        """Incremental drift step: nudge one random component mean; ``X``/``y`` are unused."""
        shifting_class = np.random.choice(range(self.K))
        shift_vector = np.random.normal(0, step_size, size=self.class_means[shifting_class].shape)

        # NOTE(review): step_size is applied twice (as the std above and as a
        # factor here), giving an effective scale of step_size**2 -- confirm
        # this is intended.
        self.class_means[shifting_class] += shift_vector * step_size

    def __str__(self):
        return "Online Gaussian Categorical Mapper"

class RandomRBFCategoricalMapper(AbstractCategoricalMapper):
    """Categorical mapper built from radial-basis-function components.

    ``K`` components (centre ``class_means[k]``, width ``radii[k]``) are
    assigned to ``n_classes`` classes through ``component_to_class``.  A
    sample is labelled with the class of the component giving the largest
    RBF response.  ``_sample_K`` is inherited from the base class.
    """

    def __init__(self, min_classes=2, max_classes=20, embed=False):
        super().__init__(min_classes, max_classes, embed)
        self.class_means = None
        self.radii = None
        self.class_counts = None
        self.component_to_class = None
        self.n_partial_fit_calls = 0

    def fit(self, X, y=None):
        """Sample class count, components and component->class assignment from ``X``.

        ``y`` is unused.

        Raises:
            ValueError: if ``X`` is None.
        """
        if X is None:
            raise ValueError("X must not be None")
        self.n_parents = X.shape[1]
        self.n_classes = self._sample_K()
        self.K = np.random.randint(self.n_classes, self.max_classes + 1)

        self.component_to_class = np.random.choice(self.n_classes, size=self.K)

        self._initialize_centers(X)

        if self.embed:
            self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        self.class_swaps = {}
        self.fitted = True

    def _initialize_centers(self, X):
        """Draw centres around the column means of ``X``; radii = std of ``X``, counts = 1."""
        self.class_means = np.random.normal(loc=np.mean(X, axis=0), scale=0.5, size=(self.K, self.n_parents))
        self.radii = np.full(self.K, np.std(X))
        self.class_counts = np.ones(self.K)

    def generate_untrained_example(self, X):
        """Label a single sample before fitting and adapt the winning component.

        Components are created lazily around the first sample.  Returns the
        index of the winning component (or its embedding).
        """
        X = X.flatten()

        if self.class_means is None:
            self.n_parents = X.shape[0]
            self.K = self._sample_K()
            # One class per component, so n_classes / component_to_class are
            # defined in the untrained phase as well (previously left None).
            self.n_classes = self.K
            self.component_to_class = np.arange(self.K)
            self.class_means = np.random.normal(loc=X, scale=0.5, size=(self.K, self.n_parents))
            self.radii = np.ones(self.K)
            self.class_counts = np.ones(self.K)

            if self.embed:
                self.embeddings = np.random.normal(0, 1, size=(self.K, 4))

        dists = np.linalg.norm(self.class_means - X, axis=1)
        # Same epsilon as map(): a radius that shrank to zero must not
        # produce a division by zero.
        responses = np.exp(-dists ** 2 / (2 * (self.radii ** 2 + 1e-6)))
        idx = np.argmax(responses)

        self.class_counts[idx] += 1
        alpha = 1.0 / self.class_counts[idx]

        old_mean = self.class_means[idx].copy()
        self.class_means[idx] = (1 - alpha) * old_mean + alpha * X

        dist_to_mean = np.linalg.norm(X - old_mean)
        self.radii[idx] = (1 - alpha) * self.radii[idx] + alpha * dist_to_mean

        return self.embeddings[idx] if self.embed else idx

    def drift_label_function(self):
        pass

    def map(self, X):
        """Label each row of ``X`` with the class of its strongest RBF component.

        Returns an array with one entry (or embedding row) per input row.

        Raises:
            RuntimeError: if the mapper has not been fitted.
            ValueError: if ``X`` is None or has the wrong number of columns.
        """
        if not self.fitted:
            raise RuntimeError("Model not fitted.")
        if X is None:
            raise ValueError("Cannot map from None input.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        X_reshaped = X.reshape(-1, self.n_parents)
        responses = []
        for k in range(self.K):
            dist = np.linalg.norm(X_reshaped - self.class_means[k], axis=1)
            responses.append(np.exp(-dist ** 2 / (2 * (self.radii[k] ** 2 + 1e-6))))

        idx = np.argmax(np.array(responses), axis=0)
        class_idx = self.component_to_class[idx]
        # Apply severe drift swaps if any
        final_idx = np.array([self.class_swaps.get(i, i) for i in class_idx])

        return self.embeddings[final_idx] if self.embed else final_idx

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: shift one centre, or redraw all components from ``X``.

        Each option is chosen with probability 0.5; redrawing needs ``X``
        and falls back to the shift when ``X`` is None.
        """
        num = np.random.rand()
        if num < 0.5 or X is None:
            max_shift = 1.0
            drifting_class = np.random.choice(range(self.K))
            shift_vector = np.random.normal(0, max_shift, size=self.class_means[drifting_class].shape)
            self.class_means[drifting_class] += shift_vector
        else:
            self._initialize_centers(X)

    def start_incremental_drift(self):
        """Reset the incremental-drift counter and cached shift vectors."""
        self.n_partial_fit_calls = 0
        if hasattr(self, "incremental_shift_vectors"):
            del self.incremental_shift_vectors

    def partial_fit(self, X=None, y=None, step_size=0.001):
        """Incremental drift step: move every centre by ``step_size`` in a random direction.

        ``X`` and ``y`` are unused.
        """
        if not hasattr(self, "incremental_shift_vectors"):
            self.incremental_shift_vectors = np.zeros_like(self.class_means)
            self.incremental_classes = range(self.K)

        for c in self.incremental_classes:
            direction = np.random.normal(0, 1, size=self.class_means[c].shape)
            direction /= np.linalg.norm(direction)  # unit vector
            self.incremental_shift_vectors[c] = direction * step_size
            self.class_means[c] += self.incremental_shift_vectors[c]
        self.n_partial_fit_calls += 1

    def save_concept(self):
        """Snapshot the current centres and radii."""
        self._saved_class_means = self.class_means.copy()
        self._saved_radii = self.radii.copy()

    def restore_concept(self):
        """Restore centres and radii from the last snapshot, if one exists."""
        if hasattr(self, "_saved_class_means"):
            self.class_means = self._saved_class_means.copy()
        if hasattr(self, "_saved_radii"):
            self.radii = self._saved_radii.copy()

    def __str__(self):
        return "RandomRBF Categorical Mapper"

class RotatingHyperplaneMapper(AbstractCategoricalMapper):
    """Binary categorical mapper that labels rows by their side of a hyperplane.

    The hyperplane is described by a unit normal ``w`` and an offset
    ``bias``.  A row ``x`` gets label ``1`` when ``x @ w - bias >= 0`` and
    ``0`` otherwise.  Abrupt drift resamples the normal; incremental drift
    rotates it by a small angle inside a fixed pair of coordinate axes.

    Args:
        noise_rate: Probability with which each produced label is flipped.
        margin: Rows whose absolute score is below this value receive a
            uniformly random label instead of the deterministic one.
        rotation_speed: Factor multiplied with ``step_size`` to obtain the
            rotation angle (in radians) of one ``partial_fit`` call.
        embed: Forwarded to the base class.  ``fit`` and ``map`` read
            ``self.embed`` (presumably set by the base class from this
            argument); when it is truthy, ``map`` returns a random
            4-dimensional embedding row per label instead of the label.
    """

    def __init__(self, noise_rate=0.0, margin=0.0, rotation_speed=0.05, embed=False):
        super().__init__(min_classes=2, max_classes=2, embed=embed)
        self.noise_rate = noise_rate
        self.margin = margin
        self.rotation_speed = rotation_speed
        # The hyperplane does not exist until ``fit`` has seen the input width.
        self.fitted = False
        self.K = 2

    def _sample_K(self):
        """Return the number of classes, which is always two."""
        return 2  # current version of rotating hyperplane only supports binary classification

    def _random_unit_normal(self):
        """Draw a random unit-length normal vector with ``n_parents`` entries."""
        w = np.random.normal(0, 1, size=self.n_parents)
        return w / np.linalg.norm(w)

    def fit(self, X, y=None):
        """Initialise a random hyperplane matching the width of ``X``.

        Only the number of columns of ``X`` is used; its values and ``y``
        are ignored.

        Args:
            X: 2-D array-like exposing ``shape``; ``shape[1]`` is the number
                of input features.
            y: Unused.

        Raises:
            ValueError: If ``X`` is ``None``.
        """
        if X is None:
            raise ValueError("X must not be None")

        self.n_parents = X.shape[1]
        self.n_classes = 2

        self.w = self._random_unit_normal()
        self.bias = 0.0

        if self.embed:
            self.embeddings = np.random.normal(0, 1, size=(self.n_classes, 4))

        self.fitted = True

    def is_fitted(self):
        """Return whether ``fit`` has completed."""
        return self.fitted

    def generate_untrained_example(self, X):
        """Return the constant placeholder label ``0``; ``X`` is ignored."""
        return 0

    def map(self, X):
        """Assign a binary label (or its embedding) to every row of ``X``.

        Args:
            X: 2-D array with ``n_parents`` columns.

        Returns:
            One integer label in ``{0, 1}`` per row, or the matching rows of
            ``self.embeddings`` when ``self.embed`` is truthy.

        Raises:
            RuntimeError: If the mapper has not been fitted.
            ValueError: If ``X`` is ``None`` or has the wrong number of
                columns.
        """
        if not self.fitted:
            raise RuntimeError("Mapper not fitted.")
        if X is None:
            raise ValueError("Cannot map from None input.")
        if X.shape[1] != self.n_parents:
            raise ValueError(f"Expected input with {self.n_parents} features, got {X.shape[1]}")

        scores = X @ self.w - self.bias
        labels = (scores >= 0).astype(int)

        if self.margin > 0:
            # Rows too close to the hyperplane get a coin-flip label.
            mask = np.abs(scores) < self.margin
            labels[mask] = np.random.randint(0, 2, size=mask.sum())

        if self.noise_rate > 0:
            # Independent label noise: flip each label with prob. noise_rate.
            flip_mask = np.random.rand(len(labels)) < self.noise_rate
            labels[flip_mask] = 1 - labels[flip_mask]

        return self.embeddings[labels] if self.embed else labels

    def drift(self, X=None, y=None, new_label_func=None):
        """Abrupt drift: replace the hyperplane normal with a fresh random one.

        The offset ``bias`` is left unchanged.  All arguments are unused.
        """
        self.w = self._random_unit_normal()

    def start_incremental_drift(self):
        """Forget the cached rotation axes so the next phase picks new ones."""
        if hasattr(self, "rotation_axes"):
            del self.rotation_axes

    def partial_fit(self, step_size=0.05, _=None):
        """Incremental drift: rotate the hyperplane normal by a small angle.

        The rotation happens inside the plane spanned by two coordinate axes
        that are chosen at random on the first call and then reused until
        ``start_incremental_drift`` clears them.  With fewer than two input
        features no such plane exists, so the call leaves the hyperplane
        unchanged instead of failing while sampling the axes.

        Args:
            step_size: Scales ``rotation_speed`` to give the rotation angle
                in radians.
            _: Unused.
        """
        # NOTE(review): the first positional argument here is ``step_size``,
        # whereas the RBF mapper's ``partial_fit`` in this file takes
        # ``(X, y, step_size)`` -- confirm callers pass ``step_size`` by keyword.
        if self.n_parents < 2:
            return

        if not hasattr(self, "rotation_axes"):
            self.rotation_axes = np.random.choice(self.n_parents, 2, replace=False)

        i, j = self.rotation_axes
        angle = self.rotation_speed * step_size
        cos_a, sin_a = np.cos(angle), np.sin(angle)

        # A plane rotation only touches coordinates i and j, so update those
        # two entries directly instead of multiplying by a dense n x n matrix.
        w_i, w_j = self.w[i], self.w[j]
        rotated = self.w.copy()
        rotated[i] = cos_a * w_i - sin_a * w_j
        rotated[j] = sin_a * w_i + cos_a * w_j

        # Re-normalise to stop rounding error from accumulating over many steps.
        self.w = rotated / np.linalg.norm(rotated)

    def save_concept(self):
        """Snapshot the current normal vector and bias."""
        self._saved_w = self.w.copy()
        self._saved_bias = self.bias

    def restore_concept(self):
        """Reinstate whichever of the saved normal vector and bias exist."""
        if hasattr(self, "_saved_w"):
            self.w = self._saved_w.copy()
        if hasattr(self, "_saved_bias"):
            self.bias = self._saved_bias

    def __str__(self):
        return "Rotating Hyperplane Mapper"

    def sample_label(self):
        """Return a uniformly random label, ``0`` or ``1``."""
        return np.random.randint(0, 2)
