
import os
import numpy as np
import glob
from typing import List, Dict, Tuple, Optional, Any
import re

def load_matrix(file_path: str, max_tasks: int) -> Optional[np.ndarray]:
    """
    Read a matrix from a .npy file, keeping at most ``max_tasks`` tasks.

    Args:
        file_path: Location of the .npy file to read
        max_tasks: Upper bound on the number of tasks kept

    Returns:
        The matrix (cut down to its top-left ``max_tasks`` x ``max_tasks``
        block when it has more rows than that), or None if reading or
        trimming raised; the error is printed in that case
    """
    try:
        loaded = np.load(file_path)
        # Only the row count decides whether trimming happens
        if loaded.shape[0] <= max_tasks:
            return loaded
        return loaded[:max_tasks, :max_tasks]
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return None

def compute_bwt(matrix: np.ndarray) -> float:
    """
    Compute the Backwards Transfer (BWT) metric.

    Note: the matrices are expected to already hold differences, i.e.
    A[i,j] = (accuracy on task j after training on task i)
             - (accuracy on task j after training on task j)

    Args:
        matrix: Difference matrix indexed as matrix[trained_task, evaluated_task]

    Returns:
        Mean of the last row over every task except the last one,
        or 0.0 when there is at most one task
    """
    n_tasks = matrix.shape[0]
    if n_tasks > 1:
        # Entries are already differences, so the final row (minus its own
        # diagonal element) is all that has to be averaged
        final_row = matrix[n_tasks - 1, :n_tasks - 1]
        return np.mean(final_row)

    # A single task cannot be forgotten
    return 0.0

def compute_one_step_bwt(matrix: np.ndarray) -> float:
    """
    Compute the One Step Backwards Transfer metric.

    Note: the matrices are expected to already hold differences, i.e.
    A[i,j] = (accuracy on task j after training on task i)
             - (accuracy on task j after training on task j)

    Args:
        matrix: Difference matrix indexed as matrix[trained_task, evaluated_task]

    Returns:
        Mean of the entries directly below the main diagonal,
        or 0.0 when there is at most one task
    """
    n_tasks = matrix.shape[0]
    if n_tasks <= 1:
        # No "next task" exists, so there is nothing to measure
        return 0.0

    # Pair each row index 1..K-1 with the column just before it: (m+1, m)
    next_task = np.arange(1, n_tasks)
    sub_diagonal = matrix[next_task, next_task - 1]
    return np.mean(sub_diagonal)

def get_seed_files(benchmark_folder: str, baseline: str) -> List[str]:
    """
    Get all seed files for a specific baseline in a benchmark folder.

    Files are matched by the name pattern ``<baseline>_seed*.npy``. The folder
    and baseline are taken literally: glob metacharacters in them
    (``[``, ``]``, ``*``, ``?``) are escaped so only the ``*`` after ``_seed``
    acts as a wildcard.

    Args:
        benchmark_folder: Folder for the benchmark
        baseline: Baseline name

    Returns:
        Sorted list of matching file paths (empty if nothing matches)
    """
    pattern = os.path.join(
        glob.escape(benchmark_folder),
        f"{glob.escape(baseline)}_seed*.npy",
    )
    # glob order is filesystem-dependent; sort so results are reproducible
    return sorted(glob.glob(pattern))

# Location of the per-benchmark matrix folders used when no path is supplied
_DEFAULT_BWT_PATH = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\processed_results\bwt_matrices'


def process_benchmark(
    benchmark: str,
    baselines: List[str],
    tasks_per_benchmark: Dict[str, int],
    bwt_path: Optional[str] = None,
) -> Dict[str, Dict[str, Tuple[Optional[float], int]]]:
    """
    Process a benchmark and compute metrics for all baselines.

    Args:
        benchmark: Name of the benchmark (also the sub-folder name)
        baselines: List of baseline names
        tasks_per_benchmark: Dictionary mapping benchmarks to number of tasks;
            a benchmark missing from it falls back to 5 tasks
        bwt_path: Root folder holding one sub-folder per benchmark. When None,
            the module default location is used

    Returns:
        Dictionary mapping each baseline to
        {"bwt": (mean, n_seeds), "one_step_bwt": (mean, n_seeds)}.
        The pair is (None, 0) when fewer than two seed files exist or
        none of them could be loaded.
    """
    root = _DEFAULT_BWT_PATH if bwt_path is None else bwt_path
    benchmark_folder = os.path.join(root, benchmark)
    max_tasks = tasks_per_benchmark.get(benchmark, 5)  # Default to 5 tasks if not specified

    results = {}

    for baseline in baselines:
        seed_files = get_seed_files(benchmark_folder, baseline)

        bwt_values = []
        one_step_bwt_values = []

        # A single seed is not considered enough to report an average
        if len(seed_files) >= 2:
            for seed_file in seed_files:
                matrix = load_matrix(seed_file, max_tasks)
                if matrix is None:
                    # Unreadable files are skipped (load_matrix already reported them)
                    continue
                bwt_values.append(compute_bwt(matrix))
                one_step_bwt_values.append(compute_one_step_bwt(matrix))

        if bwt_values:
            avg_bwt = sum(bwt_values) / len(bwt_values)
            avg_one_step_bwt = sum(one_step_bwt_values) / len(one_step_bwt_values)
            results[baseline] = {
                "bwt": (avg_bwt, len(bwt_values)),
                "one_step_bwt": (avg_one_step_bwt, len(one_step_bwt_values))
            }
        else:
            results[baseline] = {"bwt": (None, 0), "one_step_bwt": (None, 0)}

    return results

