import pandas as pd
import numpy as np
from pathlib import Path
import sys


sys.path.append(str(Path(__file__).parent))
from binning_utils_v5 import (
    get_binning_for_T,
    create_noise_sensitivity_matrix,
    fit_constrained_parseval,
    compute_derived_features
)


class JuntaAnalysis:
    """Rerun the junta-coefficient analysis with V5 square-aligned binning.

    One instance handles one CSV file. The file must contain a ``time_factor``
    column plus one ``NS_<delta>`` column per noise-sensitivity measurement,
    where ``<delta>`` parses as a float.
    """

    # Prefix that identifies noise-sensitivity columns; the rest of the
    # column name is the delta value.
    _NS_PREFIX = "NS_"

    # Temporal depths analyzed by run_full_analysis when none are given.
    DEFAULT_T_VALUES = (2, 3, 4, 5, 6, 7)

    def __init__(self, data_path, output_dir):
        """
        Load one CSV file and prepare the output directory.

        Args:
            data_path: Path to the CSV file.
            output_dir: Directory to save results; created (with parents)
                if it does not exist.
        """
        print(f"Loading data from: {data_path}")
        self.df = pd.read_csv(data_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Loaded {len(self.df)} rows")
        print(f"Output directory: {self.output_dir}")

        # NS columns are named "NS_<delta>"; delta_values[j] belongs to
        # ns_cols[j], in column order.
        prefix = self._NS_PREFIX
        self.ns_cols = [col for col in self.df.columns if col.startswith(prefix)]
        # Strip only the leading prefix (str.replace would remove every
        # occurrence of "NS_" in the name).
        self.delta_values = np.array(
            [float(col[len(prefix):]) for col in self.ns_cols],
            dtype=np.float64,
        )

        # Rows missing any NS value or their time factor cannot be fitted.
        self.df = self.df.dropna(subset=self.ns_cols + ['time_factor'])
        print(f"Found {len(self.delta_values)} NS measurements")

    def fit_junta_for_T(self, T, weight_penalty=5.0, ridge_alpha=0.001, n=256):
        """
        Fit junta coefficients for all functions at a given T using V5 binning.

        Args:
            T: Temporal depth; rows with ``time_factor == T`` are fitted.
            weight_penalty: Weight for the Parseval constraint, forwarded to
                ``fit_constrained_parseval``.
            ridge_alpha: Ridge regularization strength, forwarded to
                ``fit_constrained_parseval``.
            n: Forwarded as ``n`` to ``create_noise_sensitivity_matrix``
                (previously hard-coded to 256).

        Returns:
            A copy of the rows for ``T`` with one ``W_<start>_<end>`` column
            per bin, plus ``fit_residual``, ``fit_rmse`` and the columns added
            by ``compute_derived_features`` (``sum_Wk`` and ``avg_degree`` are
            read below). If there are no rows for ``T``, an empty DataFrame
            is returned.
        """
        # Select rows first, so depths without data skip binning and
        # matrix construction entirely.
        df_T = self.df[self.df['time_factor'] == T].copy()
        if len(df_T) == 0:
            print(f"No data for T={T}, skipping")
            return df_T

        # V5 binning for this T: a sequence of (start, end) bin bounds.
        bins = get_binning_for_T(T)['bins']

        # NS design matrix. Presumably shaped (len(delta_values), len(bins)),
        # as required by the `W_k_all @ A.T` prediction below.
        A = create_noise_sensitivity_matrix(bins, self.delta_values, n=n)

        # Extract all measurements once as a float matrix (one row per
        # function). Positional iteration avoids per-row label lookups, which
        # are slow and ambiguous if index labels repeat.
        ns_matrix = df_T[self.ns_cols].to_numpy(dtype=np.float64)

        W_k_all = []
        residuals = []
        for ns_values in ns_matrix:
            # Pass a copy so the fit cannot alter the matrix reused for RMSE.
            weights, residual = fit_constrained_parseval(
                ns_values.copy(), A,
                weight_penalty=weight_penalty,
                ridge_alpha=ridge_alpha,
            )
            W_k_all.append(weights)
            residuals.append(residual)

        print(f"Completed all {len(df_T)} functions")

        W_k_all = np.array(W_k_all)

        # One weight column per bin, named by the bin's bounds.
        for j, (start, end) in enumerate(bins):
            df_T[f"W_{start}_{end}"] = W_k_all[:, j]

        df_T['fit_residual'] = residuals

        # Per-function (unnormalized) RMSE between measured and
        # reconstructed NS values.
        ns_pred = W_k_all @ A.T
        rmse = np.sqrt(np.mean((ns_matrix - ns_pred) ** 2, axis=1))
        df_T['fit_rmse'] = rmse

        df_T = compute_derived_features(df_T, bins)

        print("\nFit quality:")
        print(f"  Mean RMSE = {rmse.mean():.4f} ± {rmse.std():.4f}")
        print(f"  Mean residual = {np.mean(residuals):.6f}")
        print(f"  Sum(W_k): {df_T['sum_Wk'].mean():.6f} ± {df_T['sum_Wk'].std():.6f}")
        print(f"  Avg degree: {df_T['avg_degree'].mean():.2f} ± {df_T['avg_degree'].std():.2f}")

        return df_T

    def run_full_analysis(self, T_values=None):
        """
        Fit every requested temporal depth, save the results and a summary.

        Args:
            T_values: Iterable of temporal depths to fit. Defaults to
                ``DEFAULT_T_VALUES`` (2 through 7).

        Returns:
            The concatenated per-T DataFrame, which is also written to
            ``data_with_v5_binning.csv`` in the output directory. Returns
            None if no requested T had data.
        """
        if T_values is None:
            T_values = self.DEFAULT_T_VALUES
        T_values = list(T_values)

        # Rows at depths outside T_values would otherwise vanish silently.
        skipped = ~self.df['time_factor'].isin(T_values)
        if skipped.any():
            extra = sorted(self.df.loc[skipped, 'time_factor'].unique())
            print(f"Note: {int(skipped.sum())} rows with time_factor not in "
                  f"{T_values} are not analyzed: {extra}")

        all_dfs = []
        for T in T_values:
            df_T = self.fit_junta_for_T(T)
            if len(df_T) > 0:
                all_dfs.append(df_T)

        if not all_dfs:
            print("No data found for any T value")
            return None

        # Each T has its own bins, so W_* columns missing for a T become NaN.
        df_final = pd.concat(all_dfs, ignore_index=True)

        output_path = self.output_dir / "data_with_v5_binning.csv"
        df_final.to_csv(output_path, index=False)
        print(f"\nSaved results to: {output_path}")
        print(f"  Total rows: {len(df_final)}")

        self.create_summary_table(df_final)

        return df_final

    def create_summary_table(self, df):
        """
        Write and print a per-T table of fit-quality and feature statistics.

        Args:
            df: Combined output of ``fit_junta_for_T``; must contain
                ``time_factor``, ``fit_rmse``, ``sum_Wk``, ``avg_degree``
                and ``W_high`` columns.

        The table is saved to ``summary_v5_binning.csv`` in the output
        directory.
        """
        # Column order in the summary: mean_<m>, std_<m> for each metric.
        metrics = ('fit_rmse', 'sum_Wk', 'avg_degree', 'W_high')

        summary_rows = []
        for T in sorted(df['time_factor'].unique()):
            df_T = df[df['time_factor'] == T]

            # time_factor may be stored as float (e.g. 2.0) when the column
            # held NaNs before dropna; use the integer value for the lookup.
            T_key = int(T) if float(T).is_integer() else T
            bin_info = get_binning_for_T(T_key)

            row = {
                'T': T_key,
                'k_max': bin_info['k_max'],
                'n_bins': bin_info['n_bins'],
                'N': len(df_T),
            }
            for metric in metrics:
                row[f'mean_{metric}'] = df_T[metric].mean()
                row[f'std_{metric}'] = df_T[metric].std()
            summary_rows.append(row)

        summary_df = pd.DataFrame(summary_rows)

        summary_path = self.output_dir / "summary_v5_binning.csv"
        summary_df.to_csv(summary_path, index=False)
        print(f"\nSaved summary to: {summary_path}")

        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)
        print(summary_df.to_string(index=False))


