# analyze_algorithms.py

import re
import matplotlib.pyplot as plt
import click
from typing import List, Dict, Optional

# Apply matplotlib's built-in "ggplot" style sheet globally, so every figure
# created below shares the same look.
plt.style.use("ggplot")


class AlgorithmData:
    """Timing measurements collected for a single algorithm.

    Each call to :meth:`add_data` records one (dimension, mean, std) sample.
    The three lists are kept index-aligned.
    """

    def __init__(self, name: str):
        self.name = name
        self.dimensions: List[int] = []
        self.mean_total_durations: List[float] = []
        # ``None`` marks samples for which no standard deviation was reported.
        self.std_total_durations: List[Optional[float]] = []

    def add_data(
        self, dimension: int, mean_duration: float, std_duration: Optional[float]
    ):
        """Append one measurement.

        Args:
            dimension: Dimension the measurement belongs to.
            mean_duration: Mean total duration.
            std_duration: Standard deviation of the total duration, or
                ``None`` if it was not reported.
        """
        self.dimensions.append(dimension)
        self.mean_total_durations.append(mean_duration)
        self.std_total_durations.append(std_duration)

    def get_sorted_data(self):
        """Return the measurements ordered by ascending dimension.

        Ties on dimension are broken by mean duration, then by insertion order.

        Returns:
            A tuple ``(dimensions, mean_durations, std_durations)`` of three
            equal-length tuples. All three are empty when no data was added.
        """
        if not self.dimensions:
            # zip(*[]) yields nothing, so unpacking it into three names would
            # raise ValueError for an algorithm without measurements.
            return (), (), ()
        rows = sorted(
            zip(self.dimensions, self.mean_total_durations, self.std_total_durations),
            # Keep the std column out of the sort key: it may contain None,
            # which cannot be ordered against floats.
            key=lambda row: (row[0], row[1]),
        )
        dimensions_sorted, mean_durations_sorted, std_durations_sorted = zip(*rows)
        return dimensions_sorted, mean_durations_sorted, std_durations_sorted


