# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import os

import json
import requests

from swift.version import __version__


# Model tagging: report model test results to the tagging service.
class ModelTag(object):
    """Client for the model-tag service that records model test results.

    The service base URL is read once, when the class is defined, from the
    ``MODEL_TAG_URL`` environment variable. When it is unset or empty the
    endpoint URLs below are not usable and no request is sent.
    """
    _URL = os.environ.get('MODEL_TAG_URL', None)

    # Submit model test results.
    BATCH_COMMIT_RESULT_URL = f'{_URL}/batchCommitResult'
    # Mark a test stage as finished.
    BATCH_REFRESH_STAGE_URL = f'{_URL}/batchRefreshStage'
    # Query the latest result of a model for one stage.
    QUERY_MODEL_STAGE_URL = f'{_URL}/queryModelStage'

    HEADER = {'Content-Type': 'application/json'}

    # Seconds to wait for the service before giving up. Without a timeout
    # ``requests`` may block indefinitely on an unresponsive server.
    REQUEST_TIMEOUT = 30

    # Test result codes.
    MODEL_SKIP = 0
    MODEL_FAIL = 1
    MODEL_PASS = 2

    class ItemResult(object):
        """One named check together with its result code and extra info."""

        def __init__(self):
            self.result = 0  # same value as ModelTag.MODEL_SKIP
            self.name = ''
            self.info = ''

        def to_json(self):
            """Return this item as a JSON-serializable dict."""
            return {
                'name': self.name,
                'result': self.result,
                'info': self.info
            }

    def __init__(self):
        self.job_name = ''
        self.job_id = ''
        self.model = ''
        self.sdk_version = ''
        self.image_version = ''
        self.domain = ''
        self.task = ''
        self.source = ''
        self.stage = ''
        # List of ItemResult.to_json() dicts; it is serialized with json.dumps.
        self.item_result = []

    def _post_request(self, url, param):
        """POST ``param`` as UTF-8 JSON to ``url``.

        Returns the response's ``content`` field when both the HTTP status
        and the body's ``errorCode`` equal 200, otherwise ``None``. Failures
        are logged rather than raised, so reporting stays best-effort.
        """
        if not self._URL:
            # The endpoint constants would be 'None/...' or '/...'; the
            # request can only fail, so don't attempt it.
            logging.warning('MODEL_TAG_URL is not set, skip posting to %s',
                            url)
            return None
        try:
            payload = json.dumps(param, ensure_ascii=False)
            logging.info('%s query: %s', url, payload)
            res = requests.post(
                url=url,
                headers=self.HEADER,
                data=payload.encode('utf8'),
                timeout=self.REQUEST_TIMEOUT)
            if res.status_code == 200:
                logging.info('%s post结果: %s', url, res.text)
                res_json = json.loads(res.text)
                if int(res_json['errorCode']) == 200:
                    return res_json['content']
            # Either a non-200 HTTP status or a non-200 service errorCode.
            logging.error(res.text)
        except Exception as e:
            logging.error(e)

        return None

    def batch_commit_result(self):
        """Submit ``item_result`` for the current model.

        Returns the service's ``content`` on success, otherwise ``None``.
        """
        try:
            param = {
                'sdkVersion':
                self.sdk_version,
                'imageVersion':
                self.image_version,
                'source':
                self.source,
                'jobName':
                self.job_name,
                'jobId':
                self.job_id,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task,
                    'itemResult': self.item_result
                }]
            }
            return self._post_request(self.BATCH_COMMIT_RESULT_URL, param)

        except Exception as e:
            logging.error(e)

        return None

    def batch_refresh_stage(self):
        """Mark ``self.stage`` as finished for the current model.

        Returns the service's ``content`` on success, otherwise ``None``.
        """
        try:
            param = {
                'sdkVersion':
                self.sdk_version,
                'imageVersion':
                self.image_version,
                'source':
                self.source,
                'stage':
                self.stage,
                'modelList': [{
                    'model': self.model,
                    'domain': self.domain,
                    'task': self.task
                }]
            }
            return self._post_request(self.BATCH_REFRESH_STAGE_URL, param)

        except Exception as e:
            logging.error(e)

        return None

    def query_model_stage(self):
        """Query the latest test result of the model for ``self.stage``.

        Only a single result is returned: the service's ``content`` on
        success, otherwise ``None``.
        """
        try:
            param = {
                'sdkVersion': self.sdk_version,
                'model': self.model,
                'stage': self.stage,
                'imageVersion': self.image_version
            }
            return self._post_request(self.QUERY_MODEL_STAGE_URL, param)

        except Exception as e:
            logging.error(e)

        return None

    def commit_ut_result(self):
        """Submit unit-test results and mark the 'integration' stage done.

        Does nothing when ``MODEL_TAG_URL`` is unset or empty. Otherwise it
        sets ``job_name``, ``source`` and ``stage``, then calls
        ``batch_commit_result`` followed by ``batch_refresh_stage``.

        Example::

            model_tag = ModelTag()
            model_tag.model = "XXX"
            model_tag.sdk_version = "0.3.7"
            model_tag.domain = "nlp"
            model_tag.task = "word-segmentation"
            item = model_tag.ItemResult()
            item.result = model_tag.MODEL_PASS
            item.name = "ALL"
            item.info = ""
            model_tag.item_result.append(item.to_json())
            model_tag.commit_ut_result()
        """
        if not self._URL:
            return
        self.job_name = 'UT'
        self.source = 'dev'
        self.stage = 'integration'

        self.batch_commit_result()
        self.batch_refresh_stage()


def commit_model_ut_result(model_name, ut_result):
    """Report a single unit-test outcome for ``model_name``.

    One ``'ALL'`` item carrying ``ut_result`` is attached to a fresh
    ``ModelTag`` and submitted via ``ModelTag.commit_ut_result``, which does
    nothing when ``MODEL_TAG_URL`` was unset or empty.

    Args:
        model_name: Model id; every ``'damo/'`` substring is removed from it.
        ut_result: Result code, presumably one of ``ModelTag.MODEL_SKIP``,
            ``MODEL_FAIL`` or ``MODEL_PASS``.
    """
    summary = ModelTag.ItemResult()
    summary.name = 'ALL'
    summary.result = ut_result
    summary.info = ''

    tagger = ModelTag()
    tagger.sdk_version = __version__
    tagger.model = model_name.replace('damo/', '')
    # domain and task are left at their empty-string defaults.
    tagger.item_result.append(summary.to_json())
    tagger.commit_ut_result()
