import math
import os
import traceback
from decimal import ROUND_DOWN, Decimal

import numpy as np
import pandas as pd
# When True, the compute_* metric helpers pass each averaged value through
# truncate_to_two_decimals before storing it.
ROUND_2 = True

def truncate_to_two_decimals(number):
    """
    Truncate a number to 2 decimal places, toward zero and without rounding.

    The truncation is done on the shortest decimal representation of the
    float, so a value such as 0.29 stays 0.29 (``int(0.29 * 100)`` would give
    28 because of binary floating-point error).

    Args:
        number (float): The number to truncate. Anything accepted by
            ``float()`` works.

    Returns:
        float: ``number`` with every digit after the second decimal dropped.
        Magnitudes below 0.01 give ``0.0``; NaN and infinities are returned
        unchanged.
    """
    value = float(number)

    # NaN and +/-inf have no decimal digits to drop.
    if not math.isfinite(value):
        return value

    # Compare the magnitude so that negative inputs are truncated instead of
    # being collapsed to zero. Returning a literal 0.0 also avoids -0.0.
    if abs(value) < 0.01:
        return 0.0

    # From 2**53 upwards every float is a whole number, so there is nothing
    # to truncate; returning early also keeps the Decimal arithmetic below
    # well inside its default precision.
    if abs(value) >= 2.0 ** 53:
        return value

    # repr() yields the shortest string that round-trips the float; Decimal
    # parses it exactly (scientific notation included) and ROUND_DOWN drops
    # the extra digits toward zero.
    truncated = Decimal(repr(value)).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    return float(truncated)
def compute_bat(backward_data, scratch_data, T, b_values):
    """Compute Backward Adaptation Transfer (BAT) for different b percentage values.

    For each percentage ``b`` the result is the mean, over tasks ``k`` in
    ``range(T)``, of ``backward_data[-1][k][i] - scratch_data[0][k][i]``,
    where ``i`` is the batch index that ``b`` percent maps to.

    Args:
        backward_data: Array-like with a ``.shape`` of at least three axes,
            read as ``backward_data[-1][k][batch]``.
        scratch_data: Array-like with a ``.shape`` of at least three axes,
            read as ``scratch_data[0][k][batch]``.
        T: Number of tasks to average over.
        b_values: Percentages (0-100) of the batch axis to evaluate.

    Returns:
        dict: Maps every ``b`` to its BAT value (truncated to two decimals
        when ``ROUND_2`` is set). The value is NaN when no task could be read
        or the final averaging failed.
    """
    bat_results = {}

    # Percentages are resolved against the shorter batch axis so that the
    # same index is valid in both arrays.
    number_of_backward_time_steps = min(backward_data.shape[2], scratch_data.shape[2])

    for b in b_values:
        # Convert the percentage to a batch count, using at least 1 batch.
        b_idx = max(1, int(b * number_of_backward_time_steps / 100))
        bat_sum = 0
        valid_count = 0

        for k in range(T):
            try:
                # A_T,b,k is accuracy on task k after training on all T tasks and fine-tuning on task k
                a_t_b_k = backward_data[-1][k][b_idx - 1]

                # A'_k,b,k is accuracy of random initialized model trained for b batches on task k
                a_prime_k_b_k = scratch_data[0][k][b_idx - 1]

                bat_sum += (a_t_b_k - a_prime_k_b_k)
                valid_count += 1
            except (IndexError, TypeError) as e:
                print(f"  [BAT Error] Task {k}, b_idx {b_idx-1}: {e}")
                continue

        if valid_count == 0:
            print(f"  [BAT Warning] No valid data for b={b}%")
            bat_results[b] = float('nan')
            continue

        try:
            result = bat_sum / valid_count
            if ROUND_2:
                result = truncate_to_two_decimals(result)
        except Exception as e:
            # Best effort: report the failure and keep the table cell empty
            # instead of aborting the whole run.
            print(f"  [BAT Error] Could not finalize b={b}%: {e}")
            result = float('nan')
        bat_results[b] = result

    return bat_results

