#!/usr/bin/env python
"""
Simplified timepoint regression analysis script.
Runs OLS regression for each timepoint using specified dimensionality types.
"""

import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats
from statsmodels.stats.multitest import fdrcorrection
from typing import Dict, List, Tuple
import sys

class SimpleTimepointRegression:
    """Per-timepoint OLS regression of generalization accuracy on dimensionality.

    Pipeline (see ``run_complete_analysis``):
      1. read the final test-task accuracy of every (scale, seed) run;
      2. load pickled dimensionality arrays indexed
         (scale, training_mode, seed, timepoint);
      3. at each timepoint, regress standardized accuracy on the standardized
         dimensionality measure(s);
      4. FDR-correct the coefficient p-values jointly across all timepoints
         and regressors;
      5. write CSV summaries (and optionally plots) to
         ``<proj_path>/results/timepoint_regression_analysis``.
    """

    def __init__(self, proj_path: str, batch_name: str, output_subdir: str):
        """
        Initialize the analysis with project paths.

        Creates the analysis output directory if needed and appends
        ``<proj_path>/src`` to ``sys.path`` (once) so that project helpers such
        as ``utils.load_experiment_results`` can be imported.

        Args:
            proj_path: Base project path
            batch_name: Batch directory name under ``results/``
                (e.g., 'slurm_cpu_batch_20250604_111254')
            output_subdir: Subdirectory of ``results/`` containing the
                dimensionality pickle files
        """
        self.proj_path = Path(proj_path)
        results_root = self.proj_path / 'results'
        self.batch_dir = results_root / batch_name
        self.dim_output_dir = results_root / output_subdir
        self.analysis_output_dir = results_root / 'timepoint_regression_analysis'
        self.analysis_output_dir.mkdir(parents=True, exist_ok=True)

        # Make project modules (e.g. `utils`) importable; skip the append when
        # the entry is already present so repeated instantiation does not
        # keep growing sys.path.
        src_dir = str(self.proj_path / 'src')
        if src_dir not in sys.path:
            sys.path.append(src_dir)

        # Experimental grid. Order matters: a scale's / seed's position in
        # these lists is its index along the first / third axis of the
        # dimensionality arrays.
        self.scales = [0.01, 0.041617914502878176, 0.17320508075688776, 0.7208434242404266, 3.0]
        self.seeds = [42, 123, 234, 345, 456, 567, 678, 789, 890, 901]
        self.training_mode = 'maximal'
        self.optimizer = 'sgd'
        self.n_timepoints = 10

        # Timepoint labels for plots and summaries (one per timepoint)
        self.timepoint_labels = ['Rule', 'Stim1_e', 'Stim1_l', 'Dly1_e', 'Dly1_l',
                                 'Stim2_e', 'Stim2_l', 'Dly2_e', 'Dly2_l', 'Resp']

        # Data containers, filled in by the pipeline steps
        self.generalization_df = None    # set by extract_generalization_accuracies
        self.dimensionality_arrays = {}  # feature name -> ndarray, set by load_dimensionality_arrays
        self.regression_results = {}     # timepoint -> result dict, set by run_all_timepoint_regressions

        print(f"Initialized {type(self).__name__}")
        print(f"Batch directory: {self.batch_dir}")
        print(f"Dimensionality data: {self.dim_output_dir}")
        print(f"Analysis output: {self.analysis_output_dir}")

    def _expected_dim_shape(self) -> Tuple[int, int, int, int]:
        """Return the expected dimensionality-array shape (scales, training_modes, seeds, timepoints)."""
        return (len(self.scales), 1, len(self.seeds), self.n_timepoints)

    def extract_generalization_accuracies(self) -> pd.DataFrame:
        """Collect the final test-task accuracy of every (scale, seed) run.

        Each run is read from
        ``<batch_dir>/cpu_experiment_<training_mode>_scale<scale>_<optimizer>_seed<seed>``
        with the project's ``utils.load_experiment_results``. The last entry
        of ``results['results'][str(scale)][str(seed)]['test_task_acc']`` is
        used. Runs with a missing directory, a load/lookup error, or an empty
        history are reported and skipped.

        Returns:
            DataFrame with columns ``scale``, ``seed``, ``generalization_acc``,
            ``scale_idx`` and ``seed_idx``, sorted by (scale_idx, seed_idx).
            It has these columns even when no run could be loaded. The frame
            is also stored on ``self.generalization_df`` and saved to
            ``generalization_accuracies.csv``.
        """
        from utils import load_experiment_results

        print("Extracting generalization accuracies...")

        results_list = []
        missing_dirs = []
        n_expected = len(self.scales) * len(self.seeds)

        for scale_idx, scale in enumerate(self.scales):
            for seed_idx, seed in enumerate(self.seeds):
                results_path = self.batch_dir / (
                    f'cpu_experiment_{self.training_mode}_scale{scale}'
                    f'_{self.optimizer}_seed{seed}'
                )

                if not results_path.exists():
                    missing_dirs.append(results_path)
                    print(f"Missing directory: {results_path}")
                    continue

                try:
                    results = load_experiment_results(results_path)
                    test_acc_history = results['results'][str(scale)][str(seed)]['test_task_acc']
                except Exception as e:  # best-effort: report this run and keep going
                    print(f"✗ Scale {scale}, Seed {seed}: Error loading - {e}")
                    continue

                if len(test_acc_history) == 0:
                    print(f"✗ Scale {scale}, Seed {seed}: Empty test_acc_history")
                    continue

                final_acc = test_acc_history[-1]  # final-epoch value
                results_list.append({
                    'scale': scale,
                    'seed': seed,
                    'generalization_acc': final_acc,
                    'scale_idx': scale_idx,
                    'seed_idx': seed_idx,
                })
                print(f"✓ Scale {scale}, Seed {seed}: {final_acc:.4f}")

        # Explicit columns keep sort_values working when nothing was loaded.
        columns = ['scale', 'seed', 'generalization_acc', 'scale_idx', 'seed_idx']
        df = pd.DataFrame(results_list, columns=columns)
        df = df.sort_values(['scale_idx', 'seed_idx']).reset_index(drop=True)

        print(f"\n{'='*50}")
        print("GENERALIZATION ACCURACY EXTRACTION SUMMARY")
        print(f"{'='*50}")
        print(f"Successfully extracted: {len(df)} / {n_expected} experiments")
        print(f"Missing directories: {len(missing_dirs)}")

        if len(df) > 0:
            print("\nGeneralization accuracy statistics:")
            print(f"Mean: {df['generalization_acc'].mean():.4f}")
            print(f"Std:  {df['generalization_acc'].std():.4f}")
            print(f"Min:  {df['generalization_acc'].min():.4f}")
            print(f"Max:  {df['generalization_acc'].max():.4f}")

            print("\nBreakdown by scale:")
            scale_summary = df.groupby('scale')['generalization_acc'].agg(['count', 'mean', 'std']).round(4)
            print(scale_summary)

        self.generalization_df = df
        df.to_csv(self.analysis_output_dir / 'generalization_accuracies.csv', index=False)

        return df

    def load_dimensionality_arrays(self):
        """Load the pickled dimensionality arrays into ``self.dimensionality_arrays``.

        Raises:
            FileNotFoundError: if a configured pickle file does not exist.

        Arrays whose shape differs from ``_expected_dim_shape()`` only trigger
        a printed warning.
        """
        print("\nLoading dimensionality arrays...")

        # Feature name (used in output column names) -> pickle file name.
        dimensionality_files = {
            'global_dim.': 'global_dim_arr.pkl',
            # Other available measures; uncomment to add them as regressors:
            # 'stimulus': 'stimulus_dim_arr.pkl',
            # 'context': 'task_dim_arr.pkl',
            # 'response': 'motor_response_dim_arr.pkl'
        }

        for dim_name, filename in dimensionality_files.items():
            filepath = self.dim_output_dir / filename

            if not filepath.exists():
                print(f"❌ Missing: {filepath}")
                raise FileNotFoundError(f"Required dimensionality file not found: {filepath}")

            # NOTE: pickle.load executes arbitrary code; only point this at
            # files produced by the project's own pipeline.
            with open(filepath, 'rb') as f:
                array = np.asarray(pickle.load(f))
            self.dimensionality_arrays[dim_name] = array
            print(f"✓ Loaded {dim_name}: shape {array.shape}")

        expected_shape = self._expected_dim_shape()
        for dim_name, array in self.dimensionality_arrays.items():
            if array.shape != expected_shape:
                print(f"⚠️  Warning: {dim_name} has shape {array.shape}, expected {expected_shape}")

    def prepare_regression_data(self, timepoint: int) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Build the design matrix and outcome vector for one timepoint.

        There is one row per experiment in ``self.generalization_df``. Each
        row's predictors come from the dimensionality arrays at that
        experiment's ``(scale_idx, 0, seed_idx, timepoint)``, where index 0 is
        the single training mode. Runs that failed to load are therefore just
        absent and cannot shift the remaining rows out of alignment.

        Args:
            timepoint: Index along the last (timepoint) axis of the arrays.

        Returns:
            ``(X, y, feature_names)``: ``X`` has shape (n_runs, n_features),
            ``y`` holds the generalization accuracies, and ``feature_names``
            gives the column order of ``X``.

        Raises:
            ValueError: if accuracies or dimensionality arrays were not loaded.
        """
        if self.generalization_df is None:
            raise ValueError("Must extract generalization accuracies first")
        feature_names = list(self.dimensionality_arrays.keys())
        if not feature_names:
            raise ValueError("No dimensionality arrays loaded")

        scale_idx = self.generalization_df['scale_idx'].to_numpy(dtype=int)
        seed_idx = self.generalization_df['seed_idx'].to_numpy(dtype=int)

        # Paired fancy indexing picks one value per run: shape (n_runs,).
        columns = [
            np.asarray(self.dimensionality_arrays[name])[scale_idx, 0, seed_idx, timepoint]
            for name in feature_names
        ]
        X = np.column_stack(columns)
        y = self.generalization_df['generalization_acc'].to_numpy(dtype=float)

        return X, y, feature_names

    @staticmethod
    def _ols_p_values(X: np.ndarray, y: np.ndarray, y_pred: np.ndarray,
                      coef: np.ndarray) -> np.ndarray:
        """Compute two-sided t-test p-values for OLS slopes of a model fitted with an intercept.

        The residual variance is estimated without bias as
        RSS / (n - p - 1), matching the classical OLS coefficient t-test.

        Args:
            X: Predictor matrix, shape (n, p); no intercept column.
            y: Observed outcome, shape (n,).
            y_pred: Fitted values of the intercept-including model.
            coef: Fitted slope coefficients, shape (p,).

        Returns:
            Array of p p-values, one per slope.

        Raises:
            ValueError: if there are no residual degrees of freedom.
            numpy.linalg.LinAlgError: if the design matrix is singular.
        """
        n_samples, n_features = X.shape
        dof = n_samples - n_features - 1
        if dof <= 0:
            raise ValueError(
                f"Not enough samples ({n_samples}) for {n_features} predictor(s) plus intercept"
            )

        residuals = y - y_pred
        sigma2 = np.sum(residuals ** 2) / dof

        design = np.column_stack([np.ones(n_samples), X])
        xtx_inv = np.linalg.inv(design.T @ design)
        se_beta = np.sqrt(sigma2 * np.diag(xtx_inv)[1:])  # drop the intercept

        t_stats = np.asarray(coef, dtype=float) / se_beta
        # sf keeps precision for tiny p-values where 1 - cdf would round to 0.
        return 2 * stats.t.sf(np.abs(t_stats), dof)

    def run_timepoint_regression(self, timepoint: int) -> Dict:
        """Fit a standardized OLS model of generalization accuracy at one timepoint.

        Both predictors and outcome are z-scored, so the coefficients are
        standardized betas.

        Returns:
            Dict with keys ``timepoint``, ``model``, ``scaler``, ``y_scaler``,
            ``feature_names``, ``coefficients`` (feature -> beta), ``r2_train``,
            ``cv_r2_mean``, ``cv_r2_std``, ``n_samples`` and ``p_values``.
            ``p_values`` maps feature -> raw p, or is None when no p-value
            could be computed.
        """
        X, y, feature_names = self.prepare_regression_data(timepoint)

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        y_scaler = StandardScaler()
        y_scaled = y_scaler.fit_transform(y.reshape(-1, 1)).ravel()

        model = LinearRegression()
        model.fit(X_scaled, y_scaled)

        y_pred = model.predict(X_scaled)
        r2 = r2_score(y_scaled, y_pred)

        # Unshuffled 5-fold CV. Rows are sorted by scale, so with the full
        # 5-scale x 10-seed grid each fold holds out one scale.
        # NOTE(review): confirm leave-one-scale-out is the intended CV scheme.
        cv_scores = cross_val_score(model, X_scaled, y_scaled, cv=5, scoring='r2')

        try:
            p_values = self._ols_p_values(X_scaled, y_scaled, y_pred, model.coef_)
        except (np.linalg.LinAlgError, ValueError) as e:
            print(f"Warning: Could not calculate p-values for timepoint {timepoint}: {e}")
            p_values = np.full(len(feature_names), np.nan)

        return {
            'timepoint': timepoint,
            'model': model,
            'scaler': scaler,
            'y_scaler': y_scaler,
            'feature_names': feature_names,
            'coefficients': dict(zip(feature_names, model.coef_)),
            'r2_train': r2,
            'cv_r2_mean': cv_scores.mean(),
            'cv_r2_std': cv_scores.std(),
            'n_samples': len(X),
            'p_values': None if np.all(np.isnan(p_values)) else dict(zip(feature_names, p_values)),
        }

    def run_all_timepoint_regressions(self):
        """Run ``run_timepoint_regression`` for every timepoint, then apply global FDR correction."""
        print("Running timepoint regressions...")

        for t in range(self.n_timepoints):
            print(f"  Timepoint {t} ({self.timepoint_labels[t]})...")
            self.regression_results[t] = self.run_timepoint_regression(t)

        self._apply_global_fdr_correction()

        print("All regressions completed!")

    def _apply_global_fdr_correction(self):
        """FDR-correct all raw p-values as one family across timepoints and regressors.

        Uses statsmodels ``fdrcorrection(alpha=0.05, method='indep')``. Every
        timepoint with at least one corrected test gains
        ``'fdr_corrected': {'corrected_p_values': {feature: q},
        'significant': {feature: bool}}`` in its result dict.
        """
        print("Applying global FDR correction...")

        feature_names = list(self.dimensionality_arrays.keys())
        all_p_values = []
        p_value_keys = []  # (timepoint, feature) for each entry of all_p_values

        for t in range(self.n_timepoints):
            raw_p = self.regression_results[t]['p_values']
            if raw_p is None:
                continue
            for feature in feature_names:
                p_val = raw_p.get(feature, np.nan)
                if not np.isnan(p_val):
                    all_p_values.append(p_val)
                    p_value_keys.append((t, feature))

        if not all_p_values:
            print("No valid p-values found for FDR correction")
            return

        all_p_values = np.array(all_p_values)
        significant, corrected_p = fdrcorrection(all_p_values, alpha=0.05, method='indep')

        print(f"FDR correction applied to {len(all_p_values)} tests")
        print(f"Significant effects: {significant.sum()}/{len(all_p_values)}")

        for (t, feature), q_val, is_sig in zip(p_value_keys, corrected_p, significant):
            fdr = self.regression_results[t].setdefault(
                'fdr_corrected', {'corrected_p_values': {}, 'significant': {}}
            )
            fdr['corrected_p_values'][feature] = q_val
            fdr['significant'][feature] = is_sig

    def create_results_summary(self) -> pd.DataFrame:
        """Tabulate per-timepoint fit statistics and save them to ``regression_summary.csv``.

        Columns: ``timepoint``, ``timepoint_label``, ``r2_train``,
        ``r2_cv_mean``, ``r2_cv_std`` and ``n_samples``, plus per feature
        ``beta_<f>``, ``p_raw_<f>``, ``p_fdr_<f>`` and ``sig_fdr_<f>``. The
        p-value and significance columns appear only when those values exist.
        """
        summary_data = []

        for t, results in self.regression_results.items():
            row = {
                'timepoint': t,
                'timepoint_label': self.timepoint_labels[t],
                'r2_train': results['r2_train'],
                'r2_cv_mean': results['cv_r2_mean'],
                'r2_cv_std': results['cv_r2_std'],
                'n_samples': results['n_samples'],
            }

            for feature, coef in results['coefficients'].items():
                row[f'beta_{feature}'] = coef

            if results['p_values']:
                for feature, pval in results['p_values'].items():
                    row[f'p_raw_{feature}'] = pval

            fdr = results.get('fdr_corrected')
            if fdr is not None:
                for feature, corrected_p in fdr['corrected_p_values'].items():
                    row[f'p_fdr_{feature}'] = corrected_p
                for feature, is_significant in fdr['significant'].items():
                    row[f'sig_fdr_{feature}'] = is_significant

            summary_data.append(row)

        df = pd.DataFrame(summary_data)
        df.to_csv(self.analysis_output_dir / 'regression_summary.csv', index=False)

        return df

    def plot_r_squared(self, figsize: Tuple[int, int] = (12, 6)):
        """Plot in-sample R² across timepoints and save it as ``r_squared_across_timepoints.png``.

        Raises:
            ValueError: if no regressions have been run.
        """
        if not self.regression_results:
            raise ValueError("Must run regressions first")

        timepoints = list(range(self.n_timepoints))
        r2_train = [self.regression_results[t]['r2_train'] for t in timepoints]

        fig, ax = plt.subplots(figsize=figsize)

        ax.plot(timepoints, r2_train, 'o-', linewidth=2.5, markersize=8,
                color='darkblue', label='R²')

        ax.set_xlabel('Timepoint', fontsize=26)
        ax.set_ylabel('R² Score', fontsize=26)
        ax.set_xticks(timepoints)
        ax.set_xticklabels(self.timepoint_labels, fontsize=20, rotation=-45)

        # Larger tick labels and thicker axes for presentation figures
        ax.tick_params(axis='both', which='major', labelsize=24, width=2, length=6)
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        ax.legend(fontsize=24, frameon=True, loc='best')
        ax.grid(True, alpha=0.3)

        sns.despine(ax=ax)  # remove top and right spines

        fig.tight_layout()
        fig.savefig(self.analysis_output_dir / 'r_squared_across_timepoints.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

        return fig

    def plot_beta_coefficients(self, figsize: Tuple[int, int] = (7, 20)):
        """Plot each regressor's standardized beta across timepoints, one panel per feature.

        Timepoints flagged significant by the global FDR correction are marked
        above the curve with ``***`` (q < 0.001), ``**`` (q < 0.01) or ``*``.
        The figure is saved as ``beta_coefficients_across_timepoints.png``.

        Raises:
            ValueError: if no regressions have been run or no features are loaded.
        """
        if not self.regression_results:
            raise ValueError("Must run regressions first")

        timepoints = list(range(self.n_timepoints))
        feature_names = list(self.dimensionality_arrays.keys())
        if not feature_names:
            raise ValueError("No dimensionality features to plot")

        # One panel per feature; squeeze=False keeps `axes` 2-D even for a single panel.
        fig, axes = plt.subplots(1, len(feature_names), figsize=figsize, squeeze=False)
        colors = ['blue', 'red', 'darkgreen']

        for i, (ax, feature) in enumerate(zip(axes[0], feature_names)):
            betas = np.array(
                [self.regression_results[t]['coefficients'][feature] for t in timepoints],
                dtype=float,
            )

            # FDR-corrected significance per timepoint (False / NaN when absent).
            significant = []
            corrected_p_values = []
            for t in timepoints:
                fdr = self.regression_results[t].get('fdr_corrected', {})
                if feature in fdr.get('significant', {}):
                    significant.append(fdr['significant'][feature])
                    corrected_p_values.append(fdr['corrected_p_values'][feature])
                else:
                    significant.append(False)
                    corrected_p_values.append(np.nan)

            ax.plot(timepoints, betas, 'o-', color=colors[i % len(colors)], linewidth=2.5,
                    markersize=8, label=f'{feature}')
            ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

            finite = betas[np.isfinite(betas)]
            if finite.size:
                data_ymin, data_ymax = finite.min(), finite.max()
                padding = 0.1 * (data_ymax - data_ymin)
                if padding == 0:
                    # Flat trajectory: avoid identical lower and upper y-limits.
                    padding = 0.1 * max(abs(data_ymax), 1.0)
                ax.set_ylim(data_ymin - padding, data_ymax + padding)

                # All markers share one row just above the highest beta.
                ylim_low, ylim_high = ax.get_ylim()
                marker_y = data_ymax + 0.05 * (ylim_high - ylim_low)

                for t, is_sig, p_val in zip(timepoints, significant, corrected_p_values):
                    if not is_sig or np.isnan(p_val):
                        continue
                    if p_val < 0.001:
                        marker = '***'
                    elif p_val < 0.01:
                        marker = '**'
                    else:
                        marker = '*'
                    ax.text(t, marker_y, marker, ha='center', va='center',
                            fontsize=12, color='black', fontweight='bold')

            ax.set_title(f'{feature.replace("_", " ").title()}', fontsize=32)
            ax.set_xlabel('Timepoints', fontsize=26)
            ax.set_ylabel('Coefficient', fontsize=26)
            ax.set_xticks(timepoints)
            ax.set_xticklabels(self.timepoint_labels, fontsize=20, rotation=-45)

            # Larger tick labels and thicker axes for presentation figures
            ax.tick_params(axis='both', which='major', labelsize=24, width=2, length=6)
            for spine in ax.spines.values():
                spine.set_linewidth(2)

            sns.despine(ax=ax)  # remove top and right spines

        fig.tight_layout()
        fig.savefig(self.analysis_output_dir / 'beta_coefficients_across_timepoints.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

        return fig

    def run_complete_analysis(self, make_plots: bool = False):
        """Run extraction, loading, regressions, FDR correction and the summary.

        Args:
            make_plots: Also produce and save the R² and beta-coefficient
                figures (off by default).

        Returns:
            Dict with ``generalization_df``, ``dimensionality_arrays``,
            ``regression_results`` and ``summary_df``.
        """
        print("="*60)
        print("SIMPLIFIED TIMEPOINT REGRESSION ANALYSIS")
        print("="*60)

        self.extract_generalization_accuracies()
        self.load_dimensionality_arrays()
        self.run_all_timepoint_regressions()

        summary_df = self.create_results_summary()
        print("\nRegression summary saved to: regression_summary.csv")

        self._print_significance_summary()

        if make_plots:
            print("\nGenerating plots...")
            self.plot_r_squared()
            self.plot_beta_coefficients()

        print(f"\nAnalysis complete! Results saved to: {self.analysis_output_dir}")

        return {
            'generalization_df': self.generalization_df,
            'dimensionality_arrays': self.dimensionality_arrays,
            'regression_results': self.regression_results,
            'summary_df': summary_df,
        }

    def _print_significance_summary(self):
        """Print, per feature, the timepoints that remain significant after global FDR correction."""
        print("\nSIGNIFICANCE SUMMARY (FDR GLOBAL CORRECTION):")

        feature_names = list(self.dimensionality_arrays.keys())
        total_significant = 0
        total_tests = 0

        for feature in feature_names:
            sig_timepoints = []
            feature_total = 0

            for t in range(self.n_timepoints):
                sig_map = self.regression_results[t].get('fdr_corrected', {}).get('significant', {})
                if feature not in sig_map:
                    continue
                feature_total += 1
                total_tests += 1
                if sig_map[feature]:
                    sig_timepoints.append(self.timepoint_labels[t])
                    total_significant += 1

            if sig_timepoints:
                print(f"  {feature}: {len(sig_timepoints)}/{feature_total} timepoints significant ({', '.join(sig_timepoints)})")
            else:
                print(f"  {feature}: No significant timepoints")

        print(f"\nOverall: {total_significant}/{total_tests} tests significant after FDR correction")


# Example usage
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded setup; override them on the
    # command line instead of editing the file.
    parser = argparse.ArgumentParser(
        description="Per-timepoint OLS regression of generalization accuracy on dimensionality."
    )
    parser.add_argument('--proj-path', default='/home/ln275/f_mc1689_1/cpro-rnn/docs/scripts/',
                        help="Base project path (contains results/ and src/)")
    parser.add_argument('--batch-name', default='slurm_cpu_batch_20250604_111254',
                        help="Batch directory name under results/")
    parser.add_argument('--output-subdir',
                        default='analysis_outputs_fixed/norm_manip_only_h2h_seqlen10_allstim',
                        help="Subdirectory of results/ holding the dimensionality pickles")
    args = parser.parse_args()

    # Initialize and run analysis
    analysis = SimpleTimepointRegression(args.proj_path, args.batch_name, args.output_subdir)
    results = analysis.run_complete_analysis()
