import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.stats import truncnorm


# Parameters shared by the Pareto and the truncated-normal scenarios
alpha, rho = 0.5, 0.8  # alpha weights the two CDFs in the t_rho equations; rho is the scaling factor
beta = 2  # Pareto distribution parameter, beta > 1

# Grid of c values on which the threshold equation is solved (adjust as needed)
c_range = np.linspace(0.001, 1, num=1000)

def F_p_pareto(u):
    """Return the Pareto CDF at ``u``.

    The value is ``1 - 1 / u**beta`` for ``u >= 1`` and ``0`` below 1, where
    ``beta`` is the module-level Pareto parameter.

    ``u`` must be a scalar: the branch is chosen from the single truth value
    of ``u >= 1``.
    """
    if u >= 1:
        return 1 - 1 / u**beta
    # Below the support of the distribution the CDF is zero.
    return 0

def F_p_rho_pareto(u, rho):
    """Return the rho-adjusted Pareto CDF at ``u``.

    The value is the Pareto CDF evaluated at ``u / rho`` when ``u >= rho``
    and ``0`` otherwise.

    Args:
        u: Scalar point at which to evaluate the CDF.
        rho: Scaling factor; must be non-zero because ``u`` is divided by it.

    Returns:
        ``F_p_pareto(u / rho)`` for ``u >= rho``, else ``0``.
    """
    if u >= rho:
        # Use the Pareto CDF here. The module-level ``F_p`` is the
        # truncated-normal CDF and would mix the two scenarios.
        return F_p_pareto(u / rho)
    return 0

def equation_for_t_rho(u, alpha, rho, c, beta):
    """Residual of the Pareto threshold equation at ``u``.

    Computes
    ``(1 - alpha) * max(0, 1 - 1/u**beta)
    + alpha * max(0, 1 - rho**beta/u**beta) - (1 - c)``;
    a root in ``u`` is the threshold t_rho for the given alpha, rho and c.

    NOTE: a second function named ``equation_for_t_rho`` (four parameters,
    truncated-normal case) is defined further down in this file and rebinds
    this module-level name once the whole file has executed.
    """
    u_pow = u**beta
    # Each CDF term is clamped at zero below its support.
    cdf_unscaled = max(0, 1 - 1 / u_pow)
    cdf_scaled = max(0, 1 - (rho**beta) / u_pow)
    weighted_unscaled = (1 - alpha) * cdf_unscaled
    weighted_scaled = alpha * cdf_scaled
    return weighted_unscaled + weighted_scaled - (1 - c)

def solve_t_rho(alpha, rho, c, beta):
    """Find the Pareto-case threshold t_rho with ``fsolve``.

    Solves
    ``(1 - alpha) * max(0, 1 - 1/u**beta)
    + alpha * max(0, 1 - rho**beta/u**beta) = 1 - c``
    for ``u``, starting from ``rho + 0.01``.

    Args:
        alpha: Weight of the rho-scaled CDF term (``1 - alpha`` weights the
            unscaled term).
        rho: Scaling factor; also the lower bound applied to the result.
        c: Right-hand side is ``1 - c``.
        beta: Pareto parameter used as the exponent.

    Returns:
        The ``fsolve`` solution, clamped from below at ``rho``. The solver's
        convergence status is not checked.
    """
    # The residual is defined locally on purpose: the module-level name
    # ``equation_for_t_rho`` is rebound later in this file to a 4-argument
    # truncated-normal version, so looking it up at call time would hand
    # fsolve the wrong function (TypeError on the extra ``beta`` argument).
    def _pareto_residual(u, alpha, rho, c, beta):
        term1 = (1 - alpha) * max(0, 1 - 1 / u**beta)
        term2 = alpha * max(0, 1 - (rho**beta) / u**beta)
        return term1 + term2 - (1 - c)

    # Start just above rho so the rho-scaled term is strictly positive and
    # the residual has a non-zero slope at the initial guess.
    t_rho_initial_guess = rho + 0.01
    solution = fsolve(
        _pareto_residual, t_rho_initial_guess, args=(alpha, rho, c, beta)
    )[0]
    return max(solution, rho)


# Example truncated-normal parameters: location inside (0, 1) and scale.
# They are captured as the default mu/sigma arguments of F_p and F_p_rho.
mu, sigma = 0.5, 0.1

def F_p(t, mu=mu, sigma=sigma):
    """CDF at ``t`` of a normal with ``loc=mu``, ``scale=sigma`` truncated to [0, 1].

    ``truncnorm`` takes its bounds in standardized units, so the interval
    ends 0 and 1 are shifted by ``mu`` and divided by ``sigma`` first.
    """
    lower_std = (0 - mu) / sigma
    upper_std = (1 - mu) / sigma
    return truncnorm.cdf(t, lower_std, upper_std, loc=mu, scale=sigma)

