import argparse
import json
import numpy as np
import os
import re


def _str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    argparse with ``type=bool`` calls ``bool("False")``, which is ``True`` for
    every non-empty string, so a dedicated converter is required.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse command-line options for the accuracy-vs-LR report.

    Returns:
        argparse.Namespace with ``record_path``, ``min_lr``, ``max_lr`` and
        ``default_lr_only`` attributes.
    """
    parser = argparse.ArgumentParser(description='Display accuracy vs learning rate statistics from JSON experiment records')
    parser.add_argument('--record_path', type=str, required=True,
                        help='Path to the JSON experiment record file')
    parser.add_argument('--min_lr', type=float, default=None,
                        help='Minimum learning rate to include (only used when --default_lr_only is false)')
    parser.add_argument('--max_lr', type=float, default=None,
                        help='Maximum learning rate to include (only used when --default_lr_only is false)')
    # nargs='?' + const=True: a bare `--default_lr_only` means True, while
    # `--default_lr_only false` now really disables the mantissa filter.
    parser.add_argument('--default_lr_only', type=_str2bool, nargs='?', const=True, default=True,
                        help='Only include learning rates with mantissa in [1.1247, 2.0, 3.5566, 6.3246] (any exponent). '
                             'Pass "false" to disable and use --min_lr/--max_lr instead')

    return parser.parse_args()


def get_mantissa(lr):
    """Return the scientific-notation mantissa of ``lr``.

    ``lr`` is divided by 10 raised to floor(log10(|lr|)), so for example
    3.5e-4 yields (approximately) 3.5. The sign of ``lr`` is preserved, and
    an input equal to zero returns 0.
    """
    if lr == 0:
        return 0
    scale = 10 ** int(np.floor(np.log10(abs(lr))))
    return lr / scale


def _fmt_pct(x: float) -> str:
    """Format a fraction (0~1) as percentage with two decimals."""
    return f"{x * 100:.2f}%"


def _fmt_pct_list(vals):
    """Format a list of fractions (0~1) as percentages with two decimals."""
    return "[" + ", ".join(f"{v * 100:.2f}%" for v in vals) + "]"


def _trial_scores(task, acc_entry):
    """Combine one trial's sub-task accuracies according to ``task``.

    Returns ``(combined_acc, {subtask_name: acc})``, or ``None`` when a required
    sub-task value is missing, so the trial should be skipped.
    """
    if 'metamath' in task:
        math_acc = acc_entry.get('math')
        gsm8k_acc = acc_entry.get('gsm8k')
        if math_acc is None or gsm8k_acc is None:
            return None
        return (math_acc + gsm8k_acc) / 2.0, {'MATH': math_acc, 'GSM8K': gsm8k_acc}
    if 'python' in task:
        # The first element of the humaneval/mbpp sequence is used; `or [None]`
        # also tolerates an explicit null or an empty list instead of crashing.
        humaneval_acc = (acc_entry.get('humaneval') or [None])[0]
        mbpp_acc = (acc_entry.get('mbpp') or [None])[0]
        if humaneval_acc is None or mbpp_acc is None:
            return None
        return (humaneval_acc + mbpp_acc) / 2.0, {'HumanEval': humaneval_acc, 'MBPP': mbpp_acc}
    # Unknown task: no combination rule exists, so every trial scores 0.0.
    return 0.0, {}


def _summarize_lr_runs(lr_acc_dict):
    """Aggregate ``{lr: [(trial_num, acc), ...]}`` into LR-sorted mean/std arrays plus runs."""
    lrs = sorted(lr_acc_dict)
    means, stds, runs = [], [], {}
    for lr in lrs:
        # Order runs by trial number so 'runs' lists are reproducible.
        accs = [acc for _, acc in sorted(lr_acc_dict[lr], key=lambda pair: pair[0])]
        arr = np.array(accs)
        means.append(arr.mean())
        # Sample std (ddof=1) is undefined for a single run; report 0.0 then.
        stds.append(arr.std(ddof=1) if len(accs) > 1 else 0.0)
        runs[float(lr)] = accs
    return {
        'lrs': np.array(lrs),
        'means': np.array(means),
        'stds': np.array(stds),
        'runs': runs,
    }


def load_data(json_path, task, min_lr=None, max_lr=None, default_lr_only=True):
    """
    Load a JSON experiment record and aggregate accuracy per method and learning rate.

    The record maps each top-level key to ``{method_name: [record, ...]}``; each
    record carries ``hyparam.lr`` and one ``acc-<trial>`` dict per trial.
    Method names are mapped to display labels (case-insensitively); entries that
    share a label -- across top-level keys or spelling variants -- are merged.

    Args:
        json_path: Path to the JSON record file.
        task: Task name; 'metamath' averages MATH and GSM8K, 'python' averages
            HumanEval and MBPP, anything else scores every trial as 0.0.
        min_lr, max_lr: Optional inclusive LR bounds, applied only when
            ``default_lr_only`` is false.
        default_lr_only: Keep only LRs whose mantissa matches the default grid.

    Returns:
        series: dict mapping label to {'lrs', 'means', 'stds', 'runs'}, arrays
            sorted by ascending LR, 'runs' mapping float LR to per-trial
            accuracies ordered by trial number.
        breakdown_data: dict mapping label -> LR -> sub-task name -> list of raw
            accuracies. Only labels/LRs with at least one sub-task value appear.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    default_mantissas = [1.1247, 2.0, 3.5566, 6.3246]

    label_map = {
        'lora': 'LoRA',
        'pissa': 'PiSSA',
        'milora': 'MiLoRA',
        'dora': 'DoRA',
        'initab': 'InitAB'
    }

    # label -> lr -> [(trial_num, acc), ...]. Accumulated per label (not per raw
    # method entry) so `series` and `breakdown_data` describe the same runs.
    lr_acc_by_label = {}
    breakdown_data = {}

    for methods in data.values():
        for method_name, records in methods.items():
            label = label_map.get(method_name.lower(), method_name)
            lr_acc_dict = lr_acc_by_label.setdefault(label, {})

            for rec in records:
                lr = rec.get('hyparam', {}).get('lr')
                if lr is None:
                    continue

                if default_lr_only:
                    mantissa = get_mantissa(lr)
                    if not any(np.isclose(mantissa, default_m, rtol=1e-4) for default_m in default_mantissas):
                        continue
                else:
                    if min_lr is not None and lr < min_lr:
                        continue
                    if max_lr is not None and lr > max_lr:
                        continue

                for k, v in rec.items():
                    if not (k.startswith('acc-') and isinstance(v, dict)):
                        continue
                    trial_num = int(k.split('-')[1])

                    scored = _trial_scores(task, v)
                    if scored is None:
                        continue
                    acc, subtasks = scored

                    # Create breakdown entries lazily so no empty label/LR
                    # buckets are left behind for trials without data.
                    for subtask_name, subtask_acc in subtasks.items():
                        (breakdown_data.setdefault(label, {})
                                       .setdefault(lr, {})
                                       .setdefault(subtask_name, [])
                                       .append(subtask_acc))
                    lr_acc_dict.setdefault(lr, []).append((trial_num, acc))

    series = {
        label: _summarize_lr_runs(lr_acc_dict)
        for label, lr_acc_dict in lr_acc_by_label.items()
        if lr_acc_dict
    }
    return series, breakdown_data


