import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

def get_final_metrics(metrics_file: Path) -> Tuple[float, float]:
    """Return the ``(r2, mse)`` pair for the last timestep in a metrics file.

    The file is parsed as JSON and the last element of its
    ``"per_timestep"`` list is used.

    Args:
        metrics_file: Path of the JSON metrics file to read.

    Returns:
        A 2-tuple. The first element is ``1 - entry["r2_score_vs_unet"]`` and
        the second is ``entry["mse_trajectory_vs_unet"]`` as stored, where
        ``entry`` is the last per-timestep record.
    """
    with open(metrics_file, "r") as handle:
        payload = json.load(handle)

    # Only the final per-timestep record contributes to the table.
    last_entry = payload["per_timestep"][-1]

    # NOTE(review): the stored score is inverted (1 - value) before being
    # reported as r2 -- presumably the file holds the unexplained fraction;
    # confirm against whatever writes metrics.json.
    return (
        1 - last_entry["r2_score_vs_unet"],
        last_entry["mse_trajectory_vs_unet"],
    )


def get_dataset_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the display name of the dataset it mentions.

    The path is rendered with ``str()``, lower-cased, and searched for each
    known dataset key as a plain substring. The whole path string is
    searched, so parent directory names can match as well.

    Args:
        path: Experiment path to inspect.

    Returns:
        The display name of the first key found in the path, or ``None`` if
        no known dataset key occurs in it.
    """
    # Insertion order is the match priority: "fashion_mnist" contains
    # "mnist", so the more specific key has to be tested first.
    datasets = {
        "fashion_mnist": "Fashion MNIST",
        "cifar10": "CIFAR10",
        "celeba_hq": "CelebA-HQ",
        "afhq": "AFHQv2",
        "mnist": "MNIST",
    }

    # Lower-case once instead of once per candidate key.
    haystack = str(path).lower()
    for key, display_name in datasets.items():
        if key in haystack:
            return display_name
    return None


def get_method_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the LaTeX label of the method it mentions.

    The path is rendered with ``str()``, lower-cased, and searched for each
    known method keyword as a plain substring, in a fixed priority order.
    The whole path string is searched, so parent directory names can match
    as well.

    Args:
        path: Experiment path to inspect.

    Returns:
        The LaTeX-ready label of the first keyword found, or ``None`` if no
        known keyword occurs in the path.
    """
    # (keyword, label) pairs, tested top to bottom; the first hit wins.
    # Backslashes are doubled so every label is a valid Python literal:
    # a bare "\&" is an invalid escape sequence that newer interpreters
    # warn about.
    # NOTE(review): short keywords such as "ours" also match inside longer
    # words (e.g. "hours"); confirm directory naming keeps this unambiguous.
    method_labels = (
        ("wiener", "Wiener (linear)"),
        ("kamb", "Kamb \\& Ganguli \\cite{kamb2024analytic}"),
        ("niedoba", "Niedoba et al.\\cite{niedoba2024towards}"),
        ("ours", "\\textbf{Ours}"),
        ("optimal", "Optimal Denoiser"),
        ("unet", "Another DDPM"),
    )

    # Lower-case once instead of once per candidate keyword.
    haystack = str(path).lower()
    for keyword, label in method_labels:
        if keyword in haystack:
            return label
    return None


