import argparse
import json
import logging
import os
from collections import Counter

# Root logger setup for this script: records at INFO level and above,
# each prefixed with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def read_json_files(directory):
    """Read and decode every ``*.json`` file directly inside *directory*.

    Files are visited in sorted filename order so the result is reproducible
    (``os.listdir`` returns entries in arbitrary order). Entries that are not
    regular files (e.g. a sub-directory whose name ends in ``.json``) are
    skipped. A file that is not valid UTF-8 JSON is logged and skipped rather
    than aborting the whole scan.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A list with the decoded top-level JSON value of each readable file.
    """
    data_list = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(directory, filename)
        # A directory named "*.json" would make open() raise; ignore it.
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                data_list.append(json.load(file))
        except (json.JSONDecodeError, UnicodeDecodeError):
            logging.error(f"Error decoding JSON file: {filename}")
    return data_list


def parse_response(response, file_path):
    """Extract the JSON object embedded in a chat-completion style response.

    Reads ``response["choices"][0]["message"]["content"]``, decodes it as JSON,
    and returns the resulting object so the caller can merge its keys into the
    flattened record. A missing or empty ``choices`` list, or a missing
    ``content`` field, silently yields ``{}``. Anything else that does not
    produce a JSON *object* is logged and also yields ``{}``. That includes a
    ``null`` response/message/content, malformed JSON, and a JSON array or
    scalar. As a result, one bad file never aborts the run.

    Args:
        response: The ``response`` value of an input record. Normally a dict,
            but malformed files may supply ``None`` or another type.
        file_path: Path of the source file, used only in log messages.

    Returns:
        A dict with the decoded content's keys, or ``{}`` on failure.
    """
    try:
        choices = response.get("choices", [])
        if not choices:
            return {}
        message = choices[0].get("message", {})
        content = message.get("content", "{}")
        parsed_content = json.loads(content)
    except (
        json.JSONDecodeError,
        AttributeError,  # response/choice/message is not a dict (e.g. null)
        IndexError,
        KeyError,
        TypeError,  # content is not a string (e.g. null)
    ) as e:
        logging.error(f"Error parsing response: {e} {file_path}")
        return {}

    # Only a JSON object can be merged into the flat record; arrays and
    # scalars would make dict.update() raise or merge garbage.
    if not isinstance(parsed_content, dict):
        logging.error(
            f"Error parsing response: content is not a JSON object {file_path}"
        )
        return {}
    return parsed_content


def process_data(directory):
    """Build one flat record per ``*.json`` file in *directory*.

    Each record copies a fixed set of metadata fields from the file's
    top-level object; missing fields become ``""``. It then merges in the keys
    decoded from the file's ``response`` by ``parse_response``.

    Files are visited in sorted filename order so output order is
    reproducible. Files that cannot be decoded as UTF-8 JSON, or whose
    top-level JSON value is not an object, are logged and skipped.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A list of flat record dicts, in filename order.
    """
    # Output key order follows this tuple, followed by any response keys.
    metadata_fields = (
        "q_id",
        "country",
        "category",
        "subcategory",
        "topic",
        "query",
        "topic_rank",
        "ranked_index",
        "image_id",
        "image_url",
        "image_info",
        "image_path",
    )
    processed_data = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(directory, filename)
        # A directory named "*.json" would make open() raise; ignore it.
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                data = json.load(file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logging.error(f"Error decoding JSON file: {file_path}")
            continue
        if not isinstance(data, dict):
            logging.error(f"Unexpected top-level JSON type in file: {file_path}")
            continue

        extracted_info = {field: data.get(field, "") for field in metadata_fields}
        # Response keys are merged last, so a response key that shares a name
        # with a metadata field overwrites the metadata value.
        extracted_info.update(parse_response(data.get("response", {}), file_path))
        processed_data.append(extracted_info)
    return processed_data


def _normalize_label(value):
    """Return *value* as a lower-cased string label; ``None`` maps to "unknown"."""
    if value is None:
        return "unknown"
    # str() guards against non-string values (numbers, lists) from the model.
    return str(value).lower()


def _print_counts(counts):
    """Print one ``key<TAB>value`` line per item, then a blank separator."""
    for key, value in counts.items():
        print(f"{key}\t{value}")
    print("\n")


def calculate_stats(processed_data):
    """Tally image_category, status, and extracted_text presence across records.

    ``image_category`` and ``status`` are lower-cased before counting. A
    missing or ``None`` value is counted as ``"unknown"``. An entry counts
    towards ``"exists"`` when its ``extracted_text`` is truthy and towards
    ``"missing"`` otherwise.

    The three tallies are logged at INFO level, then printed to stdout as
    tab-separated tables, each followed by a blank separator.

    Args:
        processed_data: Iterable of record dicts (iterated exactly once).

    Returns:
        ``(image_category_counter, status_counter, extracted_text_presence)``.
        The first two are ``Counter`` objects; the last is a dict with keys
        ``"exists"`` and ``"missing"``.
    """
    image_category_counter = Counter()
    status_counter = Counter()
    extracted_text_presence = {"exists": 0, "missing": 0}

    for entry in processed_data:
        image_category_counter[_normalize_label(entry.get("image_category"))] += 1
        status_counter[_normalize_label(entry.get("status"))] += 1
        presence_key = "exists" if entry.get("extracted_text") else "missing"
        extracted_text_presence[presence_key] += 1

    logging.info(f"Image Category Distribution: {dict(image_category_counter)}")
    logging.info(f"Status Distribution: {dict(status_counter)}")
    logging.info(f"Extracted Text Presence: {extracted_text_presence}")

    _print_counts(image_category_counter)
    _print_counts(status_counter)
    _print_counts(extracted_text_presence)

    return image_category_counter, status_counter, extracted_text_presence


def save_to_jsonl(data, output_file):
    """Write each item of *data* to *output_file* as one JSON Lines record.

    Items are serialized with ``ensure_ascii=False``, so non-ASCII text is
    written as-is. Each record is followed by a newline. The file is opened in
    UTF-8 text mode and truncated first.
    """
    serialized_lines = (
        json.dumps(record, ensure_ascii=False) + "\n" for record in data
    )
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.writelines(serialized_lines)


def main():
    """Command-line entry point: turn a directory of JSON files into JSONL.

    Reads every ``*.json`` file in ``--input_dir`` and flattens each into a
    record. Logs and prints summary statistics, then writes the records to
    ``--output_file``.
    """
    parser = argparse.ArgumentParser(
        description="Process JSON files and extract relevant information."
    )
    # Both options are mandatory. Without them the original code crashed deep
    # inside (os.path.join(None, ...) / open(None)) instead of printing usage.
    parser.add_argument(
        "-i", "--input_dir", required=True, help="Directory containing JSON files."
    )
    parser.add_argument(
        "-o", "--output_file", required=True, help="Output JSONL file."
    )
    args = parser.parse_args()

    # Fail fast with a usage message rather than a traceback from os.listdir.
    if not os.path.isdir(args.input_dir):
        parser.error(f"input directory does not exist: {args.input_dir}")

    logging.info("Reading JSON files...")

    logging.info("Processing data...")
    processed_data = process_data(args.input_dir)

    logging.info("Calculating statistics...")
    calculate_stats(processed_data)

    logging.info("Saving processed data...")
    save_to_jsonl(processed_data, args.output_file)

    logging.info(f"Processed data saved to '{args.output_file}'")


if __name__ == "__main__":
    main()
