import glob
import json
import os
import sys
import warnings

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from scipy.stats import pearsonr, spearmanr

# Silence every warning category so the printed tables stay uncluttered.
# NOTE(review): this would also hide statsmodels MixedLM convergence warnings,
# so a non-converged fit further down would go unreported.
warnings.simplefilter("ignore")

# --- Loading Functions ---
def load_raw_user_data(data_path='data/*.json'):
    """Load every listener's rating file into one long-format DataFrame.

    Each file matched by the glob pattern ``data_path`` is read as JSON. It is
    expected to be a list of entries carrying ``folder`` (song id), ``type``
    (audio type) and ``scores`` (a mapping of model name -> rating). One row is
    emitted per (file, entry, model) triple. The model name ``og`` is renamed
    to ``gt`` so it lines up with the ``gt`` key read by
    ``load_contrastive_model``.

    Files are processed in sorted order so the row order is reproducible. A
    file that is not valid JSON, or does not have the structure above, is
    skipped as a whole and a message is printed to stderr. As a result, a
    malformed file never contributes a partial set of ratings.

    Args:
        data_path: glob pattern selecting the per-user JSON files.

    Returns:
        DataFrame with columns ``user_id`` (the file name), ``song``, ``type``,
        ``model_key`` and ``rating``. The columns are present even when no
        file matched, so downstream ``groupby`` calls do not fail.
    """
    user_records = []
    for file in sorted(glob.glob(data_path)):
        # The file name (including extension) identifies the listener.
        user_id = os.path.basename(file)
        with open(file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
                # Collect this file's rows separately and only keep them once
                # the whole file parsed. A bad entry halfway through must not
                # leave the earlier entries of the same file behind.
                file_records = [
                    {
                        'user_id': user_id,
                        'song': entry.get('folder'),
                        'type': entry.get('type'),
                        'model_key': 'gt' if model_name == 'og' else model_name,
                        'rating': score,
                    }
                    for entry in data
                    for model_name, score in entry['scores'].items()
                ]
            except (ValueError, KeyError, TypeError, AttributeError) as err:
                # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Skipping ratings file {file!r}: {err!r}", file=sys.stderr)
                continue
        user_records.extend(file_records)
    return pd.DataFrame(user_records, columns=['user_id', 'song', 'type', 'model_key', 'rating'])

def load_contrastive_model(file_path, model_keys=('gt', 'sac', 'stage', 'moises')):
    """Load one metric's results file into a long-format DataFrame.

    The JSON file is expected to map each audio type to a list of entries.
    Every entry has a ``song`` id plus one score per system, stored under keys
    such as ``gt``/``sac``/``stage``/``moises``. One row is produced for every
    key in ``model_keys`` that is present in an entry; other keys are ignored.

    Args:
        file_path: path to the metric's results JSON.
        model_keys: system keys to extract, in output order. The default is
            the fixed set this loader has always extracted.

    Returns:
        DataFrame with columns ``song``, ``type``, ``model_key`` and
        ``model_score``. The columns are present even when no score was found,
        so callers can index and group it unconditionally.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        model_json = json.load(f)
    model_records = [
        {'song': entry['song'], 'type': audio_type, 'model_key': key, 'model_score': entry[key]}
        for audio_type, songs in model_json.items()
        for entry in songs
        for key in model_keys
        if key in entry
    ]
    return pd.DataFrame(model_records, columns=['song', 'type', 'model_key', 'model_score'])

from scipy.stats import norm

def steiger_test(r12, r13, r23, n):
    """Test whether two dependent correlations that share a variable differ.

    Tests H0: corr(human, model_a) == corr(human, model_b), where both are
    computed on the same n observations. It uses the Fisher-z statistic of
    Meng, Rosenthal & Rubin (1992), a refinement of Steiger (1980).

    Args:
        r12: corr(human, model_a)
        r13: corr(human, model_b)
        r23: corr(model_a, model_b) -- how similar the models are to each other
        n: sample size (number of paired observations)

    Returns:
        ``(z_score, p_value)``, two-sided. ``z_score > 0`` means r12 > r13.
        Both values are NaN when the statistic is undefined. That happens when
        ``n <= 3`` (no degrees of freedom) or when the standard error is zero,
        e.g. ``r23 == 1``.
    """
    if n <= 3:
        return np.nan, np.nan

    # Fisher r-to-z transformation
    z12 = np.arctanh(r12)
    z13 = np.arctanh(r13)
    diff = z12 - z13

    # Covariance factor. f must be capped at 1 (Meng et al., 1992). Without
    # the cap, h is underestimated when the models correlate weakly with each
    # other but strongly with humans. That shrinks the SE and yields falsely
    # small p-values.
    r2_avg = (r12**2 + r13**2) / 2
    f = min((1 - r23) / (2 * (1 - r2_avg)), 1.0)
    h = (1 - f * r2_avg) / (1 - r2_avg)

    # Standard error of the difference between the two z-transformed correlations
    se = np.sqrt(2 * (1 - r23) * h / (n - 3))
    if not se > 0:  # also catches NaN
        return np.nan, np.nan

    z_score = diff / se
    # norm.sf computes the upper tail directly; 1 - cdf loses precision for large |z|.
    p_value = 2 * norm.sf(abs(z_score))

    return z_score, p_value

def _model_display_name(path):
    """Return an upper-case label for a metric results file path.

    Strips the directory, the extension and a leading ``results_`` prefix:
    ``listening_tests/results_audiobox_ce.json`` -> ``AUDIOBOX_CE``. The
    previous ``split('_')[-1]`` parsing reduced every ``audiobox_*`` variant
    to its two-letter suffix, and it broke on directories containing ``_``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    prefix = 'results_'
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return stem.upper()


def _zscore(values):
    """Standardize a Series to mean 0 / sample-std 1; the epsilon avoids 0/0 for constant groups."""
    return (values - values.mean()) / (values.std() + 1e-9)


def run_comprehensive_analysis(df_raw, model_a_path, model_b_path):
    """Compare how well two automatic metrics agree with human ratings.

    Steps:
      1. Load both metrics' scores and z-score them within each audio type.
      2. z-score the human ratings within each (user, type) pair.
      3. Inner-join ratings and both metrics on (song, type, model_key).
      4. Item level: average the standardized ratings over users. Print the
         Pearson and Spearman correlation of each metric with that average,
         plus the p-value of a Steiger test on the difference of the two
         Pearson r's.
      5. Rating level: fit one linear mixed model per metric
         (``rating_std ~ score + C(type)``, grouped by user). Print the
         metric's coefficient, its p-value and the model AIC.

    Args:
        df_raw: long-format ratings as returned by ``load_raw_user_data``;
            it is copied, not modified.
        model_a_path: results JSON of the first metric.
        model_b_path: results JSON of the second metric.

    Returns:
        The merged rating-level DataFrame, with ``model_score_std_a`` and
        ``model_score_std_b`` columns. If at most three items overlap, the
        tables are skipped and this frame is returned directly.
    """
    model_a_name = _model_display_name(model_a_path)
    model_b_name = _model_display_name(model_b_path)

    # 1. Load metric scores and standardize within each audio type.
    # NOTE(review): this runs before the filter to rated songs below, so the
    # z-scores are relative to every song in the results file -- confirm intended.
    df_a = load_contrastive_model(model_a_path)
    df_a['model_score_std'] = df_a.groupby(['type'])['model_score'].transform(_zscore)

    df_b = load_contrastive_model(model_b_path)
    df_b['model_score_std'] = df_b.groupby(['type'])['model_score'].transform(_zscore)

    # 2. Standardize ratings per listener and type, on a copy so the caller's frame is untouched.
    df_raw = df_raw.copy()
    df_raw['rating_std'] = df_raw.groupby(['user_id', 'type'])['rating'].transform(_zscore)

    # 3. Keep only songs that listeners rated, then join ratings with both metrics.
    rated_songs = df_raw['song'].unique()
    df_a = df_a[df_a['song'].isin(rated_songs)]
    df_b = df_b[df_b['song'].isin(rated_songs)]

    join_keys = ['song', 'type', 'model_key']
    df_combined = pd.merge(df_raw, df_a, on=join_keys)
    # Both metric frames carry model_score / model_score_std; the second merge suffixes them _a / _b.
    df_combined = pd.merge(df_combined, df_b, on=join_keys, suffixes=('_a', '_b'))

    # 4. Item-level correlations: one row per (song, type, model_key) with the listener-averaged rating.
    df_agg = df_combined.groupby(join_keys).agg({
        'rating_std': 'mean',
        'model_score_std_a': 'first',
        'model_score_std_b': 'first',
    }).reset_index()

    # N is the number of unique audio samples evaluated.
    n = len(df_agg)
    if n <= 3:
        # pearsonr needs at least 2 points and the Fisher-z test needs n > 3.
        print(f"Only {n} items shared by ratings, {model_a_name} and {model_b_name}; skipping analysis.")
        return df_combined

    # Widen columns for long metric names so headers and values stay aligned.
    col_w = max(15, len('Spearman ' + model_a_name), len('Spearman ' + model_b_name))
    print(f"{'N':<4} | {'Pearson ' + model_a_name:<{col_w}} | {'Pearson ' + model_b_name:<{col_w}} | "
          f"{'Steiger test':<12} | {'Spearman ' + model_a_name:<{col_w}} | {'Spearman ' + model_b_name:<{col_w}}")

    human = df_agg['rating_std']
    score_a = df_agg['model_score_std_a']
    score_b = df_agg['model_score_std_b']

    s_a, _ = spearmanr(human, score_a)
    s_b, _ = spearmanr(human, score_b)

    r_human_a = pearsonr(human, score_a)[0]
    r_human_b = pearsonr(human, score_b)[0]
    r_a_b = pearsonr(score_a, score_b)[0]

    # The test treats the n items as independent observations; each item's
    # rating was already averaged over listeners above.
    _, p_diff = steiger_test(r_human_a, r_human_b, r_a_b, n)

    print(f"{n:<4} | {r_human_a:>{col_w}.3f} | {r_human_b:>{col_w}.3f} | {p_diff:>12.3f} | "
          f"{s_a:>{col_w}.3f} | {s_b:>{col_w}.3f}")

    # 5. Rating-level linear mixed models, one per metric.
    print("-" * 100)
    name_w = max(10, len(model_a_name), len(model_b_name))
    print(f"{'Model':<{name_w}} | {'Beta':<6} | {'p-val':<6} | {'AIC':<6}")

    for model_name, score_col in [(model_a_name, "model_score_std_a"), (model_b_name, "model_score_std_b")]:
        # The response is the per-listener standardized rating (rating_std);
        # C(type) absorbs per-audio-type offsets.
        formula = f"rating_std ~ {score_col} + C(type)"
        # Optional crossed random effect for songs, currently disabled:
        #   vc_formula = {'song': '0 + C(song)'}
        vc_formula = None
        # Maximum likelihood (reml=False): statsmodels' MixedLM reports AIC as NaN
        # for REML fits, and ML is needed to compare models with different fixed effects.
        # TODO: once we have more users, try method='lbfgs' again.
        result = smf.mixedlm(
            formula, df_combined, groups=df_combined["user_id"], vc_formula=vc_formula
        ).fit(reml=False, method='cg')
        beta = result.params[score_col]
        p_val = result.pvalues[score_col]
        aic = result.aic

        print(f"{model_name:<{name_w}} | {beta:>6.3f} | {p_val:>6.3f} | {aic:>6.2f}")

    return df_combined


# --- Script entry: compare the PHALAR results (model A) with each other metric (model B) ---
df_raw = load_raw_user_data('listening_tests/user_ratings/*.json')

# Metrics compared against PHALAR; one report is printed per metric.
available_models = [
    'cocola', 'clap', 'cdpam', 'visqol',
    'audiobox_ce', 'audiobox_cu', 'audiobox_pc', 'audiobox_pq',
]

for model in available_models:
    candidate_path = "listening_tests/results_" + model + ".json"
    df_final = run_comprehensive_analysis(df_raw, "listening_tests/results_phalar.json", candidate_path)
    print("\n\n")