import random

import numpy as np
import cma
import os
import sys
import pickle
import pandas as pd
from concurrent.futures import ThreadPoolExecutor

# To run this script: python -m experiments.experiment_4

"""
Experiment 4: Optimal Policy Table

Goal: Generate a table of optimal setups using full 
CMA-ES optimization and agent simulation.

We optimize (L, r, thresholds) for 4 distinct cost settings:
1. (Low c+, High c-)
2. (High c+, High c-)
3. (Low c+, Low c-)
4. (High c+, Low c-)

Outputs: data/exp_4/exp_4_table.csv
"""

# One seed shared by the global NumPy RNG, the stdlib RNG and the CMA-ES
# options built in optimize_single_case. Seeding happens at import time.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)

# Put the parent of this file's directory on sys.path *before* the project
# imports below, so the `functions` package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Project-local helpers: solve_fixed_point is called in objective_function,
# get_agent_step in simulate_population.
from functions.value_iter import solve_fixed_point
from functions.get_agent_step import get_agent_step

# Cost settings swept by the experiment, keyed by case label. Insertion order
# is the order in which run_experiment_4 optimizes the cases.
CASES = {
    '1_Case_1': dict(c_plus=0.8, c_minus=0.7),
    '2_Case_2': dict(c_plus=1.5, c_minus=1.2),
    '3_Case_3': dict(c_plus=0.8, c_minus=0.4),
    '4_Case_4': dict(c_plus=1.5, c_minus=0.4),
}

# Base Environment Params
# 'r' is only a starting value: objective_function overwrites it with the
# candidate reward, and optimize_single_case adds 'c_plus'/'c_minus' per case.
BASE_PARAMS = dict(
    beta=0.8,
    gamma=0.8,
    delta=0.01,
    r=1.0,
)

# Principal Params (read by objective_function when scoring a trajectory)
PRINCIPAL_PARAMS = dict(
    alpha=0.95,   # per-step discount base: step t is weighted by alpha**t
    lambda_=5.0,  # weight on the mean-attribute ("qualification") term
    xi=0.01,      # scale of the cost term xi * r * mean(level)
)

# Fico data
def get_data(csv_path=None):
    """Return the initial agent attributes on a 0-10 scale.

    Reads the ``fico_score`` column of a CSV file, clips the scores to the
    range [300, 850] and rescales them linearly to [0, 10].

    Args:
        csv_path: Path of the CSV to read. Defaults to
            ``data/fico/fico.csv`` relative to the current working
            directory.

    Returns:
        A 1-D float array of rescaled scores. If the file does not exist,
        100 synthetic values drawn uniformly from [0, 10) with the global
        NumPy RNG are returned instead and a warning is written to stderr.
    """
    if csv_path is None:
        csv_path = os.path.join('data', 'fico', 'fico.csv')

    if not os.path.exists(csv_path):
        # Keep the best-effort fallback, but make it visible: results built
        # on synthetic attributes should never be mistaken for FICO results.
        print(
            f"Warning: '{csv_path}' not found; "
            "using 100 synthetic uniform attributes instead.",
            file=sys.stderr,
        )
        return np.random.uniform(0, 10, 100)

    df = pd.read_csv(csv_path)
    raw_scores = df['fico_score'].values.astype(float)
    fico_min, fico_max = 300.0, 850.0
    raw_scores = np.clip(raw_scores, fico_min, fico_max)
    return (raw_scores - fico_min) / (fico_max - fico_min) * 10.0

