import os
import pandas as pd
import pymc as pm
import arviz as az
import numpy as np
import pytensor.tensor as pt
from sklearn.model_selection import train_test_split

# Path of the input CSV and the label printed in the final report.
# Both are empty here and are expected to be set before the script is run.
FILENAME = ""
MODEL_NAME = ""

try:
    df = pd.read_csv(FILENAME)
except FileNotFoundError:
    print(f"--- ERROR: Input file '{FILENAME}' not found. ---")
    # Exit with a non-zero status so shells/schedulers can see the failure.
    # SystemExit is a builtin, unlike the interactive `exit()` helper that is
    # only injected by the `site` module and would exit with status 0 here.
    raise SystemExit(1)
else:
    print(f"Successfully loaded '{FILENAME}'.")

# Binary target: 1 where the recorded choice is exactly 'A', 0 for any other value.
# NOTE(review): anything that is not 'A' (including malformed answers) counts as
# "not A" -- confirm the column only ever holds the two option labels.
chose_a_mask = df['llm_choice'].eq('A')
df['chose_A'] = chose_a_mask.astype(int)

# Remove every row that has a missing value in any column.
df.dropna(subset=df.columns, inplace=True)

# Fixed random_state keeps the 75/25 train/test partition reproducible.
train_df, test_df = train_test_split(df, test_size=0.25, random_state=42)
print(f"Data split into {len(train_df)} training samples and {len(test_df)} testing samples.")

def prepare_data_dict(subset_df):
    """Pull the model inputs out of a DataFrame as a dict of arrays.

    Args:
        subset_df: DataFrame holding the columns ``reward1_a``, ``reward2_a``,
            ``p_a``, ``reward1_b``, ``reward2_b``, ``p_b`` and ``chose_A``.

    Returns:
        Dict with keys ``r1_a``, ``r2_a``, ``p_a``, ``r1_b``, ``r2_b``,
        ``p_b`` and ``choice``. The four reward arrays are floored at 1e-6;
        the probability and choice arrays are passed through untouched.
    """
    floor = 1e-6

    def floored(column):
        # Presumably keeps the later power/log transforms away from zero or
        # negative rewards -- the floor itself is all this function enforces.
        return np.maximum(subset_df[column].values, floor)

    return {
        'r1_a': floored('reward1_a'),
        'r2_a': floored('reward2_a'),
        'p_a': subset_df['p_a'].values,
        'r1_b': floored('reward1_b'),
        'r2_b': floored('reward2_b'),
        'p_b': subset_df['p_b'].values,
        'choice': subset_df['chose_A'].values,
    }

# Array dictionaries for the training portion and the held-out portion.
train_data, test_data = (prepare_data_dict(part) for part in (train_df, test_df))


# Small positive constant shared by the utility and weighting functions below;
# they add it to exponents, bases and denominators.
eps = 1e-6

def linear_utility(x):
    """Identity utility: return the payoff ``x`` unchanged (risk-neutral baseline)."""
    return x

def power_utility(x, alpha):
    """Power utility: ``x`` raised to ``alpha + eps``.

    The module-level ``eps`` is added to the exponent before the power is
    taken, so an ``alpha`` of exactly zero still yields a tiny positive exponent.
    """
    return pt.power(x, alpha + eps)

def quadratic_utility(x, a, b):
    """Quadratic utility ``a * x - b * x**2``."""
    linear_part = a * x
    curvature_penalty = b * pt.power(x, 2)
    return linear_part - curvature_penalty