def F_p_rho(t, rho, mu=mu, sigma=sigma):
    """CDF at ``t`` of a normal with ``loc=rho * mu``, ``scale=sigma`` truncated to [0, 1].

    Same construction as ``F_p`` except that the location is scaled by
    ``rho``; the scale ``sigma`` and the truncation interval are unchanged.
    """
    shifted_loc = rho * mu
    # Standardized truncation bounds relative to the shifted location.
    lower_std = (0 - shifted_loc) / sigma
    upper_std = (1 - shifted_loc) / sigma
    return truncnorm.cdf(t, lower_std, upper_std, loc=shifted_loc, scale=sigma)

def equation_for_t_rho(v, alpha, rho, c):
    """Residual of the truncated-normal threshold equation at ``v``.

    Returns ``(1 - alpha) * F_p(v) + alpha * F_p_rho(v, rho) - (1 - c)``,
    with both CDFs evaluated at the module-level ``mu`` and ``sigma``.
    A root in ``v`` is the threshold t_rho.
    """
    cdf_p = F_p(v, mu, sigma)
    cdf_p_rho = F_p_rho(v, rho, mu, sigma)
    mixture_cdf = (1 - alpha) * cdf_p + alpha * cdf_p_rho
    return mixture_cdf - (1 - c)

def find_t_rho(alpha, rho, c):
    """Return the root of ``equation_for_t_rho`` on the bracket [0, 1].

    Uses ``brentq``, which needs the residual to have opposite signs (or a
    zero) at the two bracket ends.
    """
    bracket_lo, bracket_hi = 0, 1
    return brentq(equation_for_t_rho, bracket_lo, bracket_hi, args=(alpha, rho, c))

# Upper-tail probabilities 1 - F_p(t_rho) and 1 - F_p_rho(t_rho), one per c
R_1_values = []
R_2_values = []

# Compute R_1(A) and R_2(A) for each c
for c in c_range:
    # Threshold at which the alpha-weighted mixture CDF equals 1 - c
    t_rho = find_t_rho(alpha, rho, c)
    R_1 = 1 - F_p(t_rho, mu, sigma)
    R_2 = 1 - F_p_rho(t_rho, rho, mu, sigma)
    R_1_values.append(R_1)
    R_2_values.append(R_2)

R2_R1_ratio_normal = np.array(R_2_values) / np.array(R_1_values)

# Find c where R_2/R_1 is closest to 0.8.
# The ratio array is aligned with c_range, so the argmin index must be looked
# up in c_range (the previously referenced `rho_range` is not defined).
index_closest_to_08 = np.abs(R2_R1_ratio_normal - 0.8).argmin()
c_closest_to_08 = c_range[index_closest_to_08]

# Plotting: ratio curve over c, with dashed guides at the point closest to 0.8
_label_size = 25
_target_ratio = 0.8
_guide_style = {'color': 'gray', 'linestyle': '--'}

plt.figure(figsize=(10, 6))
plt.plot(
    c_range,
    R2_R1_ratio_normal,
    label=r'$r_{\mathcal{R}}(A)$ for truncated normal',
    color='blue',
)

# Partial guide lines (instead of full axhline/axvline) that stop at the marked point
plt.plot([0, c_closest_to_08], [_target_ratio, _target_ratio], **_guide_style)  # Horizontal line
plt.plot([c_closest_to_08, c_closest_to_08], [0, _target_ratio], **_guide_style)  # Vertical line

# Mark the intersection; zorder=5 draws the marker above the lines
plt.scatter([c_closest_to_08], [_target_ratio], color='gray', zorder=5)

# Text offsets: 2% of the current axis ranges, read after everything is drawn
x_min, x_max = plt.xlim()
y_min, y_max = plt.ylim()
x_offset = 0.02 * (x_max - x_min)
y_offset = 0.02 * (y_max - y_min)

# Label goes up and to the left of the marker (right/bottom-aligned text)
plt.text(
    c_closest_to_08 - x_offset,
    _target_ratio + y_offset,
    f'({c_closest_to_08:.3f}, 0.8)',
    verticalalignment='bottom',
    horizontalalignment='right',
    fontsize=_label_size,
)

plt.title(r'Plot of $r_{\mathcal{R}}(A)$ for $\rho = 0.8$, $\alpha = 0.5$', fontsize=_label_size)
plt.xlabel(r'$c$', fontsize=_label_size)
plt.ylabel('Representation ratio', fontsize=_label_size)
# plt.legend()
plt.grid(True)

# Adjust tick label sizes
plt.xticks(fontsize=_label_size)
plt.yticks(fontsize=_label_size)
plt.show()
