import json
import time
from typing import Dict, Optional, List

import joblib
from googleapiclient import discovery
from googleapiclient.errors import HttpError

import argparse

from tqdm import tqdm

# Command-line interface of the annotation script.
parser = argparse.ArgumentParser(
    description='Annotate a JSONL file with Perspective API attribute scores.')
parser.add_argument('--data-path', type=str, default='',
                    help='path of the input JSONL file (each record needs a "text" field)')
parser.add_argument('--out-path', type=str, default='',
                    help='path of the output JSONL file (default: <data-path>-annotated.jsonl)')
parser.add_argument('--api', type=str, default='',
                    help='API key for Perspective API')
parser.add_argument('--total', type=int, default=-1,
                    help='total number of input lines; only used to size the progress bar')
parser.add_argument('--workers', type=int, default=1,
                    help='number of worker processes to launch')


class PerspectiveApiScorer:
    """
    Obtain Perspective API attribute scores for texts using the Google API Client.

    Each call to ``get_scores`` sends one ``comments().analyze`` request and returns the
    summary score of every requested attribute.
    """

    # Attributes scored when the caller does not request a specific (non-empty) subset.
    DEFAULT_ATTRIBUTES = ['toxicity', 'severe_toxicity', 'sexually_explicit', 'threat', 'profanity', 'identity_attack']

    def __init__(self, api_key):
        """
        :param api_key: the API key to use. For details, see https://support.perspectiveapi.com/s/docs-get-started
        """
        # static_discovery=False makes the client load the service description from
        # discoveryServiceUrl instead of a copy bundled with the library.
        self._service = discovery.build(
            "commentanalyzer",
            "v1alpha1",
            developerKey=api_key,
            discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
            static_discovery=False,
        )

    def get_scores(self, input_text: str, requested_attributes: Optional[List[str]] = None,
                   retry_delay: float = 1.0, max_retries: Optional[int] = None) -> Dict[str, float]:
        """
        Get attribute scores for a given text via Perspective API.

        A failed request is reported and retried after ``retry_delay`` seconds. By default the
        method retries indefinitely. Pass ``max_retries`` to give up and re-raise the last error.

        :param input_text: the input text
        :param requested_attributes: lower-case attribute names to score; ``None`` or an empty
            list means ``DEFAULT_ATTRIBUTES``
        :param retry_delay: seconds to wait before retrying a failed request
        :param max_retries: maximum number of retries after the first attempt; ``None`` retries
            forever
        :return: a mapping from each requested attribute name to its summary score
        :raises Exception: the last request error, once ``max_retries`` retries are exhausted
        """
        requested_attributes = requested_attributes or PerspectiveApiScorer.DEFAULT_ATTRIBUTES

        analyze_request = {
            'comment': {'text': input_text},
            # The API expects upper-case attribute names.
            'requestedAttributes': {attribute.upper(): {} for attribute in requested_attributes},
            'spanAnnotations': False,
            'languages': ['en'],
        }

        failures = 0
        while True:
            try:
                response = self._service.comments().analyze(body=analyze_request).execute()
                break
            except Exception as e:
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                # The message reports the delay actually used by the sleep below.
                print(f'Perspective API threw an error: {e}\n Retrying in {retry_delay:g} seconds...')
                print(input_text)
                time.sleep(retry_delay)

        return {attribute: response['attributeScores'][attribute.upper()]['summaryScore']['value']
                for attribute in requested_attributes}


# Parse sys.argv once at import time; the module-level scorer construction and main() read these values.
args = parser.parse_args(args=None)