def crra_utility(x, gamma):
    """Constant-relative-risk-aversion utility.

    Returns ``log(x)`` when ``gamma`` is within 1e-6 of 1, otherwise
    ``(x**(1 - gamma) - 1) / (1 - gamma)``.

    Args:
        x: Payoff(s); must be positive for the log branch to be finite.
        gamma: Risk-aversion coefficient (number or tensor).
    """
    one_minus_gamma = 1.0 - gamma
    is_one = pt.abs(gamma - 1.0) < 1e-6
    # pt.switch evaluates both branches. At gamma == 1 the power branch would
    # be 0/0, and although its value is discarded the NaN can still reach the
    # gradient. Swap in a harmless denominator wherever the log branch wins.
    safe_denom = pt.switch(is_one, 1.0, one_minus_gamma)
    power_branch = (pt.power(x, one_minus_gamma) - 1.0) / safe_denom
    return pt.switch(is_one, pt.log(x), power_branch)

def cara_utility(x, alpha, scale=250.0):
    """Exponential (CARA-style) utility ``-exp(-(x / scale) / (|alpha| + eps))``.

    Args:
        x: Payoff(s).
        alpha: Curvature parameter. Its absolute value plus ``eps`` is used as
            the divisor inside the exponent, so larger values flatten the curve.
        scale: Divisor applied to ``x`` before exponentiation. Defaults to
            250.0, the value that used to be hard-coded; pass a different value
            for data on another payoff scale.
    """
    alpha_safe = pt.abs(alpha) + eps  # strictly positive divisor
    x_scaled = x / scale
    return -pt.exp(-x_scaled / alpha_safe)

def isoelastic_utility(x, theta, lam=1.0):
    """Isoelastic utility of the scaled payoff ``lam * x``.

    Returns ``log(lam * x)`` when ``theta`` is within 1e-6 of 1, otherwise
    ``((lam * x)**(1 - theta) - 1) / (1 - theta)``.

    Args:
        x: Payoff(s); ``lam * x`` must be positive for the log branch.
        theta: Curvature parameter (number or tensor).
        lam: Multiplier applied to ``x`` before the transform.
    """
    scaled = lam * x
    one_minus_theta = 1.0 - theta
    is_one = pt.abs(theta - 1.0) < 1e-6
    # pt.switch evaluates both branches. Guard the denominator so the unused
    # power branch cannot evaluate 0/0 at theta == 1 and push NaN into gradients.
    safe_denom = pt.switch(is_one, 1.0, one_minus_theta)
    power_branch = (pt.power(scaled, one_minus_theta) - 1.0) / safe_denom
    return pt.switch(is_one, pt.log(scaled), power_branch)

def prospect_theory_value(x, alpha, beta, lam, reference_point=0.0):
    """Prospect-theory value function around ``reference_point``.

    Gains (``x`` above the reference) are valued as ``gain**alpha``; losses
    (``x`` below it) as ``-lam * loss**beta``. The value at the reference
    point itself is exactly 0.

    Args:
        x: Payoff(s).
        alpha: Exponent applied to gains.
        beta: Exponent applied to losses.
        lam: Loss-aversion multiplier.
        reference_point: Payoff level separating gains from losses.
    """
    deviation = x - reference_point
    gain_size = pt.maximum(deviation, 0.0)
    loss_size = pt.maximum(-deviation, 0.0)
    # The eps offset keeps the power base strictly positive (the derivative of
    # 0**a with respect to a is undefined). Subtracting eps**exponent removes
    # the offset again so the inactive side contributes exactly 0: a pure gain
    # no longer carries a spurious -lam * eps**beta "loss" and vice versa.
    # When every payoff is a gain this only shifts all values by a constant,
    # which cancels in a difference of expected values.
    gain = pt.power(gain_size + eps, alpha) - pt.power(eps, alpha)
    loss = -lam * (pt.power(loss_size + eps, beta) - pt.power(eps, beta))
    return gain + loss

def hara_utility(x, a, b, gamma):
    """HARA-style utility built from the affine term ``a + b * x``.

    Computes ``((1 - gamma) / (gamma + eps)) * (max(a + b*x, eps)**gamma - 1)``.
    The affine term is floored at ``eps`` before the power is taken, and
    ``eps`` is added to ``gamma`` in the denominator of the leading factor.
    """
    g = gamma + 0.0
    base = pt.maximum(a + b * x, eps)
    leading_factor = (1 - g) / (g + eps)
    return leading_factor * (pt.power(base, g) - 1.0)