def compute_fat(forward_data, scratch_data, T, b_values):
    """Compute Forward Adaptation Transfer (FAT) for different b percentage values.

    For each percentage ``b`` the result is the mean, over tasks ``k`` in
    ``range(1, T)``, of ``forward_data[0][k][i] - scratch_data[0][k][i]``,
    where ``i`` is the batch index that ``b`` percent maps to.

    Args:
        forward_data: Array-like with a ``.shape`` of at least three axes,
            read as ``forward_data[0][k][batch]``.
        scratch_data: Array-like read as ``scratch_data[0][k][batch]``.
        T: Number of tasks; task 0 is skipped.
        b_values: Percentages (0-100) of the batch axis to evaluate.

    Returns:
        dict: Maps every ``b`` to its FAT value (truncated to two decimals
        when ``ROUND_2`` is set). The value is NaN when no task could be read
        or the final averaging failed.
    """
    fat_results = {}

    # NOTE(review): percentages are resolved against forward_data only; a
    # shorter scratch_data shows up as per-task [FAT Error] messages below.
    number_of_forward_time_steps = forward_data.shape[2]

    for b in b_values:
        # Convert the percentage to a batch count, using at least 1 batch.
        b_idx = max(1, int((b * number_of_forward_time_steps) / 100))
        fat_sum = 0
        valid_count = 0

        for k in range(1, T):
            try:
                # A_k-1,b,k is accuracy on task k after training on k-1 tasks and fine-tuning on task k
                a_k_minus_1_b_k = forward_data[0][k][b_idx - 1]

                # A'_k,b,k is accuracy of random initialized model trained for b batches on task k
                a_prime_k_b_k = scratch_data[0][k][b_idx - 1]

                fat_sum += (a_k_minus_1_b_k - a_prime_k_b_k)
                valid_count += 1
            except (IndexError, TypeError) as e:
                print(f"  [FAT Error] Task {k}, b_idx {b_idx-1}: {e}")
                continue

        if valid_count == 0:
            print(f"  [FAT Warning] No valid data for b={b}%")
            fat_results[b] = float('nan')
            continue

        try:
            result = fat_sum / valid_count
            if ROUND_2:
                result = truncate_to_two_decimals(result)
        except Exception as e:
            # Best effort: report the failure and keep the table cell empty
            # instead of aborting the whole run.
            print(f"  [FAT Error] Could not finalize b={b}%: {e}")
            result = float('nan')
        fat_results[b] = result

    return fat_results

def compute_lca_backward(backward_data, T, B):
    """Compute Learning Curve Area (LCA) for backward adaptation.

    For every batch index ``b`` in ``range(B + 1)`` the values
    ``backward_data[-1][k][b]`` are averaged over tasks ``k`` in ``range(T)``
    (and truncated to two decimals when ``ROUND_2`` is set). The result is
    the mean of those per-batch averages, ignoring NaN entries.

    Args:
        backward_data: Array-like read as ``backward_data[-1][k][b]``.
        T: Number of tasks to average over.
        B: Highest batch index to include (inclusive).

    Returns:
        float: The LCA value, or NaN when no batch index produced a valid
        average.
    """
    z_values = []
    errors_count = 0

    for b in range(B + 1):
        z_b = 0
        valid_count = 0

        for k in range(T):
            try:
                # A_T,b,k is accuracy on task k after training on all T tasks and fine-tuning on task k
                z_b += backward_data[-1][k][b]
                valid_count += 1
            except (IndexError, TypeError) as e:
                errors_count += 1
                if errors_count <= 3:  # Limit error messages to avoid flooding
                    print(f"  [LCAB Error] Task {k}, b {b}: {e}")
                continue

        if valid_count == 0:
            z_values.append(float('nan'))
            continue

        # A NaN sum needs no special case: it yields a NaN average, which the
        # isnan mask below filters out.
        try:
            result = z_b / valid_count
            if ROUND_2:
                result = truncate_to_two_decimals(result)
        except Exception as e:
            errors_count += 1
            if errors_count <= 3:  # Limit error messages to avoid flooding
                print(f"  [LCAB Error] Could not finalize b {b}: {e}")
            result = float('nan')
        z_values.append(result)

    # Calculate LCA_B as the average of Z^backward values
    z_values = np.array(z_values)
    valid_indices = ~np.isnan(z_values)

    if not np.any(valid_indices):
        print(f"  [LCAB Warning] No valid data for any b value")
        return float('nan')

    if errors_count > 3:
        print(f"  [LCAB Notice] {errors_count} total errors encountered, only first 3 shown")
    return np.mean(z_values[valid_indices])

