import os
import re
import click
from abc import ABC, abstractmethod
from datetime import datetime
from statistics import mean, stdev
from typing import List, Dict, Tuple, Optional


class Environment:
    """A benchmark environment keyed by function number, dimension and instance number.

    Besides its identity, an environment accumulates two kinds of timing
    samples: individual intervals and per-file total durations.
    """

    def __init__(self, function_number: int, dimension: int, instance_number: int):
        # Identity of the environment.
        self.function_number = function_number
        self.dimension = dimension
        self.instance_number = instance_number
        # Timing samples, filled in through the add_* methods.
        self.intervals: List[float] = []
        self.file_durations: List[float] = []

    @staticmethod
    def _mean_or_zero(samples: List[float]) -> float:
        """Return the mean of ``samples``, or 0.0 when there are none."""
        return mean(samples) if samples else 0.0

    @staticmethod
    def _stdev_or_zero(samples: List[float]) -> float:
        """Return the sample standard deviation, or 0.0 with fewer than two samples."""
        return stdev(samples) if len(samples) > 1 else 0.0

    def add_interval(self, interval: float):
        """Record one more time interval for this environment."""
        self.intervals.append(interval)

    def add_file_duration(self, duration: float):
        """Record the total duration measured for one file."""
        self.file_durations.append(duration)

    def get_mean_interval(self) -> float:
        """Return the mean recorded interval (0.0 if nothing was recorded)."""
        return self._mean_or_zero(self.intervals)

    def get_std_dev_interval(self) -> float:
        """Return the standard deviation of the intervals (0.0 below two samples)."""
        return self._stdev_or_zero(self.intervals)

    def get_total_duration(self) -> float:
        """Return the sum of all recorded per-file durations."""
        return sum(self.file_durations)

    def get_mean_total_duration(self) -> float:
        """Return the mean per-file duration (0.0 if nothing was recorded)."""
        return self._mean_or_zero(self.file_durations)

    def get_std_dev_total_duration(self) -> float:
        """Return the standard deviation of per-file durations (0.0 below two samples)."""
        return self._stdev_or_zero(self.file_durations)


class LogParser(ABC):
    """Interface that every concrete log-format parser implements."""

    @abstractmethod
    def parse_environment_line(self, lines: List[str]) -> Optional[Environment]:
        """Build an Environment from the lines of a log file.

        Implementations return None when no line carries environment info.
        """

    @abstractmethod
    def parse_timestamp_line(self, line: str) -> Optional[datetime]:
        """Extract the timestamp carried by a single log line.

        Implementations return None when the line holds no relevant timestamp.
        """


class OriginalLogParser(LogParser):
    """Parser for the original log formats.

    The environment is read from an ``env=coco F-D-I`` marker, and timestamps
    are taken from ``Exploring new data points`` INFO lines.
    """

    # Compiled once at class creation rather than on every call / loop iteration.
    _ENV_PATTERN = re.compile(r"env=coco (\d+)-(\d+)-(\d+)")
    _TIMESTAMP_PATTERN = re.compile(
        r"INFO - (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}) - Exploring new data points.*"
    )
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f"

    def parse_environment_line(self, lines: List[str]) -> Optional[Environment]:
        """Return the Environment described by the first matching line.

        Args:
            lines: Lines of the log file, scanned in order.

        Returns:
            An Environment built from the three integers after ``env=coco``
            (function number, dimension, instance number), or None if no line
            matches.
        """
        for line in lines:
            match = self._ENV_PATTERN.search(line)
            if match:
                function_number, dimension, instance_number = map(
                    int, match.groups()
                )
                return Environment(function_number, dimension, instance_number)
        return None

    def parse_timestamp_line(self, line: str) -> Optional[datetime]:
        """Return the timestamp of an ``Exploring new data points`` line.

        Args:
            line: A single log line.

        Returns:
            The parsed datetime, or None if the line does not match.
        """
        match = self._TIMESTAMP_PATTERN.search(line)
        if match:
            return datetime.strptime(match.group(1), self._TIMESTAMP_FORMAT)
        return None


