import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np


def get_final_metrics(metrics_file: Path, metric_type: str = "l2") -> Optional[float]:
    """Read a single metric from the metrics JSON file of one run.

    Args:
        metrics_file: Path to the metrics.json file.
        metric_type: 'l2' to read ``l2_distance_to_dataset`` from the last
            entry of ``per_timestep``; 'time' to read
            ``runtime.total_sample_time``.

    Returns:
        The value stored in the file, or None if the file does not contain
        the requested metric (missing key or empty ``per_timestep`` list).

    Raises:
        ValueError: If ``metric_type`` is neither 'l2' nor 'time'.
    """
    # Reject an unknown metric before doing any file I/O.
    if metric_type not in ("l2", "time"):
        raise ValueError("metric_type must be either 'l2' or 'time'")

    with open(metrics_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if metric_type == "l2":
        # A run without per-timestep data is reported as "not available"
        # instead of raising, so one incomplete run cannot abort aggregation.
        timesteps = data.get("per_timestep")
        if not timesteps:
            return None
        return timesteps[-1].get("l2_distance_to_dataset")

    # metric_type == "time": the runtime section is optional.
    runtime = data.get("runtime") or {}
    return runtime.get("total_sample_time")


def get_dataset_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the display name of the dataset it mentions.

    The whole path (not only its last component) is converted to a
    lower-cased string and searched for known dataset keywords; the first
    keyword found wins.

    Args:
        path: Path of an experiment directory.

    Returns:
        The dataset display name, or None if no known keyword occurs.
    """
    # Order matters: "fashion_mnist" must be tested before its substring
    # "mnist", otherwise Fashion MNIST runs would be reported as MNIST.
    known_datasets = (
        ("fashion_mnist", "Fashion MNIST"),
        ("cifar10", "CIFAR10"),
        ("celeba_hq", "CelebA-HQ"),
        ("afhq", "AFHQv2"),
        ("mnist", "MNIST"),
    )

    # Lower-case once instead of once per keyword.
    haystack = str(path).lower()
    for keyword, display_name in known_datasets:
        if keyword in haystack:
            return display_name
    return None


def get_method_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the LaTeX label of the method it mentions.

    The whole path (not only its last component) is converted to a
    lower-cased string and searched for known method keywords; the first
    keyword found wins.

    Args:
        path: Path of an experiment directory.

    Returns:
        The LaTeX-formatted method label, or None if no known keyword occurs.
    """
    # Checked in this order; the first keyword contained in the path decides.
    # NOTE: backslashes are written as "\\" so that no invalid escape
    # sequence (such as "\&") appears in the literals; the resulting string
    # values are unchanged.
    method_labels = (
        ("wiener", "Wiener (linear)"),
        ("kamb", "Kamb \\& Ganguli \\cite{kamb2024analytic}"),
        ("niedoba", "Niedoba et al.\\cite{niedoba2024towards}"),
        ("ours", "\\textbf{Ours}"),
        ("optimal", "Optimal Denoiser"),
        ("unet", "Another DDPM"),
    )

    # Lower-case once instead of once per keyword.
    haystack = str(path).lower()
    for keyword, label in method_labels:
        if keyword in haystack:
            return label
    return None


def _format_cell(values, mode: str, metric_type: str, samples_per_run: int) -> str:
    """Format the runs of one (method, dataset) pair as a single table cell."""
    if mode == "count":
        return str(samples_per_run * len(values))

    if metric_type == "time":
        # Runtime is shown with two decimals. The minimum over runs is
        # reported instead of the mean (original author's note: "because of
        # the delays").
        centre = f"{np.min(values):.2f}"
        spread = f"{np.std(values):.2f}"
        std_only = f"$\\pm$ {spread}"
    else:
        # L2 distance is shown with three decimals; the std-only form has no
        # space after \pm (kept exactly as previously emitted).
        centre = f"{np.mean(values):.3f}"
        spread = f"{np.std(values):.3f}"
        std_only = f"$\\pm${spread}"

    if mode == "avg":
        return centre
    if mode == "std":
        return std_only
    return f"{centre} $\\pm$ {spread}"  # mode == "as"


def generate_latex_table(
    results_dir: str = "experiment_results",
    mode: str = "avg",
    metric_type: str = "l2",
    samples_per_run: int = 6,
    output_dir: str = "paper",
) -> None:
    """Generate a LaTeX table of L2 distances or runtimes from experiment results.

    Every sub-directory of ``results_dir`` that contains a ``metrics.json``
    and whose path names a known dataset and method contributes one value.
    The table is written to ``<output_dir>/<metric_type>_table_<mode>.tex``.

    Args:
        results_dir: Directory containing one sub-directory per experiment run.
        mode: One of 'avg' (averages only), 'std' (standard deviations only),
              'as' (average ± standard deviation), or 'count' (number of
              valid runs multiplied by ``samples_per_run``). For
              ``metric_type='time'`` the "average" shown is the minimum over
              runs.
        metric_type: Either 'l2' for L2 distance or 'time' for runtime metrics.
        samples_per_run: Multiplier applied to run counts, both in the
              caption ("over N samples") and in 'count' mode.
        output_dir: Directory the .tex file is written to; created if missing.

    Raises:
        ValueError: If ``mode`` or ``metric_type`` is not one of the accepted
            values.
    """
    if mode not in ("avg", "std", "as", "count"):
        raise ValueError("Mode must be one of: 'avg', 'std', 'as', 'count'")
    if metric_type not in ("l2", "time"):
        raise ValueError("metric_type must be either 'l2' or 'time'")

    # results[method][dataset] -> list of metric values, one per valid run.
    results = defaultdict(lambda: defaultdict(list))

    for exp_dir in Path(results_dir).iterdir():
        if not exp_dir.is_dir():
            continue

        metrics_file = exp_dir / "metrics.json"
        if not metrics_file.exists():
            continue

        dataset = get_dataset_from_path(exp_dir)
        method = get_method_from_path(exp_dir)
        if dataset is None or method is None:
            continue

        metric_value = get_final_metrics(metrics_file, metric_type)
        if metric_value is not None:  # Only add if metric is available
            results[method][dataset].append(metric_value)

    # The caption reports the smallest run count over all (method, dataset)
    # pairs. With no usable runs at all this is 0 rather than infinity.
    run_counts = [
        len(runs) for per_dataset in results.values() for runs in per_dataset.values()
    ]
    total_samples = min(run_counts) * samples_per_run if run_counts else 0

    # Row and column order of the table. Labels must equal the strings
    # returned by get_method_from_path / get_dataset_from_path.
    method_order = [
        "Optimal Denoiser",
        "Wiener (linear)",
        "Kamb \\& Ganguli \\cite{kamb2024analytic}",
        "Niedoba et al.\\cite{niedoba2024towards}",
        "\\textbf{Ours}",
        "Another DDPM",
    ]
    dataset_order = ["CIFAR10", "CelebA-HQ", "AFHQv2", "MNIST", "Fashion MNIST"]

    mode_descriptions = {
        "avg": "averaged",
        "std": "standard deviations",
        "as": "averaged with standard deviations",
        "count": "number of valid runs",
    }

    if metric_type == "l2":
        caption_subject = (
            "if the analytical models are capable of producing novel samples "
            "by displaying the average L2 distance to the dataset"
        )
        label = "l2-distance"
    else:
        caption_subject = (
            "the computational efficiency of each method "
            "by displaying the total sampling time"
        )
        label = "runtime"

    latex = [
        "\\begin{table}[h]",
        "    \\footnotesize",
        f"  \\caption{{We demonstrate {caption_subject} for each of the baselines. "
        f"Results show {mode_descriptions[mode]} over {total_samples} samples.}}",
        f"  \\label{{tab:{label}-comparison}}",
        "  \\centering",
        "  \\setlength{\\tabcolsep}{4pt}",
        # One left-aligned method column plus one centred column per dataset.
        "  \\begin{tabular}{l" + "c" * len(dataset_order) + "}",
        "    \\toprule",
        "    Method & " + " & ".join(dataset_order) + " \\\\",
        "    \\midrule",
    ]

    for method in method_order:
        if method not in results:
            continue

        per_dataset = results[method]
        # Pad the method name so the generated source lines up in columns.
        row = [method.ljust(40)]
        for dataset in dataset_order:
            runs = per_dataset.get(dataset)
            if runs:
                row.append(
                    _format_cell(np.array(runs), mode, metric_type, samples_per_run)
                )
            else:
                row.append("N/A")

        latex.append("    " + " & ".join(row) + " \\\\")
        # Visually separate the reference row and our method from the rest.
        if method in ("\\textbf{Ours}", "Optimal Denoiser"):
            latex.append("    \\midrule")

    latex.append("    \\bottomrule")
    latex.append("  \\end{tabular}")
    latex.append("\\end{table}")

    # Write to file with mode and metric type in filename; make sure the
    # target directory exists so a fresh checkout does not fail here.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    output_file = out_dir / f"{metric_type}_table_{mode}.tex"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(latex))


if __name__ == "__main__":
    import argparse

    # One entry per command-line option: (flag, accepted values, default, help).
    option_specs = [
        (
            "--mode",
            ["avg", "std", "as", "count"],
            "as",
            "Display mode: 'avg' for averages, 'std' for standard deviations, 'as' for average ± std, 'count' for number of valid runs",
        ),
        (
            "--metric",
            ["l2", "time"],
            "l2",
            "Metric type: 'l2' for L2 distance, 'time' for runtime metrics",
        ),
    ]

    cli = argparse.ArgumentParser(
        description="Generate LaTeX table showing L2 distances or runtime metrics"
    )
    for flag, accepted, fallback, help_text in option_specs:
        cli.add_argument(
            flag, type=str, choices=accepted, default=fallback, help=help_text
        )

    # Build the table for the requested display mode and metric, reading
    # from the default results directory.
    options = cli.parse_args()
    generate_latex_table(mode=options.mode, metric_type=options.metric)