class DataParser:
    """Parse per-algorithm timing results from a text file.

    The layout implied by the patterns below is an algorithm name on its own
    line, followed by one line per dimension, e.g.::

        Merge-Sort
        Dimension 10: ... Mean Total Duration = 0.0123 s, Std Dev Total Duration = 0.0004 s

    The ``Std Dev`` part is optional. Lines matching neither pattern are ignored.
    """

    # A non-negative number such as ``12``, ``0.5``, ``.5``, ``5.`` or
    # ``1.2e-05``. Scientific notation matters because short durations are
    # often printed that way; malformed text like ``1.2.3`` does not match.
    _NUMBER = r"(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
    # A line made only of letters and hyphens names a new algorithm.
    _ALGORITHM_PATTERN = re.compile(r"^([A-Za-z-]+)$")
    # Groups: 1 = dimension, 2 = mean duration, 3 = optional std deviation.
    _DIMENSION_PATTERN = re.compile(
        rf"Dimension (\d+):.*Mean Total Duration = ({_NUMBER}) s"
        rf"(?:, Std Dev Total Duration = ({_NUMBER}) s)?"
    )

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.algorithms_data: Dict[str, AlgorithmData] = {}

    def parse(self) -> "Dict[str, AlgorithmData]":
        """Read ``self.file_path`` and populate ``self.algorithms_data``.

        Dimension lines that appear before any algorithm name are ignored.
        A repeated algorithm name keeps accumulating into the same entry.

        Returns:
            ``self.algorithms_data`` (also updated in place), mapping each
            algorithm name to its collected measurements.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        current_algorithm: Optional[str] = None

        with open(self.file_path, "r") as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line:
                    continue

                match_algo = self._ALGORITHM_PATTERN.match(line)
                if match_algo and not line.startswith("Dimension"):
                    current_algorithm = match_algo.group(1)
                    if current_algorithm not in self.algorithms_data:
                        self.algorithms_data[current_algorithm] = AlgorithmData(
                            current_algorithm
                        )
                    # A name-only line can never also be a dimension line.
                    continue

                match_dimension = self._DIMENSION_PATTERN.match(line)
                if match_dimension is None or current_algorithm is None:
                    continue

                dimension = int(match_dimension.group(1))
                mean_total_duration = float(match_dimension.group(2))
                std_text = match_dimension.group(3)
                std_total_duration = (
                    float(std_text) if std_text is not None else None
                )
                self.algorithms_data[current_algorithm].add_data(
                    dimension, mean_total_duration, std_total_duration
                )

        return self.algorithms_data


class Plotter:
    """Render per-algorithm duration curves with matplotlib's pyplot interface."""

    def __init__(self, algorithms_data: Dict[str, AlgorithmData]):
        self.algorithms_data = algorithms_data

    def _all_dimensions(self) -> List[int]:
        """Return every dimension seen across all algorithms, ascending and de-duplicated."""
        return sorted(
            {d for algo in self.algorithms_data.values() for d in algo.dimensions}
        )

    def _finish_figure(self, ylabel: str, title: str, has_lines: bool) -> None:
        """Apply the shared labels, grid, ticks and layout, then show the current figure."""
        plt.xlabel("Dimension")
        plt.ylabel(ylabel)
        plt.title(title)
        if has_lines:
            # legend() with no labelled lines would only emit an
            # empty-legend warning.
            plt.legend()
        plt.grid(True)
        plt.xticks(self._all_dimensions())
        plt.tight_layout()
        plt.show()

    def plot_mean_total_duration(self):
        """Show one line per algorithm: mean total duration against dimension."""
        plt.figure(figsize=(10, 6))
        has_lines = False
        for algo_data in self.algorithms_data.values():
            if not algo_data.dimensions:
                # An algorithm header with no measurements has nothing to draw.
                continue
            dimensions, mean_durations, _ = algo_data.get_sorted_data()
            plt.plot(dimensions, mean_durations, marker="o", label=algo_data.name)
            has_lines = True
        self._finish_figure(
            "Mean Total Duration (s)", "Mean Total Duration vs Dimension", has_lines
        )

    def plot_std_total_duration(self):
        """Show one line per algorithm that reported at least one std deviation."""
        plt.figure(figsize=(10, 6))
        has_lines = False
        for algo_data in self.algorithms_data.values():
            if not algo_data.dimensions:
                continue
            dimensions, _, std_durations = algo_data.get_sorted_data()
            # Check for presence, not truthiness: a std of 0.0 is real data.
            if all(std is None for std in std_durations):
                continue
            # Missing values become NaN so matplotlib leaves a gap at that point.
            std_values = [float("nan") if std is None else std for std in std_durations]
            plt.plot(dimensions, std_values, marker="o", label=algo_data.name)
            has_lines = True
        self._finish_figure(
            "Std Dev of Total Duration (s)",
            "Standard Deviation of Total Duration vs Dimension",
            has_lines,
        )


@click.command()
@click.option(
    "--file",
    "file_path",
    required=True,
    # Let click reject missing paths / directories with a usage error
    # instead of surfacing a traceback from open().
    type=click.Path(exists=True, dir_okay=False),
    help="Path to the input data file.",
)
def main(file_path):
    """Plot mean and standard deviation of total duration versus dimension
    for every algorithm found in the input file."""
    parser = DataParser(file_path)
    parser.parse()

    if not parser.algorithms_data:
        # Nothing recognisable was parsed; report that instead of showing
        # two empty figures.
        raise click.ClickException(f"No algorithm data found in {file_path!r}.")

    plotter = Plotter(parser.algorithms_data)
    plotter.plot_mean_total_duration()
    plotter.plot_std_total_duration()


if __name__ == "__main__":
    # Invoking the click command parses the command-line options and calls main().
    main()