class CmaLogParser(LogParser):
    """Parser for the CMA log format.

    Both the environment and the timestamps are read from
    ``CMA - new iteration for coco F-D-I`` INFO lines.
    """

    # Compiled once at class creation rather than on every call / loop iteration.
    _ENV_PATTERN = re.compile(
        r"INFO - .* - CMA - new iteration for coco (\d+)-(\d+)-(\d+)"
    )
    _TIMESTAMP_PATTERN = re.compile(
        r"INFO - (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}) - CMA - new iteration for coco .*"
    )
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f"

    def parse_environment_line(self, lines: List[str]) -> Optional[Environment]:
        """Return the Environment described by the first matching line.

        Args:
            lines: Lines of the log file, scanned in order.

        Returns:
            An Environment built from the three integers after ``coco``
            (function number, dimension, instance number), or None if no line
            matches.
        """
        for line in lines:
            match = self._ENV_PATTERN.search(line)
            if match:
                function_number, dimension, instance_number = map(
                    int, match.groups()
                )
                return Environment(function_number, dimension, instance_number)
        return None

    def parse_timestamp_line(self, line: str) -> Optional[datetime]:
        """Return the timestamp of a ``CMA - new iteration`` line.

        Args:
            line: A single log line.

        Returns:
            The parsed datetime, or None if the line does not match.
        """
        match = self._TIMESTAMP_PATTERN.search(line)
        if match:
            return datetime.strptime(match.group(1), self._TIMESTAMP_FORMAT)
        return None


class ParserFactory:
    """Factory class to get parsers based on file name."""

    def __init__(self):
        # Map a file-name fragment to the parser class that handles it.
        self.parser_mapping = {
            "cma": CmaLogParser,
            "hegl": OriginalLogParser,
            "hegl_norm": OriginalLogParser,
            "egl": OriginalLogParser,
            "egl_scheduler": OriginalLogParser,
            # Add more mappings as needed
        }

    def get_parsers_for_file(self, filename: str) -> List[LogParser]:
        """Return one parser instance per parser class whose key occurs in ``filename``.

        Keys are matched as substrings, so a single name can hit several keys
        that map to the same class (e.g. ``hegl_norm`` also contains ``hegl``
        and ``egl``). Each class is instantiated only once, so callers do not
        retry an identical parser on the same file.

        Args:
            filename: Name of the log file (not the full path).

        Returns:
            Parser instances in mapping order; an empty list if no key matches,
            in which case a notice is printed.
        """
        parser_classes = []
        for key, parser_class in self.parser_mapping.items():
            if key in filename and parser_class not in parser_classes:
                parser_classes.append(parser_class)
        if not parser_classes:
            print(f"No parser found for file {filename}.")
        return [parser_class() for parser_class in parser_classes]