def _abbreviate_benchmark_name(display_name: str, max_length: int = 12) -> str:
    """Shorten a display name longer than ``max_length`` for use as a column header."""
    if len(display_name) <= max_length:
        return display_name
    parts = display_name.split('-')
    if len(parts) > 1:
        # Multi-part names keep every part, each cut to its first 5 characters
        return '-'.join(part[:5] for part in parts)
    return display_name[:max_length]


def _format_metric_cell(value: Optional[float], best: float, tolerance: float = 1e-5) -> str:
    """Render one metric as a table cell, in bold when it matches ``best``."""
    if value is None:
        return "-"
    formatted = f"{value:.3f}"
    # Tolerance allows for floating point imprecision
    if abs(value - best) < tolerance:
        return f"\\textbf{{{formatted}}}"
    return formatted


def generate_latex_table(
    benchmarks: List[str],
    baselines: List[str],
    tasks_per_benchmark: Dict[str, int],
    legend_name_dict: Dict[str, str]
) -> str:
    """
    Generate a LaTeX table with BWT and One Step BWT metrics.

    The table has one row per baseline and two columns (BWT, OS-BWT) per
    benchmark. The largest value of each metric within a benchmark is set
    in bold; missing values are shown as "-".

    Args:
        benchmarks: List of benchmark names (may hold any number of entries,
            including one or none)
        baselines: List of baseline names
        tasks_per_benchmark: Dictionary mapping benchmarks to number of tasks
        legend_name_dict: Dictionary mapping baseline names to display names;
            a baseline missing from it is shown under its own name

    Returns:
        LaTeX code for the table
    """
    metrics = ("bwt", "one_step_bwt")

    # Process all benchmarks
    all_results = {
        benchmark: process_benchmark(benchmark, baselines, tasks_per_benchmark)
        for benchmark in benchmarks
    }

    column_spec = 'l' + '|cc' * len(benchmarks)
    latex_code = [
        "\\begin{table}[t]",
        "\\centering",
        "\\small",
        "\\setlength{\\tabcolsep}{4pt}",  # Reduce column spacing
        f"\\begin{{tabular}}{{{column_spec}}}",
        "\\hline"
    ]

    # First row: benchmark names, underscores shown as hyphens, shortened if needed
    headers = [_abbreviate_benchmark_name(name.replace('_', '-')) for name in benchmarks]
    last_index = len(headers) - 1
    header_cells = []
    for index, header in enumerate(headers):
        # Every group but the last carries a right-hand rule
        alignment = "c" if index == last_index else "c|"
        header_cells.append(f"\\multicolumn{{2}}{{{alignment}}}{{{header}}}")
    if header_cells:
        # One leading "&" per cell keeps the cell count equal to the column
        # count for any number of benchmarks (a single benchmark included)
        latex_code.append("".join(f"& {cell} " for cell in header_cells) + "\\\\")

    # Second row: metric names
    latex_code.append(" & ".join(["Method"] + ["BWT & OS-BWT"] * len(benchmarks)) + " \\\\")
    latex_code.append("\\hline")

    # Find max values for each metric/benchmark for bold highlighting
    max_values = {}
    for benchmark in benchmarks:
        best = {metric: float('-inf') for metric in metrics}
        for baseline in baselines:
            for metric in metrics:
                value, _ = all_results[benchmark][baseline][metric]
                if value is not None and value > best[metric]:
                    best[metric] = value
        max_values[benchmark] = best

    # Add rows for each baseline
    for baseline in baselines:
        row = [legend_name_dict.get(baseline, baseline)]
        for benchmark in benchmarks:
            for metric in metrics:
                value, _ = all_results[benchmark][baseline][metric]
                row.append(_format_metric_cell(value, max_values[benchmark][metric]))
        latex_code.append(" & ".join(row) + " \\\\")

    # Close the table
    latex_code.extend([
        "\\hline",
        "\\end{tabular}",
        "\\caption{Backwards Transfer (BWT) and One Step Backwards Transfer (OS-BWT) metrics across different benchmarks and methods. Higher values indicate better performance. Bold values represent the best performance for each metric.}",
        "\\label{tab:continual_learning_results}",
        "\\end{table}"
    ])

    return "\n".join(latex_code)

