import os
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
from collections import defaultdict
import json
# Hard-coded per-task limit on how much of each accuracy curve is kept.
# process_task_baselines() looks the task up here (falling back to 400 when the
# task is not listed) and uses the value as a row-count cap for train CSVs and
# as a cut-off on the 'Step' column value for test CSVs.
MAX_STEPS = {
    'continual_cifar100': 20,
    'continual_imagenet': 100,
    'permuted_MNIST': 200,
    'random_label_cifar10': 100,
    'random_MNIST': 50,
    'shuffle_cifar10': 100
}

# Tasks that don't have test accuracies; process_task_baselines() skips the
# test-accuracy CSVs entirely for any task named here.
NO_TEST_ACCURACY = ['random_MNIST', 'random_label_cifar10']

# Run-name layout:
#   <agent>_<model>_<optimizer>_<lr>_<task>_transform=<transform>_<seed>
# The task is captured as two groups. The second one is lazy but is allowed to
# span underscores, which is how three-part task names such as
# "random_label_cifar10" are matched without a dedicated pattern.
_RUN_NAME_RE = re.compile(
    r'(.+?)_(.+?)_(.+?)_(.+?)_(.+?)_(.+?)_transform=(.+?)_(\d+)'
)
# Same layout followed by a " - <metric>" suffix; the metric text is split at
# its first underscore.
_RUN_NAME_WITH_METRIC_RE = re.compile(_RUN_NAME_RE.pattern + r' - (.+?)_(.+)')


def extract_run_info(column_name):
    """Extract run information from column name.

    Examples:
    - BaseAgent_CNN_adam_0.001_continual_cifar100_transform=['flip', 'crop', 'norm']_1 - train_acc_continual_cifar100
    - NeuroSyncAgent_MLP_adam_0.001_random_MNIST_transform=['flip', 'crop', 'norm']_1
    - DeepFourierAgent_MLP_sgd_0.01_permuted_MNIST_transform=['flip', 'crop', 'norm']_1
    - EWCAgent_CNN_sgd_0.01_continual_imagenet_transform=['flip', 'crop', 'norm']_1
    - PReLUAgent_MLP_sgd_0.01_shuffle_cifar10_transform=['flip', 'crop', 'norm']_2
    - CReLUAgent_CNN_adam_0.001_continual_cifar100_transform=['flip', 'crop', 'norm']_2
    - NeuroSyncAgent_MLP_adam_0.001_random_label_cifar10_transform=['flip', 'crop', 'norm']_2

    Args:
        column_name: Column header to parse. Matching is anchored at the start
            of the string only, so trailing text after the seed (or after the
            metric suffix) is tolerated.

    Returns:
        A dict with the keys 'agent', 'model', 'optimizer', 'lr', 'task',
        'transform', 'seed' (int), 'metric_type', 'metric_task', 'config'
        ("<optimizer>_<lr>") and 'optimizer_only', or None when the name does
        not follow the layout. 'metric_type' and 'metric_task' are None when
        there is no " - <metric>" suffix; otherwise they are the suffix text
        before and after its first underscore (e.g. 'train' and
        'acc_continual_cifar100').
    """
    # Prefer the variant with the metric suffix; fall back to the bare run name.
    match = _RUN_NAME_WITH_METRIC_RE.match(column_name)
    if match:
        (agent, model, optimizer, lr, task_head, task_tail,
         transform, seed, metric_type, metric_task) = match.groups()
    else:
        match = _RUN_NAME_RE.match(column_name)
        if match is None:
            return None
        (agent, model, optimizer, lr, task_head, task_tail,
         transform, seed) = match.groups()
        metric_type = None
        metric_task = None

    return {
        'agent': agent,
        'model': model,
        'optimizer': optimizer,
        'lr': lr,
        # Both task groups are always non-empty, so they are always joined.
        'task': f"{task_head}_{task_tail}",
        'transform': transform,
        'seed': int(seed),
        'metric_type': metric_type,
        'metric_task': metric_task,
        'config': f"{optimizer}_{lr}",
        'optimizer_only': optimizer  # Keep just the optimizer name
    }