def expo_power_utility_saha(x, alpha, theta):
    """Expo-power utility ``(1 - exp(-a * (x + eps)**(1 - theta))) / a``.

    ``a`` is ``|alpha| + eps``, so it is always strictly positive.
    """
    rate = pt.abs(alpha) + eps
    powered_payoff = pt.power(x + eps, 1 - theta)
    return (1 - pt.exp(-rate * powered_payoff)) / rate

# Piecewise power utility: three segments split at changepoints c1 and c2, each with its own exponent
def piecewise_utility(x, c1, c2, alpha1, alpha2, alpha3):
    """Three-segment power utility with changepoints ``c1`` and ``c2``.

    Uses exponent ``alpha1`` for ``x < c1``, ``alpha2`` for ``c1 <= x < c2``
    and ``alpha3`` otherwise. Each later segment is offset by the value the
    previous segment reaches at the shared changepoint, so the segments join
    without a jump.
    """
    # Utility reached at each changepoint; used as offsets for the next segment.
    value_at_c1 = pt.power(c1, alpha1)
    c1_pow_alpha2 = pt.power(c1, alpha2)
    value_at_c2 = value_at_c1 + (pt.power(c2, alpha2) - c1_pow_alpha2)

    lower_segment = pt.power(x, alpha1)
    middle_segment = value_at_c1 + (pt.power(x, alpha2) - c1_pow_alpha2)
    upper_segment = value_at_c2 + (pt.power(x, alpha3) - pt.power(c2, alpha3))

    middle_or_upper = pt.switch(x < c2, middle_segment, upper_segment)
    return pt.switch(x < c1, lower_segment, middle_or_upper)

def epstein_zin_utility(lottery_rewards, lottery_probs, alpha, psi, beta_disc):
    """Epstein-Zin style aggregate value of a lottery.

    Args:
        lottery_rewards: Outcome payoffs, with outcomes along the first axis.
            A list of per-trial arrays ``[r1, r2]`` becomes a tensor whose
            first axis indexes outcomes and whose remaining axes index trials.
        lottery_probs: Outcome probabilities, laid out like ``lottery_rewards``.
        alpha: Risk-aversion parameter used in the ``1 - alpha`` exponent.
        psi: Parameter entering through ``1 - 1 / (psi + eps)``.
        beta_disc: Weight on the lottery term; clipped to ``[0, 0.999]``.

    Returns:
        One value per trial (a scalar when the inputs are one-dimensional).
    """
    rewards = pt.as_tensor_variable(lottery_rewards)
    probs = pt.as_tensor_variable(lottery_probs)
    psi_s = psi + eps
    beta_s = pt.clip(beta_disc, 0.0, 0.999)
    # Sum over the outcome axis only. Summing over every axis would collapse
    # all trials into a single scalar, giving each trial the same utility.
    exp_term = pt.sum(probs * pt.power(rewards + eps, 1.0 - alpha), axis=0)
    exp_term_pos = pt.maximum(exp_term, eps)
    # NOTE(review): `numerator` tends to 0 as psi approaches 1, which makes the
    # final 1 / numerator exponent blow up -- confirm the psi prior avoids this.
    numerator = 1.0 - 1.0 / psi_s
    denom = (1.0 - alpha) + eps
    inner = pt.power(exp_term_pos, numerator / denom)
    c_term = eps  # fixed, near-zero stand-in for the non-lottery term
    aggregated = pt.power(
        (1.0 - beta_s) * pt.power(c_term, numerator) + beta_s * inner,
        1.0 / numerator,
    )
    return aggregated

