import copy
import os
import pickle
import sys

import numpy as np
from joblib import Parallel, delayed

# To run this script, make sure you are in the project root directory and run:
# python -m experiments.experiment_2

"""
Experiment 2: How the interplay between farsightedness (beta) and retention (gamma) 
determines the limits of incentivizability

We run the Greedy Algorithm on a 2D grid of (beta, gamma) and measure:
    1. Max Levels: The number of levels constructed before the algorithm gets stuck.
    2. Max Attribute: The value of the highest threshold reached.

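Setup:
    - Grid: beta and gamma each take 20 evenly spaced values in [0.4, 0.99].
    - We set a high 'max_L' (50). If the algorithm stops early, the greedy
      search could not find a feasible threshold for the next level within
      its search cap (100.0); we interpret this as a higher level not being
      incentivizable.
    - Parallel execution (joblib, all available cores) is used to speed up
      the 2D grid sweep.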

Parameters (Fixed):
    - c_plus: 1.0
    - c_minus: 0.7
    - r: 1.0
    - delta: 0.0

Output:
    - Saves a dictionary of results to data/exp_2/exp_2_data.pkl, relative to
      the current working directory (the project root when run as above).
"""

# Append the parent of this file's directory (as an absolute path) to
# sys.path so that the project-local `functions` package imported below can
# be resolved.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from functions.greedy_search import find_optimal_thresholds_greedy

def solve_single_cell(beta, gamma, base_params, *, max_L=50, search_cap=100.0):
    """
    Run the greedy threshold search for a single (beta, gamma) grid cell.

    Designed to be picklable for parallel execution: it is a module-level
    function taking plain-data arguments.

    Args:
        beta: Value stored under ``p['beta']`` for this cell.
        gamma: Value stored under ``p['gamma']`` for this cell.
        base_params: Parameter dict shared by all cells. It is deep-copied,
            so neither it nor any mutable value inside it is modified here.
        max_L: Level cap forwarded to the greedy search (default 50).
        search_cap: Search cap forwarded to the greedy search (default 100.0).

    Returns:
        Tuple ``(num_levels, max_attr)``:
            num_levels: ``len(greedy_mus) - 1``. Index 0 is the 0.0
                baseline, so it is not counted as a level.
            max_attr: the last element of ``greedy_mus`` (the highest
                threshold reached).
    """
    # Deep copy rather than dict.copy(): a shallow copy would share mutable
    # values such as 'mu_list' between cells. joblib can run several tasks
    # against the same unpickled base_params object (batched dispatch, or a
    # sequential fallback), so if the solver mutated such a value in place,
    # the change would leak into neighbouring grid cells.
    p = copy.deepcopy(base_params)
    p['beta'] = beta
    p['gamma'] = gamma

    greedy_mus = find_optimal_thresholds_greedy(
        params=p,
        target_M=None,  # no target -> run as high as possible
        max_L=max_L,  # hard cap on levels in case feasibility never ends
        search_cap=search_cap,
    )

    # Metric 1: number of levels (index 0 is the 0.0 baseline).
    num_levels = len(greedy_mus) - 1
    # Metric 2: the highest threshold reached (last element).
    max_attr = greedy_mus[-1]

    return num_levels, max_attr

def run_2d_sweep():
    """
    Sweep the greedy solver over a 20x20 (beta, gamma) grid in parallel.

    Both axes are ``np.linspace(0.4, 0.99, 20)``. Each cell is solved by
    ``solve_single_cell`` through joblib with ``n_jobs=-1`` (all cores).

    Returns:
        dict with keys:
            'beta_vals', 'gamma_vals': the 1-D grid axes.
            'max_levels', 'max_attribute': arrays of shape
                (len(beta_vals), len(gamma_vals)); entry [i, j] holds the
                result for (beta_vals[i], gamma_vals[j]).
    """
    base_params = dict(c_plus=1.0, c_minus=0.7, r=1.0, delta=0.0, mu_list=[])

    beta_values = np.linspace(0.4, 0.99, 20)
    gamma_values = np.linspace(0.4, 0.99, 20)
    grid_shape = (len(beta_values), len(gamma_values))

    print(f"--- Starting Experiment 2: 2D Grid Sweep ({len(beta_values)}x{len(gamma_values)}) ---")
    print(f"Running on {os.cpu_count()} cores...")

    # Row-major enumeration: beta varies slowest, gamma fastest. This must
    # agree with the C-order reshape below so cell [i, j] maps to
    # (beta_values[i], gamma_values[j]).
    jobs = (
        delayed(solve_single_cell)(beta, gamma, base_params)
        for beta in beta_values
        for gamma in gamma_values
    )
    cell_results = Parallel(n_jobs=-1, verbose=5)(jobs)

    # Each cell result is a (num_levels, max_attr) pair; split into two columns.
    levels_per_cell, attr_per_cell = zip(*cell_results)

    results = {
        'beta_vals': beta_values,
        'gamma_vals': gamma_values,
        'max_levels': np.array(levels_per_cell).reshape(grid_shape),
        'max_attribute': np.array(attr_per_cell).reshape(grid_shape),
    }

    print("\nExperiment Complete.")
    return results

def save_data(full_results, data_dir=None):
    """
    Pickle the experiment results to ``<data_dir>/exp_2_data.pkl``.

    Args:
        full_results: Object to pickle (e.g. the dict returned by
            run_2d_sweep()).
        data_dir: Output directory. Defaults to ``data/exp_2`` relative to
            the current working directory. Created (with parents) if missing.

    Returns:
        The path of the written pickle file.
    """
    if data_dir is None:
        data_dir = os.path.join('data', 'exp_2')
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # if the directory appeared between the check and the creation.
    os.makedirs(data_dir, exist_ok=True)

    file_path = os.path.join(data_dir, 'exp_2_data.pkl')
    with open(file_path, 'wb') as f:
        pickle.dump(full_results, f)
    print(f"Data saved to {file_path}")
    return file_path

# Script entry point: run the full (beta, gamma) sweep, then pickle the
# resulting dict via save_data().
if __name__ == "__main__":
    results = run_2d_sweep()
    save_data(results)