from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import argparse
import re
from bson.json_util import loads

# Venue keywords as (abbreviation, full name) pairs. KEYWORDS is the flattened
# list, so KEYWORDS[2k] is an abbreviation and KEYWORDS[2k + 1] its full name;
# be_canonical() maps a match at index i back to the abbreviation at an even
# index. be_canonical() also scans this list in order, so more specific names
# are listed before names they contain (PAKDD/PKDD before KDD, WSDM before
# SDM). NOTE: ICDM's full name is a substring of SDM's full name and ICDM is
# listed first.
_KEYWORD_PAIRS = [
    ('AAAI', 'Association for the Advancement of Artificial Intelligence'),
    ('CIKM', 'Conference on Information and Knowledge Management'),
    ('CVPR', 'Conference on Computer Vision and Pattern Recognition'),
    ('ECIR', 'European Conference on Information Retrieval'),
    ('ECML', 'European Conference on Machine Learning'),
    ('EDBT', 'International Conference on Extending Database Technology'),
    ('ICDE', 'International Conference on Data Engineering'),
    ('ICDM', 'International Conference on Data Mining'),
    ('ICML', 'International Conference on Machine Learning'),
    ('IJCAI', 'International Joint Conference on Artificial Intelligence'),
    ('PAKDD',
     'Pacific-Asia Conference on Knowledge Discovery and Data Mining'),
    ('PKDD', 'Principles and Practice of Knowledge Discovery in Databases'),
    ('KDD', 'Knowledge Discovery and Data Mining'),
    ('PODS', 'Principles of Database Systems'),
    ('SIGIR', 'Special Interest Group on Information Retrieval'),
    ('SIGMOD', 'Special Interest Group on Management of Data'),
    ('VLDB', 'Very Large Data Bases'),
    ('WWW', 'World Wide Web Conference'),
    ('WSDM', 'Web Search and Data Mining'),
    ('SDM', 'SIAM International Conference on Data Mining'),
]
KEYWORDS = [name for pair in _KEYWORD_PAIRS for name in pair]

# Publisher ("organization") of each conference. Keys are the abbreviations
# from KEYWORDS; be_canonical() looks the organization up here once it has
# resolved a raw venue name to one of those abbreviations.
CONF2ORG = dict(
    AAAI='AAAI',
    CIKM='ACM',
    CVPR='IEEE',
    ECIR='Springer',
    ECML='Springer',
    EDBT='Springer',
    ICDE='IEEE',
    ICDM='IEEE',
    ICML='PMLR',
    IJCAI='the IJCAI, Inc.',
    KDD='ACM',
    PAKDD='Springer',
    PKDD='Springer',
    PODS='ACM',
    SDM='SIAM',
    SIGIR='ACM',
    SIGMOD='ACM',
    VLDB='VLDB',
    WWW='ACM',
    WSDM='ACM',
)

# Field-of-study names used as the four class labels (treated as mutually
# exclusive); a label's integer id is its position in this list.
LABELS = [
    'Database',
    'Data mining',
    'Artificial intelligence',
    'Information retrieval',
]

# Command-line options, parsed once at import time; the step functions below
# read the module-level ``args`` directly.
parser = argparse.ArgumentParser(
    description='Preprocess paper records into a venue-labelled graph.')
parser.add_argument(
    '--choice', type=int, default=-1,
    help='step to run: 0 = extract papers from the listed venues, '
         '1 = canonicalize venue names, 2 = build the four-class graph, '
         '3 = print graph statistics; any other value does nothing')
parser.add_argument('--input_path', type=str, default='',
                    help='input file of the selected step')
parser.add_argument('--output_path', type=str, default='',
                    help='output file of the selected step (steps 0-2)')
args = parser.parse_args()


def _clean_field(text):
    """Return *text* flattened into a single TSV cell.

    Carriage returns and line feeds are dropped and tabs become spaces.
    Carriage returns matter as well: Python 3 reads text files with universal
    newlines, so a lone CR left inside a title would split the record into
    two lines for the next step.
    """
    return text.replace('\r', '').replace('\n', '').replace('\t', ' ')