def normalize_lr_bounds(min_lr, max_lr):
    """Return ``(min_lr, max_lr)``, swapped when both are given and out of order.

    Either bound may be ``None`` (meaning "unbounded"), in which case the pair
    is returned unchanged.
    """
    both_given = min_lr is not None and max_lr is not None
    if both_given and max_lr < min_lr:
        return max_lr, min_lr
    return min_lr, max_lr


def extract_config_from_filename(filename):
    """Parse ``<task>_<model>_rank<R>_bs<B>`` from a record file's base name.

    Underscores in the model part are converted to dashes. Returns the tuple
    ``(task, model, rank, batch_size)``, or four ``None`` values when the name
    does not start with that pattern.
    """
    stem, _ext = os.path.splitext(os.path.basename(filename))
    found = re.match(r'([^_]+)_(.+)_rank(\d+)_bs(\d+)', stem)
    if not found:
        return None, None, None, None
    task, model_raw, rank_str, bs_str = found.groups()
    return task, model_raw.replace('_', '-'), int(rank_str), int(bs_str)


def print_subtask_breakdown(breakdown_data, task_name):
    """Print, for each sub-task, a per-method table of mean/std/runs at every LR.

    Args:
        breakdown_data: ``{method: {lr: {subtask: [acc, ...]}}}`` with accuracies
            as fractions (rendered as percentages).
        task_name: Task name; not used by this function, kept for interface
            compatibility.
    """
    if not breakdown_data:
        return

    # Collect sub-task names across every method and LR instead of peeking at
    # the first entry: the first method may have no LRs (StopIteration), or its
    # first LR may lack sub-task data while other LRs/methods have it.
    subtask_keys = sorted({
        subtask
        for lr_data in breakdown_data.values()
        for per_lr in lr_data.values()
        for subtask in per_lr
    })

    if not subtask_keys:
        print("No sub-task details available.")
        return

    for st in subtask_keys:
        print("\n" + "="*70)
        print(f"SUB-TASK PERFORMANCE: {st}")
        print("="*70)

        for method, lr_data in breakdown_data.items():
            print(f"\n[{method}]")

            print(f"{'LR':>12} | {'Mean':>10} | {'StdDev':>10} | Runs")
            print("-"*70)

            for lr in sorted(lr_data):
                vals = lr_data[lr].get(st, [])
                if vals:
                    arr = np.array(vals)
                    mean = arr.mean()
                    # Sample std (ddof=1) needs at least two runs.
                    std = arr.std(ddof=1) if len(vals) > 1 else 0.0
                    run_vals_str = _fmt_pct_list(vals)
                    print(f"{lr:>12.2e} | {_fmt_pct(mean):>10} | {_fmt_pct(std):>10} | {run_vals_str}")
                else:
                    print(f"{lr:>12.2e} | {'N/A':>10} | {'N/A':>10} | []")


