import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shelve
import argparse
from pathlib import Path
import sys
from typing import List, Dict, Tuple
from ast import literal_eval

# Set font and PDF rendering preferences
import matplotlib
# Global figure defaults, applied once at import time.
matplotlib.rcParams["font.family"] = "serif"
# The same font type is requested for both vector backends (PDF and PostScript).
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42
# Base text size; create_summary_plot_matplotlib sets 'font.size' again from
# its own font_size argument when it draws.
plt.rcParams.update({"font.size": 20})


def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Every option except ``--full-computation`` is mandatory. The returned
    namespace exposes the options with dashes converted to underscores
    (e.g. ``--val-path`` becomes ``args.val_path``).
    """
    # (flag, extra add_argument keywords) for each mandatory option, in the
    # order they should appear in the generated help text.
    mandatory_options = (
        ('--val-path', {'type': str, 'help': 'Path to validation CSV file'}),
        ('--test-path', {'type': str, 'help': 'Path to test CSV file'}),
        ('--val-path-random', {'type': str, 'help': 'Path to random validation CSV file'}),
        ('--test-path-random', {'type': str, 'help': 'Path to random test CSV file'}),
        ('--output-path', {'type': str, 'help': 'Output directory path'}),
        ('--num-draws', {'type': int, 'help': 'Number of random draws per sample size'}),
        ('--sample-sizes', {'type': int, 'nargs': '+', 'help': 'List of sample sizes to test'}),
        ('--name', {'type': str, 'help': 'Name to include in plot title (e.g., "context")'}),
    )

    cli = argparse.ArgumentParser(
        description='Run sampling experiment on validation and test data.'
    )
    for flag, spec in mandatory_options:
        cli.add_argument(flag, required=True, **spec)

    # The only optional switch: off unless explicitly given.
    cli.add_argument(
        '--full-computation',
        action='store_true',
        help='Run full computation including test scores',
    )

    return cli.parse_args()

def load_data(val_path: str, test_path: str, val_path_random: str, test_path_random: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read the four input CSV files and check their schema.

    Returns the frames as a tuple in the order
    ``(validation, test, random validation, random test)``.

    Raises:
        ValueError: if any file lacks the ``key`` or ``average_score`` column.
    """
    labelled_paths = (
        ("Validation", val_path),
        ("Test", test_path),
        ("Random Validation", val_path_random),
        ("Random Test", test_path_random),
    )

    # All files are read before any of them is validated.
    frames = [pd.read_csv(csv_path) for _, csv_path in labelled_paths]

    # Extra columns are allowed; only these two are required.
    expected_columns = {'key', 'average_score'}
    for (name, _), frame in zip(labelled_paths, frames):
        has_required = expected_columns.issubset(frame.columns)
        if not has_required:
            raise ValueError(f"{name} CSV must contain columns: {expected_columns}")

    return tuple(frames)