def extract_considered():
    """Filter the raw paper dump down to papers from the venues in KEYWORDS.

    Reads ``args.input_path`` as a sequence of records, each spanning the
    lines from one starting with '{' to the matching line starting with '}'
    (braces are only tracked at column 0). The first line of the file is
    skipped (presumably the opening '[' of an array). Mongo-shell
    ``NumberInt(x)`` wrappers are rewritten to ``{"$numberInt": "x"}``
    before the record is parsed with ``bson.json_util.loads``.

    A paper is kept when it has an ``_id``, a non-empty ``venue.raw``, a
    non-empty ``fos`` and ``title``, a ``references`` key, and its lowercased
    raw venue name contains one of the lowercased KEYWORDS. Each kept paper
    is written to ``args.output_path`` as one tab-separated line:
    ``_id, raw venue, title, comma-joined fos, comma-joined references``.

    A record that fails to parse or format is reported and skipped; it no
    longer aborts the whole run. Progress (written, processed) is printed
    every 100000 records.
    """
    keywords = [val.lower() for val in KEYWORDS]
    # Escape so keyword text is matched literally, never as regex syntax.
    pat = re.compile('|'.join(re.escape(kw) for kw in keywords))
    # Capture with [^\s)]+ rather than \S+: lines are joined without
    # separators, so array items such as NumberInt(1),NumberInt(2) can touch,
    # and a greedy \S+ would swallow everything up to the last ')'.
    number_int = re.compile(r'NumberInt\s*\(\s*([^\s)]+)\s*\)')

    depth = 0  # column-0 '{' lines minus column-0 '}' lines seen so far
    cnt = 0  # records processed
    rsvd = 0  # records written to the output
    ele_contents = []
    with open(args.input_path, 'r') as ips, \
            open(args.output_path, 'w') as ops:
        next(ips, None)  # the first line is not part of any record
        for line in ips:
            if line.startswith('{'):
                depth += 1
            elif line.startswith('}'):
                depth -= 1
            ele_contents.append(line.strip())
            if depth != 0:
                continue

            json_text = ''.join(ele_contents)
            ele_contents = []
            # Blank lines and a closing ']' are not records: skip them
            # instead of handing them to the JSON parser.
            if not json_text.startswith('{'):
                continue
            json_text = number_int.sub(r'{"$numberInt": "\1"}', json_text)
            if json_text.endswith(','):
                json_text = json_text[:-1]

            cnt += 1
            try:
                ele = loads(json_text)
                raw_venue = (ele.get('venue') or {}).get('raw')
                if ('_id' in ele and raw_venue and ele.get('fos')
                        and 'references' in ele and ele.get('title')
                        and pat.search(raw_venue.lower())):
                    ops.write("{}\t{}\t{}\t{}\t{}\n".format(
                        ele['_id'], _clean_field(raw_venue),
                        _clean_field(ele['title']),
                        _clean_field(','.join(ele['fos'])),
                        ','.join(ele['references'])))
                    rsvd += 1
            except Exception as ex:  # one bad record must not stop the run
                print('skipping record {}: {}'.format(cnt, ex))
            if cnt % 100000 == 0:
                print(rsvd, cnt, "======>")
    if ele_contents:
        print('ignored an unterminated record at the end of the input')


"""
{'ICDM': 4589, 'KDD': 5476, 'IJCAI': 7586, 'VLDB': 5314, 'PAKDD': 2242, 'ECIR': 1482, 'ICML': 8322, 'CIKM': 5931, 'WWW': 5553, 'CVPR': 13355, 'EDBT': 1636, 'AAAI': 9695, 'ECML': 2216, 'SIGMOD': 4206, 'ICDE': 4330, 'PODS': 1670, 'SDM': 1624, 'SIGIR': 4619, 'WSDM': 746, 'PKDD': 547}
======================
{'IEEE': 22274, 'ACM': 28201, 'the IJCAI, Inc.': 7586, 'VLDB': 5314, 'Springer': 8123, 'PMLR': 8322, 'AAAI': 9695, 'SIAM': 1624}
"""