def compute_lca_forward(forward_data, T, B):
    """Compute Learning Curve Area (LCA) for forward adaptation.

    For every batch index ``b`` in ``range(B + 1)`` the values
    ``forward_data[0][k][b]`` are averaged over tasks ``k`` in
    ``range(1, T)`` (and truncated to two decimals when ``ROUND_2`` is set).
    The result is the mean of those per-batch averages, ignoring NaN entries.

    Args:
        forward_data: Array-like read as ``forward_data[0][k][b]``.
        T: Number of tasks; task 0 is skipped.
        B: Highest batch index to include (inclusive).

    Returns:
        float: The LCA value, or NaN when no batch index produced a valid
        average.
    """
    z_values = []
    errors_count = 0

    for b in range(B + 1):
        z_b = 0
        valid_count = 0

        for k in range(1, T):
            try:
                # A_k-1,b,k is accuracy on task k after training on k-1 tasks and fine-tuning on task k
                z_b += forward_data[0][k][b]
                valid_count += 1
            except (IndexError, TypeError) as e:
                errors_count += 1
                if errors_count <= 3:  # Limit error messages to avoid flooding
                    print(f"  [LCAF Error] Task {k}, b {b}: {e}")
                continue

        if valid_count == 0:
            z_values.append(float('nan'))
            continue

        try:
            result = z_b / valid_count
            if ROUND_2:
                result = truncate_to_two_decimals(result)
        except Exception as e:
            errors_count += 1
            if errors_count <= 3:  # Limit error messages to avoid flooding
                print(f"  [LCAF Error] Could not finalize b {b}: {e}")
            result = float('nan')
        z_values.append(result)

    # Calculate LCA_F as the average of Z^forward values
    z_values = np.array(z_values)
    valid_indices = ~np.isnan(z_values)

    if not np.any(valid_indices):
        print(f"  [LCAF Warning] No valid data for any b value")
        return float('nan')

    if errors_count > 3:
        print(f"  [LCAF Notice] {errors_count} total errors encountered, only first 3 shown")
    return np.mean(z_values[valid_indices])

def load_data(benchmark, baseline, mode, number_of_rows):
    """Load the four ``.npy`` result matrices for one benchmark/baseline pair.

    Every file is loaded independently; a missing or unreadable file is
    reported on stdout and yields ``None`` in its slot instead of raising.

    Args:
        benchmark: Benchmark name; used as a directory component and as a key
            into ``number_of_rows``.
        baseline: Baseline name; used to build directory and file names.
        mode: Directory component selecting the result variant.
        number_of_rows: Mapping that must contain ``benchmark``.

    Returns:
        tuple: ``(backward_data, forward_data, backward_scratch_data,
        forward_scratch_data)``; each entry is the loaded array or ``None``.

    Raises:
        KeyError: If ``benchmark`` is not a key of ``number_of_rows``.
    """
    # The value is not used below; the lookup is kept so that an unknown
    # benchmark still fails with KeyError.
    T = number_of_rows[benchmark]

    # Root directory of the processed results
    main_path = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\processed_results'

    # MAML files carry no "Agent" suffix
    agent_name = 'MAML' if baseline == 'MAML' else f'{baseline}Agent'

    # Both scratch matrices live in the same directory, which always carries
    # the "Agent" suffix.
    scratch_dir = os.path.join(main_path, "scratch_matrices", mode, benchmark, f"{baseline}Agent")

    # (label used in messages, file path), in the order the arrays are returned.
    # Backward adaptation uses the *_interpolate scratch file, forward
    # adaptation the plain one.
    targets = (
        (
            "backward data",
            os.path.join(main_path, "bwt_matrices_backward", benchmark, baseline, mode, f"{baseline}.npy"),
        ),
        (
            "forward data",
            os.path.join(main_path, "fwt_matrices", mode, benchmark, agent_name, f"{agent_name}.npy"),
        ),
        (
            "backward scratch data",
            os.path.join(scratch_dir, f"{agent_name}_interpolate.npy"),
        ),
        (
            "forward scratch data",
            os.path.join(scratch_dir, f"{agent_name}.npy"),
        ),
    )

    loaded = []
    for label, path in targets:
        try:
            array = np.load(path)
            print(f"✓ Loaded {label}: {os.path.basename(path)}, shape: {array.shape}")
        except FileNotFoundError:
            print(f"✗ {label.capitalize()} file not found: {path}")
            array = None
        except Exception as e:
            print(f"✗ Error loading {label}: {path}")
            print(f"  Error details: {e}")
            array = None
        loaded.append(array)

    return tuple(loaded)