def run_sampling_experiment(val_df: pd.DataFrame, test_df: pd.DataFrame,
                          num_draws: int, sample_sizes: List[int],
                          output_path: str, full_computation: bool) -> "pd.DataFrame | None":
    """Select top validation keys from random samples and score them on test data.

    For each requested sample size, ``num_draws`` samples are drawn from
    ``val_df`` with replacement, and the ``key`` of the row with the highest
    ``average_score`` in each sample is recorded. The winning keys are stored
    in a shelve database ``top_scoring_keys.db`` inside ``output_path``,
    keyed by the stringified sample size.

    If ``full_computation`` is true, every winning key is then looked up in
    ``test_df`` and the per-size lists of test scores are written to
    ``experiment_results.csv`` in ``output_path`` and returned.

    A requested size larger than the validation set is clamped to
    ``len(val_df)``. Results are keyed by this effective size, and an
    effective size that has already been processed (a repeated or a second
    oversize request) is skipped so it yields a single result row.

    Parameters
    ----------
    val_df, test_df : pd.DataFrame
        Frames with at least the columns ``key`` and ``average_score``.
    num_draws : int
        Number of random samples drawn per sample size.
    sample_sizes : List[int]
        Sample sizes to evaluate; each must be at least 1.
    output_path : str
        Directory for the outputs; created if missing.
    full_computation : bool
        If false, stop after saving the winning keys and return ``None``.

    Returns
    -------
    pd.DataFrame or None
        One row per effective sample size with columns ``sample_size`` and
        ``test_scores`` (a list of scores), or ``None`` when
        ``full_computation`` is false.

    Raises
    ------
    ValueError
        If a sample of at least one row cannot be drawn (non-positive sample
        size or empty validation set).
    """
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    val_size = len(val_df)

    # Effective sample size -> list of winning keys (one per draw).
    top_scoring_keys = {}

    print("Starting sampling experiment...")

    for sample_size in sample_sizes:
        print(f"\nProcessing sample size: {sample_size}")

        # Keep the requested size untouched; the clamped value is tracked
        # separately so both phases of the experiment use the same dict key.
        effective_size = sample_size
        if sample_size > val_size:
            print(f"Warning: Sample size {sample_size} is larger than validation set size {val_size}. "
                  f"Using full validation set.")
            effective_size = val_size

        if effective_size < 1:
            raise ValueError(
                f"Cannot draw a sample of size {effective_size} "
                f"(requested {sample_size}, validation set has {val_size} rows)."
            )

        if effective_size in top_scoring_keys:
            print(f"Sample size {effective_size} was already processed. Skipping.")
            continue

        top_keys_for_size = []

        for draw in range(num_draws):
            # Draw random sample from validation set
            # This can be either replace=False or replace=True depending on your experiment needs
            sample = val_df.sample(n=effective_size, replace=True)

            # Sampling with replacement repeats index labels, so a label-based
            # lookup can match several rows. Work positionally instead: the
            # reset index makes idxmax return the row position of the maximum.
            top_position = sample['average_score'].reset_index(drop=True).idxmax()
            top_keys_for_size.append(sample['key'].iloc[top_position])

            completed = draw + 1
            if completed % 10 == 0 or completed == num_draws:  # Progress indicator
                print(f"  Completed {completed}/{num_draws} draws")

        top_scoring_keys[effective_size] = top_keys_for_size

        print(f"Completed all draws for sample size {effective_size}")

    # Save top scoring keys to shelve database
    shelve_path = output_dir / 'top_scoring_keys.db'
    with shelve.open(str(shelve_path)) as db:
        for effective_size, keys in top_scoring_keys.items():
            db[str(effective_size)] = keys

    print(f"\nTop scoring keys saved to {shelve_path}")

    if not full_computation:
        print("\nStopping at this point (full computation not requested).")
        print("To continue with test score computation, run with --full-computation flag.")
        return None

    # Continue with full computation
    print("\nComputing test scores for top scoring keys...")

    # Create mapping from key to test score for faster lookup
    test_score_map = dict(zip(test_df['key'], test_df['average_score']))

    results_data = []
    for effective_size, keys in top_scoring_keys.items():
        test_scores = []

        for key in keys:
            test_score = test_score_map.get(key)

            if test_score is None:
                print(f"Warning: Key '{key}' not found in test set. Skipping.")
                continue

            test_scores.append(test_score)

        results_data.append({
            'sample_size': effective_size,
            'test_scores': test_scores
        })

    # Create results DataFrame
    results_df = pd.DataFrame(results_data)

    # Save results to CSV
    csv_path = output_dir / 'experiment_results.csv'
    results_df.to_csv(csv_path, index=False)
    print(f"\nResults saved to {csv_path}")

    return results_df

def _aggregate_test_scores(df: pd.DataFrame, ci: str) -> pd.DataFrame:
    """Reduce per-sample-size score lists to mean and error-bar statistics.

    ``df`` has one row per sample size with columns ``sample_size`` and
    ``test_scores`` (a list of floats, or its string form as produced by a
    CSV round-trip). The result has one row per sample size, sorted
    ascending, with columns ``sample_size``, ``n``, ``mean``, ``std``,
    ``sem`` and ``yerr``; ``yerr`` is 1.96*SEM for ``ci="95ci"`` and SEM
    otherwise.

    Raises ValueError if no scores are present.
    """
    # Normalize to long form: one row per (sample_size, test_score)
    rows = []
    for _, row in df.iterrows():
        size = int(row['sample_size'])
        scores = row['test_scores']
        if isinstance(scores, str):
            # results were saved as a stringified Python list
            scores = literal_eval(scores)
        rows.extend((size, float(score)) for score in scores)

    long_df = pd.DataFrame(rows, columns=['sample_size', 'test_score'])
    if long_df.empty:
        raise ValueError("No test scores found to plot.")

    # Aggregate per sample size
    grouped = (
        long_df
        .groupby('sample_size', sort=True)
        .agg(
            n=('test_score', 'size'),
            mean=('test_score', 'mean'),
            std=('test_score', 'std')
        )
        .reset_index()
        .sort_values('sample_size')
    )
    grouped['sem'] = grouped['std'] / np.sqrt(grouped['n'])
    grouped['yerr'] = 1.96 * grouped['sem'] if ci == "95ci" else grouped['sem']
    return grouped