# --- Probability weighting functions ---
def prelec_weighting(p, gamma):
    """Prelec probability weighting ``w(p) = exp(-(-log p)**gamma)``.

    Args:
        p: Probability or array of probabilities in ``[0, 1]``.
        gamma: Curvature parameter.
    """
    # Clip into [eps, 1 - eps] so that -log(p) stays strictly positive. With a
    # plain `p + eps`, p == 1 makes -log(p + eps) negative, and a negative base
    # raised to a fractional gamma is NaN.
    p_safe = pt.clip(p, eps, 1.0 - eps)
    return pt.exp(-pt.power(-pt.log(p_safe), gamma))

def gonzalez_wu_weighting(p, delta, gamma):
    """Gonzalez-Wu (linear-in-log-odds) probability weighting.

    ``w(p) = delta * p**gamma / (delta * p**gamma + (1 - p)**gamma)``

    Args:
        p: Probability or array of probabilities in ``[0, 1]``.
        delta: Elevation parameter (scales the weight given to ``p``).
        gamma: Curvature parameter (shared exponent of ``p`` and ``1 - p``).
    """
    # Keep both p and 1 - p strictly positive so the powers and their
    # gradients stay finite at p == 0 and p == 1.
    p_safe = pt.clip(p, eps, 1.0 - eps)
    weighted_p = delta * pt.power(p_safe, gamma)
    return weighted_p / (weighted_p + pt.power(1.0 - p_safe, gamma))

# Fitted results keyed by model name; filled in by the fitting sections.
model_results = dict()

# Keyword arguments forwarded unchanged to every pm.sample call.
SAMPLING_KWARGS = dict(
    draws=3000,
    tune=1500,
    chains=6,
    cores=6,
    return_inferencedata=True,
    target_accept=0.97,
)

# --- Utility models to fit: display name -> (function name, parameter names) ---
# The function name is looked up at module level when the model is fitted.
utility_models = {
    display_name: (function_name, parameter_names)
    for display_name, function_name, parameter_names in (
        ('Linear', 'linear_utility', ['beta_sensitivity']),
        ('Power', 'power_utility', ['alpha', 'beta_sensitivity']),
        ('Quadratic', 'quadratic_utility', ['a', 'b', 'beta_sensitivity']),
        ('CRRA', 'crra_utility', ['gamma', 'beta_sensitivity']),
        ('CARA', 'cara_utility', ['alpha', 'beta_sensitivity']),
        ('HARA', 'hara_utility', ['a', 'b', 'gamma', 'beta_sensitivity']),
        ('ExpoPower_Saha', 'expo_power_utility_saha', ['alpha', 'theta', 'beta_sensitivity']),
        ('ProspectTheory', 'prospect_theory_value', ['alpha', 'beta', 'lam', 'beta_sensitivity']),
        ('EpsteinZin', 'epstein_zin_utility', ['alpha', 'psi', 'beta_disc', 'beta_sensitivity']),
        ('Piecewise', 'piecewise_utility',
         ['c1', 'c2_delta', 'alpha1', 'alpha2', 'alpha3', 'beta_sensitivity']),
    )
}

# Largest reward in the training split; sets the scale of the Piecewise
# changepoint priors.
max_reward = train_df[['reward1_a', 'reward2_a', 'reward1_b', 'reward2_b']].max().max()

