"""
Linear Star Plots: |beta| vs Horizon H

Generates star plots showing the relationship between convergence rate (|beta|)
and effective horizon H across multiple datasets and layers.

Two versions:
  - V1: Uniform round markers for all layers
  - V2: Dataset-specific markers (circle, square, triangle, diamond)

Both versions highlight specific points (German L0, L1, L5) as stars.
"""
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# ============================================================
# CONFIGURATION
# ============================================================
# Directory containing this script; figures are written next to it and the
# results tree is resolved three levels above it.
_HERE = Path(__file__).parent
BASE_DIR = _HERE.parent.parent.parent
OUTPUT_DIR = _HERE / "figures"

# Shared fallback location that every dataset also searches.
_CLUSTER_RESULTS = BASE_DIR / 'results_from_cluster'

# Directories scanned (recursively, in the order listed) for hyperparameter data.
SEARCH_DIRS = {
    'baseline_wiki': [
        BASE_DIR / 'results_baseline_wiki_k500',
        _CLUSTER_RESULTS,
    ],
    'lang_chinese': [
        BASE_DIR / 'results_lang_chinese_k500',
        BASE_DIR / 'results_lang_chinese',
        _CLUSTER_RESULTS,
    ],
    'lang_german': [
        BASE_DIR / 'results_lang_german_k500',
        BASE_DIR / 'results_lang_german',
        _CLUSTER_RESULTS,
    ],
    'domain_news': [
        BASE_DIR / 'results_domain_news',
        _CLUSTER_RESULTS,
    ],
}

# Substring a hyperparameters.csv path must contain to belong to each dataset.
DATASET_PATTERNS = dict(
    baseline_wiki='baseline_wiki_en',
    lang_chinese='lang_chinese',
    lang_german='lang_german',
    domain_news='domain_news',
)

# (dataset, layer_id) pairs that are skipped when loading results.
EXCLUDE_LAYERS = [
    *(('lang_chinese', layer) for layer in (7, 8, 9)),
    ('lang_german', 8),
    ('domain_news', 6),
]

# Highlighted (dataset, layer_id) points and the fill colour of each star.
STAR_COLORS = {
    ('lang_german', 0): '#FFAB20',
    ('lang_german', 1): '#D97030',
    ('lang_german', 5): '#B81D24',
}
# The highlighted points themselves, in the same order as STAR_COLORS.
STAR_POINTS = list(STAR_COLORS)

# Per-dataset marker shape (used by V2) and legend label.
DATASET_MARKERS = dict(baseline_wiki='o', lang_chinese='s', lang_german='^', domain_news='D')
DATASET_LABELS = dict(baseline_wiki='Wiki', lang_chinese='Chinese', lang_german='German', domain_news='News')

# One viridis colour per layer index 0-11 (RGBA rows, evenly spaced along the map).
_N_LAYER_COLORS = 12
LAYER_COLORS = plt.cm.viridis(np.linspace(0.0, 1.0, _N_LAYER_COLORS))


