from typing import Literal
import matplotlib.patches
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.display import clear_output
import matplotlib
import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from dataset_5_layer.data_utils.real_simulator_tmm import srmse_evaluate

def create_train_val_split(x_train, y_train, test_ratio=0.2, seed=None):
    '''
    Randomly split paired arrays into a training part and a held-out part.

    ``int(n_rows * test_ratio)`` rows (rounded down) go to the held-out part and
    the remaining rows stay in the training part. NumPy's *global* RNG is
    reseeded with ``seed``, so this call also affects later ``np.random`` draws.

    Returns:
        (x_train_part, x_heldout_part, y_train_part, y_heldout_part)
    '''
    np.random.seed(seed)
    n_rows = x_train.shape[0]
    order = np.random.permutation(n_rows)
    cut = int(n_rows * test_ratio)

    heldout_idx, train_idx = order[:cut], order[cut:]
    return (
        x_train[train_idx],
        x_train[heldout_idx],
        y_train[train_idx],
        y_train[heldout_idx],
    )

def scaler(x, n_layer, translate_x=False):
    '''
    Normalise the thickness columns of ``x`` in place and return ``x``.

    The last ``n_layer`` columns are divided by 30e-9. Unless ``translate_x`` is
    True, they are then shifted down by 1. Thicknesses in 0-60 nm therefore map to
    (-1, 1), or to (0, 2) when ``translate_x`` is True. ``unscaler`` with the
    same ``translate_x`` reverts the transformation.
    '''
    thickness_cols = slice(-n_layer, None)
    x[:, thickness_cols] /= 30e-9
    if not translate_x:
        x[:, thickness_cols] -= 1
    return x

def unscaler(x, n_layer, translate_x=False):
    '''
    Undo ``scaler``: map the normalised thickness columns of ``x`` back to metres.

    Operates in place on the last ``n_layer`` columns and returns ``x``. Unless
    ``translate_x`` is True, 1 is added first; the columns are then multiplied
    by 30e-9. Values in (-1, 1) thus map back to (0, 60) nm. ``translate_x`` must
    match the value used with ``scaler`` to recover the original thicknesses.

    Expected row layout: [num_layer * num_materials material values, num_layer thicknesses].
    '''
    thickness_cols = slice(-n_layer, None)
    if not translate_x:
        x[:, thickness_cols] += 1
    x[:, thickness_cols] *= 30e-9
    return x

def decode_materials(x_hat, num_layers, num_materials=5):
    '''
    Turn per-layer material scores into hard one-hot vectors.

    Only the first ``num_layers * num_materials`` columns of ``x_hat`` are used.
    They are read as ``num_layers`` groups of ``num_materials`` scores, and each
    group becomes a float one-hot vector at its argmax.

    Example: [0.2, 0.1, 0.7] -> [0.,0.,1.]

    Returns a tensor of shape (batch, num_layers * num_materials).
    '''
    n_rows = x_hat.shape[0]
    width = num_layers * num_materials

    scores = x_hat[:, :width].view(n_rows, num_layers, num_materials)
    winners = torch.argmax(scores, dim=-1)
    hard = F.one_hot(winners, num_classes=num_materials).float()

    return hard.view(n_rows, width)


def evaluate_single_srmse(decoded_material, desired_material, metamat_config):
    """
    Return the mean spectral RMSE between a decoded and a desired material.

    Inputs with fewer than 2 dimensions are treated as a single material and get
    a leading batch axis. Both tensors are detached, moved to the CPU and copied
    into NumPy arrays before being passed to ``srmse_evaluate``.

    Args:
        decoded_material: tensor holding the generated material(s).
        desired_material: tensor holding the target material(s).
        metamat_config: configuration forwarded unchanged to ``srmse_evaluate``.

    Returns:
        The first value returned by ``srmse_evaluate`` (its ``mean``). The other
        statistics it returns are discarded.
    """
    if decoded_material.dim() < 2:
        decoded_material = decoded_material.unsqueeze(0)
    if desired_material.dim() < 2:
        desired_material = desired_material.unsqueeze(0)

    # np.array(...) copies. For CPU tensors, .numpy() shares memory with the
    # tensor, so without the copy srmse_evaluate could modify the caller's data.
    decoded_np = np.array(decoded_material.detach().cpu().numpy())
    desired_np = np.array(desired_material.detach().cpu().numpy())

    mean, _std, _ci, _rmse_samples, _ytrue_ypred, _rmse_waves = srmse_evaluate(
        decoded_np,
        desired_np,
        metamat_config,
    )
    return mean