def split_lines(lines, split):
    """
    Partition ``lines`` into at most ``split`` contiguous chunks of (nearly) equal size.

    Concatenating the returned chunks in order reproduces ``lines``.

    :param lines: a sliceable sequence, e.g. a list of JSONL lines
    :param split: the desired number of chunks; must be at least 1
    :return: a list of slices of ``lines``; empty when ``lines`` is empty
    :raises ValueError: if ``split`` is less than 1
    """
    if split < 1:
        raise ValueError(f'split must be >= 1, got {split}')
    total = len(lines)
    # Ceiling division spreads the remainder over the chunks instead of creating an extra
    # trailing chunk. max(1, ...) keeps the range step non-zero when there are fewer lines
    # than chunks (or none at all).
    chunk_size = max(1, -(-total // split))
    return [lines[start:start + chunk_size] for start in range(0, total, chunk_size)]


from joblib import Parallel, delayed

# Shared scorer used by get_score, built from the --api command-line key.
scorer = PerspectiveApiScorer(api_key=args.api)


def _truncate_utf8(text, max_bytes):
    """Return the longest prefix of ``text`` whose UTF-8 encoding fits in ``max_bytes`` bytes."""
    # Slicing the bytes may split a multi-byte character at the cut. Everything before it is
    # valid UTF-8, so errors='ignore' drops only that incomplete trailing sequence.
    return text.encode('utf8')[:max_bytes].decode('utf8', errors='ignore')


def get_score(line, api_scorer=None):
    """
    Annotate one JSONL record with Perspective API scores.

    Every ``<|endoftext|>`` marker is removed from the record's ``text`` field, and the cleaned
    text is written back into the record. A blank text gets ``score = None`` without calling the
    API. Otherwise the text sent for scoring is cut to at most 20480 UTF-8 bytes, always at a
    character boundary. The scorer's result is stored under ``score``; the stored ``text`` is
    not truncated.

    :param line: one JSON-encoded record containing a ``text`` field
    :param api_scorer: object with a ``get_scores(text)`` method; defaults to the module-level
        ``scorer``, so single-argument callers such as ``Pool.imap`` are unaffected
    :return: the annotated record, JSON-encoded
    """
    data = json.loads(line)
    text = data['text'].replace("<|endoftext|>", "")
    data['text'] = text
    if not text.strip():
        data['score'] = None
        return json.dumps(data)

    if api_scorer is None:
        api_scorer = scorer
    # 20480 bytes is presumably the Perspective API per-request text size limit (20 KB).
    data['score'] = api_scorer.get_scores(_truncate_utf8(text, 20480))
    return json.dumps(data)


def get_scores(lines, api_scorer=None):
    """
    Sequentially annotate JSONL records with Perspective API scores.

    Single-process counterpart of ``get_score``:
    - each record's ``text`` has ``<|endoftext|>`` markers removed and is written back;
    - blank texts get ``score = None``;
    - other texts are cut to at most 20480 UTF-8 bytes, at a character boundary, before scoring.

    :param lines: iterable of JSON-encoded records, each with a ``text`` field
    :param api_scorer: object with a ``get_scores(text)`` method; when omitted a new
        ``PerspectiveApiScorer`` is built from the ``--api`` command-line key
    :return: list of the annotated records, JSON-encoded, in input order
    """
    if api_scorer is None:
        # PerspectiveApiScorer requires an API key; reuse the one given on the command line.
        api_scorer = PerspectiveApiScorer(args.api)

    all_data = []
    for line in tqdm(lines):
        data = json.loads(line)
        text = data['text'].replace("<|endoftext|>", "")
        data['text'] = text
        if text.strip():
            # Cutting the bytes may split one multi-byte character at the end;
            # errors='ignore' drops just that incomplete sequence.
            truncated = text.encode('utf8')[:20480].decode('utf8', errors='ignore')
            data['score'] = api_scorer.get_scores(truncated)
        else:
            data['score'] = None
        all_data.append(json.dumps(data))
    return all_data


def _score_chunk(chunk):
    """Annotate every JSONL line in ``chunk`` with ``get_score``, preserving order."""
    return [get_score(line) for line in chunk]


def get_annotated_datasets(lines, threads=10):
    """
    Annotate JSONL lines in parallel with joblib.

    ``lines`` is split into roughly ``threads`` contiguous chunks via ``split_lines``. Each chunk
    is scored in its own joblib job, and the per-chunk results are concatenated back in input
    order.

    :param lines: sequence of JSON-encoded records
    :param threads: value passed as joblib ``n_jobs`` and used as the number of chunks
    :return: flat list of annotated records, JSON-encoded
    """
    import itertools

    chunks = split_lines(lines, threads)
    print(len(lines))
    # Each job receives a whole chunk, but get_score accepts only a single JSON line,
    # so the chunk is scored line by line.
    per_chunk = Parallel(n_jobs=threads)(delayed(_score_chunk)(chunk) for chunk in chunks)
    return list(itertools.chain.from_iterable(per_chunk))


def main():
    """
    Annotate every line of ``--data-path`` and write the results as JSONL.

    Lines are scored by ``get_score`` in a pool of ``--workers`` processes (imap chunksize 25).
    Results are written in input order to ``--out-path``, or to ``<data-path>-annotated.jsonl``
    when that is empty. ``--total``, when positive, only sizes the progress bar.
    """
    import multiprocessing

    path = args.data_path
    out = args.out_path if args.out_path else path + '-annotated.jsonl'
    print(out)

    # tqdm treats total=None as "unknown length", matching the no-total progress bar.
    total = args.total if args.total > 0 else None
    # Context managers close both files and shut the pool down even if scoring fails midway.
    with open(path, 'r', encoding='utf-8') as fin:
        with multiprocessing.Pool(args.workers) as pool:
            annotated = pool.imap(get_score, fin, 25)
            with open(out, 'w', encoding='utf-8') as f:
                # imap yields results in input order, so output lines match input lines.
                for x in tqdm(annotated, total=total):
                    f.write(x + '\n')


if __name__ == "__main__":
    # Run the annotation pipeline only when executed as a script.
    main()