def find_csv_files(base_dir):
    """Locate the accuracy CSVs under a ``<base_dir>/<task>/<baseline>/`` tree.

    Returns a nested dict ``{task: {baseline: {'train': path_or_None,
    'test': path_or_None}}}``. Only directories are treated as tasks and
    baselines; a path is filled in when the baseline directory lists an entry
    named exactly 'train_accuracy.csv' / 'test_accuracy.csv'.
    """
    wanted = (('train', 'train_accuracy.csv'), ('test', 'test_accuracy.csv'))
    found = {}

    for task_name in os.listdir(base_dir):
        task_dir = os.path.join(base_dir, task_name)
        if not os.path.isdir(task_dir):
            continue

        # Every task directory gets an entry, even when it has no baselines.
        baselines = {}
        found[task_name] = baselines

        for baseline_name in os.listdir(task_dir):
            baseline_dir = os.path.join(task_dir, baseline_name)
            if not os.path.isdir(baseline_dir):
                continue

            entries = set(os.listdir(baseline_dir))
            baselines[baseline_name] = {
                split: (os.path.join(baseline_dir, file_name)
                        if file_name in entries else None)
                for split, file_name in wanted
            }

    return found


def process_csv_file(file_path):
    """Load an accuracy CSV and index its run columns by configuration and seed.

    Args:
        file_path: Path to the CSV, or a falsy value when there is no file.

    Returns:
        A ``(df, cols_by_config_seed)`` pair. ``df`` is the CSV exactly as read
        (no columns removed) and ``cols_by_config_seed`` maps
        ``config -> {seed: column_name}`` for every column that
        extract_run_info() can parse. ``(None, None)`` is returned when the
        path is falsy, does not exist, or the file cannot be read, so callers
        can always unpack two values.

    Raises:
        KeyError: If the CSV was read but has no 'Step' column.
    """
    if not file_path or not os.path.exists(file_path):
        return None, None

    try:
        df = pd.read_csv(file_path)
    except Exception:
        # Best effort: an unreadable or malformed file is reported as "no data".
        return None, None

    if 'Step' not in df.columns:
        raise KeyError(f"'Step' column not found in {file_path}")

    cols_by_config_seed = defaultdict(dict)
    for col in df.columns:
        # Skip the step axis and columns ending with min or max.
        # NOTE(review): the suffix test is case-sensitive, so columns ending in
        # upper-case 'MIN'/'MAX' are not filtered here - confirm against the
        # export format of the input files.
        if col == 'Step' or col.endswith(('min', 'max')):
            continue

        info = extract_run_info(col)
        if info:
            # A later column with the same config and seed replaces an earlier one.
            cols_by_config_seed[info['config']][info['seed']] = col

    return df, cols_by_config_seed

def average_by_seed(df, cols_by_config_seed):
    """Average columns by seed for each configuration.

    Args:
        df: DataFrame holding one column per run.
        cols_by_config_seed: Mapping ``config -> {seed: column_name}``.
            Configurations with no columns are skipped.

    Returns:
        A ``(result, variance)`` pair of dicts keyed by config. Each value is a
        Series on ``df``'s index holding, per row, the mean / population
        variance (ddof=0) of the non-NaN seed values. A row with a single
        value has variance 0; a row with no values is NaN in both.
    """
    result = {}
    variance = {}

    for config, seed_cols in cols_by_config_seed.items():
        if not seed_cols:
            continue

        # One column per seed for this configuration.
        runs = df[list(seed_cols.values())]

        # Row-wise statistics over the seeds that actually logged a value.
        result[config] = runs.mean(axis=1, skipna=True)
        variance[config] = runs.var(axis=1, ddof=0, skipna=True)

    return result, variance

def select_best_config(avg_by_config, steps):
    """Return the config whose averaged curve has the highest mean.

    Only the first ``steps`` entries of each series are considered. Ties keep
    the config seen first. Returns None when ``avg_by_config`` is empty.
    """
    top_name = None
    top_score = -float('inf')

    for name, series in avg_by_config.items():
        window = series.iloc[:min(len(series), steps)]
        if window.empty:
            continue
        score = window.mean()
        # A NaN score never compares greater, so all-NaN curves are passed over.
        if score > top_score:
            top_name, top_score = name, score

    # Nothing produced a comparable score: fall back to the first config, if any.
    if top_name is None and avg_by_config:
        top_name = list(avg_by_config.keys())[0]

    return top_name

