from matplotlib import pyplot as plt
import numpy as np
import plotly.graph_objs as go
import matplotlib

# Module-level resolution setting (presumably dots per inch for exported figures).
# NOTE(review): nothing in the visible part of this file reads it — confirm it is still needed.
dpi = 300




# Function to compute Gini index
def gini_index(w1, w2, tol=1e-12):
    """Return the Gini-style index of the 3-vector ``(w1, w2, 1 - w1 - w2)``.

    The third component is implied by the constraint that the components
    sum to one.  The components are sorted in ascending order and weighted
    by ``3, 2, 1`` respectively; with this formula the result is ``1/3`` for
    the uniform vector and ``2/3`` for a vector with a single non-zero entry.

    Args:
        w1: First component.
        w2: Second component.
        tol: Tolerance for round-off in the implied third component.  A
            value of ``w3`` in ``[-tol, 0)`` is treated as exactly zero.
            Pass ``tol=0`` to reject any negative ``w3``.

    Returns:
        The index as a float, or ``np.nan`` when the point lies outside the
        simplex (any component negative beyond the tolerance).
    """
    w3 = 1 - w1 - w2
    if w1 < 0 or w2 < 0 or w3 < -tol:
        return np.nan  # Invalid region
    # Floating-point subtraction can leave a tiny negative residue on the
    # boundary w1 + w2 == 1 (e.g. 1 - 0.8 - 0.2 == -5.55e-17); clamp it so
    # valid boundary points are not discarded.
    w3 = max(w3, 0.0)
    x = np.sort([w1, w2, w3])  # Ascending order
    n = len(x)
    weights = np.arange(n, 0, -1)  # n, n-1, ..., 1: smallest entry gets the largest weight
    return 1 - np.sum(weights * x) / (n * np.sum(x))

# Function to compute PQ index
def pq_index(w1, w2, tol=1e-12):
    """Return the PQ index (p=1, q=2) of the 3-vector ``(w1, w2, 1 - w1 - w2)``.

    The third component is implied by the constraint that the components
    sum to one, so the l1 norm is 1 and the index reduces to
    ``1 - 1 / (sqrt(3) * ||w||_2)``.  The result is ``0`` for the uniform
    vector and ``1 - 1/sqrt(3)`` for a vector with a single non-zero entry.

    Args:
        w1: First component.
        w2: Second component.
        tol: Tolerance for round-off in the implied third component.  A
            value of ``w3`` in ``[-tol, 0)`` is treated as exactly zero.
            Pass ``tol=0`` to reject any negative ``w3``.

    Returns:
        The index as a float, or ``np.nan`` when the point lies outside the
        simplex (any component negative beyond the tolerance).
    """
    w3 = 1 - w1 - w2
    if w1 < 0 or w2 < 0 or w3 < -tol:
        return np.nan  # Invalid region
    # Floating-point subtraction can leave a tiny negative residue on the
    # boundary w1 + w2 == 1 (e.g. 1 - 0.8 - 0.2 == -5.55e-17); clamp it so
    # valid boundary points are not discarded.
    w3 = max(w3, 0.0)
    l2_norm = np.sqrt(w1**2 + w2**2 + w3**2)
    # The l1 norm is always 1 here, so the numerator of the ratio is 1.
    return 1 - 1 / (l2_norm * np.sqrt(3))

# Sample the (w1, w2) unit square on a 100 x 100 lattice and flatten each
# coordinate array so the points can be walked pairwise.
w1 = np.linspace(0, 1, 100)
w2 = np.linspace(0, 1, 100)
w1_grid, w2_grid = (axis.flatten() for axis in np.meshgrid(w1, w2))

# Evaluate the Gini index at every lattice point; points outside the simplex
# come back as NaN and are dropped from the lists used for plotting.
gini_values = [gini_index(a, b) for a, b in zip(w1_grid, w2_grid)]
_gini_kept = [
    (a, b, g)
    for a, b, g in zip(w1_grid, w2_grid, gini_values)
    if not np.isnan(g)
]
valid_gini_indices = [g for _, _, g in _gini_kept]
valid_w1_w2 = [(a, b) for a, b, _ in _gini_kept]

