import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from matplotlib.lines import Line2D

from matplotlib.colors import ListedColormap,BoundaryNorm
# Generate a simple 2D dataset (two Gaussian blobs) and corrupt it with
# one-sided (asymmetric) label noise.
np.random.seed(52)

# Two Gaussian blobs: `class_1` (centred at (2, 2)) receives label 0 and
# `class_2` (centred at (5, 5)) receives label 1.
n_points = 60
class_1 = np.random.normal(loc=[2, 2], scale=0.8, size=(n_points, 2))
class_2 = np.random.normal(loc=[5, 5], scale=0.8, size=(n_points, 2))

# Combine points and labels: rows [0, n_points) are label 0, the rest label 1.
X = np.vstack((class_1, class_2))
y = np.hstack((np.zeros(n_points), np.ones(n_points)))

# Asymmetric label noise: flip a random subset of label-1 points to 0.
# The flipped points are drawn uniformly from the whole label-1 blob, not
# specifically from near the decision boundary. The count is
# noise_ratio * (total dataset size) = 12 of 120 points, i.e. 20% of label 1.
noise_ratio = 0.1
n_noisy_class_1 = int(noise_ratio * len(y))

class_0_indices = np.where(y == 0)[0]  # not used below (label 0 stays clean)
class_1_indices = np.where(y == 1)[0]

noisy_indices_class_1 = np.random.choice(class_1_indices, size=n_noisy_class_1, replace=False)

noisy_indices = noisy_indices_class_1
y[noisy_indices] = 1 - y[noisy_indices]  # flip labels 1 -> 0

# Visualize the dataset with its (noisy) labels and save it as a PDF.
plt.figure(figsize=(8, 6))

# Boolean masks select each label's rows as plain 1-D coordinate arrays.
pos = y == 0
plt.scatter(X[pos, 0], X[pos, 1], marker='s', edgecolors='blue', s=100,
            facecolor='none', label='Class 0')

pos_1 = y == 1
plt.scatter(X[pos_1, 0], X[pos_1, 1], marker='D', edgecolors='red', s=100,
            facecolor='none', label='Class 1')

plt.legend()

plt.title('Simple 2D Dataset with Asymmetric Noisy Labels', fontsize=20)
plt.xlabel('Feature 1', fontsize=20)
plt.ylabel('Feature 2', fontsize=20)
plt.savefig('./visual_sgd_raw.pdf', bbox_inches='tight', pad_inches=0.02)
plt.show()

# Custom SGD with a dynamic sampler
class CustomSGD:
    """Logistic-regression classifier trained by single-sample SGD.

    With ``alg_type='SGD-Q'`` the distribution used to sample training
    points is recomputed after every epoch as ``q_i ∝ exp(-loss_i)``, i.e.
    proportional to the model's current likelihood of each point's label,
    so points the model finds implausible are sampled less often. Any other
    ``alg_type`` keeps uniform sampling (vanilla SGD).
    """

    def __init__(self, learning_rate=0.01, max_iter=1000):
        # learning_rate: initial step size (fit() does not modify it).
        # max_iter: total number of single-sample updates; fit() runs
        #   int(max_iter / n_samples) epochs of n_samples updates each.
        self.learning_rate = learning_rate
        self.max_iter = max_iter
        self.weights = None  # set by fit(); last entry is the bias

    def fit(self, X, y, alg_type='sgd'):
        """Train the model and return the final sampling distribution.

        Args:
            X: feature matrix of shape (n_samples, n_features).
            y: binary (0/1) labels, length n_samples.
            alg_type: 'SGD-Q' for likelihood-weighted sampling; anything
                else (default 'sgd') samples uniformly.

        Returns:
            Array of length n_samples with the sampling probabilities in
            effect after the last epoch (uniform unless alg_type == 'SGD-Q').

        Prints the mean loss over the sampled points after every epoch.
        Draws from NumPy's global RNG.
        """
        n_samples, n_features = X.shape
        self.weights = np.zeros(n_features + 1)  # Include bias term
        X_bias = np.c_[X, np.ones(n_samples)]  # Add constant-1 bias column

        # Start from uniform sampling; SGD-Q re-weights after every epoch.
        sampling_probs = np.ones(n_samples) / n_samples
        # Local step size so that repeated fit() calls don't compound the decay.
        lr = self.learning_rate

        for iteration in range(int(self.max_iter / n_samples)):
            if iteration == 100:
                lr = lr * 0.2  # one-off step decay after 100 epochs

            # One call per epoch: with a probability vector, the legacy
            # global RNG draws these indices from the same uniform stream as
            # n_samples separate calls, but without re-validating `p` per step.
            sampled_indices = np.random.choice(n_samples, size=n_samples, p=sampling_probs)

            loss_sum = 0
            for sampled_index in sampled_indices:
                x_sample = X_bias[sampled_index]
                y_sample = y[sampled_index]

                # Compute prediction and loss for the sampled point
                linear_output = np.dot(x_sample, self.weights)
                prediction = 1 / (1 + np.exp(-linear_output))  # Sigmoid
                loss = -(y_sample * np.log(prediction + 1e-9)
                         + (1 - y_sample) * np.log(1 - prediction + 1e-9))

                # Gradient of the cross-entropy w.r.t. the weights
                gradient = (prediction - y_sample) * x_sample
                self.weights -= lr * gradient
                loss_sum += loss
            print('iterations:', iteration, 'loss= :', loss_sum / n_samples)

            if alg_type == 'SGD-Q':
                sampling_probs = self._likelihood_sampling_probs(X_bias, y)

        return sampling_probs

    def _likelihood_sampling_probs(self, X_bias, y):
        """Return q_i ∝ exp(-BCE_i), normalized to sum to 1."""
        z = np.dot(X_bias, self.weights)
        # Numerically stable binary cross-entropy from logits:
        # log(1 + e^z) - y*z. The naive log(sigmoid(z)) form produces
        # 0 * -inf = NaN once the sigmoid rounds to exactly 0 or 1.
        all_losses = np.logaddexp(0, z) - y * z
        q_loss = np.exp(-all_losses)
        total = np.sum(q_loss)
        if not total > 0:
            # Every likelihood underflowed to 0: fall back to uniform.
            return np.ones(len(q_loss)) / len(q_loss)
        return q_loss / total

    def predict(self, X):
        """Return hard 0/1 predictions (sigmoid output >= 0.5) for rows of X."""
        X_bias = np.c_[X, np.ones(X.shape[0])]  # Add bias term
        linear_output = np.dot(X_bias, self.weights)
        predictions = 1 / (1 + np.exp(-linear_output))  # Sigmoid
        return (predictions >= 0.5).astype(int)

    def plot_decision_boundary_model(self, X, y):
        """Debug helper: draw this model's boundary over 2-D data and show it."""
        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                             np.arange(y_min, y_max, 0.01))

        Z = self.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        plt.contour(xx, yy, Z, colors='k', linestyles='dashed')
        plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
        plt.title('Decision Boundary')
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')
        plt.show()