def format_value(value, is_best=False):
    """Render a metric as a LaTeX table cell.

    NaN becomes ``"-"``; any other value is printed with two decimals and
    wrapped in ``\\textbf{...}`` when ``is_best`` is truthy.
    """
    if np.isnan(value):
        return "-"

    rendered = f"{value:.2f}"
    return f"\\textbf{{{rendered}}}" if is_best else rendered

def create_consolidated_latex_table(benchmarks, baselines, number_of_rows, mode, name_dict, bat_fat_value="100"):
    """
    Create a consolidated LaTeX table with results for all benchmarks and baselines

    For every benchmark/baseline pair the data files are loaded, BAT, FAT,
    LCA_B and LCA_F are computed, and one row per baseline is emitted with
    four columns per benchmark. The best (largest) value of each metric
    within a benchmark is set in bold.

    Parameters:
    -----------
    benchmarks : list
        List of benchmark names
    baselines : list
        List of baseline names
    number_of_rows : dict
        Dictionary mapping benchmarks to number of rows/tasks
    mode : str
        Mode for data loading
    name_dict : dict
        Dictionary mapping baseline names to display names
    bat_fat_value : str, optional
        The BAT/FAT percentage value to report. Can be any of the evaluated
        percentages ("1", "2", "5", "10", "20", "25", "50", "75", "100"), or
        "mean" to report the average over all of them

    Returns:
    --------
    str
        LaTeX table as a string
    """
    # Define the b_values for BAT and FAT calculations
    b_values = [1, 2, 5, 10, 20, 25, 50, 75, 100]
    metrics = ('BAT', 'FAT', 'LCAB', 'LCAF')

    def reported_value(results, metric):
        """Return the figure shown in the table for ``metric``, or NaN when absent."""
        if metric in ('LCAB', 'LCAF'):
            return results[metric]
        if bat_fat_value == "mean":
            return results.get(f'{metric}_mean', float('nan'))
        return results[metric].get(int(bat_fat_value), float('nan'))

    # Dictionary to store all results
    all_results = {}

    # Process each benchmark and baseline combination
    for benchmark in benchmarks:
        print(f"\n{'='*80}")
        print(f"Processing benchmark: {benchmark}")
        print(f"{'='*80}")

        T = number_of_rows[benchmark]
        benchmark_results = {}

        for baseline in baselines:
            print(f"\n{'-'*60}")
            print(f"Processing baseline: {baseline}")
            print(f"{'-'*60}")

            # Load data
            backward_data, forward_data, backward_scratch_data, forward_scratch_data = load_data(
                benchmark, baseline, mode, number_of_rows
            )

            # Initialize results dictionary
            results = {
                'BAT': {},
                'FAT': {},
                'LCAB': float('nan'),
                'LCAF': float('nan')
            }

            if backward_data is not None:
                # BAT compares against the scratch run, so it needs both files
                if backward_scratch_data is not None:
                    print(f"\nComputing BAT metrics...")
                    try:
                        results['BAT'] = compute_bat(backward_data, backward_scratch_data, T, b_values)

                        print(f"✓ BAT computation completed")

                        # Compute mean BAT if requested
                        if bat_fat_value == "mean":
                            valid_values = [
                                results['BAT'][b] for b in b_values
                                if b in results['BAT'] and not np.isnan(results['BAT'][b])
                            ]
                            if valid_values:
                                results['BAT_mean'] = np.mean(valid_values)
                            else:
                                results['BAT_mean'] = float('nan')
                    except Exception as e:
                        print(f"✗ Error in BAT computation:")
                        print(f"  {str(e)}")
                        print(traceback.format_exc())

                # LCAB only reads backward_data, so it does not depend on the
                # scratch file being present.
                print(f"\nComputing LCAB metric...")
                try:
                    B = backward_data.shape[2] - 1  # Number of batches
                    results['LCAB'] = compute_lca_backward(backward_data, T, B)
                    print(f"✓ LCAB computation completed: {results['LCAB']:.4f}")
                except Exception as e:
                    print(f"✗ Error in LCAB computation:")
                    print(f"  {str(e)}")
                    print(traceback.format_exc())

            if forward_data is not None:
                # Compute LCAF if data is available
                print(f"\nComputing LCAF metric...")
                try:
                    B = forward_data.shape[2] - 1  # Number of batches
                    results['LCAF'] = compute_lca_forward(forward_data, T, B)
                    print(f"✓ LCAF computation completed: {results['LCAF']:.4f}")
                except Exception as e:
                    print(f"✗ Error in LCAF computation:")
                    print(f"  {str(e)}")
                    print(traceback.format_exc())

                # FAT compares against the scratch run, so it needs both files
                if forward_scratch_data is not None:
                    print(f"\nComputing FAT metrics...")
                    try:
                        # The curves may end in NaN padding. Find the last
                        # batch index (on the last task's curve) that is valid
                        # in both arrays and express it as a percentage.
                        last_valid_scratch = np.where(~np.isnan(forward_scratch_data[0, -1, :]))[0][-1]
                        last_valid_forward = np.where(~np.isnan(forward_data[0, -1, :]))[0][-1]
                        non_nan_indice = min(last_valid_forward, last_valid_scratch)
                        non_nan_indice_percentage = int((non_nan_indice * 100) / forward_scratch_data.shape[2])

                        # Percentages beyond the valid range are evaluated at
                        # the last valid percentage instead.
                        b_values_chosen = [min(b, non_nan_indice_percentage) for b in b_values]
                        was_capped = b_values_chosen != b_values

                        results['FAT'] = compute_fat(forward_data, forward_scratch_data, T, b_values_chosen)

                        if was_capped:
                            # Report the capped value under every percentage
                            # it stands in for. The comparison must be made in
                            # percent, not against the raw batch index.
                            capped_value = results['FAT'][non_nan_indice_percentage]
                            for b in b_values:
                                if b > non_nan_indice_percentage:
                                    results['FAT'][b] = capped_value

                            # Drop the helper key unless it is itself one of
                            # the reported percentages.
                            if non_nan_indice_percentage not in b_values:
                                del results['FAT'][non_nan_indice_percentage]

                        print(f"✓ FAT computation completed")

                        # Compute mean FAT if requested; capped slots
                        # contribute their refilled value.
                        if bat_fat_value == "mean":
                            valid_values = [
                                results['FAT'][b] for b in b_values
                                if b in results['FAT'] and not np.isnan(results['FAT'][b])
                            ]
                            if valid_values:
                                results['FAT_mean'] = np.mean(valid_values)
                            else:
                                results['FAT_mean'] = float('nan')

                    except Exception as e:
                        print(f"✗ Error in FAT computation:")
                        print(f"  {str(e)}")
                        print(traceback.format_exc())

            benchmark_results[baseline] = results

        all_results[benchmark] = benchmark_results

    # Find best values for each metric, separately for every benchmark
    best_values = {}
    for benchmark in benchmarks:
        best_values[benchmark] = {metric: float('-inf') for metric in metrics}

        for baseline in baselines:
            results = all_results[benchmark][baseline]
            for metric in metrics:
                value = reported_value(results, metric)
                if not np.isnan(value):
                    best_values[benchmark][metric] = max(best_values[benchmark][metric], value)

    # Create table header
    bat_fat_label = "mean" if bat_fat_value == "mean" else f"{bat_fat_value}\\%"

    table = f"\\begin{{table}}[ht]\n"
    table += f"\\centering\n"
    table += f"\\caption{{Consolidated Results Across All Benchmarks}}\n"
    table += f"\\resizebox{{\\textwidth}}{{!}}{{%\n"
    table += f"\\begin{{tabular}}{{l|{'|'.join(['cccc' for _ in benchmarks])}}}\n"
    table += f"\\hline\n"

    # First row with benchmark names spanning 4 columns each
    table += f"\\multirow{{2}}{{*}}{{Baseline}} & "
    for benchmark in benchmarks:
        table += f"\\multicolumn{{4}}{{c|}}{{\\textbf{{{benchmark}}}}} & "
    table = table[:-2] + "\\\\\n"  # Remove last '& ' and add newline

    # Second row with metric names
    table += f" & "
    for _ in benchmarks:
        table += f"BAT$_{{{bat_fat_label}}}$ & FAT$_{{{bat_fat_label}}}$ & LCA$_B$ & LCA$_F$ & "
    table = table[:-2] + "\\\\\n"  # Remove last '& ' and add newline

    table += f"\\hline\n"

    # Add results for each baseline
    for baseline in baselines:
        display_name = name_dict.get(baseline, baseline)
        row = f"{display_name} & "

        for benchmark in benchmarks:
            results = all_results[benchmark][baseline]

            for metric in metrics:
                value = reported_value(results, metric)
                # A missing (NaN) value is never the best one; format_value
                # renders it as "-".
                is_best = (
                    not np.isnan(value)
                    and abs(value - best_values[benchmark][metric]) < 1e-6
                )
                row += format_value(value, is_best) + " & "

        # Remove last '& ' and add newline
        row = row[:-2] + "\\\\\n"
        table += row

    # Complete the table
    table += "\\hline\n"
    table += "\\end{tabular}%\n"
    table += "}\n"
    table += "\\label{tab:consolidated-results}\n"
    table += "\\end{table}\n"

    return table