def process_all_csvs(input_folder, output_base="output_all"):
    """
    Run JuntaAnalysis on every CSV file directly inside ``input_folder``.

    Files are processed independently, in sorted order. If a file raises an
    exception, the error and traceback are printed and processing continues
    with the next file.

    Args:
        input_folder: Folder containing the CSV files (not searched
            recursively).
        output_base: Base directory for outputs (default: "output_all").

    Results are saved to: {output_base}/{csv_stem}/

    Returns:
        List of CSV stems whose processing raised an exception. The list is
        empty when every file succeeded or no CSV files were found.
    """
    import traceback

    input_folder = Path(input_folder)
    output_base = Path(output_base)

    csv_files = sorted(input_folder.glob("*.csv"))
    if not csv_files:
        print(f"No CSV files found in: {input_folder}")
        return []

    failed = []
    total = len(csv_files)
    for i, csv_path in enumerate(csv_files, 1):
        csv_name = csv_path.stem
        print(f"\n[{i}/{total}] Processing {csv_path.name}")

        # Output directory is named after the CSV file (without extension).
        output_dir = output_base / csv_name
        try:
            analyzer = JuntaAnalysis(csv_path, output_dir)
            analyzer.run_full_analysis()
        except Exception as e:
            # Batch boundary: report the failure and move on to the next file.
            print(f"Error processing {csv_name}: {e}")
            traceback.print_exc()
            failed.append(csv_name)

    print(f"\nProcessed {total - len(failed)}/{total} CSV files successfully")
    if failed:
        print(f"Failed: {', '.join(failed)}")
    print(f"Results saved to: {output_base}/")
    return failed


def main(argv=None):
    """
    Command-line entry point: process every CSV file in an input folder.

    Usage: python <script> INPUT_FOLDER [OUTPUT_BASE]

    Args:
        argv: Argument list excluding the program name. Defaults to
            ``sys.argv[1:]``. When empty, the hard-coded fallbacks below
            are used.
    """
    # Fallbacks, used only when the corresponding argument is not given.
    input_folder = ""
    output_base = ""

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) >= 1:
        input_folder = args[0]
    if len(args) >= 2:
        output_base = args[1]

    # Path("") means the current directory, which always exists, so an unset
    # input folder must be rejected explicitly rather than via exists().
    if not input_folder:
        print("Usage: python <script> INPUT_FOLDER [OUTPUT_BASE]")
        return

    if not Path(input_folder).is_dir():
        print(f"Input folder not found: {input_folder}")
        return

    # An empty output base would also mean the current directory; use the
    # same default as process_all_csvs instead.
    if not output_base:
        output_base = "output_all"

    process_all_csvs(input_folder, output_base)


if __name__ == "__main__":
    main()