# ============================================================
# DATA LOADING
# ============================================================
def find_all_hyperparameters(dataset_name, search_dirs, exclude_layers, pattern=None):
    """Collect one hyperparameter record per layer for a dataset.

    Recursively scans every existing directory in ``search_dirs`` for files
    named ``hyperparameters.csv`` whose full path contains ``pattern``, reads
    the first row of each, and keeps, per ``layer_id``, the record with the
    largest ``k``. On a tie in ``k`` the file visited first wins. Directories
    are visited in ``search_dirs`` order, and files within a directory in
    sorted path order, so the result is reproducible across filesystems.

    Args:
        dataset_name: Dataset key (e.g. ``'lang_german'``); also matched
            against the first element of each ``exclude_layers`` pair.
        search_dirs: Iterable of ``pathlib.Path`` directories; missing ones
            are skipped.
        exclude_layers: Collection of ``(dataset_name, layer_id)`` pairs to
            ignore.
        pattern: Substring a file path must contain to be considered.
            Defaults to ``DATASET_PATTERNS[dataset_name]``.

    Returns:
        dict mapping ``layer_id`` (int) to a dict with keys ``layer_id``,
        ``H``, ``k``, ``max_eigen``, ``slope_mean``, ``slope_mean_se``,
        ``slope_cov``, ``slope_cov_se``.

    Files that cannot be parsed, or whose first row lacks the expected
    columns or has non-integer ``layer_id``/``k``, are skipped with a
    ``RuntimeWarning`` rather than silently ignored.
    """
    if pattern is None:
        pattern = DATASET_PATTERNS[dataset_name]
    layer_data = {}

    for search_dir in search_dirs:
        if not search_dir.exists():
            continue
        # rglob order is filesystem-dependent; sort so ties in k resolve the
        # same way on every run.
        for hp_file in sorted(search_dir.rglob('hyperparameters.csv')):
            if pattern not in str(hp_file):
                continue

            try:
                df = pd.read_csv(hp_file)
            except (OSError, ValueError) as err:  # EmptyDataError/ParserError subclass ValueError
                warnings.warn(f"Skipping unreadable {hp_file}: {err}", RuntimeWarning, stacklevel=2)
                continue
            if df.empty:
                continue
            row = df.iloc[0]

            try:
                layer_id = int(row['layer_id'])
                k = int(row['k'])
            except (KeyError, TypeError, ValueError) as err:  # missing column or NaN id/k
                warnings.warn(f"Skipping {hp_file}: bad layer_id/k ({err!r})", RuntimeWarning, stacklevel=2)
                continue

            if (dataset_name, layer_id) in exclude_layers:
                continue
            # Keep data with highest k value (strictly greater: first file wins ties).
            if layer_id in layer_data and k <= layer_data[layer_id]['k']:
                continue

            try:
                layer_data[layer_id] = {
                    'layer_id': layer_id,
                    'H': row['H'],
                    'k': k,
                    'max_eigen': row['max_eigen'],
                    'slope_mean': row['slope_mean'],
                    'slope_mean_se': row['slope_mean_se'],
                    'slope_cov': row['slope_cov'],
                    'slope_cov_se': row['slope_cov_se'],
                }
            except KeyError as err:
                warnings.warn(f"Skipping {hp_file}: missing column {err}", RuntimeWarning, stacklevel=2)
    return layer_data


def correct_se(se, k):
    """Return the standard error ``se`` divided by the square root of ``k``.

    Operates element-wise when ``se`` and/or ``k`` are NumPy arrays or pandas
    Series (NumPy broadcasting rules apply).
    """
    root_k = np.sqrt(k)
    return se / root_k


def load_all_data():
    """Load the per-layer hyperparameter records of every dataset in SEARCH_DIRS.

    Each record (see ``find_all_hyperparameters``) is tagged with its
    ``dataset`` key and gets ``slope_mean_se_corrected`` /
    ``slope_cov_se_corrected`` columns computed with ``correct_se``.

    Returns:
        DataFrame sorted by ``dataset`` then ``layer_id``. When no records are
        found, an empty DataFrame with the same columns is returned instead
        of raising.
    """
    all_data = []
    for dataset, dirs in SEARCH_DIRS.items():
        layer_data = find_all_hyperparameters(dataset, dirs, EXCLUDE_LAYERS)
        for data in layer_data.values():
            data['dataset'] = dataset
            data['slope_mean_se_corrected'] = correct_se(data['slope_mean_se'], data['k'])
            data['slope_cov_se_corrected'] = correct_se(data['slope_cov_se'], data['k'])
            all_data.append(data)

    if not all_data:
        # pd.DataFrame([]) has no columns, so sort_values would raise KeyError.
        columns = [
            'layer_id', 'H', 'k', 'max_eigen',
            'slope_mean', 'slope_mean_se', 'slope_cov', 'slope_cov_se',
            'dataset', 'slope_mean_se_corrected', 'slope_cov_se_corrected',
        ]
        return pd.DataFrame(columns=columns)

    return pd.DataFrame(all_data).sort_values(['dataset', 'layer_id'])


