# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
from datetime import date
from pathlib import Path

from tabulate import tabulate


# Slack caps the text of a section block at 3000 characters; stay below that with some headroom.
MAX_LEN_MESSAGE = 2900

# Command-line interface: the only option is the Slack channel to report to
# (given without the leading "#", which is added when the message is posted).
parser = argparse.ArgumentParser()
parser.add_argument("--slack_channel_name", default="trl-push-ci")

# Root logger: timestamped records at INFO level and above.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)


def process_log_file(log):
    """Parse one JSON-lines test log and split its test records into failed and passed.

    Each line is expected to be a JSON object; records carrying a non-empty
    ``"nodeid"`` are treated as test results (presumably the file is produced
    by a pytest JSON report plugin -- confirm against the CI workflow).
    A record whose ``"outcome"`` is ``"failed"`` counts as a failure; any other
    outcome is recorded as passed.

    Malformed input is handled per line: blank lines are skipped, undecodable
    lines and JSON values that are not objects are logged and skipped, so a
    single bad line never hides the results that follow it.

    Args:
        log: Path (``pathlib.Path`` or ``str``) of the log file. The part of
            the file stem before the first ``_`` is used as the section label.

    Returns:
        A tuple ``(failed_tests, passed_tests, section_num_failed)`` where both
        lists contain ``[test_name, duration, section]`` rows (``duration`` is
        a string with four decimals, or ``"N/A"`` when absent or not numeric)
        and ``section_num_failed`` equals ``len(failed_tests)``. Errors while
        opening or reading the file are logged and yield whatever was
        collected so far (empty lists if the file could not be opened).
    """
    log = Path(log)
    section = log.stem.split("_")[0]
    failed_tests = []
    passed_tests = []

    try:
        with open(log, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Blank lines carry no record; skip them without a warning.
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"Could not decode line in {log}: {e}")
                    continue

                if not isinstance(data, dict):
                    # Valid JSON but not an object (e.g. a list or number):
                    # skip it rather than abort the rest of the file.
                    logging.warning(f"Skipping non-object record in {log}: {line.strip()[:80]}")
                    continue

                test_name = data.get("nodeid", "")
                if not test_name:
                    continue

                raw_duration = data.get("duration")
                duration = f"{raw_duration:.4f}" if isinstance(raw_duration, (int, float)) else "N/A"

                row = [test_name, duration, section]
                if data.get("outcome", "") == "failed":
                    failed_tests.append(row)
                else:
                    passed_tests.append(row)

    except FileNotFoundError as e:
        logging.error(f"Log file {log} not found: {e}")
    except Exception as e:
        # Best-effort by design: report the problem and return what was parsed.
        logging.error(f"Error processing log file {log}: {e}")

    return failed_tests, passed_tests, len(failed_tests)


def _render_failure_message(group_info):
    """Build the Slack mrkdwn text listing failed tests and empty log files.

    Args:
        group_info: One ``[name, num_failed, failed_tests, is_empty]`` entry
            per log file, where ``failed_tests`` holds rows whose first item
            is the test node id.

    Returns:
        The concatenated message; one table per log file with failures and one
        warning line per log file that contained no test records.
    """
    sections = []
    for name, num_failed, failed_tests, is_empty in group_info:
        if num_failed > 0:
            failed_table = []
            for test in failed_tests:
                parts = test[0].split("::")
                # First two node-id components plus the (truncated) last one.
                failed_table.append(parts[:2] + [parts[-1][:30] + ".."])
            sections.append(f"*{name}: {num_failed} failed test(s)*\n")
            sections.append(
                "\n```\n"
                + tabulate(failed_table, headers=["Test Location", "Test Name"], tablefmt="grid")
                + "\n```\n"
            )

        if is_empty:
            # Only flag the file that is actually empty, not every file.
            sections.append(f"\n*{name}: Warning! Empty file - check GitHub action job*\n")

    return "".join(sections)


def main(slack_channel_name):
    """Summarize the ``*.log`` test reports in the current directory and notify Slack on failure.

    Every log file is parsed with ``process_log_file`` and then deleted. When
    at least one test failed, a Block Kit message (header, failure details,
    link to the GitHub Actions run, context footer) is posted to the given
    channel. When nothing failed, no Slack message is sent; the outcome is
    only logged.

    Environment variables read: ``TEST_TYPE`` (label used in the header and
    footer, defaults to an empty string), ``GITHUB_RUN_ID`` (used to link to
    the run; the general Actions page is linked when it is unset) and
    ``SLACK_API_TOKEN`` (token passed to the Slack client).

    Args:
        slack_channel_name: Slack channel name without the leading ``#``.
    """
    # Sorted so that the report order is deterministic across runs.
    log_files = sorted(Path().glob("*.log"))
    if not log_files:
        logging.info("No log files found.")
        return

    group_info = []
    total_num_failed = 0

    for log in log_files:
        failed, passed, section_num_failed = process_log_file(log)
        is_empty = not failed and not passed

        total_num_failed += section_num_failed
        group_info.append([str(log), section_num_failed, failed, is_empty])

        # Clean up log file
        try:
            os.remove(log)
        except OSError as e:
            logging.warning(f"Could not remove log file {log}: {e}")

    if total_num_failed == 0:
        # Nothing is posted to Slack on success, but an empty log usually
        # deserves a look, so make it visible in the job output.
        for name, _, _, is_empty in group_info:
            if is_empty:
                logging.warning(f"{name}: Empty file - check GitHub action job")
        logging.info("All tests passed. No errors detected.")
        return

    test_type = os.environ.get("TEST_TYPE", "")

    message = _render_failure_message(group_info)

    # Logging
    logging.info(f"Total failed tests: {total_num_failed}")
    print(f"### {message}")

    if len(message) > MAX_LEN_MESSAGE:
        # Too long for a single Slack section: fall back to a short summary.
        message = f"❌ There are {total_num_failed} failed tests in total! Please check the action results directly."

    # Link to the specific run when its id is known, otherwise to the Actions page.
    actions_url = "https://github.com/huggingface/trl/actions"
    run_id = os.environ.get("GITHUB_RUN_ID")
    if run_id:
        actions_url += f"/runs/{run_id}"

    # Prepare Slack message payload
    payload = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"🤗 Results of the {test_type} TRL tests."},
        },
        {"type": "section", "text": {"type": "mrkdwn", "text": message}},
        {
            "type": "section",
            "text": {"type": "mrkdwn", "text": "*For more details:*"},
            "accessory": {
                "type": "button",
                "text": {"type": "plain_text", "text": "Check Action results"},
                "url": actions_url,
            },
        },
        {
            "type": "context",
            "elements": [
                {
                    "type": "plain_text",
                    "text": f"On Push main {test_type} results for {date.today()}",
                }
            ],
        },
    ]

    # Send to Slack (imported lazily so the SDK is only needed when there is something to report)
    from slack_sdk import WebClient

    slack_client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
    slack_client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)


if __name__ == "__main__":
    # Script entry point: read the target channel from the CLI and run the report.
    main(parser.parse_args().slack_channel_name)
