"""
FRACTAL: Memory Measure Comparison Visualization
Generates Figure comparing fractional measure with LegS, LagT, LegT
"""

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.patches as mpatches

# Publication-quality defaults for every figure produced below: serif text,
# axis labels/titles slightly larger than tick and legend text, 150 dpi on
# screen and 300 dpi in saved files. LaTeX rendering is switched off, so the
# $...$ labels are drawn by matplotlib's built-in mathtext.
rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'text.usetex': False,
})

# Colour lookup for each measure curve, keyed by measure name; the fractional
# family uses 'Frac_<alpha>' keys. The hex values are matplotlib's default
# "tab10" colour cycle (C0-C6).
# NOTE(review): tab10 is not designed to be colour-blind safe (e.g. the
# red/green pair below); do not rely on colour alone to tell curves apart.
colors = {
    'LegS': '#1f77b4',      # Blue
    'LagT': '#ff7f0e',      # Orange
    'LegT': '#2ca02c',      # Green
    'Frac_0.3': '#d62728',  # Red
    'Frac_0.5': '#9467bd',  # Purple
    'Frac_0.7': '#8c564b',  # Brown
    'Frac_0.9': '#e377c2',  # Pink
}

def mu_legs(x, t):
    """LegS measure: constant density 1/t on the closed interval [0, t].

    Points outside [0, t] receive weight 0.
    """
    inside = np.logical_and(x >= 0, x <= t)
    return np.where(inside, 1.0 / t, 0.0)

def mu_lagt(x, t, decay=1.0):
    """LagT measure: exponentially decaying weight looking back from time t.

    Returns ``decay * exp(-decay * (t - x))`` for every x <= t and 0 for x > t.
    There is no lower cut-off, so all past points x <= t carry some weight.
    """
    elapsed = t - x
    weight = decay * np.exp(-decay * elapsed)
    return np.where(x <= t, weight, 0.0)

def mu_legt(x, t, theta=0.3):
    """LegT measure: constant density 1/theta on the sliding window [t - theta, t].

    Points outside the window receive weight 0.
    """
    window_start = t - theta
    in_window = (window_start <= x) & (x <= t)
    return np.where(in_window, 1.0 / theta, 0.0)

