import numpy as np
import cma
import os
import sys
import pickle
import pandas as pd

# To run this script: python -m experiments.experiment_5

"""
Experiment 5: Optimal Mechanism Design and Population Dynamics 

This experiment solves the Principal's utility maximization problem to determine 
the optimal organizational structure (number of levels L, thresholds mu, and 
reward rate r) across various environmental settings.

Setup:
    - Principal Utility: A weighted trade-off between classifier robustness 
      (accuracy), agent qualification (final attribute), and implementation cost.
    - Optimization: Uses CMA-ES (Covariance Matrix Adaptation Evolution Strategy)
      to find the optimal (r, mu) vector, as the objective is non-convex.
    - Data: Initializes agents using normalized FICO credit score distributions.

Parts:
    - Part A (Parameter Ablation): Tests how varying gamma (retention),
      beta (patience), and delta (leg-up) affects the final population
      attribute and level. For each setting the design (L, r, mu) is
      re-optimized, with L swept from 2 to 8.
    - Part B (Level Ablation): Specifically examines the impact of the total 
      number of levels (L = 3, 5, 8) on the system's ability to drive improvement.

Goal:
    To demonstrate how a multi-level progression can successfully incentivize 
    populations to reach high status through genuine improvement 
    rather than gaming.

Outputs:
    - results_plot_A.pkl: Trajectory data for sensitivity analysis on gamma, beta, delta.
    - results_plot_B.pkl: Trajectory data for different total level counts (L).
"""

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from functions.value_iter import solve_fixed_point
from functions.get_agent_step import get_agent_step 

# Baseline environment settings. The experiment copies this dict and
# overrides individual entries (and adds 'mu_list') per run.
DEFAULT_ENV_PARAMS = {
    # Agent discount factor
    'beta': 0.8,
    # Attribute retention factor
    'gamma': 0.7,
    # Cost of improvement
    'c_plus': 1.0,
    # Cost of gaming
    'c_minus': 0.5,
    # Leg-up factor
    'delta': 0.1,
    # Reward rate
    'r': 1.0,
}

# Weights of the principal's utility, read by objective_function.
PRINCIPAL_PARAMS = {
    # Principal's discount factor
    'alpha': 0.95,
    # Weight for qualification
    'lambda_': 5.0,
    # Weight for implementation cost
    'xi': 0.01,
}

# Helper functions