def print_per_lr_table(series: dict):
    """Print one table per method with mean, std and individual runs at each LR.

    ``series`` maps a method label to a dict holding parallel ``'lrs'``,
    ``'means'`` and ``'stds'`` arrays plus an optional ``'runs'`` mapping
    (float LR -> list of per-run accuracies). Rows are printed in ascending LR.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("AGGREGATED PERFORMANCE (per method)")
    print(rule)
    header = f"{'LR':>12} | {'Mean':>10} | {'StdDev':>10} | Runs"
    for method_name, stats in series.items():
        lr_values, mean_values, std_values = stats['lrs'], stats['means'], stats['stds']
        run_lookup = stats.get('runs', {})
        ascending = np.argsort(lr_values)

        print(f"\n[{method_name}]")
        print(header)
        print("-" * 70)
        for pos in ascending:
            lr_value = float(lr_values[pos])
            cells = (
                f"{lr_value:>12.2e}",
                f"{_fmt_pct(mean_values[pos]):>10}",
                f"{_fmt_pct(std_values[pos]):>10}",
                _fmt_pct_list(run_lookup.get(lr_value, [])),
            )
            print(" | ".join(cells))


def _report_best_results(series):
    """Print each method's best mean accuracy (with its LR and std) and a ranking.

    Args:
        series: Mapping from method label to a dict with parallel 'lrs', 'means'
            and 'stds' arrays (accuracies as fractions) and a 'runs' mapping
            (float LR -> per-run accuracies).
    """
    print("\n" + "="*70)
    print("BEST PERFORMANCE FOR EACH METHOD")
    print("="*70)

    best_results = []

    for method, data in series.items():
        best_idx = np.argmax(data['means'])
        best_accuracy = data['means'][best_idx] * 100
        best_lr = data['lrs'][best_idx]
        best_std = data['stds'][best_idx] * 100

        num_runs = len(data['runs'].get(float(best_lr), []))

        best_results.append({
            'method': method,
            'accuracy': best_accuracy,
            'lr': best_lr,
            'std': best_std,
            'num_runs': num_runs
        })

        print(f"{method:12s} | Best Accuracy: {best_accuracy:.2f}% (±{best_std:.2f}%) | "
              f"Learning Rate: {best_lr:.2e} | Runs: {num_runs}")

    print("\n" + "="*70)
    print("RANKING (by Best Accuracy)")
    print("="*70)

    # sorted() is stable (also with reverse=True): ties keep series order.
    best_results_sorted = sorted(best_results, key=lambda x: x['accuracy'], reverse=True)

    # `position`, not `rank`: `rank` is the LoRA rank parsed from the filename.
    for position, result in enumerate(best_results_sorted, 1):
        print(f"{position}. {result['method']:12s} | {result['accuracy']:.2f}% (±{result['std']:.2f}%) | "
              f"LR: {result['lr']:.2e} | Runs: {result['num_runs']}")

    print("="*70)


def main():
    """Entry point: parse CLI args, load the record file and print all reports."""
    args = parse_arguments()
    args.min_lr, args.max_lr = normalize_lr_bounds(args.min_lr, args.max_lr)

    # isfile (not exists): a directory path is reported here instead of
    # failing later inside open().
    if not os.path.isfile(args.record_path):
        print(f"Error: Record file '{args.record_path}' not found.")
        return

    task, model, rank, batch_size = extract_config_from_filename(args.record_path)

    if None in (task, model, rank, batch_size):
        print(f"Warning: Could not extract configuration from filename '{args.record_path}'")
        task, model, rank, batch_size = "unknown", "unknown", 0, 0

    series, breakdown_data = load_data(args.record_path, task, args.min_lr, args.max_lr, args.default_lr_only)

    if not series:
        print("No valid data series found.")
        return

    print(f"\nAnalyzing: {args.record_path}")
    print(f"Task: {task}, Model: {model}, Rank: {rank}, Batch Size: {batch_size}")
    print(f"Data series found: {list(series.keys())}")

    # Only these tasks record named sub-task accuracies in breakdown_data.
    if 'metamath' in task or 'python' in task:
        print_subtask_breakdown(breakdown_data, task)

    print_per_lr_table(series)

    _report_best_results(series)


# Run the report only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()