#!/usr/bin/env python3

# NOTE: This is forked from https://github.com/huggingface/knockknock
import datetime
import functools
import json
import os
import socket
import traceback
from typing import List, Optional

import requests

# strftime pattern for the start/end/crash timestamps embedded in the Slack
# messages, e.g. "2024-01-31 13:05:09" (ISO-like date, space, 24-hour time).
DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S"

def slack_sender(webhook_url: str, channel: str, user_mentions: Optional[List[str]] = None,
                 progress=None, ignore_exceptions=None,
                 dash_url: str = "http://44.241.150.228:8081/",
                 request_timeout: float = 10.0):
    """
    Slack sender wrapper: execute func, send a Slack notification with the end status
    (successfully finished or crashed) at the end. Also send a Slack notification before
    executing func.

    `webhook_url`: str
        The webhook URL to access your slack room. When falsy, no request is sent.
        Visit https://api.slack.com/incoming-webhooks#create_a_webhook for more details.
    `channel`: str
        The slack room to log.
    `user_mentions`: List[str] (default=None, i.e. no mentions)
        Optional users ids to notify; joined with spaces and appended to every message.
        Visit https://api.slack.com/methods/users.identity for more details.
    `progress`: object (default=None)
        Optional object whose `.step` attribute is reported as the last step in the
        completion and crash messages ("N/A" when not given).
    `ignore_exceptions`: Exception subclass or iterable of them (default=None)
        Exceptions of these types are re-raised without sending a crash notification.
    `dash_url`: str
        Base ml-dash URL; `logger.prefix` is appended to build the run link.
    `request_timeout`: float (default=10.0)
        Seconds to wait for Slack before giving up on a single notification.

    Raises:
        TypeError: at decoration time, if `ignore_exceptions` contains anything that
        is not a subclass of Exception.

    Notification delivery is best effort: network errors while posting to Slack are
    reported as a RuntimeWarning and never interrupt or mask the wrapped call.
    """
    import warnings

    mentions = ' '.join(user_mentions or [])

    # Accept a single exception class as well as a list/tuple of them.
    if isinstance(ignore_exceptions, type):
        ignored = (ignore_exceptions,)
    else:
        ignored = tuple(ignore_exceptions or ())
    # Validate once, up front: failing inside the `except` handler below would
    # replace the real training exception with a configuration error.
    for ignr_ex in ignored:
        if not (isinstance(ignr_ex, type) and issubclass(ignr_ex, Exception)):
            raise TypeError(f"Invalid Exception:{ignore_exceptions}")

    def _last_step_line():
        """Format the last-step line; read at message time so it reflects current progress."""
        return 'Last step: %s' % (progress.step if progress is not None else 'N/A')

    def _notify(contents, emoji):
        """Post `contents` (plus mentions) to the webhook; network errors only warn."""
        if not webhook_url:
            return
        # Fresh payload per message: no state is shared between calls.
        payload = {
            "username": "Knock Knock",
            "channel": channel,
            "icon_emoji": emoji,
            "text": '\n'.join(list(contents) + [mentions]),
        }
        try:
            requests.post(webhook_url, json.dumps(payload), timeout=request_timeout)
        except requests.RequestException as err:
            # A notification failure must not kill the job or hide its outcome.
            warnings.warn("slack_sender: failed to post notification: %s" % err, RuntimeWarning)

    def decorator_sender(func):
        @functools.wraps(func)
        def wrapper_sender(*args, **kwargs):
            # Imported at call time, so decorating alone does not import ml_logger.
            from ml_logger import logger

            start_time = datetime.datetime.now()
            host_name = socket.gethostname()
            func_name = func.__name__

            # Handling distributed training edge case.
            # In PyTorch, the launch of `torch.distributed.launch` sets up a RANK environment variable for each process.
            # This can be used to detect the master process.
            # See https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py#L211
            # Except for errors, only the master process will send notifications.
            if 'RANK' in os.environ:
                master_process = int(os.environ['RANK']) == 0
                host_name += ' - RANK: %s' % os.environ['RANK']
            else:
                master_process = True

            if master_process:
                try:
                    contd = logger.read_params('job.errorTime')
                except KeyError:
                    contd = False
                _notify(['Your training has started 🎬',
                         'Prefix: `%s`' % logger.prefix,
                         'Prev finish time (UTC) %s' % (contd if contd else 'N/A'),
                         'Machine name: %s' % host_name,
                         'Starting date: %s' % start_time.strftime(DATE_FORMAT)],
                        ':clapper:')

            try:
                value = func(*args, **kwargs)
            except Exception as ex:
                if not isinstance(ex, ignored):
                    end_time = datetime.datetime.now()
                    elapsed_time = end_time - start_time
                    _notify(["Your training has crashed ☠️",
                             'Prefix: `%s`' % logger.prefix,
                             _last_step_line(),
                             'Machine name: %s' % host_name,
                             'ml-dash URL: %s%s' % (dash_url, logger.prefix),
                             'Main call: %s' % func_name,
                             'Starting date: %s' % start_time.strftime(DATE_FORMAT),
                             'Crash date: %s' % end_time.strftime(DATE_FORMAT),
                             'Crashed training duration: %s\n\n' % str(elapsed_time),
                             "Here's the error:",
                             '%s\n\n' % ex,
                             "Traceback:",
                             '%s' % traceback.format_exc()],
                            ':skull_and_crossbones:')
                # Bare `raise` re-raises the active exception with its original traceback.
                raise

            # Outside the `try`: a problem while reporting success must not be
            # reported as a training crash.
            if master_process:
                end_time = datetime.datetime.now()
                elapsed_time = end_time - start_time
                _notify(["Your training is complete 🎉",
                         'Prefix: `%s`' % logger.prefix,
                         _last_step_line(),
                         'Machine name: %s' % host_name,
                         'ml-dash URL: %s%s' % (dash_url, logger.prefix),
                         'Starting date: %s' % start_time.strftime(DATE_FORMAT),
                         'End date: %s' % end_time.strftime(DATE_FORMAT),
                         'Training duration: %s' % str(elapsed_time)],
                        ':tada:')
            return value

        return wrapper_sender

    return decorator_sender
