#!/usr/bin/env python3
"""
CLI for analysis and visualization of results.
"""

import argparse
import datetime
import json
import os
import shutil
import sys
from pathlib import Path

from rtrank.analysis.plot_combined_dataset_metrics import (
    run_analysis as run_kendall_analysis,
)
from rtrank.analysis.plot_combined_metrics import (
    run_analysis as run_combined_metrics_analysis,
)
from rtrank.analysis.plot_datasize import run_analysis as run_datasize_analysis

# Learners included in the combined-metric boxplots and the dataset-size plots
# (passed as `learners=` to both analyses below). Presumably these match the
# learner names recorded in the experiment outputs — TODO confirm.
LEARNERS = [
    "bt",
    "rt_regression",
    "rt_regression_perm",
    "rt_rank",
    "rt_rank_pooled",
    "rt_rank_perm",
]
# Repository root: the directory three levels above the one containing this
# file. NOTE(review): this silently breaks if the module is moved to a
# different depth in the source tree.
REPO_ROOT = Path(__file__).parents[3]
# Destination for all generated figures and the saved experiment-path mapping.
OUTPUT_DIR = REPO_ROOT / "paper" / "figures"
# JSON mapping of config dataset name -> experiment directory. The `find`
# subcommand writes it and the other subcommands read it.
PATHS_FILE = OUTPUT_DIR / "experiment_paths.json"
# Root directory scanned recursively for experiment runs (dirs that contain a
# metadata.json file).
EXPERIMENT_DIR = REPO_ROOT / "outputs" / "experiment_runner"

# Datasets grouped by variability condition. "name" is the human-readable
# dataset name; "config_name" is the key used for that dataset in the
# experiment-paths mapping.
DATASETS = {
    "with_variability": [
        {"name": "Deterministic", "config_name": "deterministic_all"},
        {"name": "Drift Diffusion", "config_name": "drift_diffusion"},
        {"name": "Stochastic", "config_name": "stochastic"},
    ],
    "no_variability": [
        {
            "name": "Deterministic No Variability",
            "config_name": "deterministic_all_no_variability",
        },
        {
            "name": "Drift Diffusion No Variability",
            "config_name": "drift_diffusion_no_variability",
        },
        {
            "name": "Stochastic No Variability",
            "config_name": "stochastic_no_variability",
        },
    ],
}