def get_data():
    """Load FICO credit scores rescaled linearly from [300, 850] to [0, 10].

    Reads the ``fico_score`` column of ``data/fico/fico.csv`` (resolved
    relative to the current working directory), drops missing scores, clips
    the remaining values to [300, 850] and maps them onto [0, 10].

    If the CSV file does not exist, 100 samples drawn uniformly from
    [0, 10) are returned instead, and a ``RuntimeWarning`` is issued so the
    substitution is visible in the run output rather than silent.

    Returns:
        1-D float ``numpy.ndarray`` of initial agent attributes.
    """
    import warnings

    csv_path = os.path.join('data', 'fico', 'fico.csv')
    if not os.path.exists(csv_path):
        warnings.warn(
            f"FICO data not found at '{csv_path}'; "
            "falling back to 100 uniform random attributes in [0, 10].",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.random.uniform(0, 10, 100)

    fico_min, fico_max = 300.0, 850.0
    raw_scores = pd.read_csv(csv_path)['fico_score'].values.astype(float)
    # np.clip passes NaN through unchanged, which would propagate into every
    # downstream population mean; discard missing scores up front.
    raw_scores = raw_scores[~np.isnan(raw_scores)]
    raw_scores = np.clip(raw_scores, fico_min, fico_max)
    return (raw_scores - fico_min) / (fico_max - fico_min) * 10.0

def simulate_population(W, X_grid, params, initial_attributes, T=20):
    """Roll the whole population forward for T steps via get_agent_step.

    Every agent starts at level 0 with its entry from ``initial_attributes``.
    At each step the current levels and attributes are handed to
    ``get_agent_step`` together with ``V = W - c_plus * X_grid`` (X_grid
    broadcast along the first axis of W), and the returned actions and next
    state are recorded.

    Args:
        W: Array combined with X_grid to form the table V passed to
            get_agent_step (presumably levels x grid points; confirm against
            solve_fixed_point).
        X_grid: 1-D attribute grid.
        params: Parameter dict; 'c_plus' is read here, the full dict is
            forwarded to get_agent_step.
        initial_attributes: Starting attribute for each agent.
        T: Number of simulated steps.

    Returns:
        Tuple ``(traj_l, traj_x, traj_a_plus, traj_a_minus)``: integer levels
        and float attributes of shape (n_agents, T + 1), and the two action
        arrays of shape (n_agents, T).
    """
    n_agents = len(initial_attributes)
    V = W - params['c_plus'] * X_grid[np.newaxis, :]

    # State histories carry one extra column for the initial condition.
    levels = np.zeros((n_agents, T + 1), dtype=int)
    attributes = np.zeros((n_agents, T + 1))
    attributes[:, 0] = initial_attributes

    # Action histories: one column per simulated step.
    plus_actions = np.zeros((n_agents, T))
    minus_actions = np.zeros((n_agents, T))

    for step in range(T):
        outcome = get_agent_step(
            levels[:, step], attributes[:, step], V, X_grid, params
        )
        step_plus, step_minus, next_levels, next_attributes = outcome

        plus_actions[:, step] = step_plus
        minus_actions[:, step] = step_minus
        levels[:, step + 1] = next_levels
        attributes[:, step + 1] = next_attributes

    return levels, attributes, plus_actions, minus_actions


def objective_function(decision_vector, current_params, initial_data, X_grid):
    """Return the negative discounted principal utility of one (r, mu) design.

    The value is meant to be minimised by the optimiser, hence the sign flip.

    Args:
        decision_vector: Sequence whose first entry is the reward rate r and
            whose remaining entries are the thresholds mu (sorted ascending
            here before use).
        current_params: Environment parameter dict. It is copied, never
            mutated; 'r' and 'mu_list' are set on the copy.
        initial_data: Initial attribute of each agent.
        X_grid: 1-D attribute grid handed to the solver and the simulation.

    Returns:
        ``-sum_t alpha**t * u_t`` as a Python float, where per step
        ``u_t = mean(a_minus < 1e-5) + lambda_ * mean(x) - xi * r * mean(level)``
        with alpha, lambda_ and xi taken from PRINCIPAL_PARAMS. Returns the
        penalty ``1e9`` if the fixed-point solver raises or if the utility is
        not finite.
    """
    failure_penalty = 1e9

    r_curr = decision_vector[0]
    mu_curr = np.sort(decision_vector[1:])

    p = current_params.copy()
    p['r'] = r_curr
    p['mu_list'] = mu_curr

    # 1. Solve Agent Policy
    try:
        W_star = solve_fixed_point(p, X_grid)
    except Exception:
        # Deliberately not a bare `except:` so that KeyboardInterrupt and
        # SystemExit still abort a long optimisation run.
        return failure_penalty

    # 2. Simulate
    traj_l, traj_x, _, traj_a_minus = simulate_population(
        W_star, X_grid, p, initial_data
    )

    # 3. Compute Utility
    alpha = PRINCIPAL_PARAMS['alpha']
    lambda_ = PRINCIPAL_PARAMS['lambda_']
    xi = PRINCIPAL_PARAMS['xi']
    T = traj_a_minus.shape[1]

    discount_factors = alpha ** np.arange(T)
    # Share of agents whose a_minus action is (numerically) zero at each step.
    accuracy_scores = np.mean(traj_a_minus < 1e-5, axis=0)
    # Column 0 holds the initial state, so post-step values start at column 1.
    qualification_scores = np.mean(traj_x[:, 1:], axis=0)
    cost_scores = xi * r_curr * np.mean(traj_l[:, 1:], axis=0)

    step_utilities = accuracy_scores + lambda_ * qualification_scores - cost_scores
    total_utility = np.sum(discount_factors * step_utilities).item()

    # A NaN/inf utility would be meaningless to rank; treat it like a
    # solver failure instead of handing it to the optimiser.
    if not np.isfinite(total_utility):
        return failure_penalty

    return -total_utility


def run_optimization(L, env_params, initial_attributes, X_grid):
    """Search for the best (r, mu) with CMA-ES for a fixed number of levels L.

    The search starts from r = 1.0 and L thresholds spaced evenly over
    [0.1, 10], with initial step size 0.5, a lower bound of 0 on every
    coordinate and at most 40 iterations.

    Args:
        L: Number of thresholds in the design.
        env_params: Environment parameter dict forwarded to objective_function.
        initial_attributes: Initial attribute of each agent.
        X_grid: 1-D attribute grid.

    Returns:
        Tuple ``(r, mus)``: the first entry of the best evaluated vector and
        its remaining entries sorted ascending.
    """
    def negative_utility(candidate):
        return objective_function(candidate, env_params, initial_attributes, X_grid)

    n_dims = L + 1
    start_point = np.concatenate(([1.0], np.linspace(0.1, 10, L)))
    initial_step = 0.5

    settings = {
        'verb_disp': 0,
        'maxiter': 40,
        # Lower bound 0 on every coordinate, no upper bound.
        'bounds': [[0.0] * n_dims, [None] * n_dims],
    }

    strategy = cma.CMAEvolutionStrategy(start_point, initial_step, settings)
    strategy.optimize(negative_utility)

    winner = strategy.result.xbest
    return winner[0], np.sort(winner[1:])


def optimize_single_case(env_params, initial_data, X_grid):
    """Pick the best design over L = 2..8 for one environment setting.

    For each L a CMA-ES run (30 iterations, population 10, lower bound 0 on
    every coordinate) is started from r = 1.0 and L thresholds spaced evenly
    over [1, 9]. A run that raises is reported on stdout and skipped.

    Args:
        env_params: Environment parameter dict; copied, not mutated here.
        initial_data: Initial attribute of each agent.
        X_grid: 1-D attribute grid.

    Returns:
        Dict with keys 'Costs (c+, c-)', 'Opt Levels (L*)', 'Reward (r*)',
        'Mu list (mu*)' and 'Utility' describing the highest-utility design.
        If no run beats a utility of -1e9, the level, reward and mu entries
        stay at their initial value 0.
    """
    current_params = env_params.copy()

    # The objective wrapper does not depend on L, so build it once.
    def fit_func(x):
        return objective_function(x, current_params, initial_data, X_grid)

    winner = {'L': 0, 'util': -1e9, 'r': 0, 'mus': 0}

    # Sweep L from 2 to 8
    for L in range(2, 9):
        print(f"   Optimizing L={L}...", end='', flush=True)

        n_dims = L + 1
        start_point = np.concatenate(([1.0], np.linspace(1.0, 9.0, L)))
        initial_step = 0.5
        settings = {
            'verbose': -9,  # Silent
            'maxiter': 30,  # Fast convergence check
            'popsize': 10,
            'bounds': [[0.0] * n_dims, [None] * n_dims],
        }

        try:
            strategy = cma.CMAEvolutionStrategy(start_point, initial_step, settings)
            strategy.optimize(fit_func)

            # fbest is the minimised negative utility; flip the sign back.
            utility = -strategy.result.fbest
            reward = strategy.result.xbest[0]
            thresholds = np.sort(strategy.result.xbest[1:])

            print(f" Util: {utility:.2f}")

            if utility > winner['util']:
                winner = {'L': L, 'util': utility, 'r': reward, 'mus': thresholds}

        except Exception as e:
            print(f" Failed ({e})")

    costs_label = f"({current_params['c_plus']}, {current_params['c_minus']})"
    return {
        'Costs (c+, c-)': costs_label,
        'Opt Levels (L*)': winner['L'],
        'Reward (r*)': winner['r'],
        'Mu list (mu*)': winner['mus'],
        'Utility': winner['util'],
    }

# Main experiment

def run_experiment_5():
    """Run both parts of Experiment 5 and pickle the resulting trajectories.

    Part A varies gamma, beta and delta one at a time around
    DEFAULT_ENV_PARAMS; for each value the design is re-optimised with
    optimize_single_case (which itself sweeps the number of levels), and the
    population is then simulated for 20 steps under that design.

    Part B fixes the number of levels to 3, 5 and 8 in turn, optimises
    (r, mu) with run_optimization under DEFAULT_ENV_PARAMS, and simulates the
    population for 20 steps.

    Side effects:
        Creates ``data/exp_5`` (relative to the current working directory) if
        needed and writes ``results_plot_A.pkl`` and ``results_plot_B.pkl``
        there. Progress and final population means are printed to stdout.
    """
    X_grid = np.linspace(0, 15, 151)
    initial_attributes = get_data()
    output_dir = os.path.join('data', 'exp_5')
    os.makedirs(output_dir, exist_ok=True)

    # --- Part A: Parameter Ablation (design re-optimized per setting) ---
    print("--- Starting Plot A: Parameter Ablation ---")
    ablation_config = {
        'gamma': [0.4, 0.5, 0.7, 0.8, 0.9],
        'beta': [0.4, 0.5, 0.7, 0.8, 0.9],
        'delta': [0.0, 0.1, 0.2, 0.3, 0.5]
    }

    results_A = {}

    for param_name, values in ablation_config.items():
        results_A[param_name] = {}
        for val in values:
            print(f"\nTesting {param_name} = {val}...")
            current_env = DEFAULT_ENV_PARAMS.copy()
            current_env[param_name] = val

            result = optimize_single_case(current_env, initial_attributes,
                                          X_grid)

            # Re-solve and re-simulate under the selected design to obtain
            # the trajectories that get stored.
            current_env['r'] = result['Reward (r*)']
            current_env['mu_list'] = result['Mu list (mu*)']
            W_star = solve_fixed_point(current_env, X_grid)

            traj_l, traj_x, traj_ap, traj_am = simulate_population(
                W_star, X_grid, current_env, initial_attributes, T=20
            )

            # Print Final Population State
            final_mean_x = np.mean(traj_x[:, -1])
            final_mean_l = np.mean(traj_l[:, -1])
            print(f"  > Final Mean Attribute (x): {final_mean_x:.4f}")
            print(f"  > Final Mean Level (l):    {final_mean_l:.4f}")

            results_A[param_name][val] = {
                'traj_x': traj_x,
                'traj_l': traj_l,
                'traj_ap': traj_ap,
                'traj_am': traj_am,
                'opt_r': result['Reward (r*)'],
                'opt_mus': result['Mu list (mu*)']
            }

    # --- Part B: Level Ablation ---
    print("\n--- Starting Plot B: Level Ablation ---")
    L_values = [3, 5, 8]
    results_B = {}

    for L in L_values:
        print(f"\nTesting L = {L}...")
        best_r, best_mus = run_optimization(L, DEFAULT_ENV_PARAMS, initial_attributes, X_grid)

        print(f"  > Optimized r: {best_r:.4f}")
        print(f"  > Thresholds (mu): {', '.join([f'{m:.3f}' for m in best_mus])}")

        current_env = DEFAULT_ENV_PARAMS.copy()
        current_env['r'] = best_r
        current_env['mu_list'] = best_mus
        W_star = solve_fixed_point(current_env, X_grid)

        traj_l, traj_x, traj_ap, traj_am = simulate_population(
            W_star, X_grid, current_env, initial_attributes, T=20
        )

        final_mean_x = np.mean(traj_x[:, -1])
        final_mean_l = np.mean(traj_l[:, -1])
        print(f"  > Final Mean Attribute (x): {final_mean_x:.4f}")
        print(f"  > Final Mean Level (l):    {final_mean_l:.4f}")

        # Same record layout as Part A so both pickles can be read alike.
        results_B[L] = {
            'traj_x': traj_x,
            'traj_l': traj_l,
            'traj_ap': traj_ap,
            'traj_am': traj_am,
            'opt_r': best_r,
            'opt_mus': best_mus
        }

    # Save results
    with open(os.path.join(output_dir, 'results_plot_A.pkl'), 'wb') as f:
        pickle.dump(results_A, f)
    with open(os.path.join(output_dir, 'results_plot_B.pkl'), 'wb') as f:
        pickle.dump(results_B, f)
    print(f"\nResults saved to {output_dir}")

# Script entry point: run the experiment only when this file is executed
# directly (e.g. `python -m experiments.experiment_5`), not when imported.
if __name__ == "__main__":
    run_experiment_5()