def mu_fractional(x, t, alpha):
    """Fractional (power-law) measure with singularity index ``alpha``.

    Density on the half-open interval [0, t)::

        mu(x) = (1 - alpha) / t**(1 - alpha) * (t - x)**(-alpha)

    For 0 <= alpha < 1 this integrates to 1 over [0, t), and alpha = 0 gives
    the uniform density 1/t. Points outside [0, t) get weight 0, including
    the singular endpoint x = t itself.

    Args:
        x: Evaluation point(s). May be a scalar or array-like. Non-floating
            input (e.g. an integer array) is promoted to float so the weights
            are not truncated.
        t: Current time (right end of the support).
        alpha: Singularity index; larger values put more weight near x = t.

    Returns:
        Floating ndarray with the same shape as ``x`` (0-d for scalar input).
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # zeros_like() inherits x's dtype; an integer result array would
        # silently truncate the fractional weights assigned below.
        x = x.astype(float)
    result = np.zeros_like(x)
    valid = (x >= 0) & (x < t)
    if np.any(valid):
        # x < t already makes t - x positive; the floor only guards against
        # round-off producing a zero base for the negative power.
        diff = np.maximum(t - x[valid], 1e-10)
        result[valid] = (1 - alpha) / (t ** (1 - alpha)) * (diff ** (-alpha))
    return result

# Create figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

t = 1.0  # Current time normalized to 1
# Stop just short of x = t, where the fractional density is singular.
x = np.linspace(0, t - 1e-4, 1000)

# ============================================
# Left Panel: Compare all measure families
# ============================================
ax1 = axes[0]

# Plot existing measures
ax1.plot(x, mu_legs(x, t), label='LegS (Uniform)', color=colors['LegS'],
         linewidth=2, linestyle='-')
ax1.plot(x, mu_lagt(x, t, decay=3.0), label='LagT (Exponential)',
         color=colors['LagT'], linewidth=2, linestyle='--')
ax1.plot(x, mu_legt(x, t, theta=0.3), label='LegT (Window)',
         color=colors['LegT'], linewidth=2, linestyle=':')

# Plot fractional measures
ax1.plot(x, mu_fractional(x, t, 0.5), label=r'FRACTAL $\alpha$=0.5',
         color=colors['Frac_0.5'], linewidth=2.5, linestyle='-')

ax1.set_xlabel(r'Relative time $x/t$')
ax1.set_ylabel(r'Memory weight $\mu^{(t)}(x)$')
ax1.set_title('(a) Comparison of Memory Measures')
ax1.legend(loc='upper left', framealpha=0.9)
ax1.set_xlim([0, 1])
ax1.set_ylim([0, 8])  # the alpha = 0.5 curve diverges near x = t; cap the view
ax1.grid(True, alpha=0.3)

# Annotate the singularity. The arrow tip is computed from the alpha = 0.5
# curve itself so that it lands on the plotted line, not in empty space.
tip_x = 0.99
tip_y = mu_fractional(np.array([tip_x]), t, 0.5)[0]
ax1.annotate('Singularity\nat present', xy=(tip_x, tip_y), xytext=(0.7, 5.5),
             fontsize=9, arrowprops=dict(arrowstyle='->', color='gray'),
             color='gray')

# ============================================
# Right Panel: Fractional measures with varying alpha
# ============================================
ax2 = axes[1]

alphas = [0.0, 0.3, 0.5, 0.7, 0.9]
# Take colours from the shared palette so a curve looks the same in both
# panels: alpha = 0 reduces to LegS, and alpha = 0.5 matches the FRACTAL
# curve in panel (a).
alpha_colors = [colors['LegS'] if a == 0 else colors[f'Frac_{a}']
                for a in alphas]

for alpha, color in zip(alphas, alpha_colors):
    label = r'$\alpha$ = ' + f'{alpha}'
    if alpha == 0:
        label += ' (LegS)'
    y = mu_fractional(x, t, alpha)
    # Clip the near-singular values for visualization
    y = np.clip(y, 0, 15)
    ax2.plot(x, y, label=label, color=color, linewidth=2)

ax2.set_xlabel(r'Relative time $x/t$')
ax2.set_ylabel(r'Memory weight $\mu^{(t)}(x)$')
ax2.set_title(r'(b) Fractional Measure Family: Effect of $\alpha$')
ax2.legend(loc='upper left', framealpha=0.9)
ax2.set_xlim([0, 1])
ax2.set_ylim([0, 12])
ax2.grid(True, alpha=0.3)

# Add shaded region for "recent history"; label centred in the 0.8-1.0 span
ax2.axvspan(0.8, 1.0, alpha=0.1, color='red', label='_nolegend_')
ax2.text(0.9, 11, 'Recent\nHistory', fontsize=8, ha='center', color='darkred')

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('/home/claude/experiments', exist_ok=True)
plt.savefig('/home/claude/experiments/measure_comparison.pdf',
            bbox_inches='tight', format='pdf')
plt.savefig('/home/claude/experiments/measure_comparison.png',
            bbox_inches='tight', format='png')
print("Saved: measure_comparison.pdf/png")

# ============================================
# Additional: Log-scale comparison showing power-law vs exponential
# ============================================
fig2, ax3 = plt.subplots(1, 1, figsize=(6, 4))

# Use log scale to show decay behavior. Stay strictly inside (0, t): x = t is
# the fractional singularity.
x_log = np.linspace(0.01, 0.99, 500)
t = 1.0

# Power-law (fractional) vs exponential decay from t
distance_from_present = t - x_log  # How far in the past (0.99 down to 0.01)

# Normalize BOTH curves to 1 at their peak. For both measures the peak is the
# most recent sample (x_log[-1]), because each weight increases towards x = t.
# A shared reference point is what makes the two decay rates comparable on
# one axis.
frac_05 = mu_fractional(x_log, t, 0.5)
frac_05 = frac_05 / frac_05.max()

exp_decay = mu_lagt(x_log, t, decay=3.0)
exp_decay = exp_decay / exp_decay.max()

# Reverse the arrays so the x-axis (distance into the past) increases.
ax3.semilogy(distance_from_present[::-1], frac_05[::-1],
             label=r'FRACTAL ($\alpha$=0.5): Power-law',
             color=colors['Frac_0.5'], linewidth=2)
ax3.semilogy(distance_from_present[::-1], exp_decay[::-1],
             label='LagT: Exponential',
             color=colors['LagT'], linewidth=2, linestyle='--')

ax3.set_xlabel(r'Distance from present $(t - x)$')
ax3.set_ylabel('Relative weight (log scale)')
ax3.set_title('Long-term Retention: Power-law vs Exponential Decay')
ax3.legend(loc='upper right')
ax3.grid(True, alpha=0.3, which='both')

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('/home/claude/experiments', exist_ok=True)
plt.savefig('/home/claude/experiments/decay_comparison.pdf',
            bbox_inches='tight', format='pdf')
plt.savefig('/home/claude/experiments/decay_comparison.png',
            bbox_inches='tight', format='png')
print("Saved: decay_comparison.pdf/png")

plt.close('all')
print("\nAll measure comparison figures generated successfully!")
