import numpy as np
import os
import sys
import pickle

# To run this script: python -m experiments.experiment_1

"""
Experiment 1: Parameter Sensitivity Sweeps for Threshold Sequence

This experiment performs a sensitivity analysis on the core model parameters 
to observe how they influence the formation of the optimal threshold sequence.

Setup:
We sweep across ranges for retention (gamma), patience (beta), improvement cost (c_plus),
and gaming cost (c_minus). For each parameter value, the greedy search algorithm 
determines the sequence of optimal thresholds [mu_1, mu_2, ..., mu_5].

Goal:
To identify which parameters have the most significant impact on the threshold
sequence and on the maximum attribute values achievable by agents.

Outputs:
- A directory 'data/exp_1/' created if it does not exist.
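- A pickle file 'exp_1_data.pkl' containing a dictionary keyed by the swept
  parameter ('gamma', 'beta', 'c_plus', 'c_minus'). Each entry is a dictionary with:
    - x_vals: The array of parameter values swept.
    - mu_2 through mu_5: Arrays of the resulting threshold values for each sweep point
      (NaN where the greedy search returned fewer thresholds).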
"""

# Append the parent of this file's directory to sys.path before importing from
# the `functions` package (presumably that package lives there -- verify
# against the repository layout).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from functions.greedy_search import find_optimal_thresholds_greedy

def run_sensitivity_sweeps():
    """Run one-at-a-time parameter sweeps and record the greedy thresholds.

    Four sweeps are performed: retention (gamma), patience (beta),
    improvement cost (c_plus) and gaming cost (c_minus). Within a sweep every
    other parameter stays at its baseline value, and
    ``find_optimal_thresholds_greedy`` is called once per sweep point with
    ``target_M=None``, ``max_L=5`` and ``search_cap=60.0``.

    Returns:
        dict: Maps each swept parameter name ('gamma', 'beta', 'c_plus',
        'c_minus') to a dict with keys 'x_vals' (the swept values) and
        'mu_2' ... 'mu_5' (NumPy arrays with one entry per sweep point; NaN
        where the search returned fewer thresholds).
    """
    # Function-scope import so the fix does not depend on the file-level imports.
    from copy import deepcopy

    base_params = {
        'beta': 0.8,
        'gamma': 0.9,
        'c_plus': 1.0,
        'c_minus': 0.7,
        'r': 1.0,
        'delta': 0.0,
        'mu_list': [],
    }

    # Threshold positions (1-based) that are recorded for every sweep point.
    tracked_levels = (2, 3, 4, 5)

    def collect_sweep_data(param_name, values):
        """Sweep ``param_name`` over ``values`` and collect mu_2..mu_5."""
        print(f"\n--- Sweeping {param_name} ---")

        mu_data = {k: [] for k in tracked_levels}

        for v in values:
            # Deep copy: a shallow dict.copy() would hand the *same*
            # 'mu_list' object to every run, so any in-place mutation by the
            # search would leak into later sweep points and into base_params.
            p = deepcopy(base_params)
            p[param_name] = v

            # Run the greedy search capped at L=5.
            greedy_mus = find_optimal_thresholds_greedy(
                params=p,
                target_M=None,
                max_L=5,
                search_cap=60.0,
            )

            for k in tracked_levels:
                idx = k - 1  # mu_k sits at zero-based position k - 1
                # Pad with NaN when the returned sequence is too short.
                mu_data[k].append(
                    greedy_mus[idx] if len(greedy_mus) > idx else np.nan
                )

        sweep = {'x_vals': values}
        sweep.update({f'mu_{k}': np.array(mu_data[k]) for k in tracked_levels})
        return sweep

    # --- Define the sweeps (run in this order) ---
    sweeps = {
        'gamma': np.linspace(0.6, 0.99, 20),   # 1. Retention
        'beta': np.linspace(0.5, 0.99, 20),    # 2. Patience
        'c_plus': np.linspace(0.8, 1.5, 20),   # 3. Improvement cost
        'c_minus': np.linspace(0.2, 1.2, 20),  # 4. Gaming cost
    }

    return {
        name: collect_sweep_data(name, values)
        for name, values in sweeps.items()
    }

def save_data(full_results):
    """Pickle ``full_results`` to 'data/exp_1/exp_1_data.pkl'.

    The path is relative to the current working directory. The target
    directory is created if it does not exist, and an existing file at that
    path is overwritten.

    Args:
        full_results: The object to serialize with ``pickle.dump``.
    """
    data_dir = os.path.join('data', 'exp_1')
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and is a no-op when the directory is present.
    os.makedirs(data_dir, exist_ok=True)

    file_path = os.path.join(data_dir, 'exp_1_data.pkl')
    with open(file_path, 'wb') as f:
        pickle.dump(full_results, f)
    print(f"Data saved to {file_path}")

if __name__ == "__main__":
    # Script entry point: run every sweep, then persist the combined results.
    save_data(run_sensitivity_sweeps())