def evaluate_reconstruction(X_test, dec, n_layer, n_mat=5, log_file=None):
    """
    Evaluate reconstruction metrics for a material autoencoder.

    Prints the metrics, optionally appends them to a log file, and returns them.
    Each row of ``X_test`` and ``dec`` holds ``n_layer * n_mat`` material values
    followed by ``n_layer`` thickness values.

    Args:
        X_test: ground-truth batch. Its material part is compared element-wise
            with the argmax-decoded one-hot of ``dec``, so it is expected to be
            one-hot already.
        dec: reconstructed batch with the same layout as ``X_test``.
        n_layer: number of layers (i.e. number of trailing thickness columns).
        n_mat: number of candidate materials per layer.
        log_file: optional path prefix. When given, the report is appended to
            ``f"{log_file}.txt"``.

    Returns:
        dict with float values under ``"one_hot_accuracy"``, ``"confidence_mean"``,
        ``"thickness_rmse"``, ``"thickness_mse"`` and ``"thickness_mae"``.
        ``one_hot_accuracy`` is NaN for an empty batch instead of raising.
    """
    # Metrics only: no need to build an autograd graph.
    with torch.no_grad():
        one_hot_X = X_test[:, :-n_layer]
        materials_dec = dec[:, :-n_layer]
        one_hot_Y = decode_materials(materials_dec, n_layer, num_materials=n_mat)

        # A layer counts as correct only if its entire one-hot row matches.
        layer_hits = (one_hot_X.reshape(-1, n_mat) == one_hot_Y.reshape(-1, n_mat)).all(dim=-1)
        correct = layer_hits.sum().item()
        total = layer_hits.numel()
        accuracy = correct / total if total else float("nan")

        # Sum of squared material values per layer (1.0 for an exact one-hot row).
        confidence_mean = materials_dec.reshape(-1, n_mat).square().sum(dim=1).mean().item()

        thickness_X = X_test[:, -n_layer:]
        thickness_Y = dec[:, -n_layer:]
        mse_t = F.mse_loss(thickness_X, thickness_Y)
        rmse = torch.sqrt(mse_t).item()
        mse = mse_t.item()
        mae = F.l1_loss(thickness_X, thickness_Y).item()

    log_message = (f"One-Hot accuracy:  {accuracy:.3f}\n"
                   f"One-Hot confidence mean: {confidence_mean:.3f}\n"
                   f"Thickness RMSE: {rmse:.2e}\n"
                   f"Thickness MSE: {mse:.2e}\n"
                   f"Thickness MAE: {mae:.2e}\n")

    print(log_message)
    if log_file is not None:
        with open(f"{log_file}.txt", "a", encoding="utf-8") as f:
            f.write(log_message + "\n")

    return {
        "one_hot_accuracy": accuracy,
        "confidence_mean": confidence_mean,
        "thickness_rmse": rmse,
        "thickness_mse": mse,
        "thickness_mae": mae,
    }