def create_summary_plot_matplotlib(
    results_df: pd.DataFrame,
    results_df_random: pd.DataFrame,
    output_path: str,
    name: str,
    *,
    use_tex: bool = True,
    font_size: int = 20,
    ci: str = "95ci"  # options: "95ci" or "sem"
):
    """
    Create a clean PDF plot (matplotlib-only) showing, for each sample size:
      - mean test score (point),
      - error bar (95% CI or SEM),
      - line connecting means.
    X-axis is logarithmic base 2.
    Overlays both regular and random results on the same plot.

    Each series is drawn with a single ``errorbar`` call so that its markers,
    error bars and connecting line share one colour. The rc settings used for
    the figure are applied through ``rc_context`` and therefore do not leak
    into the global matplotlib configuration.

    Parameters
    ----------
    results_df : pd.DataFrame
        Output of run_sampling_experiment with columns:
          - 'sample_size' (int)
          - 'test_scores' (list[float] or stringified list)
    results_df_random : pd.DataFrame
        Same format as results_df but for random baseline
    output_path : str
        Directory to write 'summary_plot.pdf'. Created if missing.
    name : str
        Name to include in plot title (e.g., "context")
    use_tex : bool
        If True, use LaTeX text rendering so fonts match your LaTeX doc.
    font_size : int
        Font size for the base text and the title (axis labels, ticks and
        legend use fixed smaller sizes).
    ci : str
        "95ci" for 1.96*SEM bars, or "sem" to show standard error.

    Raises
    ------
    ValueError
        If ``ci`` is not one of the supported options, or if either results
        frame contains no test scores. Both are checked before any file or
        figure is created.
    """
    if ci not in ("95ci", "sem"):
        raise ValueError("ci must be '95ci' or 'sem'.")

    # Process both datasets up front so bad input fails before plotting starts
    grouped = _aggregate_test_scores(results_df, ci)
    grouped_random = _aggregate_test_scores(results_df_random, ci)

    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    pdf_path = out_dir / 'summary_plot.pdf'

    # Configure matplotlib to use LaTeX (for matching document fonts)
    rc_overrides = {
        'text.usetex': bool(use_tex),
        'font.size': font_size,
        'axes.titlesize': font_size,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
    }

    # (statistics, marker + line style, legend label) for each plotted series
    series = (
        (grouped, 'o-', 'DCT'),
        (grouped_random, 's--', 'Random'),
    )

    with plt.rc_context(rc_overrides):
        fig, ax = plt.subplots(figsize=(6.0, 4.0))  # single, clean chart
        try:
            for stats, fmt, label in series:
                ax.errorbar(
                    stats['sample_size'].to_numpy(),
                    stats['mean'].to_numpy(),
                    yerr=stats['yerr'].to_numpy(),
                    fmt=fmt,            # marker plus the line connecting the means
                    capsize=3,          # small caps on error bars
                    linewidth=1,
                    label=label
                )

            # Logarithmic x-axis (base 2) to assess linearity / diminishing returns
            ax.set_xscale('log', base=2)

            ax.set_xlabel(r'Sample size')
            ax.set_ylabel(r'Test score')
            ax.set_title(name)

            ax.grid(True, alpha=0.3)
            ax.legend()

            # Save as PDF for LaTeX inclusion
            fig.tight_layout()
            fig.savefig(pdf_path, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving failed
            plt.close(fig)

    print(f"Saved plot to {pdf_path}")

def main():
    """Run the full pipeline from the command line.

    Parses the arguments, loads the four CSV files, runs the sampling
    experiment once on the main data (outputs in ``--output-path``) and once
    on the random baseline (outputs in a sibling directory named
    ``<output-path>_random``), and, when ``--full-computation`` is given,
    writes the summary plot into ``--output-path``.

    Any exception raised while loading, sampling or plotting is reported on
    stderr and the process exits with status 1.
    """
    args = parse_arguments()

    # Path() drops a trailing separator, so "out/" maps to the sibling
    # "out_random" rather than a nested "out/_random" directory.
    random_output_path = str(Path(args.output_path)) + "_random"

    try:
        # Load data
        val_df, test_df, val_df_random, test_df_random = load_data(
            args.val_path, args.test_path, args.val_path_random, args.test_path_random
        )
        print(f"Loaded validation data: {len(val_df)} rows")
        print(f"Loaded test data: {len(test_df)} rows")
        print(f"Loaded random validation data: {len(val_df_random)} rows")
        print(f"Loaded random test data: {len(test_df_random)} rows")

        # Run experiment for main data
        print("\n=== Running experiment on main data ===")
        results_df = run_sampling_experiment(
            val_df=val_df,
            test_df=test_df,
            num_draws=args.num_draws,
            sample_sizes=args.sample_sizes,
            output_path=args.output_path,
            full_computation=args.full_computation
        )

        # Run experiment for random data
        print("\n=== Running experiment on random data ===")
        results_df_random = run_sampling_experiment(
            val_df=val_df_random,
            test_df=test_df_random,
            num_draws=args.num_draws,
            sample_sizes=args.sample_sizes,
            output_path=random_output_path,
            full_computation=args.full_computation
        )

        # Create visualization if full computation was performed
        if args.full_computation and results_df is not None and results_df_random is not None:
            create_summary_plot_matplotlib(
                results_df,
                results_df_random,
                args.output_path,
                args.name,
                use_tex=False,    # set to True if you have LaTeX installed
                font_size=20,     # tweak to match your paper
                ci="95ci"         # or "sem"
            )
            print("\nExperiment completed successfully!")

    except Exception as e:
        # Top-level boundary: report the failure on stderr and signal it
        # through the exit status instead of printing a traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

# Run the pipeline only when this file is executed as a script; importing it
# as a module defines the functions without running main().
if __name__ == "__main__":
    main()