import os
import sys
import glob
import yaml
import numpy as np
import pandas as pd
import hashlib


def parse_experiment_name(exp_name):
    """Parse an experiment name into a parameter dictionary.

    Args:
        exp_name (str): Experiment name in the format
            ``param1_value1__param2_value2__param3_value3``, where ``__``
            separates parameters and the FIRST ``_`` of each pair separates
            the parameter name from its value (the value may itself contain
            underscores). Clipped names produced by ``clip_experiment_name``
            are recognised in both of its forms:

            * ``param1_value1__hash_abcdef12`` (some pairs kept), and
            * ``exp_hash_abcdef12`` (no pair fitted; the whole name is the hash).

    Returns:
        dict: Mapping of parameter names to string values. Segments without
        any ``_`` are ignored. For a clipped name, the dict also contains
        ``'_hash'`` (the hash string without its prefix) and
        ``'_clipped': True``. Any segments after the hash segment are ignored.

    Note:
        A genuine, unclipped name consisting of a single ``exp_hash_...``
        segment is indistinguishable from a fully clipped name and is
        treated as clipped.
    """
    param_pairs = exp_name.split('__')
    params = {}

    is_clipped = False
    if len(param_pairs) == 1 and param_pairs[0].startswith('exp_hash_'):
        # clip_experiment_name emits 'exp_hash_<hex>' when not even the first
        # parameter pair fits into the length budget.
        is_clipped = True
        params['_hash'] = param_pairs[0][len('exp_hash_'):]
        param_pairs = []
    else:
        for i, pair in enumerate(param_pairs):
            if pair.startswith('hash_'):
                is_clipped = True
                params['_hash'] = pair[len('hash_'):]
                # Only the pairs before the hash carry real parameter values.
                param_pairs = param_pairs[:i]
                break

    for pair in param_pairs:
        # Split at the first underscore only; the value keeps any others.
        name, sep, value = pair.partition('_')
        if sep:
            params[name] = value

    if is_clipped:
        params['_clipped'] = True

    return params


def clip_experiment_name(exp_id, max_length=128):
    """Clip an experiment name that is too long, to prevent path-length issues.

    Leading ``__``-separated parameter pairs are kept for as long as they
    fit. The remaining pairs are replaced by ``__hash_<8 hex chars>``, an
    MD5 digest of the dropped part that keeps names distinct. If not even
    the first pair fits, the result is ``exp_hash_<8 hex chars>``.

    Args:
        exp_id (str): Original experiment ID.
        max_length (int): Maximum length allowed for the experiment ID.
            Keep this at least 17, the length of the ``exp_hash_...`` form.

    Returns:
        str: ``exp_id`` unchanged if it already fits, otherwise the clipped
        ID, which is at most ``max_length`` characters long (for
        ``max_length >= 17``). A note about the clipping is printed to stderr.
    """
    if len(exp_id) <= max_length:
        return exp_id

    hash_len = 8
    hash_marker = '__hash_'
    # The kept prefix plus '__hash_' plus the digest must fit in max_length.
    budget = max_length - len(hash_marker) - hash_len

    param_pairs = exp_id.split('__')

    clipped_pairs = []
    kept_length = 0  # == len('__'.join(clipped_pairs))
    for pair in param_pairs:
        # The '__' separator is only needed between pairs, not before the first.
        separator_len = 2 if clipped_pairs else 0
        new_length = kept_length + separator_len + len(pair)
        if new_length > budget:
            break
        clipped_pairs.append(pair)
        kept_length = new_length

    # Hash the dropped pairs so distinct IDs sharing a prefix stay distinct.
    remaining_pairs = '__'.join(param_pairs[len(clipped_pairs):])
    hash_suffix = hashlib.md5(remaining_pairs.encode()).hexdigest()[:hash_len]

    if clipped_pairs:
        clipped_exp_id = '__'.join(clipped_pairs) + f'{hash_marker}{hash_suffix}'
    else:
        clipped_exp_id = f'exp_hash_{hash_suffix}'

    print(f"Experiment ID was clipped from {len(exp_id)} to {len(clipped_exp_id)} characters", file=sys.stderr)
    return clipped_exp_id