# Fit one logistic-choice model per entry in `utility_models`:
#   P(choose A) = sigmoid(beta_sensitivity * (U_A - U_B))
# where U_A / U_B are the options' utilities under the model's utility function.
for name, (u_func_str, params) in utility_models.items():
    # The registry stores function names; resolve to the module-level function.
    u_func = globals()[u_func_str]
    print(f"\n--- Fitting {name} Utility Model ---")
    with pm.Model() as model:
        priors = {}
        # Define priors for all parameters
        for p in params:
            if p == 'beta_sensitivity':
                priors[p] = pm.HalfNormal(p, sigma=2.0)
            elif p == 'lam':
                priors[p] = pm.Normal(p, mu=2.0, sigma=1.0)
            elif p == 'a':
                priors[p] = pm.Normal(p, mu=1.0, sigma=1.0)
            elif p == 'b':
                priors[p] = pm.HalfNormal(p, sigma=1.0)
            elif p in ('gamma', 'alpha', 'theta', 'delta'):
                priors[p] = pm.HalfNormal(p, sigma=1.0)
            elif p == 'psi':
                priors[p] = pm.Normal(p, mu=1.0, sigma=0.5)
            elif p == 'beta_disc':
                priors[p] = pm.Beta(p, alpha=2.0, beta=2.0)
            elif p == 'c1':
                priors[p] = pm.TruncatedNormal(p, mu=max_reward * 0.25, sigma=max_reward * 0.1, lower=eps, upper=max_reward)
            elif p == 'c2_delta':
                priors[p] = pm.HalfNormal(p, sigma=max_reward * 0.2)
            elif p == 'alpha1' or p == 'alpha3':
                priors[p] = pm.TruncatedNormal(p, mu=0.7, sigma=0.3, lower=eps, upper=1.0)
            elif p == 'alpha2':
                priors[p] = pm.TruncatedNormal(p, mu=1.3, sigma=0.3, lower=1.0)
            else:
                priors[p] = pm.Normal(p, mu=0.8, sigma=0.5)

        utility_params = {k: priors[k] for k in params if k != 'beta_sensitivity'}

        if name == 'EpsteinZin':
            # Epstein-Zin consumes a whole lottery (all outcomes and their
            # probabilities) rather than one payoff at a time.
            U_A = epstein_zin_utility(
                [train_data['r1_a'], train_data['r2_a']], [train_data['p_a'], 1 - train_data['p_a']],
                **utility_params
            )
            U_B = epstein_zin_utility(
                [train_data['r1_b'], train_data['r2_b']], [train_data['p_b'], 1 - train_data['p_b']],
                **utility_params
            )
        else:
            if name == 'ProspectTheory':
                call_kwargs = dict(utility_params, reference_point=0.0)
            elif name == 'Piecewise':
                # c2 = c1 + positive increment, so the changepoints stay ordered.
                c2_ordered = pm.Deterministic('c2', priors['c1'] + priors['c2_delta'])
                call_kwargs = {
                    'c1': priors['c1'], 'c2': c2_ordered,
                    'alpha1': priors['alpha1'], 'alpha2': priors['alpha2'], 'alpha3': priors['alpha3']
                }
            else:
                call_kwargs = utility_params

            # Expected utility of each two-outcome option.
            U_A = train_data['p_a'] * u_func(train_data['r1_a'], **call_kwargs) + \
                  (1 - train_data['p_a']) * u_func(train_data['r2_a'], **call_kwargs)
            U_B = train_data['p_b'] * u_func(train_data['r1_b'], **call_kwargs) + \
                  (1 - train_data['p_b']) * u_func(train_data['r2_b'], **call_kwargs)

        p_choose_A = pm.math.sigmoid(priors['beta_sensitivity'] * (U_A - U_B))
        y_obs = pm.Bernoulli('y_obs', p=p_choose_A, observed=train_data['choice'])

        trace = pm.sample(**SAMPLING_KWARGS)

        # Names to read back from the trace. Build a fresh list instead of
        # editing `params` in place: that list is owned by `utility_models`,
        # and mutating it would corrupt the registry for any later use.
        if name == 'Piecewise':
            # Downstream code needs the derived changepoint c2, not its increment.
            param_names = [p for p in params if p != 'c2_delta'] + ['c2']
        else:
            param_names = list(params)

        model_results[name] = {'trace': trace, 'utility_func': u_func, 'param_names': param_names}



