import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
import torch
import torch.nn.functional as F

def decode_materials(x_hat, num_layers, num_materials=5):
    """Turn per-layer material scores into hard one-hot vectors.

    The first ``num_layers * num_materials`` columns of ``x_hat`` are read as
    ``num_layers`` consecutive groups of ``num_materials`` scores. Each group
    is replaced by a one-hot vector marking the position of its largest score.

    Example: [0.2, 0.1, 0.7] -> [0., 0., 1.]

    Args:
        x_hat: 2-D tensor of shape (batch, >= num_layers * num_materials);
            any columns beyond the material block are ignored.
        num_layers: number of layer groups to decode.
        num_materials: size of each group (number of candidate materials).

    Returns:
        Float tensor of shape (batch, num_layers * num_materials).
    """
    n_rows = x_hat.shape[0]
    width = num_layers * num_materials

    # (batch, width) -> (batch, layers, materials) so argmax works per layer.
    grouped = x_hat[:, :width].view(n_rows, num_layers, num_materials)
    winners = torch.argmax(grouped, dim=-1)

    hard = F.one_hot(winners, num_classes=num_materials).float()
    return hard.view(n_rows, width)

def evaluate_reconstruction(X_test, dec, n_layer, n_mat=5, log_file=None):
    """
    Evaluate reconstruction metrics for a material autoencoder.

    Each row of ``X_test`` and ``dec`` is laid out as the material columns
    (``n_layer`` groups of ``n_mat``) followed by ``n_layer`` thickness columns.

    Args:
        X_test: ground-truth tensor, shape (batch, n_layer * n_mat + n_layer);
            its material part is compared against the decoded one-hot vectors.
        dec: reconstructed tensor with the same layout; its material part holds
            per-layer scores that are decoded by argmax (see decode_materials).
        n_layer: number of layers per sample (must be positive).
        n_mat: number of candidate materials per layer.
        log_file: optional extra destination for the report. Either a path
            (``str`` / ``os.PathLike``; the report is appended to that file)
            or an object with a ``write`` method. ``None`` only prints.

    Returns:
        dict with float values under ``"accuracy"``, ``"confidence_mean"``,
        ``"rmse"``, ``"mse"`` and ``"mae"``.

    Raises:
        ValueError: if ``n_layer`` is not positive or the batch is empty.
        TypeError: if ``log_file`` is neither a path nor writable.
    """
    import os  # local import: only needed to recognise path-like log targets

    # Fail early with a clear message instead of an opaque ZeroDivisionError.
    if n_layer <= 0:
        raise ValueError(f"n_layer must be positive, got {n_layer}")
    if X_test.shape[0] == 0:
        raise ValueError("X_test is empty; cannot compute reconstruction metrics")

    Y_test = dec
    # Metrics only: no need to build an autograd graph through the losses.
    with torch.no_grad():
        # Material block = everything except the trailing n_layer thickness columns.
        one_hot_X = X_test[:, :-n_layer]
        one_hot_Y = decode_materials(Y_test[:, :-n_layer], n_layer, num_materials=n_mat)

        # A layer counts as correct only if its whole one-hot vector matches.
        layers_X = one_hot_X.reshape(one_hot_X.shape[0], -1, n_mat)
        layers_Y = one_hot_Y.reshape(one_hot_Y.shape[0], -1, n_mat)
        correct = (layers_X == layers_Y).all(dim=-1).sum().item()
        total = one_hot_X.numel() // n_mat
        accuracy = correct / total

        # Per-layer sum of squared raw outputs, averaged over all layers.
        values = Y_test[:, :-n_layer].reshape(-1, n_mat).square().sum(dim=1)
        confidence_mean = values.mean().item()

        # Thickness metrics on the trailing n_layer columns.
        thickness_X = X_test[:, -n_layer:]
        thickness_Y = Y_test[:, -n_layer:]
        mse_t = F.mse_loss(thickness_X, thickness_Y)
        mse = mse_t.item()
        rmse = torch.sqrt(mse_t).item()
        mae = F.l1_loss(thickness_X, thickness_Y).item()

    log_message = (f"One-Hot accuracy:  {accuracy:.3f}\n"
                   f"One-Hot confidence mean: {confidence_mean:.3f}\n"
                   f"Thickness RMSE: {rmse:.2e}\n"
                   f"Thickness MSE: {mse:.2e}\n"
                   f"Thickness MAE: {mae:.2e}\n")

    print(log_message)

    # Previously log_file was accepted but silently ignored.
    if log_file is not None:
        if hasattr(log_file, "write"):
            log_file.write(log_message)
        elif isinstance(log_file, (str, os.PathLike)):
            with open(log_file, "a", encoding="utf-8") as fh:
                fh.write(log_message)
        else:
            raise TypeError(
                f"log_file must be a path or a writable object, got {type(log_file).__name__}"
            )

    return {
        "accuracy": accuracy,
        "confidence_mean": confidence_mean,
        "rmse": rmse,
        "mse": mse,
        "mae": mae,
    }

    

