#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import glob
import json
import numpy as np
import pandas as pd
from math import lgamma
from typing import List, Tuple, Optional, Dict, Any

def _beta_logpdf(x: float, a: float, b: float) -> float:
    """Return the log-density of Beta(a, b) evaluated at ``x``.

    ``x`` outside [0, 1] (or NaN) yields ``-inf``. At the endpoints the
    ``0 * log(0) = 0`` convention is used, so Beta(1, 1) has log-density 0 at
    x = 0 and x = 1. At x = 0 the result is ``-inf`` when a > 1 and ``+inf``
    when a < 1; at x = 1 the same rule applies with ``b``.
    """
    if not (0.0 <= x <= 1.0):
        return -np.inf

    def _xlogy(coef: float, base: float) -> float:
        # coef * log(base) with 0 * log(0) := 0; never calls log(0).
        if base > 0.0:
            return coef * np.log(base)
        if coef == 0.0:
            return 0.0
        return np.inf if coef < 0.0 else -np.inf

    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)
    log_beta_fn = lgamma(a) + lgamma(b) - lgamma(a + b)
    return _xlogy(a - 1.0, x) + _xlogy(b - 1.0, 1.0 - x) - log_beta_fn

def _real_roots_in_interval(coeffs, lo=0.0, hi=1.0, tol=1e-12):
    """Return the ascending real roots of a polynomial inside [lo - tol, hi + tol].

    ``coeffs`` uses the ``np.roots`` convention (highest degree first). A root
    is treated as real only if its imaginary part is exactly zero. The result
    is a plain Python list of floats.
    """
    every_root = np.roots(coeffs)
    real_part = every_root[np.isreal(every_root)].real
    lower_edge, upper_edge = lo - tol, hi + tol
    inside = (real_part >= lower_edge) & (real_part <= upper_edge)
    return np.sort(real_part[inside]).tolist()

def _secant_u(r: float, g: float) -> float:
    """Slope of the chord of u(p) = (1 - p)**4 between p = g and p = r.

    Callers must ensure ``r != g``; otherwise the division fails.
    """
    rise = (1.0 - r) ** 4 - (1.0 - g) ** 4
    run = r - g
    return rise / run

def _bisect_sign_change(f, a: float, fa: float, b: float, fb: float,
                        ftol: float, max_iter: int = 200) -> Optional[float]:
    """Return a root of ``f`` inside the sign-change bracket ``[a, b]``.

    Requires ``a < b`` and ``fa * fb < 0``. Bisection never leaves the bracket,
    so for continuous ``f`` the result is within float resolution of a genuine
    root. Stops early once ``|f(mid)| < ftol``. Returns ``None`` if ``f`` is NaN
    at an interior point, because the root cannot be certified there.
    """
    for _ in range(max_iter):
        mid = 0.5 * (a + b)
        if not (a < mid < b):
            break  # bracket has collapsed to adjacent floats
        fm = f(mid)
        if np.isnan(fm):
            return None
        if abs(fm) < ftol:
            return mid
        if (fm < 0.0) == (fa < 0.0):
            a, fa = mid, fm
        else:
            b, fb = mid, fm
    return a if abs(fa) <= abs(fb) else b


def _newton_polish(f, x0: float, lo: float, hi: float, step_tol: float,
                   max_iter: int = 30) -> float:
    """Refine ``x0`` with central-difference Newton steps clamped to [lo, hi].

    Intended for roots where ``f`` touches zero without changing sign, which a
    sign-change scan cannot detect. Stops when a step is below ``step_tol``,
    the derivative estimate is ~0, or ``f`` returns NaN. The caller must check
    ``f`` at the returned point.
    """
    x = x0
    for _ in range(max_iter):
        fx = f(x)
        if np.isnan(fx):
            break
        h = max(1e-6, 1e-3 * max(1e-3, abs(x)))
        xp, xm = min(x + h, hi), max(x - h, lo)
        if not (xm < xp):
            break
        fp, fm = f(xp), f(xm)
        if np.isnan(fp) or np.isnan(fm):
            break
        # Divide by the real stencil width: clamping at lo/hi can make it < 2h.
        d = (fp - fm) / (xp - xm)
        if abs(d) < 1e-14:
            break
        x_new = min(max(x - fx / d, lo), hi)
        converged = abs(x_new - x) < step_tol
        x = x_new
        if converged:
            break
    return x