print("\n--- Fitting Prelec Weighting Model ---")
# Logistic-choice model in which rewards enter linearly and the stated
# probabilities are transformed by the Prelec weighting function.
with pm.Model() as prelec_model:
    gamma = pm.Normal('gamma', mu=0.7, sigma=0.3)
    beta_s = pm.HalfNormal('beta_sensitivity', sigma=2.0)

    # One shared gamma weights the first-outcome probability of both options.
    w_p_a, w_p_b = (prelec_weighting(train_data[prob_key], gamma) for prob_key in ('p_a', 'p_b'))

    # Weighted value of each option: w on outcome 1, (1 - w) on outcome 2.
    U_A = w_p_a * train_data['r1_a'] + (1-w_p_a) * train_data['r2_a']
    U_B = w_p_b * train_data['r1_b'] + (1-w_p_b) * train_data['r2_b']

    p_choose_A = pm.math.sigmoid(beta_s * (U_A - U_B))
    y_obs = pm.Bernoulli('y_obs', p=p_choose_A, observed=train_data['choice'])

    trace = pm.sample(**SAMPLING_KWARGS)
    model_results['Prelec'] = dict(
        trace=trace,
        weight_func=prelec_weighting,
        param_names=['gamma','beta_sensitivity'],
    )

print("\n--- Fitting Gonzalez-Wu Weighting Model ---")
# Logistic-choice model in which rewards enter linearly and the stated
# probabilities are transformed by gonzalez_wu_weighting.
with pm.Model() as gw_model:
    delta = pm.Normal('delta', mu=0.9, sigma=0.3)
    gamma = pm.Normal('gamma', mu=0.7, sigma=0.3)
    beta_s = pm.HalfNormal('beta_sensitivity', sigma=2.0)

    # The same (delta, gamma) pair weights the first-outcome probability of both options.
    w_p_a, w_p_b = (
        gonzalez_wu_weighting(train_data[prob_key], delta, gamma) for prob_key in ('p_a', 'p_b')
    )

    # Weighted value of each option: w on outcome 1, (1 - w) on outcome 2.
    U_A = w_p_a * train_data['r1_a'] + (1-w_p_a) * train_data['r2_a']
    U_B = w_p_b * train_data['r1_b'] + (1-w_p_b) * train_data['r2_b']

    p_choose_A = pm.math.sigmoid(beta_s * (U_A - U_B))
    y_obs = pm.Bernoulli('y_obs', p=p_choose_A, observed=train_data['choice'])

    trace = pm.sample(**SAMPLING_KWARGS)
    model_results['GonzalezWu'] = dict(
        trace=trace,
        weight_func=gonzalez_wu_weighting,
        param_names=['delta','gamma','beta_sensitivity'],
    )


print("\n--- Model Verification on Test Set ---")
# For every fitted model: plug the posterior means reported by az.summary into
# the model's formula, predict the test-set choices, and record the hit rate.
accuracies = {}