def load_experiment_paths(paths_file=None):
    """Load the saved dataset -> experiment-directory mapping.

    Args:
        paths_file: JSON file to read. Defaults to ``PATHS_FILE`` (resolved at
            call time), so existing callers are unaffected.

    Returns:
        The decoded JSON content (the mapping written by
        ``save_experiment_paths``), or an empty dict if the file does not exist.
    """
    if paths_file is None:
        paths_file = PATHS_FILE
    # EAFP: a single open() avoids the exists()/open() race and an extra stat.
    try:
        with open(paths_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def save_experiment_paths(paths, paths_file=None):
    """Write the dataset -> experiment-directory mapping as indented JSON.

    Args:
        paths: JSON-serializable mapping to store.
        paths_file: Destination file. Defaults to ``PATHS_FILE`` (resolved at
            call time). Missing parent directories are created.
    """
    target = Path(PATHS_FILE if paths_file is None else paths_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as f:
        json.dump(paths, f, indent=2)


def find_latest_experiments(experiment_dir=None):
    """Find the newest experiment directory for each known dataset.

    Walks ``experiment_dir`` recursively. Every directory that contains a
    ``metadata.json`` file is a candidate run. The file's ``"dataset"`` field
    (e.g. ``"Drift Diffusion"``) is mapped to the config dataset name. Its
    ``"timestamp"`` field (format ``"%Y-%m-%d %H:%M:%S"``) decides which run
    is the newest.

    Runs whose dataset is missing or unknown are ignored. Metadata files that
    cannot be read or parsed, or that lack a valid timestamp, are skipped with
    a warning on stderr instead of aborting the whole scan.

    Args:
        experiment_dir: Root directory to scan. Defaults to ``EXPERIMENT_DIR``
            (resolved at call time).

    Returns:
        Dict mapping config dataset name -> path (str) of its newest run
        directory. Datasets without any run are absent.
    """
    if experiment_dir is None:
        experiment_dir = EXPERIMENT_DIR

    # Dataset name as stored in metadata.json -> config name (mirrors DATASETS).
    dataset_metadata_to_config = {
        "Deterministic": "deterministic_all",
        "Stochastic": "stochastic",
        "Drift Diffusion": "drift_diffusion",
        "Deterministic No Variability": "deterministic_all_no_variability",
        "Stochastic No Variability": "stochastic_no_variability",
        "Drift Diffusion No Variability": "drift_diffusion_no_variability",
    }

    # config name -> (timestamp, run directory)
    latest_by_dataset = {}

    for root, _, files in os.walk(experiment_dir):
        if "metadata.json" not in files:
            continue
        metadata_path = Path(root) / "metadata.json"

        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, ValueError) as err:  # ValueError covers JSON/Unicode errors
            print(f"Warning: skipping unreadable {metadata_path}: {err}", file=sys.stderr)
            continue
        if not isinstance(metadata, dict):
            print(f"Warning: skipping malformed {metadata_path}", file=sys.stderr)
            continue

        dataset_name = metadata.get("dataset")
        if not isinstance(dataset_name, str) or dataset_name not in dataset_metadata_to_config:
            continue
        config_dataset = dataset_metadata_to_config[dataset_name]

        try:
            timestamp_dt = datetime.datetime.strptime(
                metadata["timestamp"], "%Y-%m-%d %H:%M:%S"
            )
        except (KeyError, TypeError, ValueError) as err:
            print(
                f"Warning: skipping {metadata_path}: bad or missing timestamp ({err})",
                file=sys.stderr,
            )
            continue

        # Compare the naive datetimes directly. Calling .timestamp() would add
        # a pointless round-trip through local-time conversion.
        current = latest_by_dataset.get(config_dataset)
        if current is None or timestamp_dt > current[0]:
            latest_by_dataset[config_dataset] = (timestamp_dt, str(Path(root)))

    return {dataset: run_dir for dataset, (_, run_dir) in latest_by_dataset.items()}


def find_command():
    """Locate the newest run of every dataset and persist the mapping.

    The mapping is written to ``PATHS_FILE``. A summary then lists each
    dataset with the final path component of its run directory.
    """
    print("Finding latest experiment directories...")
    latest = find_latest_experiments()
    save_experiment_paths(latest)
    print(f"Updated experiment paths in {PATHS_FILE}")

    print("Found latest experiments:")
    for name, run_dir in latest.items():
        print(f"- {name}: {Path(run_dir).name}")


def confusion_matrices_command():
    """Copy each experiment's test confusion-matrix plot into OUTPUT_DIR.

    For every dataset in the saved experiment paths, copies
    ``test_confusion_matrix_confusion_matrix.png`` from its experiment
    directory to ``OUTPUT_DIR/<short name>/``. A missing plot is reported and
    skipped, so the remaining datasets are still copied.
    """
    paths = load_experiment_paths()
    if not paths:
        print("No experiment paths found. Run 'find' command first.")
        return

    print("Copying confusion matrix plots...")

    # Config dataset name -> subdirectory of OUTPUT_DIR that receives its plot.
    config_to_result_dir = {
        "deterministic_all": "deterministic",
        "drift_diffusion": "ddm",
        "stochastic": "stochastic",
        "deterministic_all_no_variability": "deterministic_no_variability",
        "drift_diffusion_no_variability": "ddm_no_variability",
        "stochastic_no_variability": "stochastic_no_variability",
    }
    plot_file = "test_confusion_matrix_confusion_matrix.png"

    # Create every destination up front. The list is derived from the mapping
    # so the two cannot drift apart.
    for result_dir in config_to_result_dir.values():
        os.makedirs(OUTPUT_DIR / result_dir, exist_ok=True)

    for config, exp_path in paths.items():
        result_dir = config_to_result_dir.get(config)
        if result_dir is None:
            continue
        src_path = Path(exp_path) / plot_file
        if not src_path.is_file():
            # Previously shutil.copy2 raised here and aborted all later copies.
            print(f"Warning: {src_path} not found; skipping {config}.")
            continue
        shutil.copy2(src_path, OUTPUT_DIR / result_dir / plot_file)


def run_combined_metrics(result_dirs, learners, output_dir, output_suffix):
    """Produce the combined-metric boxplots for one group of result dirs.

    Thin wrapper around ``run_combined_metrics_analysis``. It fixes the metrics
    to ``pearson_distance_correlation`` and ``choice_accuracy``; every other
    argument is forwarded unchanged.
    """
    selected_metrics = ["pearson_distance_correlation", "choice_accuracy"]
    run_combined_metrics_analysis(
        result_dirs=result_dirs,
        learners=learners,
        output_dir=output_dir,
        metrics=selected_metrics,
        output_suffix=output_suffix,
    )


def boxplots_command(variability):
    """Generate combined-metric boxplots for the selected condition(s).

    Args:
        variability: ``"with"``, ``"without"`` or ``"both"``. Selects which
            group(s) of datasets from ``DATASETS`` are plotted.

    A warning is printed for any dataset whose experiment path is unknown. Its
    entry is still passed as ``""`` so the plot layout stays fixed.
    """
    paths = load_experiment_paths()
    if not paths:
        print("No experiment paths found. Run 'find' command first.")
        return

    # CLI flag -> DATASETS key, which doubles as the output suffix.
    for flag, group in (("with", "with_variability"), ("without", "no_variability")):
        if variability not in (flag, "both"):
            continue
        print(f"Generating boxplots {flag} variability...")

        # Config names in the order defined by DATASETS (single source of truth).
        configs = [dataset["config_name"] for dataset in DATASETS[group]]
        result_dirs = [paths.get(config, "") for config in configs]

        missing = [config for config, path in zip(configs, result_dirs) if path == ""]
        if missing:
            print(
                f"Warning: Missing paths for {', '.join(missing)}. "
                "Some plots may be incomplete."
            )

        run_combined_metrics(result_dirs, LEARNERS, OUTPUT_DIR, group)


def run_datasize_plots(result_dirs, learners, output_dir, output_filename):
    """Plot metrics against training-set size for one group of result dirs.

    Delegates to ``run_datasize_analysis`` with the metrics fixed to
    ``pearson_distance_correlation`` and ``choice_accuracy``.
    """
    selected_metrics = ["pearson_distance_correlation", "choice_accuracy"]
    run_datasize_analysis(
        result_dirs=result_dirs,
        learners=learners,
        output_dir=output_dir,
        metrics=selected_metrics,
        output_filename=output_filename,
    )


def datasize_command(variability):
    """Generate dataset-size-vs-metric plots for the selected condition(s).

    Args:
        variability: ``"with"``, ``"without"`` or ``"both"``. Selects which
            group(s) of datasets from ``DATASETS`` are plotted.

    Like ``boxplots_command``, this warns about datasets with no known
    experiment path. It still passes ``""`` in their slot.
    """
    paths = load_experiment_paths()
    if not paths:
        print("No experiment paths found. Run 'find' command first.")
        return

    # CLI flag -> DATASETS key, which also forms the output filename suffix.
    for flag, group in (("with", "with_variability"), ("without", "no_variability")):
        if variability not in (flag, "both"):
            continue
        print(f"Generating dataset size plots {flag} variability...")

        configs = [dataset["config_name"] for dataset in DATASETS[group]]
        result_dirs = [paths.get(config, "") for config in configs]

        missing = [config for config, path in zip(configs, result_dirs) if path == ""]
        if missing:
            print(
                f"Warning: Missing paths for {', '.join(missing)}. "
                "Some plots may be incomplete."
            )

        run_datasize_plots(
            result_dirs,
            LEARNERS,
            OUTPUT_DIR,
            f"metrics_vs_datasize_combined_{group}",
        )


def run_kendall_tau(
    result_dirs, output_dir, output_suffix, x_range_min=0, x_range_max=1
):
    """Draw Kendall-tau histograms for one group of result dirs.

    Delegates to ``run_kendall_analysis`` with the single metric
    ``test_abs_logit_diff_rt_kendall_tau``. ``x_range_min``/``x_range_max``
    are forwarded unchanged.
    """
    kendall_metrics = ["test_abs_logit_diff_rt_kendall_tau"]
    run_kendall_analysis(
        result_dirs=result_dirs,
        metrics=kendall_metrics,
        output_dir=output_dir,
        output_suffix=output_suffix,
        x_range_min=x_range_min,
        x_range_max=x_range_max,
    )


def kendall_command(variability):
    """Generate Kendall-tau histograms for the selected condition(s).

    Args:
        variability: ``"with"``, ``"without"`` or ``"both"``. Selects which
            group(s) of datasets from ``DATASETS`` are plotted.

    Like ``boxplots_command``, this warns about datasets with no known
    experiment path. It still passes ``""`` in their slot.
    """
    paths = load_experiment_paths()
    if not paths:
        print("No experiment paths found. Run 'find' command first.")
        return

    # CLI flag -> DATASETS key, which doubles as the output suffix.
    for flag, group in (("with", "with_variability"), ("without", "no_variability")):
        if variability not in (flag, "both"):
            continue
        print(f"Generating Kendall tau histograms {flag} variability...")

        configs = [dataset["config_name"] for dataset in DATASETS[group]]
        result_dirs = [paths.get(config, "") for config in configs]

        missing = [config for config, path in zip(configs, result_dirs) if path == ""]
        if missing:
            print(
                f"Warning: Missing paths for {', '.join(missing)}. "
                "Some plots may be incomplete."
            )

        run_kendall_tau(result_dirs, OUTPUT_DIR, group)


def all_command():
    """Run the whole pipeline in order.

    The steps are: refresh experiment paths, copy confusion matrices, then
    produce boxplots, dataset-size plots and Kendall-tau histograms for both
    variability conditions.
    """
    find_command()
    confusion_matrices_command()

    plot_steps = (boxplots_command, datasize_command, kendall_command)
    for step in plot_steps:
        step(variability="both")

    print("All analyses completed!")


def main():
    """Parse the command line and dispatch to the chosen subcommand.

    Each subcommand's parsed options (``--variability`` for the plotting
    commands, none otherwise) are passed to its handler as keyword arguments.
    Without a subcommand, prints the help text and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="RtRank Analysis Tool")
    subparsers = parser.add_subparsers(dest="command", help="Analysis command to run")

    # (name, help text, handler, accepts --variability), in --help display order.
    command_table = [
        ("find", "Find latest experiment directories", find_command, False),
        ("confusion-matrices", "Copy confusion matrix plots", confusion_matrices_command, False),
        ("boxplots", "Generate combined metric boxplots", boxplots_command, True),
        ("datasize", "Generate dataset size vs metrics plots", datasize_command, True),
        ("kendall", "Generate Kendall tau histograms", kendall_command, True),
        ("all", "Run all analyses", all_command, False),
    ]

    handlers = {}
    for name, help_text, handler, takes_variability in command_table:
        sub = subparsers.add_parser(name, help=help_text)
        if takes_variability:
            sub.add_argument(
                "--variability",
                choices=["with", "without", "both"],
                default="both",
                help="Which variability condition to generate plots for",
            )
        handlers[name] = handler

    options = vars(parser.parse_args())
    chosen = options.pop("command")
    handler = handlers.get(chosen)
    if handler is None:
        parser.print_help()
        sys.exit(1)

    # Whatever is left in options (e.g. ``variability``) maps onto the
    # handler's keyword parameters; commands without options get no kwargs.
    handler(**options)


if __name__ == "__main__":
    main()