def generate_latex_table(
    results_dir: str = "experiment_results", mode: str = "avg"
) -> None:
    """Generate a LaTeX comparison table from experiment results.

    Every sub-directory of ``results_dir`` that contains a ``metrics.json``
    and whose path identifies both a known dataset and a known method
    contributes one ``(r2, mse)`` pair. Pairs are aggregated per
    (method, dataset) cell and written as a booktabs table to
    ``paper/main_table_appendix_{mode}.tex`` (relative to the current
    working directory); the ``paper`` directory is created if missing.

    Args:
        results_dir: Directory containing experiment results
        mode: One of 'avg' (averages only), 'std' (standard deviations only),
              or 'as' (average ± standard deviation)

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode not in ("avg", "std", "as"):
        raise ValueError("Mode must be one of: 'avg', 'std', 'as'")

    results = _collect_results(Path(results_dir))

    # Smallest number of runs over all populated (method, dataset) cells.
    # With no usable experiments this is 0, so the caption never reads
    # "over inf samples".
    min_runs = min(
        (
            len(runs)
            for per_dataset in results.values()
            for runs in per_dataset.values()
        ),
        default=0,
    )

    # NOTE(review): each run is counted as 6 samples in the caption --
    # presumably the number of samples evaluated per run; confirm.
    samples_per_run = 6
    total_samples = min_runs * samples_per_run

    # Row and column order of the table. Backslashes are doubled so the
    # labels are valid Python literals (a bare "\&" is an invalid escape).
    method_order = [
        "Optimal Denoiser",
        "Wiener (linear)",
        "Kamb \\& Ganguli \\cite{kamb2024analytic}",
        "Niedoba et al.\\cite{niedoba2024towards}",
        "\\textbf{Ours}",
        "Another DDPM",
    ]
    dataset_order = ["CIFAR10", "CelebA-HQ", "AFHQv2", "MNIST", "Fashion MNIST"]

    # Rows after which a horizontal rule separates groups of methods.
    separator_after = {"Optimal Denoiser", "\\textbf{Ours}"}

    mode_descriptions = {
        "avg": "averaged",
        "std": "standard deviations",
        "as": "averaged with standard deviations",
    }

    caption = (
        "  \\caption{We compare how well our analytical model and other "
        "baselines explain a trained diffusion model across datasets using "
        "$r^2$ and MSE metrics. "
        f"Results are {mode_descriptions[mode]} over {total_samples} samples.}}"
    )

    # One label column followed by an (r2, MSE) column pair per dataset.
    column_spec = "l" + " ".join("cc" for _ in dataset_order)

    latex = [
        "\\begin{table}[h]",
        "    \\footnotesize",
        caption,
        "  \\label{tab:metrics-comparison}",
        "  \\centering",
        "  \\setlength{\\tabcolsep}{4pt}",
        f"  \\begin{{tabular}}{{{column_spec}}}",
        "    \\toprule",
    ]

    # Dataset names, each spanning its two metric columns.
    latex.append(
        "    & "
        + " & ".join(
            f"\\multicolumn{{2}}{{c}}{{{dataset}}}" for dataset in dataset_order
        )
        + " \\\\"
    )

    # One partial rule under each dataset name; column 1 is the method label,
    # so dataset i occupies columns 2 + 2*i and 3 + 2*i.
    for index in range(len(dataset_order)):
        first_col = 2 + 2 * index
        latex.append(f"    \\cmidrule(lr){{{first_col}-{first_col + 1}}}")

    # Metric names, repeated for every dataset.
    latex.append(
        "    Method & "
        + " & ".join(
            "\\(r^2 \\uparrow\\) & MSE\\(\\downarrow\\)" for _ in dataset_order
        )
        + " \\\\"
    )
    latex.append("    \\midrule")

    # Data rows, restricted to methods that actually have results.
    present_methods = [method for method in method_order if method in results]
    for position, method in enumerate(present_methods):
        # Pad the label so the generated source lines up column-wise.
        row = [method.ljust(40)]
        for dataset in dataset_order:
            if dataset in results[method]:
                r2_values, mse_values = zip(*results[method][dataset])
                row.append(_format_cell(np.array(r2_values), mode))
                row.append(_format_cell(np.array(mse_values), mode))
            else:
                row.extend(["N/A", "N/A"])

        latex.append("    " + " & ".join(row) + " \\\\")

        # Separate method groups, but never place a \midrule directly in
        # front of the closing \bottomrule.
        is_last_row = position == len(present_methods) - 1
        if method in separator_after and not is_last_row:
            latex.append("    \\midrule")

    latex.append("    \\bottomrule")
    latex.append("  \\end{tabular}")
    latex.append("\\end{table}")

    # Write to file with mode in filename; create the target directory so a
    # fresh checkout does not fail at the very last step.
    output_path = Path(f"paper/main_table_appendix_{mode}.tex")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(latex))


def _collect_results(results_path: Path):
    """Group ``(r2, mse)`` pairs by method label, then by dataset label."""
    results = defaultdict(lambda: defaultdict(list))

    # Sorted so aggregation order does not depend on directory listing order.
    for exp_dir in sorted(results_path.iterdir()):
        if not exp_dir.is_dir():
            continue

        metrics_file = exp_dir / "metrics.json"
        if not metrics_file.exists():
            continue

        dataset = get_dataset_from_path(exp_dir)
        method = get_method_from_path(exp_dir)
        if dataset is None or method is None:
            # Directory does not identify a known dataset/method pair.
            continue

        results[method][dataset].append(get_final_metrics(metrics_file))

    return results


def _format_cell(values, mode: str) -> str:
    """Render one table cell (mean, std, or mean ± std) for ``values``."""
    if mode == "avg":
        return f"{np.mean(values):.3f}"
    if mode == "std":
        return f"$\\pm$ {np.std(values):.3f}"
    # mode == "as"
    return f"{np.mean(values):.3f} $\\pm$ {np.std(values):.3f}"


if __name__ == "__main__":
    import argparse

    # Command-line entry point: the only option is the table display mode,
    # which is forwarded unchanged to generate_latex_table.
    cli = argparse.ArgumentParser(
        description="Generate LaTeX table from experiment results"
    )
    mode_help = (
        "Display mode: 'avg' for averages, 'std' for standard deviations, "
        "'as' for average ± std"
    )
    cli.add_argument(
        "--mode",
        type=str,
        choices=["avg", "std", "as"],
        default="avg",
        help=mode_help,
    )
    options = cli.parse_args()
    generate_latex_table(mode=options.mode)