def be_canonical():
    """Replace each extracted paper's raw venue with a canonical conference.

    Reads the 5-column TSV written by ``extract_considered`` from
    ``args.input_path`` (``_id, raw venue, title, fos, references``) and
    writes ``_id, conference, organization, title, fos, references`` to
    ``args.output_path``. Lines without exactly four tabs are printed (tab
    count, then the line with tabs shown as 'TAB') and skipped. Venues that
    match no keyword are printed and skipped. Finally the per-conference and
    per-organization paper counts are printed.

    KEYWORDS alternates abbreviation / full name, so a keyword at index ``i``
    maps to the abbreviation at index ``i - i % 2``. Keywords are tried in
    KEYWORDS order and the first match wins. The exception is a keyword
    contained in a longer keyword that also matches: it is skipped in favour
    of the more specific one.
    """
    keywords = [val.lower() for val in KEYWORDS]
    # For each keyword, the other keywords that contain it. Example: ICDM's
    # 'international conference on data mining' is inside SDM's 'siam
    # international conference on data mining'; without this check SIAM SDM
    # papers were credited to ICDM (IEEE) because ICDM comes first.
    containing = {
        kw: [other for other in keywords if other != kw and kw in other]
        for kw in keywords
    }

    conf_cnts = dict()
    org_cnts = dict()
    with open(args.input_path, 'r') as ips, \
            open(args.output_path, 'w') as ops:
        for line in ips:
            num_of_tab = line.count('\t')
            if num_of_tab != 4:
                print(num_of_tab)
                print(line.replace('\t', 'TAB'))
                continue
            # Drop only the line ending: strip() would also remove the
            # trailing tab of a paper whose references column is empty,
            # leaving four columns and breaking cols[4].
            cols = line.rstrip('\r\n').split('\t')
            conf_raw_name = cols[1].lower()
            conf_name = ''
            for i, kw in enumerate(keywords):
                if kw in conf_raw_name and not any(
                        other in conf_raw_name for other in containing[kw]):
                    conf_name = keywords[i - i % 2].upper()
                    break
            if conf_name == '':
                print(cols[1])
                continue
            org = CONF2ORG[conf_name]
            conf_cnts[conf_name] = conf_cnts.get(conf_name, 0) + 1
            org_cnts[org] = org_cnts.get(org, 0) + 1
            ops.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
                cols[0], conf_name, org, cols[2], cols[3], cols[4]))

    print(conf_cnts)
    print("======================")
    print(org_cnts)


def _read_tsv_rows(path):
    """Yield the tab-separated columns of each non-blank line of *path*.

    Only the line ending is removed, so an empty last column (e.g. a paper
    without references) stays as '' instead of vanishing with its tab, which
    is what ``str.strip()`` would do.
    """
    with open(path, 'r') as ips:
        for line in ips:
            line = line.rstrip('\r\n')
            if line:
                yield line.split('\t')


def _first_label_index(fos_column, label_index):
    """Return the label id of the first label in a comma-separated fos column.

    *label_index* maps a lowercased label name to its id. Returns None when no
    field of study in *fos_column* (compared lowercased) is a label.
    """
    for val in fos_column.split(','):
        idx = label_index.get(val.lower())
        if idx is not None:
            return idx
    return None


