import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tensorflow.keras.datasets import mnist

# Prepend the parent of this file's directory to the module search path so
# that modules located one level up can be imported.
# NOTE(review): nothing in the visible portion of this file imports from that
# directory — confirm this is still needed.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

def weight_variable(shape):
    """Create a weight variable initialised from a truncated normal.

    Args:
        shape: Shape of the variable to create.

    Returns:
        A ``tf.Variable`` whose initial values come from
        ``tf.random.truncated_normal`` with standard deviation 0.1.
    """
    return tf.Variable(tf.random.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Create a bias variable with every element set to 0.1.

    Args:
        shape: Shape of the variable to create.

    Returns:
        A ``tf.Variable`` filled with the constant 0.1.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, W):
    """Apply a stride-1, 'SAME'-padded 2-D convolution of ``x`` with filter ``W``.

    The four-element strides list assumes ``tf.nn.conv2d``'s default data
    layout (NHWC) — TODO confirm callers use that layout.
    """
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')

def max_pool_2x2(x):
    """Max-pool ``x`` with a 2x2 window and stride 2 ('SAME' padding).

    The window and strides are given as four-element lists, so this assumes
    the default NHWC layout of ``tf.nn.max_pool2d`` — TODO confirm.
    """
    window = [1, 2, 2, 1]
    # The pooling window and its stride are identical, so pooled regions do
    # not overlap.
    return tf.nn.max_pool2d(x, ksize=window, strides=window, padding='SAME')

class ModelCNNMnist:
    """Convolutional classifier for flattened 28x28 single-channel images.

    The compiled Keras model is available as ``self.model``.
    """

    def __init__(self):
        self.build_model()

    def build_model(self):
        """Construct and compile the Keras model, storing it on ``self.model``.

        Architecture: reshape a 784-vector to 28x28x1, then two identical
        stages of [Conv 5x5 (32 filters, ReLU) -> MaxPool 2x2 -> BatchNorm],
        followed by a 256-unit ReLU dense layer and a 10-way softmax output.
        Compiled with Adam (learning rate 0.001), categorical cross-entropy
        (so targets are expected one-hot) and an accuracy metric.
        """
        layers = tf.keras.layers

        def conv_stage():
            # One convolution -> pooling -> batch-norm stage; both stages use
            # exactly the same settings.
            return [
                layers.Conv2D(32, (5, 5), activation='relu', padding='same'),
                layers.MaxPooling2D((2, 2), strides=2, padding='same'),
                layers.BatchNormalization(),
            ]

        stack = [layers.Reshape(target_shape=[28, 28, 1], input_shape=(784,))]
        stack += conv_stage()
        stack += conv_stage()
        stack += [
            layers.Flatten(),
            layers.Dense(256, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ]
        self.model = tf.keras.Sequential(stack)

        adam = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.model.compile(
            optimizer=adam,
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )

def plot_loss_landscape(model, data, steps=30, range_scale=1.0, seed=None):
    """Plot a 2-D slice of the loss surface around the model's current weights.

    Two random directions ``d1`` and ``d2`` are drawn once, with one Gaussian
    array per trainable variable. The loss is then evaluated at
    ``w0 + a * d1 + b * d2`` for every ``(a, b)`` on a ``steps`` x ``steps``
    grid spanning ``[-range_scale, range_scale]``. The result is drawn as a
    3-D surface.

    Each grid point runs a full ``model.model.evaluate`` over ``data``, so
    the cost is ``steps**2`` evaluations.

    Args:
        model: Object exposing a compiled Keras model as ``model.model``
            (e.g. ``ModelCNNMnist``).
        data: ``(inputs, targets)`` tuple passed to ``model.model.evaluate``.
        steps: Number of grid points along each direction.
        range_scale: Half-width of the grid along each direction.
        seed: Optional seed for the direction sampler, for reproducible plots.

    Returns:
        ``(X, Y, Z)``: the 2-D grid coordinates (``indexing='ij'``) and the
        loss at each grid point, as plotted.

    The model's original trainable weights are restored before returning,
    even if an evaluation raises.
    """
    x_vals = np.linspace(-range_scale, range_scale, steps)
    y_vals = np.linspace(-range_scale, range_scale, steps)
    z_vals = np.zeros((steps, steps))

    # Only trainable variables are perturbed. BatchNormalization's moving
    # mean/variance are non-trainable, and adding noise to the moving variance
    # can drive it negative, which turns the evaluated loss into NaN.
    variables = model.model.trainable_weights
    originals = [np.asarray(v.numpy()) for v in variables]

    # The two directions are fixed for the whole grid. Re-drawing them at
    # every point would make neighbouring grid values unrelated noise rather
    # than a slice of one surface.
    rng = np.random.default_rng(seed)
    dir_x = [rng.standard_normal(w.shape) for w in originals]
    dir_y = [rng.standard_normal(w.shape) for w in originals]

    try:
        for i, a in enumerate(x_vals):
            for j, b in enumerate(y_vals):
                for var, w0, dx, dy in zip(variables, originals, dir_x, dir_y):
                    # Cast back to the variable's dtype (typically float32);
                    # the float64 directions would otherwise upcast it.
                    var.assign((w0 + a * dx + b * dy).astype(w0.dtype))
                loss = model.model.evaluate(data[0], data[1], verbose=0)[0]
                z_vals[i, j] = loss
    finally:
        # Leave the caller's model exactly as it was.
        for var, w0 in zip(variables, originals):
            var.assign(w0)

    # 'ij' indexing makes X[i, j] == x_vals[i] and Y[i, j] == y_vals[j],
    # matching how z_vals[i, j] was filled. The default 'xy' would transpose
    # Z relative to the axes.
    X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')
    Z = z_vals

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, cmap='viridis')

    ax.set_xlabel('X direction')
    ax.set_ylabel('Y direction')
    ax.set_zlabel('Loss')

    plt.show()
    return X, Y, Z

if __name__ == '__main__':
    # Only the MNIST training split is used; the test split is discarded.
    (images, labels), (_, _) = mnist.load_data()

    # Flatten each image to a 784-vector scaled by 1/255, and one-hot encode
    # the labels to match the categorical cross-entropy loss.
    flat_images = images.reshape(-1, 784) / 255.0
    one_hot_labels = tf.keras.utils.to_categorical(labels, 10)

    # NOTE(review): the model is never fit here, so the landscape is plotted
    # around its random initialisation.
    cnn = ModelCNNMnist()
    plot_loss_landscape(
        cnn,
        (flat_images, one_hot_labels),
        steps=30,
        range_scale=0.1,
    )