# ============================================================
# PLOTTING FUNCTIONS
# ============================================================
def _draw_points(ax, df_all, slope_col, se_col, marker_for):
    """Plot |slope| against H with error bars, one marker per row of ``df_all``.

    Rows whose (dataset, layer_id) is in STAR_POINTS are drawn as large stars
    in their STAR_COLORS colour, above the other points (zorder=10). All
    other rows use ``marker_for(dataset)`` and are coloured by layer. Rows
    whose ``layer_id`` has no entry in LAYER_COLORS are not drawn.
    """
    for layer_id in range(len(LAYER_COLORS)):
        df_layer = df_all[df_all['layer_id'] == layer_id]
        for _, row in df_layer.iterrows():
            x, y, err = row['H'], np.abs(row[slope_col]), row[se_col]
            key = (row['dataset'], layer_id)
            if key in STAR_POINTS:
                ax.errorbar(x, y, yerr=err, fmt='*', color=STAR_COLORS[key],
                            markersize=18, capsize=2, alpha=0.95,
                            markeredgecolor='black', markeredgewidth=0.8, zorder=10)
            else:
                ax.errorbar(x, y, yerr=err, fmt=marker_for(row['dataset']),
                            color=LAYER_COLORS[layer_id],
                            markersize=9, capsize=2, alpha=0.85,
                            markeredgecolor='white', markeredgewidth=0.5)


def _draw_linear_fit(ax, df_all, slope_col):
    """Overlay the least-squares line of |slope| on H over all rows of ``df_all``.

    Rows with non-finite H or slope are ignored. Returns the correlation
    coefficient r from ``scipy.stats.linregress``. Returns None and draws
    nothing when fewer than two usable points remain or all H values are
    equal, because no line is defined in those cases.
    """
    H = df_all['H'].to_numpy(dtype=float)
    beta = np.abs(df_all[slope_col].to_numpy(dtype=float))
    usable = np.isfinite(H) & np.isfinite(beta)
    H, beta = H[usable], beta[usable]
    if H.size < 2 or H.min() == H.max():
        return None

    slope_fit, intercept_fit, r_value, _, _ = stats.linregress(H, beta)
    H_line = np.linspace(H.min() * 0.95, H.max() * 1.05, 100)
    ax.plot(H_line, intercept_fit + slope_fit * H_line, '--', color='darkred', linewidth=2.5)
    return r_value


def _draw_legend(ax, r_value, extra_handles=(), extra_labels=()):
    """Add the legend in this order: layer colours, ``extra_*`` entries, stars, fit line.

    The fit entry is omitted when ``r_value`` is None (no fit was drawn).
    """
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colour, markersize=10)
               for colour in LAYER_COLORS]
    labels = [f'L{i}' for i in range(len(LAYER_COLORS))]

    handles += list(extra_handles)
    labels += list(extra_labels)

    # Stars are listed in reverse STAR_POINTS order (German L5, L1, L0). The
    # colours come from STAR_COLORS so the legend always matches the plot.
    for dataset, layer_id in reversed(STAR_POINTS):
        handles.append(plt.Line2D([0], [0], marker='*', color='w',
                                  markerfacecolor=STAR_COLORS[(dataset, layer_id)], markersize=14))
        labels.append(f'{DATASET_LABELS[dataset]} L{layer_id}')

    if r_value is not None:
        handles.append(plt.Line2D([0], [0], color='darkred', linestyle='--', linewidth=2))
        labels.append(f'Fit (R={r_value:.2f})')

    ax.legend(handles, labels, loc='upper right', fontsize=9, ncol=2)