# Train two classifiers on the same noisy data: vanilla SGD (uniform
# sampling) and SGD-Q (likelihood-weighted sampling). Both draw from the
# global NumPy RNG, so the training order affects the results.
custom_sgd_uni = CustomSGD(learning_rate=0.01, max_iter=100000)
custom_sgd_uni.fit(X, y)

custom_sgd = CustomSGD(learning_rate=0.01, max_iter=100000)
# SGD-Q sampling probabilities after the final epoch; the decision-boundary
# plot shades each point by this value.
p_sample = custom_sgd.fit(X, y, 'SGD-Q')

# Training accuracy of both models, measured against the noisy labels y.
y_pred = custom_sgd.predict(X)
accuracy = accuracy_score(y, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")

accuracy_uni = accuracy_score(y, custom_sgd_uni.predict(X))
print(f"Vanilla SGD accuracy: {accuracy_uni * 100:.2f}%")

# Visualize the decision boundary
def plot_decision_boundary(clf, clf_uni, X, y, sample_probs=None):
    """Plot the SGD-Q and vanilla-SGD decision boundaries over the data.

    Each point is filled with a grey level encoding its sampling probability
    and outlined according to its (noisy) label. The figure is saved to
    ./visual_sgd_trained.pdf and then shown.

    Args:
        clf: trained SGD-Q model (drawn as a solid red boundary).
        clf_uni: trained vanilla-SGD model (drawn as a dashed black boundary).
        X: feature matrix with two feature columns, shape (n_samples, 2).
        y: 0/1 labels, length n_samples.
        sample_probs: per-point sampling probabilities used for the grey
            shading. Defaults to the module-level ``p_sample`` so existing
            four-argument calls keep working.
    """
    if sample_probs is None:
        sample_probs = p_sample

    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Predicted 0/1 labels on the grid; the contour between them is the boundary.
    Z = clf.predict(grid).reshape(xx.shape)
    plt.contour(xx, yy, Z, colors='red', linestyles='-')

    Z_uni = clf_uni.predict(grid).reshape(xx.shape)
    plt.contour(xx, yy, Z_uni, colors='k', linestyles='dashed')

    # Contour sets don't produce legend entries, so build proxy handles.
    custom_legend = [
        Line2D([0], [0], linestyle='-', color='r', label='SGD-Q boundary'),
        Line2D([0], [0], linestyle='--', color='k', label='Vanilla SGD boundary'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', label='Posteriors q'),
    ]

    # A single scatter call puts every point on ONE grey scale. Plotting each
    # class in its own call would normalise each class to its own min/max,
    # so equal shades in the two classes would not mean equal probabilities.
    plt.scatter(X[:, 0], X[:, 1], c=sample_probs, s=60, cmap='gray', label='Posteriors q')

    is_class_0 = y == 0
    is_class_1 = y == 1
    scatter1 = plt.scatter(X[is_class_0, 0], X[is_class_0, 1], marker='s', edgecolors='blue',
                           s=100, facecolor='none', label='Class 0')
    scatter2 = plt.scatter(X[is_class_1, 0], X[is_class_1, 1], marker='D', edgecolors='red',
                           s=100, facecolor='none', label='Class 1')

    plt.title('Decision Boundary', fontsize=20)
    plt.xlabel('Feature 1', fontsize=20)
    plt.ylabel('Feature 2', fontsize=20)
    plt.legend(handles=custom_legend + [scatter1, scatter2])
    plt.savefig('./visual_sgd_trained.pdf', bbox_inches='tight', pad_inches=0.02)
    plt.show()


plot_decision_boundary(custom_sgd, custom_sgd_uni, X, y, sample_probs=p_sample)