def _compute_matrix_score(matrix):
    """Return ``sqrt(mean(diagonal) * (1 - mean(row_off_diagonal_max)))``.

    ``row_off_diagonal_max`` is, for each row, the largest entry outside the
    main diagonal. The score is high when the diagonal entries are large and
    every row's strongest competing entry is small.

    Args:
        matrix: 2-D array-like of numbers. It is copied and never modified.

    Returns:
        numpy.float64: The score. NaN for an empty matrix, and NaN when the
        product under the square root is negative. Rows with no off-diagonal
        entries (a single-column matrix) contribute 0 to the row-max mean.

    Raises:
        ValueError: If ``matrix`` is non-numeric or not two-dimensional.
    """
    # Work on a private float copy: the input is never mutated, and -inf can
    # be stored even when the YAML gave integers.
    values = np.array(matrix, dtype=float)
    if values.size == 0:
        return np.float64(np.nan)
    if values.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {values.shape}")

    diag_mean = np.mean(np.diag(values))  # taken before the diagonal is masked
    np.fill_diagonal(values, -np.inf)     # exclude the diagonal from the row maxima
    row_maxes = values.max(axis=1)
    row_maxes[np.isneginf(row_maxes)] = 0.0  # no competitor in the row
    return np.sqrt(diag_mean * (1 - np.mean(row_maxes)))


def _load_experiment_result(exp_dir):
    """Build the result row for one experiment directory.

    Args:
        exp_dir (str): Directory expected to contain ``eval_metrics.yaml`` and
            ``config.yaml``.

    Returns:
        dict | None: Scores plus parsed parameters, or None when the directory
        must be skipped. The reason for skipping is printed to stderr.
    """
    metrics_path = os.path.join(exp_dir, "eval_metrics.yaml")
    config_path = os.path.join(exp_dir, "config.yaml")

    if not os.path.exists(metrics_path):
        print(f"Warning: No metrics file found in {exp_dir}, skipping...", file=sys.stderr)
        return None
    if not os.path.exists(config_path):
        print(f"Warning: No config file found in {exp_dir}, skipping...", file=sys.stderr)
        return None

    with open(metrics_path, 'r') as f:
        metrics = yaml.safe_load(f)
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    if config is None:
        print(f"Warning: Config file is empty for {exp_dir}, skipping...", file=sys.stderr)
        return None
    if not isinstance(config, dict):
        print(f"Warning: Config file is not a mapping for {exp_dir}, skipping...", file=sys.stderr)
        return None

    # A missing/empty metrics file (None) raises TypeError; missing keys raise KeyError.
    try:
        factor_metrics = metrics['factor_eval_metrics']
        full_metrics = metrics['full_score_metrics']
        gt_matrix = factor_metrics['gt_score/permuted_matrix']
        factor_matrix = factor_metrics['factor_score/permuted_matrix']
        raw_full_scores = full_metrics['full_factor_score']
        raw_full_gt_scores = full_metrics['full_gt_score']
    except (KeyError, TypeError) as err:
        print(f"Warning: Malformed metrics file in {exp_dir} ({err!r}), skipping...", file=sys.stderr)
        return None

    try:
        gt_score = _compute_matrix_score(gt_matrix)
        factor_score = _compute_matrix_score(factor_matrix)
        # Full scores are clipped into [0, 1] before averaging.
        full_score = np.mean(np.clip(np.array(raw_full_scores), 0, 1))
        full_gt_score = np.mean(np.clip(np.array(raw_full_gt_scores), 0, 1))
    except (ValueError, TypeError) as err:
        print(f"Warning: Invalid metric values in {exp_dir} ({err!r}), skipping...", file=sys.stderr)
        return None

    exp_name = os.path.basename(exp_dir)

    # Prefer the unclipped original_exp_id, then exp_id, then the directory
    # name, so that parameters survive directory-name clipping.
    full_exp_id = config.get('original_exp_id', config.get('exp_id', exp_name))
    params = parse_experiment_name(full_exp_id)

    # Important top-level config values that might not be encoded in exp_id.
    for key in ['rep', 'env', 'lr', 'weight_decay', 'batch_size', 'horizon', 'thickness']:
        if key not in params and key in config:
            params[key] = str(config[key])

    # Flatten up to two levels of the nested 'reps' config into dotted keys.
    reps = config.get('reps')
    if isinstance(reps, dict):
        for param_key, param_val in reps.items():
            if isinstance(param_val, dict):
                for nested_key, nested_val in param_val.items():
                    params[f'reps.{param_key}.{nested_key}'] = str(nested_val)
            else:
                params[f'reps.{param_key}'] = str(param_val)

    result = {
        'experiment': exp_name,
        'full_exp_id': full_exp_id,
        'gt_score': gt_score,
        'factor_score': factor_score,
        'mean_score': (gt_score + factor_score + full_score + full_gt_score) / 4,
        'full_score': full_score,
        'full_gt_score': full_gt_score,
    }
    result.update(params)
    return result