def _load_accuracy_csv(file_path):
    """Return ``(df, cols_by_config_seed)`` for a CSV path, or None if unusable."""
    if not file_path:
        return None
    loaded = process_csv_file(file_path)
    # Accept both a bare None and a (None, None) pair as "no data".
    if loaded is None:
        return None
    df, cols_by_config_seed = loaded
    if df is None:
        return None
    return df, cols_by_config_seed


def _combine_baseline_series(series_by_baseline):
    """Build one DataFrame with a 1-based 'Step' column and one column per baseline."""
    if not series_by_baseline:
        return pd.DataFrame()

    # Size the frame from the longest curve so no baseline is ever too long
    # for it; shorter curves are padded with NaN at the end.
    n_rows = max(len(series) for series in series_by_baseline.values())
    combined = pd.DataFrame({'Step': range(1, n_rows + 1)})

    for baseline, series in series_by_baseline.items():
        values = np.full(n_rows, np.nan)
        values[:len(series)] = series.to_numpy(dtype=float)
        combined[baseline] = values

    return combined


def process_task_baselines(csv_files, task):
    """Process all baselines for a specific task.

    For every baseline the train CSV is averaged over seeds per configuration
    and the configuration with the best mean train accuracy is selected. The
    test CSV (when the task has one) is then averaged for that same
    configuration.

    Args:
        csv_files: Mapping as produced by find_csv_files().
        task: Key of ``csv_files`` to process.

    Returns:
        A 5-tuple ``(train_df, train_var_by_baseline, test_df,
        test_var_by_baseline, best_configs)``. The two DataFrames have a
        'Step' column numbered from 1 plus one column per baseline, and are
        empty when nothing could be processed. The variance dicts map baseline
        to a per-row variance Series, and ``best_configs`` maps baseline to
        the selected configuration name.
    """
    task_data = csv_files[task]
    max_steps = MAX_STEPS.get(task, 400)  # Default to 400 if not specified

    # --- Train accuracy: pick the best configuration per baseline ---
    train_avg_by_baseline = {}
    train_var_by_baseline = {}
    best_configs = {}

    for baseline, files in task_data.items():
        loaded = _load_accuracy_csv(files['train'])
        if loaded is None:
            continue
        df, cols_by_config_seed = loaded

        # Train curves are capped by row count.
        # NOTE(review): test curves below are capped by 'Step' value instead -
        # confirm the two are meant to differ.
        if max_steps < len(df):
            df = df.iloc[:max_steps].copy()

        avg_by_config, var_by_config = average_by_seed(df, cols_by_config_seed)
        best_config = select_best_config(avg_by_config, max_steps)

        if best_config:
            best_configs[baseline] = best_config
            train_avg_by_baseline[baseline] = avg_by_config[best_config]
            train_var_by_baseline[baseline] = var_by_config[best_config]

    train_df = _combine_baseline_series(train_avg_by_baseline)

    # --- Test accuracy (if available): reuse the configuration chosen on train ---
    test_df = pd.DataFrame()
    test_var_by_baseline = {}

    if task not in NO_TEST_ACCURACY:
        test_avg_by_baseline = {}

        for baseline, files in task_data.items():
            if baseline not in best_configs:
                continue
            loaded = _load_accuracy_csv(files['test'])
            if loaded is None:
                continue
            df, cols_by_config_seed = loaded

            # Keep rows whose 'Step' value is <= max_steps (assumes 'Step' is sorted).
            cutoff = np.searchsorted(df['Step'].values, max_steps, side='right')
            if cutoff < len(df):
                df = df.iloc[:cutoff].copy()

            avg_by_config, var_by_config = average_by_seed(df, cols_by_config_seed)

            best_config = best_configs[baseline]
            if best_config in avg_by_config:
                test_avg_by_baseline[baseline] = avg_by_config[best_config]
                test_var_by_baseline[baseline] = var_by_config[best_config]

        test_df = _combine_baseline_series(test_avg_by_baseline)

    return train_df, train_var_by_baseline, test_df, test_var_by_baseline, best_configs

