import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Translucent fill color per method. plot_metric uses these as the fill of the
# standard-deviation band, falling back to a translucent black for methods
# that are not listed here. "Ours" gets a higher alpha (0.5) than the other
# methods (0.25).
METHOD_COLORS = {
    "Optimal Denoiser": "rgba(31, 119, 180, 0.25)",  # blue, alpha 0.25
    "Wiener (linear)": "rgba(255, 127, 14, 0.25)",  # orange, alpha 0.25
    "Kamb & Ganguli": "rgba(44, 160, 44, 0.25)",  # green, alpha 0.25
    "Niedoba et al.": "rgba(214, 39, 40, 0.25)",  # red, alpha 0.25
    "Ours": "rgba(148, 103, 189, 0.5)",  # purple, alpha 0.5
    "Another DDPM": "rgba(140, 86, 75, 0.25)",  # brown, alpha 0.25
}

# Line color per method, same hues as METHOD_COLORS. plot_metric uses these
# for the mean curve, falling back to black for unlisted methods. Only "Ours"
# is fully opaque; every other method is drawn at alpha 0.5.
METHOD_LINE_COLORS = {
    "Optimal Denoiser": "rgba(31, 119, 180, 0.5)",  # blue, alpha 0.5
    "Wiener (linear)": "rgba(255, 127, 14, 0.5)",  # orange, alpha 0.5
    "Kamb & Ganguli": "rgba(44, 160, 44, 0.5)",  # green, alpha 0.5
    "Niedoba et al.": "rgba(214, 39, 40, 0.5)",  # red, alpha 0.5
    "Ours": "rgba(148, 103, 189, 1)",  # purple, fully opaque
    "Another DDPM": "rgba(140, 86, 75, 0.5)",  # brown, alpha 0.5
}

# Dash style per method, passed by plot_metric as the line's `dash` value
# (unlisted methods fall back to "solid"). Every method is dotted except
# "Ours", which is drawn solid.
METHOD_LINE_STYLES = {
    "Optimal Denoiser": "dot",
    "Wiener (linear)": "dot",
    "Kamb & Ganguli": "dot",
    "Niedoba et al.": "dot",
    "Ours": "solid",
    "Another DDPM": "dot",
}