def plot_wave_charts(data_points, predicted_data_points=None, filename="spectra.png",
                     wl_min=0.3, wl_max=20):
    """Plot an actual (and optionally a predicted) spectrum, save it, and show it.

    The x axis uses ``len(data_points)`` evenly spaced wavelengths from
    ``wl_min`` to ``wl_max``. With the defaults and 2001 samples this is the
    same grid as the previous hard-coded ``np.linspace(0.3, 20, 2001)``.

    Args:
        data_points: tensor or array-like of actual values, one entry per
            wavelength sample along the first axis.
        predicted_data_points: optional tensor/array-like of predicted values
            with the same length as ``data_points``.
        filename: path of the saved image (written at 400 dpi).
        wl_min: first wavelength of the grid, in micrometres.
        wl_max: last wavelength of the grid, in micrometres.
    """
    def _to_numpy(values):
        # detach()/cpu() so tensors that require grad or live on a GPU also work;
        # plain arrays and sequences pass through np.asarray.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    actual = _to_numpy(data_points)
    # One wavelength sample per data point instead of a fixed 2001-point grid.
    wavelength = np.linspace(wl_min, wl_max, len(actual))

    plt.figure(figsize=(10, 5))
    plt.plot(wavelength, actual, linewidth=1, color='blue', label="Actual")

    if predicted_data_points is not None:
        predicted = _to_numpy(predicted_data_points)
        plt.plot(wavelength, predicted, linewidth=1, color='orange', label="Predicted")

    # Same axis label whether or not a prediction is overlaid.
    plt.xlabel(r"Wavelength $[\mu\mathrm{m}]$")
    plt.title("Average reflectance (TE, TM polarizations - 15 angles)")
    plt.legend()
    plt.savefig(filename, dpi=400)
    plt.show()


def draw_rectangle(ax, x, y, z, dz, dx = 2, dy = 2, color = 'red'):
    """Add an axis-aligned box (cuboid) to a 3-D axes.

    The box spans [x, x + dx] x [y, y + dy] x [z, z + dz]. It is added as a
    single Poly3DCollection of six quadrilateral faces filled with ``color``,
    with 0.3-wide dark-gray edges and ``zsort="min"``.

    Args:
        ax: 3-D matplotlib axes (must provide ``add_collection3d``).
        x, y, z: coordinates of the minimum corner.
        dz: extent along z.
        dx, dy: extent along x and y.
        color: face color.
    """
    x_far, y_far, z_top = x + dx, y + dy, z + dz

    # Four corners of the bottom face in order around the rectangle,
    # then the same corners lifted to the top of the box.
    bottom = [[x, y, z], [x_far, y, z], [x_far, y_far, z], [x, y_far, z]]
    top = [[cx, cy, z_top] for cx, cy, _ in bottom]

    # Each side face joins one bottom edge (a, b) with the matching top edge.
    side_edges = [(0, 1), (2, 3), (1, 2), (0, 3)]  # front, back, right, left
    sides = [[bottom[a], bottom[b], top[b], top[a]] for a, b in side_edges]

    faces = [bottom, top, *sides]

    box = Poly3DCollection(faces, facecolors=color, linewidths=0.3,
                           edgecolors='darkgray', zsort="min")
    ax.add_collection3d(box)