def evaluate_one_hot_gidnet(X_test, n_layer, n_mat=5, log_file=None, quiet=True):
    '''
    Summarise how close the gidnet-generated material weights are to one-hot vectors.

    For every layer of every row, the squared L2 norm of its ``n_mat`` material
    values is computed (1.0 for an exact one-hot row). The standard deviation and
    mean of those norms are reported via ``torch.std_mean``.

    [X_test] batch whose rows hold ``n_layer * n_mat`` material values followed by
        ``n_layer`` thickness values
    [n_layer] number of trailing thickness columns
    [n_mat] number of materials per layer
    [log_file] optional path; the report line is appended to it as-is
    [quiet] when False, the report is also printed

    Returns (std, mean) as 0-dim tensors.
    '''
    per_layer_norm = X_test[:, :-n_layer].reshape(-1, n_mat).square().sum(dim=1)

    std, mean = torch.std_mean(per_layer_norm)

    # Both numbers use ".4f" (4 decimals). A bare "4f" would be a minimum width,
    # which prints 6 decimals.
    log = f"One-hot accuracy: {mean.item():.4f}+-({std.item():.4f})"
    if not quiet:
        print(log)

    if log_file is not None:
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(log + "\n")

    return std, mean


def plot_wave_charts(data_points, predicted_data_points=None):
    '''
    Plot a flat spectra vector as a 4x3 grid of charts and save it to spectra.png.

    ``data_points`` is cut into consecutive chunks of 200 samples, one chart per
    chunk, titled in the order reflectance S, reflectance P, transmittance S,
    transmittance P, each at 25°, 45° and 65°. When ``predicted_data_points`` is
    given, the matching chunk is overlaid in orange. The figure is saved at
    400 dpi and then shown.
    '''
    chunk = 200
    wavelength = [int(w) for w in np.linspace(450, 950, 200)]

    # rs, rp, ts, tp -- three incidence angles each
    quantities = ["reflectance S", "reflectance P", "transmittance S", "transmittance P"]
    titles = [f"{quantity} @ {angle}°" for quantity in quantities for angle in [25, 45, 65]]

    plt.figure(figsize=(20, 15))
    for chart_no, start in enumerate(range(0, len(data_points), chunk)):
        stop = start + chunk
        actual = data_points[start:stop]
        predicted = None if predicted_data_points is None else predicted_data_points[start:stop]
        title = titles[chart_no]

        plt.subplot(4, 3, chart_no + 1)
        plt.plot(wavelength, actual, linewidth=1, color='blue', label="Actual")
        plt.title(title)
        plt.xlabel(r"Wavelength [$\mathrm{nm}$]")

        if predicted is not None:
            plt.plot(wavelength, predicted, linewidth=1, color='orange', label="Predicted")

        plt.legend()

    plt.tight_layout()
    plt.savefig("spectra.png", dpi=400)
    plt.show()


def plot_material_square(material, scale_to_nm=False, ax=None, square_height=10):
    '''
    Draw a layered material as a square made of horizontal coloured bands.

    [material] flat vector: 25 material weights (5 layers x 5 materials, one row
        of 5 per layer) followed by one thickness value per layer. It is never
        modified.
    [scale_to_nm] if True, thicknesses are first mapped with (t + 1) * 30e-9,
        i.e. the inverse of ``scaler``'s default normalisation
    [ax] axis to draw on; a new figure (figsize 5x6) is created when None
    [square_height] side of the square. Band heights are proportional to the
        thicknesses and, together with the 0.1 gaps between bands, fill exactly
        this height.

    Each band's colour blends the five base colours (red, green, blue, orange,
    cyan) by that layer's weights, clipped to [0, 1]. Bands are stacked
    bottom-up in layer order. Returns the axis.
    '''
    base_colors = np.array([
        [1, 0, 0],    # Red
        [0, 1, 0],    # Green
        [0, 0, 1],    # Blue
        [1, 0.5, 0],  # Orange
        [0, 1, 1],    # Cyan
    ])
    gap = 0.1

    # Blend the base colours by each layer's weights.
    # E.g. weights [0 0 0.9 0 0.1] -> 0.9 * base_colors[2] + 0.1 * base_colors[4]
    layer_weights = material[:25].reshape(5, 5)
    interpolated_colors = [
        np.clip(np.sum(row * base_colors.T, axis=1), 0, 1) for row in layer_weights
    ]

    thickness_values = material[25:]
    if scale_to_nm:
        # Out-of-place on purpose: material[25:] is a view for NumPy arrays and
        # torch tensors, so in-place += / *= would modify the caller's data
        # (and compound on every repeated call).
        thickness_values = (thickness_values + 1) * 30e-9

    total_thickness = sum(thickness_values)
    # Reserve room for the gaps so the whole stack fits inside the axis limits;
    # otherwise the top band overflows the y-limit and is clipped.
    usable_height = square_height - gap * (len(thickness_values) - 1)
    normalized_thicknesses = [t / total_thickness * usable_height for t in thickness_values]

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(5, 6))

    # --------- Plot the material square ------------------

    # Centre the square on the origin.
    ax.set_xlim(-square_height / 2, square_height / 2)
    ax.set_ylim(-square_height / 2, square_height / 2)

    # Remove ticks and frame for a cleaner look.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

    # Draw horizontal bands starting from the bottom of the square.
    current_y = -square_height / 2
    for layer_idx, h in enumerate(normalized_thicknesses):
        rect = matplotlib.patches.Rectangle(
            (-square_height / 2, current_y),   # (x, y) of lower-left corner
            square_height,                     # full width of the square
            h,                                 # height of this band
            facecolor=interpolated_colors[layer_idx],
        )
        ax.add_patch(rect)
        # Move up to the next band, leaving a small gap.
        current_y += h + gap

    return ax