# Modified main function to include the consolidated table

def create_latex_table(benchmarks, baselines, number_of_rows, mode, name_dict):
    """Create a LaTeX table with results for all benchmarks and baselines"""
    b_values = [10, 25, 50, 75, 100]  # BAT and FAT percentage values
    
    # Dictionary to store all results
    all_results = {}
    
    # Process each benchmark and baseline combination
    for benchmark in benchmarks:
        print(f"\n{'='*80}")
        print(f"Processing benchmark: {benchmark}")
        print(f"{'='*80}")
        
        T = number_of_rows[benchmark]
        benchmark_results = {}
        
        for baseline in baselines:
            print(f"\n{'-'*60}")
            print(f"Processing baseline: {baseline}")
            print(f"{'-'*60}")
            
            # Load data
            backward_data, forward_data, backward_scratch_data, forward_scratch_data = load_data(benchmark, baseline, mode, number_of_rows)
            
            # Initialize results dictionary
            results = {
                'BAT': {},
                'FAT': {},
                'LCAB': float('nan'),
                'LCAF': float('nan')
            }
            
            # Compute BAT for different b values if data is available
            if backward_data is not None:
                # Compute LCAB if data is available
                print(f"\nComputing LCAB metric...")
                try:
                    B = backward_data.shape[2] - 1  # Number of batches
                    results['LCAB'] = compute_lca_backward(backward_data, T, B)
                    print(f"✓ LCAB computation completed: {results['LCAB']:.4f}")
                except Exception as e:
                    print(f"✗ Error in LCAB computation:")
                    print(f"  {str(e)}")
                    print(traceback.format_exc())
                
                if backward_data is not None and backward_scratch_data is not None:
                    print(f"\nComputing BAT metrics...")
                    try:
                        results['BAT'] = compute_bat(backward_data, backward_scratch_data, T, b_values)
                        print(f"✓ BAT computation completed")
                    except Exception as e:
                        print(f"✗ Error in BAT computation:")
                        print(f"  {str(e)}")
                        print(traceback.format_exc())
                
                
            if backward_data is None or backward_scratch_data is None:
                missing = []
                cant_compute = ''
                if backward_data is None:
                    missing.append("backward_data")
                    cant_compute = 'BAT/LCAB'
                if backward_scratch_data is None:
                    missing.append("backward_scratch_data")
                    if backward_data is None:
                        cant_compute = 'BAT/LCAB'
                    else:
                        cant_compute = 'BAT'
                print(f"✗ Cannot compute BAT/LCAB: Missing data files: {', '.join(missing)}")
            



            # Compute FAT for different b values if data is available

            if forward_data is not None:
                # Compute LCAF if data is available
                print(f"\nComputing LCAF metric...")
                try:
                    B = forward_data.shape[2] - 1  # Number of batches
                    results['LCAF'] = compute_lca_forward(forward_data, T, B)
                    print(f"✓ LCAF computation completed: {results['LCAF']:.4f}")
                except Exception as e:
                    print(f"✗ Error in LCAF computation:")
                    print(f"  {str(e)}")
                    print(traceback.format_exc())
                    
                if forward_data is not None and forward_scratch_data is not None:
                    print(f"\nComputing FAT metrics...")
                    try:
                        results['FAT'] = compute_fat(forward_data, forward_scratch_data, T, b_values)
                        print(f"✓ FAT computation completed")
                    except Exception as e:
                        print(f"✗ Error in FAT computation:")
                        print(f"  {str(e)}")
                        print(traceback.format_exc())
                
                
            if forward_data is None or forward_scratch_data is None:
                missing = []
                cant_compute = ''
                if forward_data is None:
                    missing.append("forward_data")
                    cant_compute = 'FAT/LCAF'
                if forward_scratch_data is None:
                    missing.append("forward_scratch_data")
                    if forward_data is None:
                        cant_compute = 'FAT/LCAF'
                    else:
                        cant_compute = 'FAT'
                print(f"✗ Cannot compute FAT/LCAF: Missing data files: {', '.join(missing)}")
            
            # Print summary of results
            print("\nResults summary:")
            for b in b_values:
                if b in results['BAT']:
                    print(f"  BAT {b}%: {results['BAT'][b]:.4f}")
                else:
                    print(f"  BAT {b}%: -")
            for b in b_values:
                if b in results['FAT']:
                    print(f"  FAT {b}%: {results['FAT'][b]:.4f}")
                else:
                    print(f"  FAT {b}%: -")
            print(f"  LCAB: {results['LCAB']:.4f}" if not np.isnan(results['LCAB']) else "  LCAB: -")
            print(f"  LCAF: {results['LCAF']:.4f}" if not np.isnan(results['LCAF']) else "  LCAF: -")
            
            benchmark_results[baseline] = results
        
        all_results[benchmark] = benchmark_results
    
    # Create LaTeX tables for each benchmark
    latex_tables = {}
    
    for benchmark in benchmarks:
        print(f"\n{'='*80}")
        print(f"Creating LaTeX table for: {benchmark}")
        print(f"{'='*80}")
        
        benchmark_results = all_results[benchmark]
        
        # Find best values for each metric
        best_values = {
            'BAT': {b: float('-inf') for b in b_values},
            'FAT': {b: float('-inf') for b in b_values},
            'LCAB': float('-inf'),
            'LCAF': float('-inf')
        }
        
        for baseline in baselines:
            results = benchmark_results[baseline]
            
            # Find best BAT values
            for b in b_values:
                if b in results['BAT'] and not np.isnan(results['BAT'][b]):
                    best_values['BAT'][b] = max(best_values['BAT'][b], results['BAT'][b])
            
            # Find best FAT values
            for b in b_values:
                if b in results['FAT'] and not np.isnan(results['FAT'][b]):
                    best_values['FAT'][b] = max(best_values['FAT'][b], results['FAT'][b])
            
            # Find best LCAB value
            if not np.isnan(results['LCAB']):
                best_values['LCAB'] = max(best_values['LCAB'], results['LCAB'])
            
            # Find best LCAF value
            if not np.isnan(results['LCAF']):
                best_values['LCAF'] = max(best_values['LCAF'], results['LCAF'])
        
        print("\nBest values summary:")
        for b in b_values:
            if best_values['BAT'][b] > float('-inf'):
                print(f"  Best BAT {b}%: {best_values['BAT'][b]:.4f}")
            else:
                print(f"  Best BAT {b}%: No valid data")
        for b in b_values:
            if best_values['FAT'][b] > float('-inf'):
                print(f"  Best FAT {b}%: {best_values['FAT'][b]:.4f}")
            else:
                print(f"  Best FAT {b}%: No valid data")
        print(f"  Best LCAB: {best_values['LCAB']:.4f}" if best_values['LCAB'] > float('-inf') else "  Best LCAB: No valid data")
        print(f"  Best LCAF: {best_values['LCAF']:.4f}" if best_values['LCAF'] > float('-inf') else "  Best LCAF: No valid data")
        
        # Create table header
        table = f"\\begin{{table}}[ht]\n"
        table += f"\\centering\n"
        table += f"\\caption{{{benchmark} Results}}\n"
        # <-- HERE we start the resizebox
        table += "  \\resizebox{\\textwidth}{!}{%\n"
        table += "    \\begin{tabular}{l|ccccc|ccccc|cc}\n"
        table += "      \\hline\n"
        table += "      \\multirow{2}{*}{Baseline} & "
        table += "\\multicolumn{5}{c|}{BAT (\\%)} & "
        table += "\\multicolumn{5}{c|}{FAT (\\%)} & "
        table += "\\multirow{2}{*}{LCA$_B$} & \\multirow{2}{*}{LCA$_F$} \\\\\n"
        table += "      & 10\\% & 25\\% & 50\\% & 75\\% & 100\\% & "
        table += "10\\% & 25\\% & 50\\% & 75\\% & 100\\% & & \\\\\n"
        table += "      \\hline\n"
        
        # Add results for each baseline
        for baseline in baselines:
            results = benchmark_results[baseline]
            display_name = name_dict.get(baseline, baseline)
            
            row = f"{display_name} & "
            
            # Add BAT values
            for b in b_values:
                if b in results['BAT'] and not np.isnan(results['BAT'][b]):
                    is_best = abs(results['BAT'][b] - best_values['BAT'][b]) < 1e-6
                    row += format_value(results['BAT'][b], is_best) + " & "
                else:
                    row += "- & "
            
            # Add FAT values
            for b in b_values:
                if b in results['FAT'] and not np.isnan(results['FAT'][b]):
                    is_best = abs(results['FAT'][b] - best_values['FAT'][b]) < 1e-6
                    row += format_value(results['FAT'][b], is_best) + " & "
                else:
                    row += "- & "
            
            # Add LCAB value
            is_best_lcab = abs(results['LCAB'] - best_values['LCAB']) < 1e-6 if not np.isnan(results['LCAB']) else False
            row += format_value(results['LCAB'], is_best_lcab) + " & "
            
            # Add LCAF value
            is_best_lcaf = abs(results['LCAF'] - best_values['LCAF']) < 1e-6 if not np.isnan(results['LCAF']) else False
            row += format_value(results['LCAF'], is_best_lcaf) + " \\\\\n"
            
            table += row
        
        # Complete the table
        table += "      \\hline\n"
        table += "    \\end{tabular}%\n"
        # close the resizebox
        table += "  }\n"
        table += f"  \\label{{tab:{benchmark.lower().replace('_','-')}}}\n"
        table += "\\end{table}\n"
        
        latex_tables[benchmark] = table
        print(f"✓ Table for {benchmark} created successfully")
    
    return latex_tables