def simulate_population(W, X_grid, params, initial_attributes, T=20):
    """Roll a population forward for ``T`` steps using ``get_agent_step``.

    Args:
        W: Array from which ``params['c_plus'] * X_grid`` (as a row) is
            subtracted to form the value array passed to the step function.
        X_grid: 1-D attribute grid.
        params: Parameter dict; ``'c_plus'`` is read here and the whole dict
            is forwarded to ``get_agent_step``.
        initial_attributes: Attribute of each agent at step 0; its length
            sets the population size.
        T: Number of simulated steps.

    Returns:
        Tuple ``(traj_l, traj_x, traj_a_plus, traj_a_minus)``. The first two
        have shape ``(n_agents, T + 1)`` (``traj_l`` is integer-typed and
        starts at 0 for everyone); the last two have shape ``(n_agents, T)``.
    """
    # V = W - Cost, with the cost broadcast along the grid axis.
    V = W - params['c_plus'] * X_grid[np.newaxis, :]

    n_agents = len(initial_attributes)

    traj_l = np.zeros((n_agents, T + 1), dtype=int)
    traj_x = np.zeros((n_agents, T + 1))
    traj_x[:, 0] = initial_attributes
    traj_a_plus = np.zeros((n_agents, T))
    traj_a_minus = np.zeros((n_agents, T))

    for step in range(T):
        nxt = step + 1
        # Column `step` holds the current state; the step function's outputs
        # fill this step's actions and the next step's state.
        (
            traj_a_plus[:, step],
            traj_a_minus[:, step],
            traj_l[:, nxt],
            traj_x[:, nxt],
        ) = get_agent_step(traj_l[:, step], traj_x[:, step], V, X_grid, params)

    return traj_l, traj_x, traj_a_plus, traj_a_minus


def objective_function(decision_vector, current_params, initial_data, X_grid):
    """Return the negated discounted principal utility of one candidate.

    The candidate is turned into a parameter set, the agent policy is solved
    with ``solve_fixed_point``, the population is simulated, and the
    per-step utility is accumulated with discount ``alpha**t``.

    Args:
        decision_vector: ``[r, mu_1, ..., mu_L]``. The thresholds are sorted
            ascending before use, so their order in the vector is irrelevant.
        current_params: Base parameter dict. It is copied, not mutated.
        initial_data: Initial attribute of every agent.
        X_grid: 1-D attribute grid.

    Returns:
        ``-total_utility`` (the optimizer minimizes), or the penalty ``1e9``
        when the solver raises or the utility is not finite.
    """
    r_curr = decision_vector[0]
    mu_curr = np.sort(decision_vector[1:])

    p = current_params.copy()
    p['r'] = r_curr
    p['mu_list'] = mu_curr

    # 1. Solve Agent Policy
    try:
        W_star = solve_fixed_point(p, X_grid)
    except Exception:
        # Penalty for solver failure. Deliberately not a bare `except:` so
        # KeyboardInterrupt / SystemExit still stop a long optimization run.
        return 1e9

    # 2. Simulate
    traj_l, traj_x, traj_a_plus, traj_a_minus = simulate_population(
        W_star, X_grid, p, initial_data
    )

    # 3. Compute Utility
    alpha = PRINCIPAL_PARAMS['alpha']
    lambda_ = PRINCIPAL_PARAMS['lambda_']
    xi = PRINCIPAL_PARAMS['xi']
    T = traj_a_plus.shape[1]

    total_utility = 0.0
    for t in range(T):
        discount = alpha**t
        # Share of agents whose a_minus at step t is (numerically) zero.
        accuracy_score = np.mean(traj_a_minus[:, t] < 1e-5)
        qualification_score = np.mean(traj_x[:, t+1])
        cost_score = xi * r_curr * np.mean(traj_l[:, t+1])

        step_util = accuracy_score + lambda_ * qualification_score - cost_score
        total_utility += discount * step_util

    # A NaN/inf utility cannot be ranked meaningfully against other
    # candidates, so treat it like a solver failure.
    if not np.isfinite(total_utility):
        return 1e9

    return -total_utility