def plot_latent_material(path_embeddings, umap_labels, index, decoded_materials, materials_labels, scale_to_nm=False, square_height=10, save_frame=False):
    '''
    Plot projected latent-space paths next to the decoded materials, as one frame.

    [path_embeddings] sequence of groups. Each group is a list of 2-D point arrays
        (one per path) drawn on the same axis: group 0 on "VAE latent space",
        group 1 on "Gidnet latent space". Within a group the first path is black,
        the second green; at most two paths per group are supported by the
        colour table.
    [umap_labels] legend label for each path within a group
    [index] point to highlight (in red) on every path
    [decoded_materials] materials drawn with plot_material_square, one axis each
    [materials_labels] title of each material axis
    [scale_to_nm] forwarded to plot_material_square; needed when the material
        thicknesses are in (-1, 1)
    [square_height] not forwarded to plot_material_square (currently unused)
    [save_frame] if True, save to frame_{index}.png, close the figure and return
        the file name; otherwise show the figure and return None
    '''
    clear_output(wait=True)

    n_squares = len(decoded_materials)
    fig, axes = plt.subplots(1, n_squares + 2, figsize=(5 + 5 * n_squares, 6))

    # -------- Latent space projections -----------
    panel_titles = ["VAE latent space", "Gidnet latent space"]
    # Colour codes: 0 -> first path, 10 -> second path, 1 -> highlighted point.
    code_to_colour = {0: "black", 1: "red", 10: "green"}

    for panel, group in enumerate(path_embeddings):
        for path_no, points in enumerate(group):
            codes = np.full(points.shape[0], path_no * 10.0)
            codes[index] = 1
            colours = [code_to_colour[code] for code in codes]

            sns.scatterplot(x=points[:, 0], y=points[:, 1], c=colours,
                            label=umap_labels[path_no], ax=axes[panel])
            axes[panel].set_title(panel_titles[panel])

    # -------- Decoded materials -----------
    for slot, material in enumerate(decoded_materials, start=2):
        plot_material_square(material, scale_to_nm, axes[slot])
        axes[slot].set_title(materials_labels[slot - 2])

    if not save_frame:
        plt.show()
        return None

    frame_path = f"frame_{index}.png"
    plt.savefig(frame_path)
    plt.close(fig)
    return frame_path


