import argparse
import sys
import re
import random
import os


def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``concepts`` (str, required),
        ``num_samples`` (int, required), ``num_sets`` (int, default 1) and
        ``exclude`` (list of str with at least one entry, or None when the
        option is omitted).
    """
    cli = argparse.ArgumentParser(
        description="Randomly sample n lines from a concept .txt file, with optional multiple sets."
    )
    add = cli.add_argument

    add("--concepts", "-c", required=True, type=str,
        help="Path to the concepts .txt file.")
    add("--num_samples", "-n", required=True, type=int,
        help="Number of lines to sample per set.")
    add("--num_sets", "-k", default=1, type=int,
        help="Number of distinct sets to sample. Default is 1.")
    # nargs='+' means that when the flag is given, at least one path must follow it
    add("--exclude", "-e", nargs='+', default=None, type=str,
        help="Paths to one or more files containing concepts to exclude, one per line. Default is None.")

    parsed = cli.parse_args()
    return parsed


def read_file(file_path):
    """Read a concepts file and return its sanitized, non-empty lines.

    Every raw line goes through ``sanitize_line``. Lines that are empty
    *after* sanitizing are dropped, so they never reach the sampling pool.
    Examples are a line holding only a bracketed ID like ``[m.06cfrk]`` or
    only digits.

    Args:
        file_path (str): Path to the UTF-8 text file to read.

    Returns:
        list[str]: Sanitized lines in file order. Duplicates are kept.

    Raises:
        SystemExit: exit status 1, after printing a message, if the file does
            not exist, cannot be read, or is not valid UTF-8.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            # Sanitize first, then filter: a non-blank raw line can sanitize to "".
            cleaned = (sanitize_line(line).strip() for line in file)
            return [line for line in cleaned if line]
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' does not exist.")
        sys.exit(1)
    except (IOError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError, so it needs an explicit catch.
        print(f"Error reading file '{file_path}': {e}")
        sys.exit(1)


def sample_distinct_sets(lines, n, k):
    """Draw ``k`` mutually disjoint random sets of ``n`` lines each.

    Duplicate entries in ``lines`` are collapsed first, keeping first-seen
    order. As a result, no value can appear twice within a set or in more
    than one set.

    Args:
        lines (list[str]): Candidate lines. May contain duplicates.
        n (int): Number of lines per set.
        k (int): Number of sets.

    Returns:
        list[list[str]]: ``k`` lists, each holding ``n`` distinct lines. The
        lists are pairwise disjoint.

    Raises:
        SystemExit: exit status 1, after printing a message, if fewer than
            ``n * k`` unique lines are available.
    """
    # dict.fromkeys removes duplicates while preserving insertion order.
    unique_lines = list(dict.fromkeys(lines))
    total = len(unique_lines)
    if n * k > total:
        print(f"Error: Requested {k} sets of {n} samples, but the file only contains {total} unique lines.")
        sys.exit(1)

    # One draw of n*k items without replacement, split into k consecutive
    # chunks, gives the same distribution as k successive draws that each
    # exclude earlier picks. It also avoids rebuilding the pool after every set.
    picked = random.sample(unique_lines, n * k)
    return [picked[i * n:(i + 1) * n] for i in range(k)]


def write_sampled_sets(sampled_sets, original_file_path, n, k):
    """Write each sampled set to its own file next to the source file.

    Files are named ``<stem>_random_<n>_set_<i><ext>``, where ``i`` starts
    at 1. They go in the same directory as ``original_file_path``. Each line
    of a set is written followed by a newline. Files are opened with mode
    ``'w'``, so an existing file with the same name is overwritten.

    Args:
        sampled_sets (list[list[str]]): The sets to write, in order.
        original_file_path (str): Path whose directory, stem and extension
            shape the output file names.
        n (int): Sample size, embedded in the output file names.
        k (int): Accepted for interface compatibility; not used here.

    Raises:
        SystemExit: exit status 1, after printing a message, on the first
            write error.
    """
    parent_dir, base_name = os.path.split(original_file_path)
    stem, suffix = os.path.splitext(base_name)

    for set_number, chunk in enumerate(sampled_sets, start=1):
        target = os.path.join(parent_dir, f"{stem}_random_{n}_set_{set_number}{suffix}")
        try:
            with open(target, 'w') as handle:
                handle.write("".join(f"{entry}\n" for entry in chunk))
            print(f"Sampled {len(chunk)} lines written to '{target}'.")
        except IOError as e:
            print(f"Error writing to file '{target}': {e}")
            sys.exit(1)


def sanitize_line(line):
    """
    Sanitizes a line by:
    - Removing any text within square brackets (e.g., [m.06cfrk])
    - Removing digits, punctuation and any other character that is not an
      ASCII letter, whitespace, or underscore
    - Replacing underscores with spaces
    - Converting to lowercase
    - Stripping leading/trailing whitespace

    Only ASCII letters survive, so accented or non-Latin letters are dropped
    (e.g. "café" becomes "caf"). Internal runs of whitespace are not
    collapsed.

    Args:
        line (str): The line to sanitize.

    Returns:
        str: The sanitized string. It may be empty if no letters remain.
    """
    # Remove text inside square brackets (non-greedy, so "[a] x [b]" keeps " x ")
    line = re.sub(r"\[.*?\]", "", line)
    # Keep only ASCII letters, whitespace and underscores
    line = re.sub(r"[^a-zA-Z\s_]", "", line)
    # Replace underscores with spaces
    line = line.replace("_", " ")
    # Convert to lowercase (as documented above) and strip leading/trailing spaces
    return line.lower().strip()


def read_files(file_paths):
    """Read and combine sanitized, non-empty lines from multiple files.

    Each raw line goes through ``sanitize_line``. Lines that are empty after
    sanitizing are discarded, so ``""`` never appears in the result.

    Args:
        file_paths (Iterable[str]): Paths to UTF-8 text files.

    Returns:
        set[str]: Union of the sanitized lines from every file.

    Raises:
        SystemExit: exit status 1, after printing a message, on the first
            file that does not exist, cannot be read, or is not valid UTF-8.
    """
    lines = set()
    for file_path in file_paths:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                # Sanitize first, then filter: a non-blank raw line can sanitize to "".
                cleaned = (sanitize_line(line).strip() for line in file)
                lines.update(line for line in cleaned if line)
        except FileNotFoundError:
            print(f"Error: The file '{file_path}' does not exist.")
            sys.exit(1)
        except (IOError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an OSError, so it needs an explicit catch.
            print(f"Error reading file '{file_path}': {e}")
            sys.exit(1)
    return lines


def main():
    """Command-line entry point: parse arguments, filter, sample and write.

    Steps:
    1. Read the concepts file.
    2. If exclusion files were given, read them and drop every concept that
       appears in them.
    3. Sample ``k`` disjoint sets of ``n`` lines.
    4. Write each set next to the concepts file.

    Prints an error and exits with status 1 if ``n`` or ``k`` is not
    positive, or if no concepts remain after exclusions. The helpers it calls
    may also exit on their own errors.
    """
    args = parse_arguments()
    n, k = args.num_samples, args.num_sets

    if not (n > 0 and k > 0):
        print("Error: Both the number of samples 'n' and the number of sets 'k' must be positive integers.")
        sys.exit(1)

    # The concepts file is read before any exclusion file, so its errors surface first.
    concepts = read_file(args.concepts)
    if args.exclude:
        excluded = read_files(args.exclude)
    else:
        excluded = set()

    pool = [concept for concept in concepts if concept not in excluded]
    if len(pool) == 0:
        print("Error: No concepts available for sampling after applying exclusions.")
        sys.exit(1)

    chosen = sample_distinct_sets(pool, n, k)
    write_sampled_sets(chosen, args.concepts, n, k)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