if __name__ == "__main__":
    # Script entry point: build one consolidated LaTeX table for the selected
    # benchmarks/baselines, print it, and save it under ``output_tables/``.
    #
    # Names that appeared in earlier selections, kept here for reference only:
    #   benchmarks: 'random_MNIST', 'random_label_cifar10', 'shuffle_cifar10',
    #               'permuted_MNIST', 'continual_cifar100', 'continual_imagenet'
    #   baselines:  'Base', 'CBP', 'CReLU', 'DeepF', 'EWC', 'L2', 'L2Init',
    #               'LayerNorm', 'NeuroSync', 'PReLU', 'ReDo', 'L2InitPlusEWC',
    #               'MAML'

    # Per-benchmark value forwarded to the table builder as ``number_of_rows``.
    # NOTE(review): its exact meaning is defined by the builder — confirm there.
    number_of_rows = {
        'random_MNIST': 30,
        'random_label_cifar10': 30,
        'shuffle_cifar10': 30,
        'permuted_MNIST': 25,
        'continual_cifar100': 20,
        'continual_imagenet': 100,
    }

    mode = 'train'

    # Maps internal baseline identifiers to the names shown in the table.
    name_dict = {
        'Base': 'Base',
        'CBP': 'CBP',
        'CReLU': 'CReLU',
        'DeepF': 'DeepF',
        'EWC': 'EWC',
        'L2': 'L2',
        'L2Init': 'L2Init',
        'LayerNorm': 'LayerNorm',
        'NeuroSync': 'NeuroSync',
        'PReLU': 'PReLU',
        'ReDo': 'ReDo',
        'Scratch': 'Scratch',
        "L2InitPlusEWC": 'L2Init + EWC',
        'MAML': 'MAML',
    }

    # Active selection. This is the only assignment of these two names, so what
    # is listed here is exactly what gets processed.
    benchmarks = ['permuted_MNIST']
    baselines = ['CBP', 'CReLU', 'NeuroSync', 'ReDo', 'L2InitPlusEWC']

    consolidated_table = create_consolidated_latex_table(
        benchmarks, baselines, number_of_rows, mode, name_dict, bat_fat_value="50"
    )

    # Create the output directory up front so a permissions problem surfaces
    # before anything is printed.
    output_dir = "output_tables"
    os.makedirs(output_dir, exist_ok=True)

    print("\nConsolidated LaTeX Table:")
    print(consolidated_table)

    # Explicit encoding keeps the written file independent of the platform's
    # default locale encoding.
    output_path = os.path.join(output_dir, "consolidated_results_table.tex")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(consolidated_table)

    print("LaTeX tables generated successfully!")