def compute_material_interpolation(latent_mat_1, latent_mat_2, steps = 100):
    '''
    Compute the linear interpolation between two materials in the latent space.

        alpha * mat1 + (1 - alpha) * mat2

    alpha goes from 1 down to 0 in ``steps`` equal increments. The result
    therefore holds ``steps + 1`` points: the first equals ``latent_mat_1`` and
    the last equals ``latent_mat_2``.

    Args:
        latent_mat_1: start point (tensor).
        latent_mat_2: end point (tensor with the same shape).
        steps: number of interpolation intervals; must be >= 1.

    Returns:
        Tensor of shape ``(steps + 1, *latent_mat_1.shape)``.

    Raises:
        ValueError: if ``steps`` is smaller than 1.
    '''
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # (steps - k) / steps gives exactly the old 1000-based alphas when steps
    # divides 1000. It also always reaches alpha == 0 and works for any steps.
    # The old grid skipped the endpoint for e.g. steps=3 and failed for steps > 1000.
    alphas = [(steps - k) / steps for k in range(steps + 1)]

    return torch.stack([latent_mat_1 * alpha + latent_mat_2 * (1 - alpha) for alpha in alphas])


# NOTE(review): the triple-quoted string below is a disabled draft of
# `compute_latent_projection`. It is a bare string expression, so no function
# is defined at import time. If it is ever re-enabled:
#   - the 'tsne' branch fits each path separately, so the two embeddings do not
#     share a coordinate system (the PCA/UMAP branch fits once on their union);
#   - the PCA/UMAP branch calls `.cuda()`, which requires a CUDA device;
#   - an unrecognised `method` leaves `reducer` as None.
"""def compute_latent_projection(path1, path2, method: Literal['pca', 'umap', 'tsne'] = 'pca'):
    reducer = None

    if method == 'umap':
        reducer = umap.UMAP(n_neighbors=20)
    elif method == 'pca':
        reducer = PCA(n_components=2)
    elif method == 'tsne':
        reducer = TSNE(n_components=2, perplexity=20)

    path1_embedding = None
    path2_embedding = None

    if method == 'tsne':
        path1_embedding = reducer.fit_transform(path1.detach().cpu())
        path2_embedding = reducer.fit_transform(path2.detach().cpu())

    else:
        union = torch.cat([path1.cuda(), path2.cuda()], dim=0).detach().cpu()
        reducer.fit(union)                                         # type: ignore
        #print(f"{reducer.explained_variance_ratio_} %")            # type: ignore
        path1_embedding = reducer.transform(path1.detach().cpu())  # type: ignore
        path2_embedding = reducer.transform(path2.detach().cpu())  # type: ignore

    return path1_embedding, path2_embedding"""



from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def draw_rectangle(ax, x, y, z, dz, dx=2, dy=2, color='red'):
    '''
    Add an axis-aligned cuboid to a 3D axis.

    (x, y, z) is the origin corner; the box extends by dx, dy and dz along the
    x, y and z axes. All six faces are filled with ``color`` and outlined with
    thin dark-gray edges.
    '''
    # The four corners of the base, walked around the perimeter.
    footprint = [(x, y), (x + dx, y), (x + dx, y + dy), (x, y + dy)]
    corners = [[cx, cy, z] for cx, cy in footprint] + [[cx, cy, z + dz] for cx, cy in footprint]

    # Corner indices of each face: bottom, top, front, back, right, left.
    face_corner_ids = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 1, 5, 4),
        (2, 3, 7, 6),
        (1, 2, 6, 5),
        (0, 3, 7, 4),
    )
    faces = [[corners[i] for i in ids] for ids in face_corner_ids]

    ax.add_collection3d(Poly3DCollection(faces, facecolors=color, linewidths=0.3, edgecolors='darkgray', zsort="min"))