# Create 3D plot for Gini index: one marker per valid (w1, w2) point, with the
# index value used both as height and as marker colour.
gini_fig = go.Figure(data=[go.Scatter3d(
    x=[point[0] for point in valid_w1_w2],
    y=[point[1] for point in valid_w1_w2],
    z=valid_gini_indices,
    mode='markers',
    marker=dict(size=5, color=valid_gini_indices, colorscale='Viridis')
)])

# Tick-label font for the w1 and w2 axes (size can be adjusted as needed).
_gini_tickfont = dict(
    family="Times New Roman, serif",
    size=25,
    color="black",
    weight="bold",
)

# Lay out the figure.  A Scatter3d trace is drawn on `layout.scene`, so the
# axis fonts have to be set on `scene.xaxis` / `scene.yaxis`; the top-level
# `layout.xaxis` / `layout.yaxis` only describe 2D cartesian axes and would
# have no effect on this plot.
gini_fig.update_layout(
    font=dict(family="Times New Roman, serif", color="black", weight="bold", size=16),
    scene=dict(
        xaxis=dict(title=dict(text='w1'), tickfont=_gini_tickfont),
        yaxis=dict(title=dict(text='w2'), tickfont=_gini_tickfont),
        # Pad the z-range by 0.1 on each side and show only two tick marks.
        zaxis=dict(
            title='',
            range=[min(valid_gini_indices) - 0.1, max(valid_gini_indices) + 0.1],
            tickvals=[0.4, 0.6],
        ),
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=2.10)  # Adjust x, y, z for desired angle
        ),
    ),
    margin=dict(l=0, r=0, b=50, t=50),
    width=600,
    height=450,
    title='',
)

# Save the Gini index plot as a PDF file
gini_fig.write_image("gini_index_plot_scaled.pdf")

# Calculate PQ index values for the grid; points outside the simplex come
# back as NaN and are dropped from the lists used for plotting.
pqi_values = [pq_index(a, b) for a, b in zip(w1_grid, w2_grid)]
valid_pqi_indices = [p for p in pqi_values if not np.isnan(p)]
valid_w1_w2_pqi = [(a, b) for a, b, p in zip(w1_grid, w2_grid, pqi_values) if not np.isnan(p)]

# Create 3D plot for PQ index: one marker per valid (w1, w2) point, with the
# index value used both as height and as marker colour.
pqi_fig = go.Figure(data=[go.Scatter3d(
    x=[point[0] for point in valid_w1_w2_pqi],
    y=[point[1] for point in valid_w1_w2_pqi],
    z=valid_pqi_indices,
    mode='markers',
    marker=dict(size=5, color=valid_pqi_indices, colorscale='Viridis')
)])

# Lay out the figure.  A Scatter3d trace is drawn on `layout.scene`, so the
# axis titles and tick fonts have to be set on `scene.xaxis` / `scene.yaxis`;
# the top-level `layout.xaxis` / `layout.yaxis` only describe 2D cartesian
# axes and would have no effect on this plot.
# NOTE(review): the x and y font sizes below differ (25/20 vs 30/25); they are
# kept exactly as originally written — confirm whether that is intended.
pqi_fig.update_layout(
    font=dict(family="Times New Roman, serif", color="black", weight="bold", size=16),
    scene=dict(
        xaxis=dict(
            title=dict(text='w1', font=dict(size=25)),
            tickfont=dict(
                family="Times New Roman, serif",
                size=20,  # You can adjust the size as needed
                color="black",
                weight="bold",
            ),
        ),
        yaxis=dict(
            title=dict(text='w2', font=dict(size=30)),
            tickfont=dict(
                family="Times New Roman, serif",
                size=25,
                color="black",
                weight="bold",
            ),
        ),
        # Pad the z-range by 0.1 on each side and show only two tick marks.
        zaxis=dict(
            title='',
            range=[min(valid_pqi_indices) - 0.1, max(valid_pqi_indices) + 0.1],
            tickvals=[0.2, 0.4],
        ),
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=2.10)  # Adjust x, y, z for desired angle
        ),
    ),
    margin=dict(l=0, r=0, b=50, t=50),
    width=600,
    height=450,
    title='',
)

# Save the PQ index plot as a PDF file
pqi_fig.write_image("pq_index_plot_scaled.pdf")