def optimize_single_case(case_name, case_params, initial_data, X_grid):
    """Find the best (L, r, thresholds) setup for one cost setting.

    For every number of levels ``L`` in 2..8 a CMA-ES run minimizes
    ``objective_function`` over ``[r, mu_1, ..., mu_L]``; the runs are
    dispatched through a thread pool and the setup with the highest utility
    is kept.

    Args:
        case_name: Label used in the progress output.
        case_params: Dict with the ``'c_plus'`` and ``'c_minus'`` costs.
        initial_data: Initial attribute of every agent.
        X_grid: 1-D attribute grid.

    Returns:
        One table row as a dict with the keys ``'Costs (c+, c-)'``,
        ``'Opt Levels (L*)'``, ``'Reward (r*)'``, ``'Max Qual (mu_L)'``,
        ``'Utility'`` and ``'Sequence'``. If every run fails, the row keeps
        its placeholders (0 / -1e9 / None).
    """
    print(f"\n>>> Processing Case: {case_name}")

    # Build Params for this case
    current_params = BASE_PARAMS.copy()
    current_params['c_plus'] = case_params['c_plus']
    current_params['c_minus'] = case_params['c_minus']

    # Sweep L from 2 to 8
    L_values = [2, 3, 4, 5, 6, 7, 8]

    def thread_task(L):
        """Run CMA-ES for a fixed L; return (L, util, r, mus) or None."""
        # Init Guess: r = 1.0 followed by L thresholds evenly spaced in [1, 9].
        x0 = np.concatenate(([1.0], np.linspace(1.0, 9.0, L)))
        sigma0 = 0.5

        def fit_func(x):
            return objective_function(x, current_params, initial_data, X_grid)

        # CMA-ES
        opts = {
            'verbose': -9,  # Silent
            'maxiter': 50,  # Fast convergence check
            'popsize': 100,
            # Every coordinate is bounded below by 0 and unbounded above.
            'bounds': [[0.0]*(L+1), [None]*(L+1)],
            'seed': GLOBAL_SEED
        }

        try:
            es = cma.CMAEvolutionStrategy(x0, sigma0, opts)
            es.optimize(fit_func)

            curr_util = -es.result.fbest
            curr_r = es.result.xbest[0]
            curr_mus = np.sort(es.result.xbest[1:])
        except Exception as e:
            # Best-effort sweep: report the failure and let the other L
            # values continue. The caller skips None results.
            print(f"   Optimizing L={L}... Failed ({e})")
            return None

        # One complete line per run, so output from concurrently running
        # threads does not interleave mid-line.
        print(f"   Optimizing L={L}... Util: {curr_util:.2f}")
        return L, curr_util, curr_r, curr_mus

    best_L_global = 0
    best_util_global = -1e9
    best_r_global = 0
    best_mu_global = 0
    best_sequence = None

    # NOTE(review): whether threads give real parallelism here depends on how
    # much time the solver/simulation spends outside the GIL — confirm.
    with ThreadPoolExecutor(max_workers=8) as executor:
        for outcome in executor.map(thread_task, L_values):
            if outcome is None:
                # This L failed; unpacking None would raise TypeError.
                continue
            L, curr_util, curr_r, curr_mus = outcome
            if curr_util > best_util_global:
                best_util_global = curr_util
                best_L_global = L
                best_r_global = curr_r
                best_mu_global = curr_mus[-1]  # Max Threshold
                best_sequence = curr_mus

    return {
        'Costs (c+, c-)': f"({case_params['c_plus']}, {case_params['c_minus']})",
        'Opt Levels (L*)': best_L_global,
        'Reward (r*)': best_r_global,
        'Max Qual (mu_L)': best_mu_global,
        'Utility': best_util_global,
        'Sequence': best_sequence,
    }

# Main loop
def run_experiment_4():
    """Optimize every case in ``CASES`` and save the resulting table.

    Builds the attribute grid (121 points on [0, 12]), loads the initial
    attributes with ``get_data``, runs ``optimize_single_case`` for each
    case in order, prints the table and writes it to
    ``data/exp_4/exp_4_table.csv`` (relative to the working directory).
    """
    X_grid = np.linspace(0, 12, 121)
    initial_attributes = get_data()

    final_results = [
        optimize_single_case(name, params, initial_attributes, X_grid)
        for name, params in CASES.items()
    ]

    # Create DataFrame
    df = pd.DataFrame(final_results)

    # Format for display
    print("\n\n Final Optimal Policy Table")
    print(df)

    # Save. exist_ok avoids the check-then-create race of testing for the
    # directory first.
    data_dir = os.path.join('data', 'exp_4')
    os.makedirs(data_dir, exist_ok=True)

    df.to_csv(os.path.join(data_dir, 'exp_4_table.csv'), index=False)
    print(f"\nSaved table to {data_dir}/exp_4_table.csv")

# Run the experiment only when executed as a script/module
# (e.g. `python -m experiments.experiment_4`), not when imported.
if __name__ == "__main__":
    run_experiment_4()