def be_fourclass_data():
    """Build the labelled reference graph over papers carrying one of LABELS.

    Input (``args.input_path``) is the 6-column TSV written by
    ``be_canonical``: ``_id, conference, organization, title, fos,
    references``. The file is read three times:

    1. papers whose fos column contains a label are kept (count printed);
    2. references from a kept paper to a kept paper are counted (edge count
       printed), tracking each kept paper's out- and in-reference counts;
    3. kept papers with at least one such reference are renumbered 0..n-1 in
       first-seen order (n printed). Each is written to ``args.output_path``
       as ``node, conference, organization, title, label, comma-separated
       neighbour nodes``.

    A paper's label is the index in LABELS of its first field of study that
    is a label (classes are assumed mutually exclusive).
    """
    labels = [val.lower() for val in LABELS]
    label_index = {name: i for i, name in enumerate(labels)}

    # Pass 1: labelled papers -> [out-references, in-references] among them.
    degrees = dict()
    cnt = 0
    for cols in _read_tsv_rows(args.input_path):
        if _first_label_index(cols[4], label_index) is not None:
            cnt += 1
            degrees[cols[0]] = [0, 0]
    print(cnt)

    # Pass 2: count references that stay inside the labelled set.
    e_cnt = 0
    for cols in _read_tsv_rows(args.input_path):
        if cols[0] not in degrees:
            continue
        for ref in cols[-1].split(','):
            if ref in degrees:
                e_cnt += 1
                degrees[cols[0]][0] += 1
                degrees[ref][1] += 1
    print(e_cnt)

    # Keep only papers touching at least one such reference; number them.
    connected_ids = [nid for nid, deg in degrees.items()
                     if deg[0] > 0 or deg[1] > 0]
    connected = {nid: i for i, nid in enumerate(connected_ids)}
    print(len(connected))

    # Pass 3: write the renumbered, labelled graph.
    with open(args.output_path, 'w') as ops:
        for cols in _read_tsv_rows(args.input_path):
            nid = cols[0]
            if nid not in connected:
                continue
            lb = _first_label_index(cols[4], label_index)
            if lb is None:
                # Only possible for a repeated id on a line without a label;
                # skip it rather than reuse the previous line's label.
                continue
            adjs = ','.join(str(connected[ref])
                            for ref in cols[-1].split(',')
                            if ref in connected)
            ops.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
                connected[nid], cols[1], cols[2], cols[3], lb, adjs))


def _add_group_stats(table, group, node_group, adjs, label):
    """Fold one node into ``table[group]``.

    ``table[group]`` is ``[nodes, same_group_edges, cross_group_edges,
    per_label_node_counts]`` (four label slots, one per entry of LABELS).
    ``node_group`` maps a node id to its group, ``adjs`` are the node's
    neighbour ids and ``label`` its integer label.
    """
    entry = table.get(group)
    if entry is None:
        entry = table[group] = [0, 0, 0, [0, 0, 0, 0]]
    entry[0] += 1
    for v in adjs:
        if node_group[v] == group:
            entry[1] += 1
        else:
            entry[2] += 1
    entry[3][label] += 1


def stats(input_path=None):
    """Print per-conference and per-organization statistics of the graph.

    Reads the file written by ``be_fourclass_data``: ``node, conference,
    organization, title, label, comma-separated neighbour nodes``. For every
    conference, then every organization, prints the group name and
    ``[nodes, same_group_edges, cross_group_edges, per_label_node_counts]``.
    Each entry in a node's neighbour column counts as one edge.

    Args:
        input_path: file to read; defaults to ``args.input_path``.

    Returns:
        ``(conf_stats, org_stats)``: the two printed dicts.
    """
    if input_path is None:
        input_path = args.input_path

    def rows():
        # Drop only the line ending: nodes with no out-neighbours end in a
        # tab, and strip() would remove it, shifting cols[-1] onto the label.
        with open(input_path, 'r') as ips:
            for line in ips:
                line = line.rstrip('\r\n')
                if line:
                    yield line.split('\t')

    node_conf = dict()
    node_org = dict()
    for cols in rows():
        node_conf[cols[0]] = cols[1]
        node_org[cols[0]] = cols[2]

    conf_stats = dict()
    org_stats = dict()
    for cols in rows():
        # An empty neighbour column splits to [''], which is not a node.
        adjs = [v for v in cols[-1].split(',') if v]
        label = int(cols[4])
        _add_group_stats(conf_stats, cols[1], node_conf, adjs, label)
        _add_group_stats(org_stats, cols[2], node_org, adjs, label)

    for k, v in conf_stats.items():
        print(k, v)
    for k, v in org_stats.items():
        print(k, v)
    return conf_stats, org_stats


def main():
    """Run the pipeline step selected by ``--choice``.

    0 = extract_considered, 1 = be_canonical, 2 = be_fourclass_data,
    3 = stats; any other value does nothing.
    """
    steps = {
        0: extract_considered,
        1: be_canonical,
        2: be_fourclass_data,
        3: stats,
    }
    step = steps.get(args.choice)
    if step is not None:
        step()


if __name__ == "__main__":
    main()