class LogFileProcessor:
    """Processes a single log file to extract environment info and time intervals."""

    # Files containing this marker are skipped entirely.
    CUDA_ERROR_MARKER = "RuntimeError: No CUDA GPUs are available"

    def __init__(self, file_path: str, parsers: List[LogParser]):
        """Store the file to process and the candidate parsers to try, in order."""
        self.file_path = file_path
        self.parsers = parsers
        # Set by process() once a parser recognises the file.
        self.environment: Optional[Environment] = None
        self.timestamps: List[datetime] = []
        self.first_timestamp: Optional[datetime] = None
        self.last_timestamp: Optional[datetime] = None

    def contains_error(self) -> bool:
        """Return True if the file contains the CUDA error marker."""
        with open(self.file_path, "r") as file:
            # The marker has no newline, so a line-wise scan is equivalent to
            # searching the whole content and avoids loading it at once.
            return any(self.CUDA_ERROR_MARKER in line for line in file)

    def process(self) -> bool:
        """Process the log file and populate the environment and timing attributes.

        The file is read once. It is skipped if it contains the CUDA error
        marker or is empty. Otherwise each parser is tried in order; the first
        one that yields an environment is also used to extract timestamps.

        Returns:
            True if a parser recognised the file (even when no timestamps were
            found), False if the file was skipped or no parser matched.
        """
        file_name = os.path.basename(self.file_path)

        with open(self.file_path, "r") as file:
            lines = file.readlines()

        if any(self.CUDA_ERROR_MARKER in line for line in lines):
            print(f"Skipping file {file_name} due to CUDA error.")
            return False
        if not lines:
            return False

        # Try each parser until one succeeds
        for parser in self.parsers:
            parser_name = parser.__class__.__name__
            print(f"Trying parser {parser_name} for file {file_name}")
            env = parser.parse_environment_line(lines)
            if not env:
                print(f"Parser {parser_name} did not match for file {file_name}")
                continue

            print(f"Parser {parser_name} succeeded for file {file_name}")
            self.environment = env
            self._collect_timings(parser, lines)
            return True  # Successfully parsed with this parser

        # If none of the parsers could parse the file
        print(f"Could not parse file {file_name} with available parsers.")
        return False

    def _collect_timings(self, parser: LogParser, lines: List[str]) -> None:
        """Extract timestamps with ``parser`` and record intervals and file duration."""
        parsed = (parser.parse_timestamp_line(line) for line in lines)
        # Rebuilt from scratch so repeated process() calls do not accumulate
        # duplicates; sorted in case log lines are out of order.
        self.timestamps = sorted(stamp for stamp in parsed if stamp)

        if not self.timestamps:
            # No timestamps found: nothing is added to the environment.
            self.first_timestamp = None
            self.last_timestamp = None
            return

        self.first_timestamp = self.timestamps[0]
        self.last_timestamp = self.timestamps[-1]

        # Gap in seconds between each pair of consecutive timestamps.
        for earlier, later in zip(self.timestamps, self.timestamps[1:]):
            self.environment.add_interval((later - earlier).total_seconds())

        # Span of the whole file, first to last timestamp, in seconds.
        total_duration = (self.last_timestamp - self.first_timestamp).total_seconds()
        self.environment.add_file_duration(total_duration)


