import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.stats import truncnorm


# --- Model parameters --------------------------------------------------------
# alpha: weight of the rho-scaled component in the mixture CDF used by compute_S.
# rho:   multiplicative factor applied to mu for that second component.
alpha, rho = 0.5, 0.8

# Grid of c values at which S_2 / S_1 is evaluated (adjust range/resolution
# as needed; it starts just above 0).
c_range = np.linspace(0.001, 1, num=1000)

# Location and scale of the normal distribution that is truncated to [0, 1].
mu, sigma = 0.5, 0.1  # example values; mu is meant to lie in (0, 1)

# Define the truncated normal density functions
def p(v, mu=mu, sigma=sigma):
    """Return the density at ``v`` of Normal(mu, sigma**2) truncated to [0, 1].

    ``truncnorm`` expects its clipping points in standardized units, i.e.
    measured in multiples of ``sigma`` away from ``loc``.
    """
    lower_z = (0 - mu) / sigma
    upper_z = (1 - mu) / sigma
    return truncnorm.pdf(v, a=lower_z, b=upper_z, loc=mu, scale=sigma)

def p_rho(v, rho, mu=mu, sigma=sigma):
    """Return the density at ``v`` of Normal(rho * mu, sigma**2) truncated to [0, 1].

    Same family as ``p`` but with the location scaled by ``rho``; the scale
    ``sigma`` and the [0, 1] support are unchanged.
    """
    center = rho * mu
    # Convert the support edges 0 and 1 into standardized truncnorm bounds.
    lower_z, upper_z = ((edge - center) / sigma for edge in (0, 1))
    return truncnorm.pdf(v, lower_z, upper_z, loc=center, scale=sigma)

# Define the CDF of the truncated normal distribution for p
def F_p(t, mu=mu, sigma=sigma):
    """Return the CDF at ``t`` of Normal(mu, sigma**2) truncated to [0, 1].

    Passing ``mu=rho * mu`` yields the CDF that corresponds to ``p_rho``.
    """
    standardized_bounds = [(edge - mu) / sigma for edge in (0, 1)]
    return truncnorm.cdf(t, *standardized_bounds, loc=mu, scale=sigma)

# Function to compute S_1(A) and S_2(A)
def compute_S(alpha, rho, c, mu=mu, sigma=sigma):
    """Return ``(S_1, S_2)`` for the given ``alpha``, ``rho`` and level ``c``.

    The threshold ``t`` solves

        (1 - alpha) * F(t; mu) + alpha * F(t; rho * mu) = 1 - c,

    i.e. ``t`` is the (1 - c)-quantile of the two-component mixture, where
    F(.; m) is the CDF of Normal(m, sigma**2) truncated to [0, 1] (``F_p``).
    Then

        S_1 = integral over [t, 1] of (u - t) * p(u)      (location mu)
        S_2 = integral over [t, 1] of (u - t) * p_rho(u)  (location rho * mu)

    Parameters
    ----------
    alpha : float
        Weight of the rho-scaled component in the mixture CDF.
    rho : float
        Factor applied to ``mu`` for the second component.
    c : float
        Mixture mass above the threshold; must lie in [0, 1].
    mu, sigma : float, optional
        Location and scale of the untruncated normal. They default to the
        module-level ``mu`` and ``sigma``.

    Returns
    -------
    tuple of float
        ``(S_1, S_2)``.

    Raises
    ------
    ValueError
        If ``c`` is not in [0, 1]; no threshold in [0, 1] exists then.
    """
    if not 0.0 <= c <= 1.0:
        raise ValueError(f"c must lie in [0, 1], got {c!r}")

    def equation_for_t_rho(v):
        # Mixture CDF minus (1 - c). Both truncated CDFs are 0 at v = 0 and 1
        # at v = 1, so this is -(1 - c) at 0 and c at 1: brentq always has a
        # sign change on [0, 1] for c in [0, 1].
        return (1 - alpha) * F_p(v, mu, sigma) + alpha * F_p(v, rho * mu, sigma) - (1 - c)

    t_rho = brentq(equation_for_t_rho, 0, 1)

    # Integrate (u - t) directly. The equivalent form
    # "integral of u*p(u) minus t*(1 - F(t))" subtracts two nearly equal
    # numbers when t sits in the upper tail and loses precision.
    S_1, _ = quad(lambda u: (u - t_rho) * p(u, mu, sigma), t_rho, 1)
    S_2, _ = quad(lambda u: (u - t_rho) * p_rho(u, rho, mu, sigma), t_rho, 1)
    return S_1, S_2

# Value of the ratio S_2 / S_1 that the plot highlights.
target_ratio = 0.8

# Compute S_1(A) and S_2(A) at every c on the grid (alpha and rho fixed).
S_1_values = []
S_2_values = []
for c in c_range:
    S_1, S_2 = compute_S(alpha, rho, c)
    S_1_values.append(S_1)
    S_2_values.append(S_2)

S2_S1_ratio_truncated = np.array(S_2_values) / np.array(S_1_values)

# Grid point whose ratio is closest to target_ratio. The grid is discrete, so
# the ratio there only approximately equals the target.
index_closest_to_08 = np.abs(S2_S1_ratio_truncated - target_ratio).argmin()
c_closest_to_08 = c_range[index_closest_to_08]

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(c_range, S2_S1_ratio_truncated, label=r'$r_{\mathcal{S}}(A)$ for truncated normal', color='blue')

# Dashed guide segments from the axes to (c_closest_to_08, target_ratio),
# drawn as finite segments instead of full-width axhline/axvline.
plt.plot([0, c_closest_to_08], [target_ratio, target_ratio], color='gray', linestyle='--')  # Horizontal line
plt.plot([c_closest_to_08, c_closest_to_08], [0, target_ratio], color='gray', linestyle='--')  # Vertical line

plt.scatter([c_closest_to_08], [target_ratio], color='gray', zorder=5)  # Mark the intersection

# Label offsets as a fraction of the current axis spans.
xlim = plt.xlim()
ylim = plt.ylim()
x_offset = (xlim[1] - xlim[0]) * 0.01
y_offset = (ylim[1] - ylim[0]) * 0.02

# Place the label to the left of and below the marker: the anchor is shifted
# left by x_offset and down by y_offset, and the text hangs from it
# (horizontalalignment='right', verticalalignment='top').
plt.text(c_closest_to_08 - x_offset, target_ratio - y_offset, f'({c_closest_to_08:.3f}, {target_ratio})',
         verticalalignment='top', horizontalalignment='right', fontsize=25)

# Title reflects the actual parameters instead of hard-coded values.
plt.title(rf'Plot of $r_{{\mathcal{{S}}}}(A)$ for $\rho = {rho}$, $\alpha = {alpha}$', fontsize=25)
plt.xlabel(r'$c$', fontsize=25)
plt.ylabel('Social welfare ratio', fontsize=25)
# plt.legend()
plt.grid(True)

# Adjust tick label sizes
plt.xticks(fontsize=25)
plt.yticks(fontsize=25)
plt.show()