def solve_all_solutions(RE: float, VE_mean: float, miss_all: float,
                        g_grid: int = 2401, e_abs: float = 1e-10, newton_tol: float = 5e-13) -> Dict[str, Any]:
    """Find two-point mixture parameters ``(theta, r, g)`` that match three moments.

    The model puts weight ``theta`` on rate ``r`` and ``1 - theta`` on rate
    ``g`` (with ``g < VE_mean < r``) and requires::

        theta * r**4       + (1 - theta) * g**4       == RE
        theta * r          + (1 - theta) * g          == VE_mean
        theta * (1 - r)**4 + (1 - theta) * (1 - g)**4 == miss_all

    Eliminating ``theta = (VE_mean - g) / (r - g)`` leaves one unknown, ``g``:
    - For each ``g``, ``r`` is the root of a cubic that comes from the first
      equation.
    - ``E(g)`` is the mismatch in the third equation.

    Roots of ``E`` are found in two ways:
    - scanning ``g_grid`` grid points and bisecting every sign change;
    - Newton-polishing grid-local minima of ``|E|``, which catches roots where
      ``E`` touches zero without crossing.

    Args:
        RE, VE_mean, miss_all: target moments (coerced to float).
        g_grid: number of grid points used to scan ``g``.
        e_abs: residual a Newton-polished (touching) root must beat to be kept.
        newton_tol: Newton step tolerance and bisection early-exit residual.

    Returns:
        ``{'degenerate': bool, 'solutions': [(theta, r, g), ...]}``.
        ``degenerate`` is True (with no solutions) when the moments equal
        those of a point mass at ``VE_mean``. At most three solutions are
        returned. If more are found, the three with the smallest total
        absolute moment residual are kept.
    """
    A, B, C = float(RE), float(VE_mean), float(miss_all)
    if abs(A - B**4) <= 1e-12 and abs(C - (1.0 - B)**4) <= 1e-12:
        return {'degenerate': True, 'solutions': []}
    # g must stay strictly below VE_mean so that theta > 0; g_lo is 0 unless
    # VE_mean exceeds 0.999999.
    g_lo = max(0.0, B - 0.999999)
    g_hi = min(B - 1e-12, 1.0 - 1e-12)
    if not (g_lo < g_hi):
        return {'degenerate': False, 'solutions': []}

    def D1(g):
        # Value that (r^4 - g^4) / (r - g) must take to satisfy the RE equation.
        return (A - g**4) / (B - g)

    def D2(g):
        # Value that the chord slope of (1 - p)^4 must take for the miss_all equation.
        return (C - (1.0 - g)**4) / (B - g)

    def r_solutions_for_g(g):
        # (r^4 - g^4) / (r - g) = r^3 + g r^2 + g^2 r + g^3; solve it equal to D1(g).
        coeffs = [1.0, g, g**2, g**3 - D1(g)]
        r_roots = _real_roots_in_interval(coeffs, lo=max(g, B) + 1e-12, hi=1.0 - 1e-12)
        return [r for r in r_roots if r > g + 1e-12 and r > B + 1e-12]

    def E_val(g):
        rs = r_solutions_for_g(g)
        if not rs:
            return np.nan
        return _secant_u(rs[0], g) - D2(g)

    gs = np.linspace(g_lo, g_hi, g_grid)
    Es = np.array([E_val(g) for g in gs], dtype=float)

    cand_g = []

    # Pass 1: exact zeros and sign changes between neighbouring grid points.
    for i in range(len(gs) - 1):
        g1, g2 = gs[i], gs[i + 1]
        v1, v2 = Es[i], Es[i + 1]
        if np.isnan(v1) or np.isnan(v2):
            continue
        if v1 == 0.0:
            cand_g.append(g1)
        elif v2 == 0.0:
            cand_g.append(g2)
        elif v1 * v2 < 0.0:
            root = _bisect_sign_change(E_val, g1, v1, g2, v2, newton_tol)
            if root is not None:
                cand_g.append(root)

    # Pass 2: grid-local minima of |E| that are already near zero (touching roots).
    absE = np.abs(Es)
    for i in range(1, len(gs) - 1):
        prev_v, cur_v, next_v = absE[i - 1], absE[i], absE[i + 1]
        if np.isnan(prev_v) or np.isnan(cur_v) or np.isnan(next_v):
            continue
        if cur_v <= prev_v and cur_v <= next_v and cur_v < 1e-6:
            g = _newton_polish(E_val, gs[i], g_lo, g_hi, newton_tol)
            fg = E_val(g)
            if not np.isnan(fg) and abs(fg) < e_abs:
                cand_g.append(g)

    # Merge candidates closer than 1e-8 (both passes can find the same root).
    cand_g = sorted(x for x in cand_g if not np.isnan(x))
    uniq_g = []
    for g in cand_g:
        if not uniq_g or abs(g - uniq_g[-1]) > 1e-8:
            uniq_g.append(g)

    sols = []
    for g in uniq_g:
        rs = r_solutions_for_g(g)
        if not rs:
            continue
        r = rs[0]
        theta = (B - g) / (r - g)
        if not (0.0 <= theta <= 1.0) or not (g <= B <= r):
            continue
        sols.append((float(np.clip(theta, 0.0, 1.0)),
                     float(np.clip(r, 0.0, 1.0)),
                     float(np.clip(g, 0.0, 1.0))))

    sols_sorted = sorted(sols, key=lambda t: (round(t[1], 9), round(t[2], 9), round(t[0], 9)))
    uniq = []
    for th, r, g in sols_sorted:
        if not uniq or (abs(th - uniq[-1][0]) > 1e-8 or abs(r - uniq[-1][1]) > 1e-8 or abs(g - uniq[-1][2]) > 1e-8):
            uniq.append((th, r, g))

    if len(uniq) > 3:
        def residual(triple):
            # Sum of absolute errors of the three moment equations.
            th, rr, gg = triple
            eq1 = th * (rr**4) + (1 - th) * (gg**4) - A
            eq2 = th * rr + (1 - th) * gg - B
            eq3 = th * ((1 - rr)**4) + (1 - th) * ((1 - gg)**4) - C
            return abs(eq1) + abs(eq2) + abs(eq3)
        uniq = sorted(uniq, key=residual)[:3]

    return {'degenerate': False, 'solutions': uniq}

