import argparse
import statistics
import os
import pandas as pd

def load_expected_labels_from_csv(csv_file: str, label_column: str = "label") -> list:
    """
    Loads expected labels directly from the test CSV file.

    Every column is read as text (``dtype=str``) so labels keep the exact
    spelling they have in the file. Without this, pandas infers a numeric
    dtype, and a single blank label cell turns an integer column into floats
    (``1`` -> ``1.0``), which then never matches a predicted ``"1"``.

    Args:
        csv_file (str): Path to the test CSV file.
        label_column (str, optional): Name of the label column in the CSV. Defaults to "label".

    Returns:
        list: List of expected labels as strings (blank cells come back as
        NaN, per pandas' default missing-value handling). An empty list is
        returned if the file is missing, unreadable, or lacks the column.
    """
    if not os.path.exists(csv_file):
        print(f"Test CSV file {csv_file} does not exist.")
        return []

    try:
        # dtype=str prevents numeric inference from altering label text.
        df = pd.read_csv(csv_file, dtype=str)
        if label_column not in df.columns:
            print(f"CSV file does not contain the '{label_column}' column.")
            return []
        expected_labels = df[label_column].tolist()
        print(f"Loaded {len(expected_labels)} expected labels from {csv_file}")
        return expected_labels
    except Exception as e:
        # Best-effort loader: report the problem and let the caller see an empty list.
        print(f"Error reading {csv_file}: {e}")
        return []

def load_predicted_labels(predicted_file: str) -> list:
    """
    Loads predicted labels from a text file, one label per line.

    The file is decoded as UTF-8, matching pandas' default for the test CSV.
    A leading byte-order mark, if present, is discarded ("utf-8-sig"). The
    BOM character is not whitespace, so ``str.strip()`` alone would leave it
    glued to the first label.

    Args:
        predicted_file (str): Path to the predicted labels text file.

    Returns:
        list: List of predicted labels, each stripped of surrounding
        whitespace. An empty list is returned if the file is missing or
        unreadable.
    """
    if not os.path.exists(predicted_file):
        print(f"Predicted labels file {predicted_file} does not exist.")
        return []

    try:
        with open(predicted_file, "r", encoding="utf-8-sig") as f:
            stripped = (line.strip() for line in f)
            # NOTE: blank lines are skipped. An intentionally empty
            # prediction therefore shifts every later prediction by one.
            predicted_labels = [label for label in stripped if label]
        print(f"Loaded {len(predicted_labels)} predicted labels from {predicted_file}")
        return predicted_labels
    except Exception as e:
        # Best-effort loader: report the problem and let the caller see an empty list.
        print(f"Error reading {predicted_file}: {e}")
        return []

def calculate_accuracy(expected: list, predicted: list) -> float:
    """
    Calculates the overall accuracy.

    Labels are compared as whitespace-stripped strings, so ``1`` and ``"1 "``
    count as a match.

    Args:
        expected (list): List of expected labels.
        predicted (list): List of predicted labels.

    Returns:
        float: Accuracy percentage (0-100). Returns 0.0 and prints a message
        when the lists differ in length or are both empty.
    """
    if len(expected) != len(predicted):
        print(f"Mismatch in number of expected ({len(expected)}) and predicted ({len(predicted)}) labels.")
        return 0.0

    if not expected:
        # Both lists are empty: the ratio below would divide by zero.
        print("No labels to compare.")
        return 0.0

    # True counts as 1, so summing the comparisons counts the matches.
    matching = sum(str(e).strip() == str(p).strip() for e, p in zip(expected, predicted))
    return (matching / len(expected)) * 100

def calculate_chunk_accuracies(expected: list, predicted: list, chunk_size: int = 5) -> list:
    """
    Calculates accuracy for each consecutive, non-overlapping chunk of predictions.

    Only complete chunks are scored; trailing labels that do not fill a whole
    chunk are ignored. If the two lists differ in length, only positions
    present in both are scored, so a missing expected label is never counted
    as a wrong prediction.

    Args:
        expected (list): List of expected labels.
        predicted (list): List of predicted labels.
        chunk_size (int, optional): Number of predictions per chunk. Defaults to 5.

    Returns:
        list: List of accuracy percentages (0-100) per chunk.

    Raises:
        ValueError: If chunk_size is not a positive integer.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    total_predictions = len(predicted)
    # Only positions that have both a label and a prediction can be scored.
    comparable = min(len(expected), total_predictions)
    if comparable < total_predictions:
        print(f"Only {comparable} of {total_predictions} predicted labels have a matching expected label; the rest will be ignored.")
    if comparable % chunk_size != 0:
        print(f"Number of predicted labels ({comparable}) is not a multiple of chunk size ({chunk_size}). Some labels will be ignored.")
    total_chunks = comparable // chunk_size

    accuracies = []
    for i in range(total_chunks):
        start_idx = i * chunk_size
        end_idx = start_idx + chunk_size
        matching = sum(
            str(e).strip() == str(p).strip()
            for e, p in zip(expected[start_idx:end_idx], predicted[start_idx:end_idx])
        )
        chunk_accuracy = (matching / chunk_size) * 100
        accuracies.append(chunk_accuracy)
        print(f"Chunk {i+1}: {matching}/{chunk_size} correct ({chunk_accuracy:.2f}%)")

    return accuracies

def main(test_csv: str, predicted_file: str, chunk_size: int = 5):
    """
    Load the ground-truth and predicted labels, then print accuracy statistics.

    Prints the overall accuracy first. It then prints the mean and standard
    deviation of the per-chunk accuracies, where each chunk covers
    ``chunk_size`` consecutive predictions.

    Args:
        test_csv (str): Path to the test CSV file containing true labels.
        predicted_file (str): Path to the predicted labels text file.
        chunk_size (int, optional): Number of predictions per chunk for calculating mean accuracy. Defaults to 5.
    """
    true_labels = load_expected_labels_from_csv(test_csv)
    pred_labels = load_predicted_labels(predicted_file)

    # Nothing meaningful can be reported if either side failed to load.
    if not (true_labels and pred_labels):
        print("Expected or predicted labels are empty. Exiting.")
        return

    overall = calculate_accuracy(true_labels, pred_labels)
    print(f"\nOverall Accuracy: {overall:.2f}%")

    per_chunk = calculate_chunk_accuracies(true_labels, pred_labels, chunk_size)
    if not per_chunk:
        print("No chunk accuracies calculated.")
        return

    mean_acc = statistics.mean(per_chunk)
    # Sample stdev needs at least two data points; a lone chunk has zero spread.
    spread = statistics.stdev(per_chunk) if len(per_chunk) > 1 else 0.0
    print(f"\nMean Accuracy per {chunk_size} predictions: {mean_acc:.2f}%")
    print(f"Standard Deviation: {spread:.2f}%")

if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the accuracy report.
    parser = argparse.ArgumentParser(description="Calculate Accuracy of Predictions")
    parser.add_argument(
        "--test_csv",
        type=str,
        required=True,
        help="Path to the test CSV file containing true labels.",
    )
    parser.add_argument(
        "--predicted_file",
        type=str,
        required=True,
        help="Path to the predicted labels text file.",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=5,
        help="Number of predictions per chunk for calculating mean accuracy.",
    )

    args = parser.parse_args()

    # A chunk size of 0 would divide by zero, and a negative one would
    # silently produce no chunks; reject both with a usage error instead.
    if args.chunk_size <= 0:
        parser.error("--chunk_size must be a positive integer.")

    main(args.test_csv, args.predicted_file, args.chunk_size)