def draw_3d_material(materials, filename):
    """Render a layer stack as a 3-D column of boxes, save it, and show it.

    A base slab is drawn first. Each layer is then stacked on top as a
    MAT_WIDTH x MAT_WIDTH box. Its height is its share of the total thickness
    scaled to TOT_HEIGHT, clamped to at least MIN_LAYER_HEIGHT. Each layer
    gets a gray leader line and a "<name>  [<thickness>] um" label. The
    figure is saved to ``filename`` at 300 dpi and then shown.

    Args:
        materials: iterable of ``(material_name, thickness)`` pairs, drawn
            bottom-up in the given order. Any iterable is accepted, including
            a one-shot generator. Names must appear in the color table below.
        filename: output image path.

    Raises:
        KeyError: if a material name has no assigned color (checked before
            any drawing happens).
    """
    # Materialise once: the stack is traversed twice (total thickness, then
    # drawing), which would silently draw nothing for a one-shot iterator.
    layers = list(materials)

    material_colors = {
        "Ag":    "#5F5F5F",  "SiO2": "#5595F4", "SiC": "#1ACE5F", "AlN": "#055F4E",
        "MgF2":    "#843030",  "TiO2":  "#FFC404", "ZnO": "#CDC075", "Al2O3": "#C4C4C4"
    }

    # Validate up front so an unknown name does not leave a half-drawn figure.
    unknown = sorted({mat for mat, _ in layers} - material_colors.keys())
    if unknown:
        raise KeyError(
            f"No color defined for material(s) {unknown}; known: {sorted(material_colors)}"
        )

    tot_thic = sum(thic for _, thic in layers)

    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111, projection='3d')

    # Hide grid, ticks, panes and axis lines; transparent background.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])  # type: ignore
    ax.set_facecolor((1, 1, 1, 0))
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):  # type: ignore
        axis.pane.set_visible(False)  # type: ignore
        axis.line.set_color((0., 0., 0., 0.))  # type: ignore

    # Fixed view box for the drawing.
    ax.set_xlim([0, 8])  # type: ignore
    ax.set_ylim([0, 8])  # type: ignore
    ax.set_zlim([0, 10])  # type: ignore

    MAT_WIDTH = 3
    BASELINE_WIDTH = 3.5
    TOT_HEIGHT = 6
    MIN_LAYER_HEIGHT = 0.2

    curr_x = 1
    curr_y = 2
    curr_z = 0

    # Base slab, slightly wider than the layers and offset under them.
    offset = abs(BASELINE_WIDTH - MAT_WIDTH) / 2
    draw_rectangle(ax,
                   x = curr_x - offset - 0.25, y = curr_y - offset + 0.35, z = curr_z,
                   dx = BASELINE_WIDTH, dy = BASELINE_WIDTH, dz = 0.2, color="#42347F")  # type: ignore

    curr_z += 0.6  # leave a gap between the slab and the first layer

    for mat, thic in layers:
        if tot_thic > 0:
            mat_z = (thic / tot_thic) * TOT_HEIGHT
            if mat_z < MIN_LAYER_HEIGHT:
                mat_z = MIN_LAYER_HEIGHT  # keep very thin layers visible
        else:
            # All-zero (or non-positive) total thickness: avoid dividing by zero.
            mat_z = MIN_LAYER_HEIGHT

        draw_rectangle(ax, x = curr_x, y = curr_y, z = curr_z,
                       dx = MAT_WIDTH, dy = MAT_WIDTH, dz = mat_z, color = material_colors[mat])

        # Leader line starts at the box's far (x, y) edge, at the layer's mid-height.
        pos_x = curr_x + MAT_WIDTH
        pos_y = curr_y + MAT_WIDTH
        pos_z = curr_z + mat_z / 2
        # NOTE(review): labels sit at twice the layer's mid-height, which doubles
        # the spacing between neighbouring labels but can place upper labels
        # above the z-limit of 10 set above — confirm this is intended.
        label_z = pos_z + pos_z

        ax.plot([pos_x, pos_x + 1], [pos_y, pos_y + 1], [pos_z, label_z], c="gray")
        ax.plot([pos_x + 1, pos_x + 2], [pos_y + 1, pos_y + 1.3], [label_z, label_z], c="gray")
        ax.text(pos_x + 2.1, pos_y + 1.3, label_z, fr"{mat}  [{thic:.3f}] um", fontsize=9)  # type: ignore

        curr_z += mat_z

    fig.savefig(filename, dpi=300)
    plt.show()