def pick_map_solution(solutions: List[Tuple[float, float, float]],
                      r_beta: Tuple[float, float] = (8, 2),
                      g_beta: Tuple[float, float] = (2, 8),
                      th_beta: Tuple[float, float] = (1, 1)) -> Tuple[Optional[Tuple[float,float,float]], Optional[float]]:
    """Choose the ``(theta, r, g)`` candidate with the largest joint Beta log-prior.

    Each candidate's score is
    ``logBeta(r; r_beta) + logBeta(g; g_beta) + logBeta(theta; th_beta)``.
    Only a score strictly greater than the best so far replaces it, so ties
    keep the earlier candidate. A candidate whose score is ``-inf`` or NaN is
    never chosen.

    Returns:
        ``(None, None)`` for an empty list. Otherwise returns
        ``(best_triple, best_score)``. If no candidate beats ``-inf``, this is
        ``(None, -inf)``.
    """
    if not solutions:
        return None, None
    (a_r, b_r), (a_g, b_g), (a_t, b_t) = r_beta, g_beta, th_beta

    scored = [
        (_beta_logpdf(r, a_r, b_r) + _beta_logpdf(g, a_g, b_g) + _beta_logpdf(theta, a_t, b_t),
         (theta, r, g))
        for theta, r, g in solutions
    ]

    top_triple = None
    top_score = -np.inf
    for score, triple in scored:
        if score > top_score:
            top_score, top_triple = score, triple
    return top_triple, top_score