def analyze_experiment_results(base_dir, group_by=None):
    """Analyze experiment results from multiple runs and score their permuted matrices.

    Args:
        base_dir (str): Base directory whose immediate subdirectories are
            experiment runs.
        group_by (list, optional): Parameter (column) names to group
            experiments by.

    Returns:
        pandas.DataFrame: One row per usable experiment, with the columns
        ``experiment``, ``full_exp_id``, ``gt_score``, ``factor_score``,
        ``mean_score``, ``full_score`` and ``full_gt_score``, plus one column
        per parsed parameter.

    The function:
    1. Loads evaluation metrics and config from every experiment directory,
       in sorted order. Directories with missing, empty or malformed files are
       skipped with a warning on stderr.
    2. Computes scores based on the permuted matrices.
    3. If ``group_by`` is given, prints the best experiment of each group.
       If none of the requested columns exist, it prints the overall best
       instead; in that case the returned/saved DataFrame is sorted by
       ``mean_score`` and gains a ``full_score_avg`` column. Without
       ``group_by``, no summary is printed.
    4. Saves the results to ``<base_dir>/experiment_scores.csv``.
    """
    exp_dirs = sorted(d for d in glob.glob(os.path.join(base_dir, "*")) if os.path.isdir(d))

    results = []
    for exp_dir in exp_dirs:
        result = _load_experiment_result(exp_dir)
        if result is not None:
            results.append(result)

    df = pd.DataFrame(results)

    if df.empty:
        # Nothing to rank; the column lookups below would raise KeyError.
        print(f"Warning: No usable experiment results found in {base_dir}", file=sys.stderr)
    elif group_by:
        print(f"\nGrouping by parameters: {', '.join(group_by)}")
        print("\nBest experiment for each group:")
        print("-" * 80)

        # Check if all group_by columns exist in the dataframe
        missing_cols = [col for col in group_by if col not in df.columns]
        if missing_cols:
            print(f"Warning: Some grouping columns not found in results: {missing_cols}", file=sys.stderr)
            group_by = [col for col in group_by if col in df.columns]
            if not group_by:
                print("No valid grouping columns remain, showing overall best", file=sys.stderr)
                group_by = None

        if group_by:
            grouped = df.groupby(group_by)
            for group, group_df in grouped:
                best_exp = group_df.loc[group_df['mean_score'].idxmax()]
                full_score_avg = (group_df['full_score'] + group_df['full_gt_score']) / 2
                best_full_exp = group_df.loc[full_score_avg.idxmax()]

                # Depending on the pandas version, keys may be tuples even for one column.
                group_dict = dict(zip(group_by, group)) if isinstance(group, tuple) else {group_by[0]: group}
                print(f"\nGroup: {group_dict}")
                print(f"Best experiment: {best_exp['experiment']}")
                print(f"Mean score: {best_exp['mean_score']:.4f}")
                print(f"GT score: {best_exp['gt_score']:.4f}")
                print(f"Factor score: {best_exp['factor_score']:.4f}")
                print(f"Full score avg: {(best_exp['full_score'] + best_exp['full_gt_score'])/2:.4f}")
                print(f"Full score: {best_exp['full_score']:.4f}")
                print(f"Full GT score: {best_exp['full_gt_score']:.4f}")

                print(f"\nBest experiment by full score average: {best_full_exp['experiment']}")
                print(f"Full score avg: {(best_full_exp['full_score'] + best_full_exp['full_gt_score'])/2:.4f}")
                print(f"Full score: {best_full_exp['full_score']:.4f}")
                print(f"Full GT score: {best_full_exp['full_gt_score']:.4f}")
                print(f"Factor score: {best_full_exp['factor_score']:.4f}")
                print(f"GT score: {best_full_exp['gt_score']:.4f}")
        else:
            # Fall back to showing overall best
            df = df.sort_values('mean_score', ascending=False)
            best_exp = df.iloc[0]
            df['full_score_avg'] = (df['full_score'] + df['full_gt_score']) / 2
            best_full_exp = df.loc[df['full_score_avg'].idxmax()]

            print(f"\nBest experiment: {best_exp['experiment']}")
            print(f"Mean score: {best_exp['mean_score']:.4f}")
            print(f"GT score: {best_exp['gt_score']:.4f}")
            print(f"Factor score: {best_exp['factor_score']:.4f}")
            print(f"\nBest experiment by full score average: {best_full_exp['experiment']}")
            print(f"Full score avg: {(best_full_exp['full_score'] + best_full_exp['full_gt_score'])/2:.4f}")
            print(f"Full score: {best_full_exp['full_score']:.4f}")
            print(f"Full GT score: {best_full_exp['full_gt_score']:.4f}")
            print(f"Factor score: {best_full_exp['factor_score']:.4f}")
            print(f"GT score: {best_full_exp['gt_score']:.4f}")

    output_path = os.path.join(base_dir, 'experiment_scores.csv')
    df.to_csv(output_path, index=False)
    print(f"\nFull results saved to: {output_path}")

    return df