def get_dataset_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the display name of its dataset.

    The whole path string (not just its last component) is lowercased and
    searched for each known dataset key; the first key found, in table order,
    determines the result.

    Args:
        path: Path of an experiment directory.

    Returns:
        The display name of the first matching dataset, or ``None`` when the
        path contains none of the known keys.
    """
    # Order matters: "mnist" is a substring of "fashion_mnist", so the longer
    # key has to be tested first.
    datasets = {
        "fashion_mnist": "Fashion MNIST",
        "cifar10": "CIFAR10",
        "celeba_hq": "CelebA-HQ",
        "afhq": "AFHQv2",
        "mnist": "MNIST",
    }

    # Lowercase once instead of on every loop iteration.
    path_str = str(path).lower()
    for key, display_name in datasets.items():
        if key in path_str:
            return display_name
    return None


def get_method_from_path(path: Path) -> Optional[str]:
    """Map an experiment path to the display name of its method.

    The whole path string (not just its last component) is lowercased and
    searched for each keyword in priority order; the first keyword found
    determines the result.

    Args:
        path: Path of an experiment directory.

    Returns:
        The display name of the first matching method, or ``None`` when the
        path contains none of the known keywords.
    """
    # (keyword, display name) pairs in match priority order.
    # NOTE: this is a plain substring test on the full path, so a short
    # keyword such as "ours" also matches inside unrelated words in any
    # parent directory name.
    keyword_to_method = (
        ("wiener", "Wiener (linear)"),
        ("kamb", "Kamb & Ganguli"),
        ("niedoba", "Niedoba et al."),
        ("ours", "Ours"),
        ("optimal", "Optimal Denoiser"),
        ("unet", "Another DDPM"),
    )

    # Lowercase once instead of once per branch.
    path_str = str(path).lower()
    for keyword, display_name in keyword_to_method:
        if keyword in path_str:
            return display_name
    return None


def load_metrics(
    results_dir: Path,
) -> Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]:
    """Load per-timestep metrics from every run and aggregate them across runs.

    Each immediate subdirectory of ``results_dir`` that contains a
    ``metrics.json`` file is treated as one run. The dataset and method of a
    run are derived from its directory path; runs for which either cannot be
    determined are ignored. Runs sharing the same dataset, metric and method
    are combined into a mean and a (population) standard deviation per
    timestep.

    Args:
        results_dir: Directory whose subdirectories hold the individual runs.

    Returns:
        A nested mapping ``dataset -> metric -> method -> record`` where each
        record has the keys ``"timesteps"``, ``"values"`` (mean across runs)
        and ``"std"`` (standard deviation across runs). A method whose runs
        disagree on their timesteps is reported with a warning and left out.

    Raises:
        FileNotFoundError: If ``results_dir`` does not exist.
        KeyError: If a metrics file lacks ``"per_timestep"``, or a step lacks
            ``"timestep"`` or one of the metrics present in the first step.
    """
    # dataset -> metric -> method -> list of per-run records
    all_runs = {}

    # iterdir() yields entries in arbitrary order; sorting makes the order of
    # datasets (and therefore of subplots) reproducible across machines.
    for exp_dir in sorted(results_dir.iterdir()):
        if not exp_dir.is_dir():
            continue

        metrics_file = exp_dir / "metrics.json"
        if not metrics_file.exists():
            continue

        dataset = get_dataset_from_path(exp_dir)
        method = get_method_from_path(exp_dir)
        if dataset is None or method is None:
            continue

        with open(metrics_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        per_timestep = data["per_timestep"]
        if not per_timestep:
            # Without at least one step there are no metric names to read.
            print(f"Warning: No per-timestep metrics in {metrics_file}")
            continue

        dataset_runs = all_runs.setdefault(dataset, {})

        # The metric names are taken from the first step; every later step is
        # expected to carry the same keys.
        timesteps = [step["timestep"] for step in per_timestep]
        for metric in per_timestep[0]:
            if metric == "timestep":
                continue
            dataset_runs.setdefault(metric, {}).setdefault(method, []).append(
                {
                    "timesteps": timesteps,
                    "values": [step[metric] for step in per_timestep],
                }
            )

    # Aggregate the collected runs into mean and std per timestep.
    metrics_by_dataset = {}

    for dataset, dataset_runs in all_runs.items():
        dataset_summary = metrics_by_dataset.setdefault(dataset, {})
        for metric, runs_by_method in dataset_runs.items():
            metric_summary = dataset_summary.setdefault(metric, {})
            for method, runs in runs_by_method.items():
                # Averaging is only meaningful when all runs share one grid.
                timesteps = runs[0]["timesteps"]
                if any(run["timesteps"] != timesteps for run in runs):
                    print(
                        f"Warning: Inconsistent timesteps for {dataset}/{metric}/{method}"
                    )
                    continue

                # Stack once: one row per run, one column per timestep.
                stacked = np.array([run["values"] for run in runs])

                metric_summary[method] = {
                    "timesteps": timesteps,
                    "values": np.mean(stacked, axis=0).tolist(),
                    "std": np.std(stacked, axis=0).tolist(),
                }

    return metrics_by_dataset


def plot_metric(
    metrics_by_dataset: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
    metric_name: str,
    output_dir: Path,
    width: int = 1200,
    height: int = 800,
    use_log_scale: bool = False,
    show_std: bool = False,
    total_timesteps: int = 1000,
):
    """Plot one metric over the timesteps, one subplot per dataset.

    Only datasets that actually contain ``metric_name`` get a subplot. When no
    dataset contains it, a warning is printed and nothing is written. The
    figure is saved as ``<output_dir>/<metric_name>.pdf``.

    Args:
        metrics_by_dataset: Nested mapping ``dataset -> metric -> method ->
            record`` where a record has the keys ``"timesteps"``, ``"values"``
            and optionally ``"std"``.
        metric_name: Name of the metric to plot; also the output file stem.
        output_dir: Directory for the PDF; created if it does not exist.
        width: Figure width in pixels.
        height: Figure height in pixels.
        use_log_scale: Use a logarithmic y axis instead of a linear one.
        show_std: Draw a translucent band of mean +/- std behind each curve
            for records that carry a ``"std"`` entry.
        total_timesteps: Value from which each timestep is subtracted to
            produce the x tick labels.
    """
    # Datasets lacking this metric would otherwise raise a KeyError below.
    datasets = [
        dataset
        for dataset, metrics in metrics_by_dataset.items()
        if metric_name in metrics
    ]
    if not datasets:
        print(f"Warning: No data for metric '{metric_name}', skipping plot")
        return

    fig = make_subplots(
        rows=1,
        cols=len(datasets),
        subplot_titles=[f"{dataset}" for dataset in datasets],
        horizontal_spacing=0.02,
    )

    # Styling shared by the x and y axes of every subplot.
    axis_style = dict(
        showgrid=True,
        gridwidth=1,
        gridcolor="lightgray",
        zeroline=False,
        showline=True,
        linewidth=1,
        linecolor="black",
    )

    # Methods that already have a legend entry. Each method is listed once, at
    # its first appearance, so methods missing from the first subplot still
    # show up in the legend.
    legend_methods = set()

    for col, dataset in enumerate(datasets, 1):
        by_method = metrics_by_dataset[dataset][metric_name]

        # Stable sort that moves "Ours" to the end so it is drawn on top.
        methods = sorted(by_method, key=lambda name: name == "Ours")

        subplot_timesteps = set()
        for method in methods:
            data = by_method[method]
            values = np.array(data["values"])
            timesteps = np.array(data["timesteps"])
            subplot_timesteps.update(data["timesteps"])

            # The band is added before the curve so the curve lies above it.
            if show_std and "std" in data:
                std_values = np.array(data["std"])
                # Closed polygon: upper edge left-to-right, then lower edge
                # right-to-left.
                # NOTE(review): with use_log_scale the lower edge
                # (mean - std) can be <= 0, which a log axis cannot display -
                # confirm the band renders acceptably in that case.
                fig.add_trace(
                    go.Scatter(
                        x=np.concatenate([timesteps, timesteps[::-1]]),
                        y=np.concatenate(
                            [values + std_values, (values - std_values)[::-1]]
                        ),
                        fill="toself",
                        fillcolor=METHOD_COLORS.get(method, "rgba(0, 0, 0, 0.2)"),
                        line=dict(width=0),
                        showlegend=False,
                        hoverinfo="skip",
                    ),
                    row=1,
                    col=col,
                )

            fig.add_trace(
                go.Scatter(
                    x=timesteps,
                    y=values,
                    name=method,
                    line=dict(
                        color=METHOD_LINE_COLORS.get(method, "#000000"),
                        width=3 if method == "Ours" else 1,  # emphasize "Ours"
                        dash=METHOD_LINE_STYLES.get(method, "solid"),
                    ),
                    showlegend=method not in legend_methods,
                ),
                row=1,
                col=col,
            )
            legend_methods.add(method)

        # One tick per distinct timestep in this subplot; the label shows
        # total_timesteps - t rather than t itself.
        tickvals = sorted(subplot_timesteps)
        ticktext = [f"{total_timesteps - t}" for t in tickvals]

        fig.update_xaxes(
            row=1,
            col=col,
            title="Denoising step t",
            tickvals=tickvals,
            ticktext=ticktext,
            **axis_style,
        )
        fig.update_yaxes(
            row=1,
            col=col,
            type="log" if use_log_scale else "linear",
            **axis_style,
        )

    fig.update_layout(
        width=width,
        height=height,
        font=dict(family="Times New Roman", size=11),
        # Legend sits just outside the right edge of the plotting area.
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.05,
            font=dict(size=11),
            bgcolor="rgba(255, 255, 255, 0.8)",
        ),
        margin=dict(l=20, r=20, t=20, b=20),
        plot_bgcolor="white",
        paper_bgcolor="white",
        showlegend=True,
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_dir / f"{metric_name}.pdf")


def main(argv: Optional[List[str]] = None):
    """Command-line entry point: load all results and write one PDF per metric.

    Args:
        argv: Argument list to parse instead of the process arguments. When
            ``None`` (the default) the arguments are read from ``sys.argv``.

    Exits with an argparse usage error when the results directory does not
    exist, instead of failing later with a traceback.
    """
    parser = argparse.ArgumentParser(description="Plot metrics over time")
    parser.add_argument(
        "--results_dir",
        type=str,
        default="experiment_results",
        help="Directory containing experiment results",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="paper/figures",
        help="Directory where the PDF figures are written",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1200,
        help="Figure width in pixels",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=200,
        help="Figure height in pixels",
    )
    parser.add_argument(
        "--show_std",
        action="store_true",
        help="Show standard deviation as translucent background",
    )
    args = parser.parse_args(argv)

    # Fail early with a readable message rather than inside load_metrics.
    results_dir = Path(args.results_dir)
    if not results_dir.is_dir():
        parser.error(f"results directory not found: {results_dir}")

    metrics_by_dataset = load_metrics(results_dir)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # (metric name, whether to use a logarithmic y axis)
    metric_scales = [
        ("r2_score_vs_unet", False),
        ("r2_score_vs_unet_single", True),
        ("r2_score_vs_unet_single_eps", True),
        ("l2_distance_to_dataset", False),
        ("mse_trajectory_vs_unet", False),
        ("mse_trajectory_vs_unet_single", True),
        ("mse_trajectory_vs_unet_single_eps", True),
    ]
    for metric, use_log in metric_scales:
        plot_metric(
            metrics_by_dataset,
            metric,
            output_dir,
            width=args.width,
            height=args.height,
            use_log_scale=use_log,
            show_std=args.show_std,
        )


if __name__ == "__main__":
    main()