for name, result in model_results.items():
    summary = az.summary(result['trace'], var_names=result['param_names'])
    param_means = summary['mean'].to_dict()
    beta_mean = param_means['beta_sensitivity']

    if 'utility_func' in result:
        u_func = result['utility_func']
        pars = {p: param_means[p] for p in result['param_names'] if p != 'beta_sensitivity'}

        if name == 'EpsteinZin':
            # Takes the whole lottery at once instead of one payoff per call.
            U_A_calc = epstein_zin_utility(
                [test_data['r1_a'], test_data['r2_a']], [test_data['p_a'], 1 - test_data['p_a']],
                **pars
            )
            U_B_calc = epstein_zin_utility(
                [test_data['r1_b'], test_data['r2_b']], [test_data['p_b'], 1 - test_data['p_b']],
                **pars
            )
        else:
            if name == 'ProspectTheory':
                call_kwargs = dict(pars, reference_point=0.0)
            elif name == 'Piecewise':
                call_kwargs = {
                    key: pars[key] for key in ('c1', 'c2', 'alpha1', 'alpha2', 'alpha3')
                }
            else:
                call_kwargs = pars

            # Expected utility of each two-outcome option.
            U_A_calc = test_data['p_a'] * u_func(test_data['r1_a'], **call_kwargs) + \
                       (1-test_data['p_a']) * u_func(test_data['r2_a'], **call_kwargs)
            U_B_calc = test_data['p_b'] * u_func(test_data['r1_b'], **call_kwargs) + \
                       (1-test_data['p_b']) * u_func(test_data['r2_b'], **call_kwargs)

    elif 'weight_func' in result:
        weight_func = result['weight_func']
        if name == 'Prelec':
            weight_args = (param_means['gamma'],)
        elif name == 'GonzalezWu':
            weight_args = (param_means['delta'], param_means['gamma'])
        w_p_a_test = weight_func(test_data['p_a'], *weight_args)
        w_p_b_test = weight_func(test_data['p_b'], *weight_args)
        U_A_calc = w_p_a_test * test_data['r1_a'] + (1-w_p_a_test) * test_data['r2_a']
        U_B_calc = w_p_b_test * test_data['r1_b'] + (1-w_p_b_test) * test_data['r2_b']

    # The pt.* helpers return symbolic tensors; evaluate them to concrete values.
    U_A_test = U_A_calc.eval() if isinstance(U_A_calc, pt.TensorVariable) else U_A_calc
    U_B_test = U_B_calc.eval() if isinstance(U_B_calc, pt.TensorVariable) else U_B_calc

    # Logistic choice rule; the exponent is clipped so np.exp cannot overflow.
    exponent = np.clip(-beta_mean * (U_A_test - U_B_test), -700, 700)
    p_test = 1 / (1 + np.exp(exponent))

    predictions = (p_test > 0.5).astype(int)
    accuracies[name] = np.mean(predictions == test_data['choice'])

# ---- Console report ----
banner = "="*50
print("\n\n" + banner)
print(" " * 10 + "COMPREHENSIVE MODEL FITTING REPORT")
print(banner)
print(f"\nModel: {MODEL_NAME}")

print("\n--- Prediction Accuracy on Unseen Data (25% Test Set) ---")
# Best-scoring model first.
sorted_accuracies = sorted(accuracies.items(), key=lambda item: item[1], reverse=True)
for model_name, acc in sorted_accuracies:
    print(f"  - {model_name:<25} Accuracy: {acc:.2%}")

# ---- Files on disk ----
output_dir = "./fit_utility"
os.makedirs(output_dir, exist_ok=True)
accuracy_csv_path = os.path.join(output_dir, "utility_model_accuracies.csv")
report_path = os.path.join(output_dir, "utility_model_report.txt")

acc_df = pd.DataFrame(sorted_accuracies, columns=["Model", "Accuracy"])
acc_df.to_csv(accuracy_csv_path, index=False)
print(f"\nAccuracies saved to '{accuracy_csv_path}'.")

# Text report: the accuracy table, then one parameter summary per model
# (alphabetical by model name), then a one-line interpretation.
with open(report_path, "w") as f:
    f.write("="*50 + "\n")
    f.write("COMPREHENSIVE MODEL FITTING REPORT\n")
    f.write("="*50 + "\n")
    f.write(f"\nModel: {MODEL_NAME}\n")
    f.write("\n--- Prediction Accuracy on Unseen Data (25% Test Set) ---\n")
    for model_name, acc in sorted_accuracies:
        f.write(f"  - {model_name:<25} Accuracy: {acc:.2%}\n")
    f.write("\n--- Inferred Parameter Summaries ---\n")
    for name, result in sorted(model_results.items()):
        f.write(f"\n{name} Model:\n")
        summary_table = az.summary(result['trace'], var_names=result['param_names'])
        f.write(summary_table.to_string() + "\n")
    if accuracies:
        best_model = max(accuracies, key=accuracies.get)
        f.write("\n--- Interpretation ---\n")
        f.write(f"The model with the highest accuracy ('{best_model}') best captures this LLM's choice behavior.\n")

print(f"Full report saved to '{report_path}'.")