def _make_star_plot(df_all, slope_col, se_col, slope_name, output_dir, version,
                    marker_for, extra_handles=(), extra_labels=()):
    """Render one |beta|-vs-H star plot and save it as PNG (150 dpi) and PDF.

    Output files are ``output_dir / f'slope_vs_H_stars_{version}_{slope_name.lower()}'``
    with ``.png`` and ``.pdf`` suffixes. The figure is closed even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        _draw_points(ax, df_all, slope_col, se_col, marker_for)
        r_value = _draw_linear_fit(ax, df_all, slope_col)

        ax.set_xlabel('Horizon H', fontsize=16)
        ax.set_ylabel(f'|β| ({slope_name})', fontsize=16)
        ax.grid(True, alpha=0.3)
        _draw_legend(ax, r_value, extra_handles, extra_labels)

        fig.tight_layout()
        stem = f'slope_vs_H_stars_{version}_{slope_name.lower()}'
        fig.savefig(output_dir / f'{stem}.png', dpi=150, bbox_inches='tight')
        fig.savefig(output_dir / f'{stem}.pdf', bbox_inches='tight')
        print(f"Saved: {output_dir}/{stem}.png")
    finally:
        plt.close(fig)


def create_star_plot_v1(df_all, slope_col, se_col, slope_name, output_dir):
    """
    V1: All round markers for layers, stars for highlighted points.

    Args:
        df_all: DataFrame with columns ``dataset``, ``layer_id``, ``H``,
            ``slope_col`` and ``se_col`` (e.g. the output of ``load_all_data``).
        slope_col: Column whose absolute value is plotted on the y axis.
        se_col: Column used as the y error bar.
        slope_name: Axis label suffix; its lower-case form goes into the file names.
        output_dir: ``pathlib.Path`` directory the PNG/PDF files are written to.
    """
    _make_star_plot(df_all, slope_col, se_col, slope_name, output_dir,
                    version='v1', marker_for=lambda _dataset: 'o')


def create_star_plot_v2(df_all, slope_col, se_col, slope_name, output_dir):
    """
    V2: Dataset-specific markers, stars for highlighted points.

    Takes the same arguments as ``create_star_plot_v1``. The legend adds
    one grey marker entry per dataset in DATASET_MARKERS.
    """
    # Pair each marker with its own label by key, so the legend stays correct
    # even if DATASET_LABELS is declared in a different order.
    datasets = list(DATASET_MARKERS)
    dataset_handles = [plt.Line2D([0], [0], marker=DATASET_MARKERS[ds], color='gray',
                                  linestyle='None', markersize=6)
                       for ds in datasets]
    dataset_labels = [DATASET_LABELS[ds] for ds in datasets]

    _make_star_plot(df_all, slope_col, se_col, slope_name, output_dir,
                    version='v2', marker_for=lambda dataset: DATASET_MARKERS[dataset],
                    extra_handles=dataset_handles, extra_labels=dataset_labels)


# ============================================================
# MAIN
# ============================================================
def main():
    """Load every dataset and write the V1 and V2 star plots for the mean and covariance slopes.

    Plots are written to OUTPUT_DIR, which is created if needed. If no data
    was loaded, a message is printed and no plots are attempted.
    """
    print("=" * 60)
    print("Generating Linear Star Plots")
    print("=" * 60)

    # Create output directory
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Load data
    print("\nLoading data from results directories...")
    df_all = load_all_data()
    print(f"Loaded {len(df_all)} data points")
    if df_all.empty:
        # Nothing to draw; the plotting/fit code cannot work on an empty frame.
        print("No hyperparameter data found; no plots generated.")
        return

    # (slope column, corrected-SE column, display name) for each plot family.
    slope_types = [
        ('slope_mean', 'slope_mean_se_corrected', 'Mean'),
        ('slope_cov', 'slope_cov_se_corrected', 'Covariance'),
    ]

    print("\nGenerating plots...")
    for slope_col, se_col, slope_name in slope_types:
        create_star_plot_v1(df_all, slope_col, se_col, slope_name, OUTPUT_DIR)
        create_star_plot_v2(df_all, slope_col, se_col, slope_name, OUTPUT_DIR)

    print("\nDone!")


if __name__ == "__main__":
    main()