def process_results(base_dir):
    """Run the per-task pipeline over every task directory under ``base_dir``.

    Returns a dict with the keys 'train_dfs', 'test_dfs', 'train_variances',
    'test_variances' and 'best_configs', each mapping task name to that
    task's output. Tasks whose train (or test) frame is empty are left out of
    the corresponding frame and variance maps; 'best_configs' has an entry
    for every task.
    """
    csv_files = find_csv_files(base_dir)

    collected = {
        'train_dfs': {},
        'test_dfs': {},
        'train_variances': {},
        'test_variances': {},
        'best_configs': {},
    }

    for task in csv_files:
        outcome = process_task_baselines(csv_files, task)
        train_frame, train_var, test_frame, test_var, chosen = outcome

        if not train_frame.empty:
            collected['train_dfs'][task] = train_frame
            collected['train_variances'][task] = train_var

        if not test_frame.empty:
            collected['test_dfs'][task] = test_frame
            collected['test_variances'][task] = test_var

        collected['best_configs'][task] = chosen

    return collected

def _save_accuracy_plot(df, title, out_path):
    """Plot every non-'Step' column of ``df`` against 'Step' and save it to ``out_path``."""
    fig = plt.figure(figsize=(12, 6))
    try:
        for baseline in df.columns:
            if baseline != 'Step':
                plt.plot(df['Step'], df[baseline], label=baseline)

        plt.title(title)
        plt.xlabel('Step')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)
        plt.savefig(out_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def visualize_results(results, output_dir='./plots'):
    """Visualize the results.

    Writes one PNG per task for train accuracy and one per task for test
    accuracy into ``output_dir`` (created if missing), named
    ``<task>_train_accuracy.png`` and ``<task>_test_accuracy.png``.

    Args:
        results: Dict with 'train_dfs' and 'test_dfs' entries mapping task
            name to a DataFrame that has a 'Step' column plus one column per
            baseline.
        output_dir: Directory the images are written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    for task, df in results['train_dfs'].items():
        _save_accuracy_plot(
            df,
            f'Train Accuracy - {task}',
            os.path.join(output_dir, f'{task}_train_accuracy.png'),
        )

    for task, df in results['test_dfs'].items():
        _save_accuracy_plot(
            df,
            f'Test Accuracy - {task}',
            os.path.join(output_dir, f'{task}_test_accuracy.png'),
        )

if __name__ == "__main__":
    import sys

    # Path to the results directory: the first command-line argument when
    # given, otherwise the original hard-coded location.
    default_results_dir = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\result_analysis\results'
    results_dir = sys.argv[1] if len(sys.argv) > 1 else default_results_dir

    # Process the results
    results = process_results(results_dir)

    # Print the best configurations
    print("Best configurations by task and baseline:")
    for task, configs in results['best_configs'].items():
        print(f"\n{task}:")
        for baseline, config in configs.items():
            print(f"  {baseline}: {config}")

    # Visualize the results
    visualize_results(results)

    # Save the results to CSV files
    output_dir = "./processed_results"
    os.makedirs(output_dir, exist_ok=True)

    for split in ('train', 'test'):
        for task, df in results[f'{split}_dfs'].items():
            df.to_csv(os.path.join(output_dir, f"{task}_{split}.csv"), index=False)

    # Save the variance information, one file per task and split
    for split in ('train', 'test'):
        for task, variances in results[f'{split}_variances'].items():
            var_df = pd.DataFrame({'Step': results[f'{split}_dfs'][task]['Step']})
            n_rows = len(var_df)
            for baseline, var_series in variances.items():
                # Make sure the variance length matches the DataFrame length
                values = var_series.values
                if len(values) < n_rows:
                    # Pad with NaNs if the variance series is shorter
                    padded_values = np.full(n_rows, np.nan)
                    padded_values[:len(values)] = values
                    values = padded_values
                else:
                    # Trim if the variance series is longer
                    values = values[:n_rows]
                var_df[baseline] = values
            var_df.to_csv(
                os.path.join(output_dir, f"{task}_{split}_variance.csv"),
                index=False,
            )

    # Save the best configurations as JSON
    with open(os.path.join(output_dir, "best_configs.json"), 'w', encoding='utf-8') as f:
        json.dump(results['best_configs'], f, indent=2)