def check_exp_id_consistency(base_dir):
    """Print a table comparing experiment directory names with the IDs in their configs.

    Args:
        base_dir (str): Base directory whose immediate subdirectories are
            experiment runs.

    For every subdirectory (in sorted order), one row is printed with:
    the directory name (truncated to 40 characters), whether the name looks
    clipped, the length of the stored ``original_exp_id``, and the length of
    the stored ``exp_id`` (0 when a value is absent). Comparing these lengths
    makes clipped or inconsistent names easy to spot. Directories without
    ``config.yaml`` are listed as such. An empty or non-mapping config is
    treated as having no stored IDs. A summary with counts closes the report.
    """
    exp_dirs = sorted(d for d in glob.glob(os.path.join(base_dir, "*")) if os.path.isdir(d))

    print(f"Checking {len(exp_dirs)} experiment directories in {base_dir}\n")
    print(f"{'Directory Name':<40} | {'Clipped':<7} | {'Original Length':<15} | {'Config Length':<15}")
    print("-" * 85)

    clipped_count = 0
    missing_config_count = 0
    for exp_dir in exp_dirs:
        dir_name = os.path.basename(exp_dir)
        config_path = os.path.join(exp_dir, "config.yaml")

        if not os.path.exists(config_path):
            missing_config_count += 1
            print(f"{dir_name:<40} | {'N/A':<7} | {'No config file':<15} | {'N/A':<15}")
            continue

        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        # An empty YAML file loads as None; treat it (and any non-mapping) as
        # a config without stored IDs instead of crashing on .get().
        if not isinstance(config, dict):
            config = {}

        stored_exp_id = config.get('exp_id', 'N/A')
        original_exp_id = config.get('original_exp_id', 'N/A')

        # The substring test also catches 'hash_' segments the parser may not flag.
        params = parse_experiment_name(dir_name)
        is_clipped = 'hash_' in dir_name or params.get('_clipped', False)
        clipped_status = 'Yes' if is_clipped else 'No'
        if is_clipped:
            clipped_count += 1

        # str() guards against IDs that YAML parsed as numbers.
        orig_len = len(str(original_exp_id)) if original_exp_id != 'N/A' else 0
        config_len = len(str(stored_exp_id)) if stored_exp_id != 'N/A' else 0

        display_name = dir_name[:37] + '...' if len(dir_name) > 40 else dir_name
        print(f"{display_name:<40} | {clipped_status:<7} | {orig_len:<15} | {config_len:<15}")

    print("\nSummary:")
    print(f"Total directories: {len(exp_dirs)}")
    print(f"Clipped directory names: {clipped_count}")
    print(f"Directories without config file: {missing_config_count}")