def generate_continual_learning_table(
    benchmarks: List[str], 
    baselines: List[str], 
    tasks_per_benchmark: Dict[str, int] = None,
    legend_name_dict: Dict[str, str] = None,
    output_folder: str = "output"
) -> str:
    """
    Build the continual learning LaTeX table, write it to disk and return it.

    The table is written to "bwt_results.tex" inside ``output_folder``,
    which is created when missing.

    Args:
        benchmarks: Benchmark names to include
        baselines: Baseline methods to include
        tasks_per_benchmark: Number of tasks per benchmark; a built-in
            mapping is used when omitted
        legend_name_dict: Display name per baseline; an identity mapping
            over a built-in set of baselines is used when omitted
        output_folder: Folder receiving the LaTeX file

    Returns:
        LaTeX code for the table
    """
    # Fall back to the built-in task counts
    if tasks_per_benchmark is None:
        tasks_per_benchmark = {
            'random_MNIST': 5,
            'random_label_cifar10': 5,
            'shuffle_cifar10': 5,
            'permuted_MNIST': 10,
            'continual_cifar100': 20,
            'continual_imagenet': 10
        }

    # Fall back to showing each known baseline under its own name
    if legend_name_dict is None:
        known_baselines = (
            'Base', 'CBP', 'CReLU', 'DeepF', 'EWC', 'L2',
            'L2Init', 'LayerNorm', 'NeuroSync', 'PReLU', 'ReDo', 'Scratch',
        )
        legend_name_dict = {method: method for method in known_baselines}

    table_text = generate_latex_table(benchmarks, baselines, tasks_per_benchmark, legend_name_dict)

    # Make sure the destination exists before writing
    os.makedirs(output_folder, exist_ok=True)
    destination = os.path.join(output_folder, "bwt_results.tex")
    with open(destination, "w") as handle:
        handle.write(table_text)

    print(f"LaTeX table saved to {destination}")

    return table_text

# Example usage:
if __name__ == "__main__":
    # Fuller benchmark/baseline sets, kept for reference:
    # benchmarks = ['random_MNIST', 'random_label_cifar10', 'shuffle_cifar10', 'permuted_MNIST', 'continual_cifar100', 'continual_imagenet']
    # baselines = ['Base', 'CBP', 'CReLU', 'DeepF', 'EWC', 'L2', 'L2Init', 'LayerNorm', 'NeuroSync', 'PReLU', 'ReDo', 'L2InitPlusEWC']

    # Per-benchmark cap on how many tasks (matrix rows/columns) are used
    number_of_rows = {
        'random_MNIST': 30,
        'random_label_cifar10': 30,
        'shuffle_cifar10': 30,
        'permuted_MNIST': 25,
        'continual_cifar100': 20,
        'continual_imagenet': 100,
    }

    benchmarks = [
        'random_MNIST',
        'random_label_cifar10',
        'shuffle_cifar10',
        'permuted_MNIST',
        'continual_cifar100',
        'continual_imagenet',
    ]
    baselines = ['Base', 'CBP', 'CReLU', 'EWC', 'NeuroSync', 'PReLU', 'ReDo', 'L2InitPlusEWC', 'MAML']
    # Baselines missing from this mapping (e.g. 'MAML') are shown under their own name
    legend_name_dict = {
        'Base': 'Base', 'CBP': 'CBP', 'CReLU': 'CReLU',
        'DeepF': 'DeepF', 'EWC': 'EWC', 'L2': 'L2', 'L2Init': 'L2Init', 'LayerNorm': 'LayerNorm',
        'NeuroSync': 'NeuroSync', 'PReLU': 'PReLU', 'ReDo': 'ReDo', 'Scratch': 'Scratch',
        "L2InitPlusEWC": 'L2Init + EWC',
    }
    output_folder = './processed_results'

    # NOTE(review): the original passed an undefined name (MAX_STEPS) here, which
    # raised NameError; number_of_rows was otherwise unused and appears to be the
    # intended mapping -- confirm.
    latex_table = generate_continual_learning_table(
        benchmarks,
        baselines,
        tasks_per_benchmark=number_of_rows,
        legend_name_dict=legend_name_dict,
        output_folder=output_folder,
    )
    print(latex_table)