class LogDirectoryProcessor:
    """Processes all log files in directories and aggregates statistics."""

    def __init__(self, directory_paths: List[str], parser_factory: ParserFactory):
        """Store the directories to scan and the factory that selects parsers."""
        self.directory_paths = directory_paths
        self.parser_factory = parser_factory
        # Merged environments keyed by (function_number, dimension, instance_number).
        self.environments: Dict[Tuple[int, int, int], Environment] = {}
        # Timing samples grouped by dimension and by function number.
        self.dimension_intervals: Dict[int, List[float]] = {}
        self.function_intervals: Dict[int, List[float]] = {}
        self.dimension_durations: Dict[int, List[float]] = {}
        self.function_durations: Dict[int, List[float]] = {}
        # Number of successfully processed files per group and overall.
        self.dimension_file_counts: Dict[int, int] = {}
        self.function_file_counts: Dict[int, int] = {}
        self.total_file_count: int = 0
        # Every sample, for the overall statistics.
        self.all_intervals: List[float] = []
        self.all_total_durations: List[float] = []

    def process_logs(self):
        """Process every regular file directly inside each configured directory.

        Files for which no parser is available, that raise during processing,
        or that cannot be parsed are reported and skipped.
        """
        for directory_path in self.directory_paths:
            print(f"\nProcessing directory: {directory_path}")
            for filename in os.listdir(directory_path):
                file_path = os.path.join(directory_path, filename)
                if not os.path.isfile(file_path):
                    continue  # Skip directories

                env = self._process_file(file_path, filename)
                if env is not None:
                    self._record_environment(env)

    def _process_file(self, file_path: str, filename: str) -> Optional[Environment]:
        """Run a LogFileProcessor on one file; return its environment or None on failure."""
        parsers = self.parser_factory.get_parsers_for_file(filename)
        if not parsers:
            print(f"Skipping file {filename}: No appropriate parser found.")
            return None

        processor = LogFileProcessor(file_path, parsers)

        print(f"Processing file: {filename}")
        try:
            processed = processor.process()
        except Exception as e:
            # One unreadable or malformed file must not abort the whole run.
            print(f"Error processing file {filename}: {e}")
            return None

        if not processed:
            print(f"Failed to process file {filename} with available parsers.")
            return None

        print(f"Successfully processed file: {filename}")
        return processor.environment

    def _record_environment(self, env: Environment) -> None:
        """Fold one processed file's environment into all aggregates."""
        dimension = env.dimension
        function_number = env.function_number

        # Count the file per dimension, per function and overall.
        self.dimension_file_counts[dimension] = (
            self.dimension_file_counts.get(dimension, 0) + 1
        )
        self.function_file_counts[function_number] = (
            self.function_file_counts.get(function_number, 0) + 1
        )
        self.total_file_count += 1

        # Merge samples if the same environment appears in multiple files.
        key = (function_number, dimension, env.instance_number)
        existing = self.environments.get(key)
        if existing is None:
            self.environments[key] = env
        else:
            existing.intervals.extend(env.intervals)
            existing.file_durations.extend(env.file_durations)

        # Collect samples per dimension, per function number and overall.
        self.dimension_intervals.setdefault(dimension, []).extend(env.intervals)
        self.dimension_durations.setdefault(dimension, []).extend(env.file_durations)
        self.function_intervals.setdefault(function_number, []).extend(env.intervals)
        self.function_durations.setdefault(function_number, []).extend(
            env.file_durations
        )
        self.all_intervals.extend(env.intervals)
        self.all_total_durations.extend(env.file_durations)

    @staticmethod
    def _summarize(intervals: List[float], durations: List[float]) -> Dict[str, float]:
        """Return mean and sample std dev of intervals and durations.

        A mean is 0.0 for an empty list; a standard deviation is 0.0 with
        fewer than two samples.
        """
        return {
            "mean_interval": mean(intervals) if intervals else 0.0,
            "std_dev_interval": stdev(intervals) if len(intervals) > 1 else 0.0,
            "mean_total_duration": mean(durations) if durations else 0.0,
            "std_dev_total_duration": (
                stdev(durations) if len(durations) > 1 else 0.0
            ),
        }

    def calculate_statistics(self):
        """Calculate statistics for environments, dimensions, functions and overall.

        Returns:
            A tuple ``(env_stats, dimension_stats, function_stats, overall_stats)``.
            The first three are dicts keyed by environment tuple, dimension and
            function number respectively; every stats value is a dict with
            ``mean_interval``, ``std_dev_interval``, ``mean_total_duration`` and
            ``std_dev_total_duration``. Dimension and function entries also
            carry ``file_count``, the number of files processed for that group.
        """
        env_stats = {
            key: self._summarize(env.intervals, env.file_durations)
            for key, env in self.environments.items()
        }

        dimension_stats = {}
        for dimension, intervals in self.dimension_intervals.items():
            stats = self._summarize(intervals, self.dimension_durations[dimension])
            stats["file_count"] = self.dimension_file_counts.get(dimension, 0)
            dimension_stats[dimension] = stats

        function_stats = {}
        for function_number, intervals in self.function_intervals.items():
            stats = self._summarize(
                intervals, self.function_durations[function_number]
            )
            # Count files, not distinct environments, so the figure matches
            # its "Files Processed" label and the per-dimension count.
            stats["file_count"] = self.function_file_counts.get(function_number, 0)
            function_stats[function_number] = stats

        overall_stats = self._summarize(self.all_intervals, self.all_total_durations)

        return env_stats, dimension_stats, function_stats, overall_stats

    def get_total_file_count(self) -> int:
        """Returns the total number of files processed."""
        return self.total_file_count

    def get_dimension_file_counts(self) -> Dict[int, int]:
        """Returns the number of files processed per dimension."""
        return self.dimension_file_counts