def calculate_metrics_row(row: pd.Series,
                          r_prior=(8,2), g_prior=(2,8), theta_prior=(1,1),
                          report_solutions=True) -> Dict[str, Any]:
    """Fit the two-point mixture to one result row and summarise the fit.

    Reads these columns:
    - ``hit_all``, used as RE;
    - ``miss_all``;
    - ``rot0``, ``rot90``, ``rot180`` and ``rot270``, averaged into ``VE_mean``.
    It then solves for every ``(theta, r, g)`` candidate with
    ``solve_all_solutions`` and picks the MAP candidate with
    ``pick_map_solution`` under the given Beta priors.

    Args:
        row: row supporting ``row[...]`` lookups for the columns above.
        r_prior, g_prior, theta_prior: Beta ``(a, b)`` parameters forwarded to
            ``pick_map_solution``.
        report_solutions: currently not read by this function; kept so
            existing callers keep working.

    Returns:
        A flat dict of metrics:
        - ``A_adj`` is ``theta * r`` for one candidate.
        - ``A_adj_min`` and ``A_adj_max`` span all candidates.
        - ``A_adj_map`` uses the MAP candidate.
        - ``solutions_json`` lists every candidate.
        - ``degenerate_rg`` mirrors the solver's ``degenerate`` flag.
        MAP fields are NaN when no candidate was selected.
    """
    RE = float(row['hit_all'])
    miss_all = float(row['miss_all'])
    rot0, rot90, rot180, rot270 = (float(row['rot0']), float(row['rot90']), float(row['rot180']), float(row['rot270']))
    VE_mean = (rot0 + rot90 + rot180 + rot270) / 4.0
    hit_any = 1.0 - miss_all

    out = solve_all_solutions(RE, VE_mean, miss_all)
    degenerate = bool(out['degenerate'])
    solutions = [] if degenerate else out['solutions']
    map_sol, _ = pick_map_solution(solutions, r_prior, g_prior, theta_prior)

    def _adjusted_accuracy(theta, r, g):
        # theta * r; NaN guard for g >= 1.
        return theta * r if (1 - g) > 0 else np.nan

    A_adj_list = [
        {'theta': float(th_i), 'r': float(r_i), 'g': float(g_i),
         'A_adj': float(_adjusted_accuracy(th_i, r_i, g_i))}
        for th_i, r_i, g_i in solutions
    ]
    finite_adj = [x['A_adj'] for x in A_adj_list if not np.isnan(x['A_adj'])]
    A_adj_min = float(np.min(finite_adj)) if finite_adj else np.nan
    A_adj_max = float(np.max(finite_adj)) if finite_adj else np.nan
    A_adj_map = _adjusted_accuracy(*map_sol) if map_sol is not None else np.nan

    result = {
        'RE': RE,
        'VE_mean': VE_mean,
        'miss_all': miss_all,
        'hit_any': hit_any,
        'num_solutions': len(solutions),
        # Previously hard-coded to False; now reports the solver's flag.
        'degenerate_rg': degenerate,
        'theta_map': np.nan,
        'r_map': np.nan,
        'g_map': np.nan,
        'A_adj_map': A_adj_map,
        'A_adj_min': A_adj_min,
        'A_adj_max': A_adj_max,
        'solutions_json': json.dumps(A_adj_list, ensure_ascii=False)
    }
    if map_sol is not None:
        theta, r, g = map_sol
        result.update({'theta_map': float(theta), 'r_map': float(r), 'g_map': float(g)})
    return result

def process_all_models(input_glob: str = 're_results/**/*_REresult.csv',
                       r_prior=(8,2), g_prior=(2,8), theta_prior=(1,1)) -> pd.DataFrame:
    """Compute per-category metrics for every CSV file matching ``input_glob``.

    Files are matched recursively and processed in sorted order, so the output
    row order is reproducible. The model name is the file's basename with
    ``'_REresult.csv'`` removed. Rows whose ``category`` is ``'Average'`` are
    skipped.

    Processing is best-effort. An unreadable file, or a row whose metrics
    raise, is skipped with a ``RuntimeWarning`` that names the file; nothing
    is dropped silently.

    Returns:
        One row per processed category. Numeric metric columns are rounded to
        10 decimals. The frame is empty (no columns) if nothing was produced.
    """
    import warnings

    all_rows = []
    csv_files = sorted(glob.glob(input_glob, recursive=True))
    for csv_file in csv_files:
        model_name = os.path.basename(csv_file).replace('_REresult.csv', '')
        try:
            df = pd.read_csv(csv_file)
        except Exception as exc:  # best-effort: skip the file, but report it
            warnings.warn(f"Skipping unreadable file {csv_file}: {exc}",
                          RuntimeWarning, stacklevel=2)
            continue
        for _, row in df.iterrows():
            if str(row.get('category', '')) == 'Average':
                continue
            try:
                metrics = calculate_metrics_row(row, r_prior, g_prior, theta_prior, report_solutions=True)
            except Exception as exc:  # best-effort: skip the row, but report it
                warnings.warn(
                    f"Skipping row category={row.get('category')!r} in {csv_file}: {exc}",
                    RuntimeWarning, stacklevel=2)
                continue
            record = {'model': model_name, 'category': row.get('category'), 'total_questions': row.get('total_questions')}
            record.update(metrics)
            all_rows.append(record)

    results_df = pd.DataFrame(all_rows)
    num_cols = ['RE','VE_mean','miss_all','hit_any','theta_map','r_map','g_map','A_adj_map','A_adj_min','A_adj_max']
    for col in num_cols:
        if col in results_df.columns:
            results_df[col] = results_df[col].astype(float).round(10)
    return results_df

if __name__ == '__main__':
    # Beta (a, b) prior parameters forwarded to process_all_models.
    prior_kwargs = {
        'r_prior': (8, 2),
        'g_prior': (2, 8),
        'theta_prior': (1, 1),
    }
    results = process_all_models(input_glob='aggregated_results/**/*_REresult.csv', **prior_kwargs)
    out_csv = 'reliable_metrics.csv'
    # The CSV is written even when no rows were produced.
    results.to_csv(out_csv, index=False)
    status = "No rows produced." if results.empty else f"Saved: {out_csv}"
    print(status)