def draw_3d_material(materials, filename, add_labels=True):
    '''
    Render a layered material as a stack of 3D slabs on a base plate, then save and show it.

    [materials] sequence of (material_name, thickness) pairs, drawn bottom-up.
        Names must be keys of the colour table below. The sequence is iterated
        twice, so pass a list rather than a one-shot iterator.
    [filename] path the figure is saved to (300 dpi) before being shown
    [add_labels] when True, each slab gets a gray leader line and a
        "name  [thickness] nm" label

    Slab heights are proportional to thickness and together span TOT_HEIGHT units.
    '''
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    tot_thic = sum(thic for _, thic in materials)

    # Hide grid, ticks, background, panes and axis lines.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])  # type: ignore
    ax.set_facecolor((1, 1, 1, 0))  # transparent background
    hidden_axes = (ax.xaxis, ax.yaxis, ax.zaxis)  # type: ignore
    for axis in hidden_axes:
        axis.pane.set_visible(False)  # type: ignore
    for axis in hidden_axes:
        axis.line.set_color((0., 0., 0., 0.))  # type: ignore

    # Same limits on every axis for a consistent view.
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):  # type: ignore
        set_limits([0, 8])

    MAT_WIDTH = 3
    BASELINE_WIDTH = 3.5
    TOT_HEIGHT = 3

    curr_x = 1
    curr_y = 2
    curr_z = 0

    material_colors = {
        "Ag":    "#C0C0C0",  "Al2O3": "#5595F4", "ITO": "#1ACE5F",
        "Ni":    "#4B4B4B",  "TiO2":  "#D4AF37",
    }

    # Base plate, slightly wider than the slabs and offset under them.
    inset = abs(BASELINE_WIDTH - MAT_WIDTH) / 2
    draw_rectangle(ax,
                   x=curr_x - inset - 0.25, y=curr_y - inset + 0.35, z=curr_z,
                   dx=BASELINE_WIDTH, dy=BASELINE_WIDTH, dz=0.2, color="#42347F")  # type: ignore

    layer_z = curr_z + 0.6

    for mat, thic in materials:
        slab_height = (thic / tot_thic) * TOT_HEIGHT

        draw_rectangle(ax, x=curr_x, y=curr_y, z=layer_z, dx=MAT_WIDTH, dy=MAT_WIDTH,
                       dz=slab_height, color=material_colors[mat])

        if add_labels:
            # Leader line from the slab's back corner at mid-height to the label.
            anchor_x = curr_x + MAT_WIDTH
            anchor_y = curr_y + MAT_WIDTH
            anchor_z = layer_z + slab_height / 2
            label_z = anchor_z + anchor_z
            ax.plot([anchor_x, anchor_x + 1], [anchor_y, anchor_y + 1], [anchor_z, label_z], c="gray")
            ax.plot([anchor_x + 1, anchor_x + 2], [anchor_y + 1, anchor_y + 1.3], [label_z, label_z], c="gray")
            ax.text(anchor_x + 2.1, anchor_y + 1.3, label_z, fr"{mat}  [{thic:.3f}] nm", fontsize=9)  # type: ignore

        layer_z += slab_height

    plt.savefig(filename, dpi=300)
    plt.show()



def material_tensor_to_list_tuples(tensor_mat, n_layer=5):
    '''
    Convert a flat material tensor into a list of (material_name, thickness_nm) tuples.

    [tensor_mat] 1-D tensor: ``n_layer * 5`` material scores (one row of 5 per
        layer) followed by ``n_layer`` thicknesses in metres
    [n_layer] number of layers (default 5, the previously hard-coded layout)

    Each layer's material is the argmax of its score row, mapped by index to
    "Ag", "Al2O3", "ITO", "Ni", "TiO2". Thicknesses are converted to nanometres
    and returned as Python floats.
    '''
    materials = ["Ag", "Al2O3", "ITO", "Ni", "TiO2"]
    n_mat = len(materials)
    split = n_layer * n_mat

    mat_indices = tensor_mat[:split].reshape(n_layer, n_mat).argmax(dim=1).tolist()
    # From m to nm
    thicknesses_nm = (tensor_mat[split:] * 1e9).tolist()

    return [(materials[m], float(thicknesses_nm[i])) for i, m in enumerate(mat_indices)]