@click.command()
@click.option(
    "--directory_path",
    "directory_paths",
    multiple=True,
    required=True,
    help="Path(s) to the directory containing log files. Can be used multiple times.",
)
def main(directory_paths):
    # CLI entry point: process every given directory, then print the report.
    # Kept as a comment rather than a docstring so the command's --help text
    # stays unchanged.
    log_processor = LogDirectoryProcessor(list(directory_paths), ParserFactory())
    log_processor.process_logs()
    env_stats, dimension_stats, function_stats, overall_stats = (
        log_processor.calculate_statistics()
    )

    # The four timing figures every stats entry carries, in print order.
    timing_keys = (
        "mean_interval",
        "std_dev_interval",
        "mean_total_duration",
        "std_dev_total_duration",
    )

    # File counts: overall, then per dimension.
    total_file_count = log_processor.get_total_file_count()
    dimension_file_counts = log_processor.get_dimension_file_counts()
    print(f"\nTotal number of files processed: {total_file_count}")
    print("\nNumber of files processed per dimension:")
    for dimension in dimension_file_counts:
        count = dimension_file_counts[dimension]
        print(f"Dimension {dimension}: {count} files")

    # Per-environment report.
    print("\nEnvironment Statistics:")
    for (function_number, dimension, instance_number), stats in env_stats.items():
        (
            mean_interval,
            std_dev_interval,
            mean_total_duration,
            std_dev_total_duration,
        ) = (stats[name] for name in timing_keys)
        print(
            f"Env (Function {function_number}, Dimension {dimension}, Instance {instance_number}): "
            f"Mean Interval = {mean_interval:.2f} s, Std Dev Interval = {std_dev_interval:.2f} s, "
            f"Mean Total Duration = {mean_total_duration:.2f} s, Std Dev Total Duration = {std_dev_total_duration:.2f} s"
        )

    # Per-dimension report.
    print("\nDimension Statistics:")
    for dimension, stats in dimension_stats.items():
        (
            mean_interval,
            std_dev_interval,
            mean_total_duration,
            std_dev_total_duration,
        ) = (stats[name] for name in timing_keys)
        file_count = stats["file_count"]
        print(
            f"Dimension {dimension}: Mean Interval = {mean_interval:.2f} s, "
            f"Std Dev Interval = {std_dev_interval:.2f} s, Mean Total Duration = {mean_total_duration:.2f} s, "
            f"Std Dev Total Duration = {std_dev_total_duration:.2f} s, Files Processed = {file_count}"
        )

    # Per-function report.
    print("\nFunction Number Statistics:")
    for function_number, stats in function_stats.items():
        (
            mean_interval,
            std_dev_interval,
            mean_total_duration,
            std_dev_total_duration,
        ) = (stats[name] for name in timing_keys)
        file_count = stats["file_count"]
        print(
            f"Function {function_number}: Mean Interval = {mean_interval:.2f} s, "
            f"Std Dev Interval = {std_dev_interval:.2f} s, Mean Total Duration = {mean_total_duration:.2f} s, "
            f"Std Dev Total Duration = {std_dev_total_duration:.2f} s, Files Processed = {file_count}"
        )

    # Figures aggregated over every processed file.
    print("\nOverall Statistics:")
    print(f"Mean Interval over all files: {overall_stats['mean_interval']:.2f} s")
    print(f"Std Dev Interval over all files: {overall_stats['std_dev_interval']:.2f} s")
    print(
        f"Mean Total Duration over all files: {overall_stats['mean_total_duration']:.2f} s"
    )
    print(
        f"Std Dev Total Duration over all files: {overall_stats['std_dev_total_duration']:.2f} s"
    )


# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    main()
