import re
from os import listdir
from os.path import isfile, join
# Regex capturing one atom of the form ``predicate(arguments)``: the predicate
# is made of word characters and '-', the argument list of word characters,
# spaces, commas, dots, hyphens and single/double quotes.
# Written as a raw string so the regex escapes (\w, \(, \-) are not parsed as
# invalid string escape sequences; the compiled pattern is unchanged.
atom_pattern = r'([\w-]+\([" \w,\-\' .]+\))'

def load_data_and_index_tuples(inpath, edb_tables, separator, ignore_first_position):
    """Load the facts of every non-EDB predicate stored in a directory.

    Each regular file directly inside ``inpath`` is treated as the fact table
    of one predicate. The predicate name is the file name up to the first
    ".csv", or the whole file name when it contains no ".csv".

    Args:
        inpath: Directory to scan (sub-directories are ignored).
        edb_tables: Container of predicate names to skip.
        separator: String used to split each line into arguments.
        ignore_first_position: When true, the first field of every line is
            dropped.

    Returns:
        A dict mapping each loaded predicate to a list of argument lists,
        one per line of its file.
    """
    start = 1 if ignore_first_position else 0
    derived_data = dict()
    for fname in listdir(inpath):
        path = join(inpath, fname)
        if not isfile(path):
            continue
        print('Processing file {}'.format(fname))
        pos = fname.find(".csv")
        predicate = fname if pos == -1 else fname[:pos]
        # Decide before opening: files of skipped predicates are never read.
        if predicate in edb_tables:
            continue
        with open(path) as infile:
            derived_data[predicate] = [
                line.replace('\n', '').split(separator)[start:]
                for line in infile
            ]
    return derived_data

def create_atom(predicate, arguments):
    """Build the textual atom ``predicate(arg1,arg2,...)``.

    Args:
        predicate: Predicate name (a string).
        arguments: Arguments of the atom; each one is converted with ``str``.

    Returns:
        The atom as a string, arguments joined by "," without spaces.
    """
    return predicate + "(" + ",".join(str(argument) for argument in arguments) + ")"

def rewrite_rule_to_triple_format(rule, sep=":-"):
    """Rewrite a rule so that every non-magic body atom carries an extra variable.

    Body atoms whose predicate starts with "mgc_" are copied unchanged. Every
    other body atom at position ``i`` receives the extra argument ``S{i}``:
    atoms with more than one argument keep their predicate and get ``S{i}``
    appended, while unary atoms ``p(X)`` become ``typeOf(X,p,S{i})``. Each
    added ``S{i}`` is also appended to the arguments of the head atom.

    Args:
        rule: Rule text of the form ``head <sep> body``.
        sep: Separator between head and body, forwarded to ``get_head_atom``
            and ``get_body_atoms`` (both require it).

    Returns:
        The rewritten rule as ``new_head + " :- " + new_body``, with the body
        atoms joined by ", ".
    """
    head_atom = get_head_atom(rule, sep)
    head_arguments = getArguments(head_atom)

    new_atoms = []
    # The position counter advances for every body atom, magic atoms included,
    # so the S-variable index is the atom's position in the original body.
    for index, atom in enumerate(get_body_atoms(rule, sep)):
        predicate = getPredicate(atom)
        if predicate.startswith("mgc_"):
            new_atoms.append(atom)
            continue

        extra_variable = "S{}".format(index)
        arguments = getArguments(atom)
        if getArity(atom) > 1:
            new_atom = create_atom(predicate, arguments + [extra_variable])
        else:
            new_atom = create_atom("typeOf", [arguments[0], predicate, extra_variable])
        head_arguments.append(extra_variable)
        new_atoms.append(new_atom)

    head_predicate = head_atom[:head_atom.find("(")]
    new_head = create_atom(head_predicate, head_arguments)
    return new_head + " :- " + ", ".join(new_atoms)

def list_edb_tables(path:str):
    """Return the predicate names of the table files found in ``path``.

    Every regular file directly inside ``path`` contributes one name: the file
    name up to the first ".csv", or the whole file name when it contains no
    ".csv" (previously such names silently lost their last character).

    Args:
        path: Directory to scan (sub-directories are ignored).

    Returns:
        A list of predicate names, in directory-listing order.
    """
    all_tables = []
    for fname in listdir(path):
        if not isfile(join(path, fname)):
            continue
        pos = fname.find(".csv")
        all_tables.append(fname if pos == -1 else fname[:pos])
    return all_tables

def getArity(atom):
    """Return the arity of ``atom``, computed as its number of commas plus one."""
    separators = atom.count(',')
    return 1 + separators

def getArguments(atom):
    """Return the arguments of ``atom`` as a list of stripped strings.

    The argument text is whatever lies between the first "(" and the first
    ")" of the atom; it is split on "," (commas inside quoted arguments are
    not treated specially) and each piece is stripped of surrounding
    whitespace. An atom with no comma yields a single-element list.
    """
    left = atom.find("(")
    right = atom.find(")")
    # str.split already returns a one-element list when there is no comma, so
    # no separate single-argument case is needed.
    return [argument.strip() for argument in atom[left + 1:right].split(",")]

def getPredicate(atom):
    """Return the predicate name of ``atom``: the text before its first "(".

    When the atom contains no "(", the whole string is returned (slicing with
    the -1 returned by ``str.find`` used to drop the last character instead).
    """
    return atom.partition("(")[0]

def isGround(atom):
    """Return True when no argument of ``atom`` starts with an uppercase letter.

    An argument whose first character is uppercase makes the atom non-ground.
    Empty arguments (as in ``p()``) have no first character and are treated as
    not uppercase instead of raising IndexError.
    """
    # argument[:1] is "" for an empty argument, and "".isupper() is False.
    return not any(argument[:1].isupper() for argument in getArguments(atom))
    
def load_rules(filename):
    """Read ``filename`` and return its lines as a list.

    Newline characters are removed from every line; empty lines are kept as
    empty strings.
    """
    with open(filename) as infile:
        return [raw.replace('\n', '') for raw in infile]
    
def create_edb_file(paths, outfilename, separator):
    """Write one group of ``EDB{n}_*`` entries per ".csv" file found in ``paths``.

    For every regular file in each directory of ``paths`` whose name contains
    ".csv" after at least one leading character, the predicate is the file
    name up to the first ".csv". The entries written are predname, type
    (always INMEMORY), param0 (the directory) and param1 (the predicate),
    followed by a blank line. When ``separator`` is not "," an additional
    ``param2=t`` entry is written.
    NOTE(review): param2=t presumably tells the consumer the data is not
    comma-separated -- confirm against the reader of this file.

    Args:
        paths: Iterable of directories to scan.
        outfilename: File to (over)write.
        separator: Field separator of the data files; only compared with ",".
    """
    with open(outfilename, 'w') as outfile:
        index = 0
        with_param2 = separator != ','
        for path in paths:
            for fname in [f for f in listdir(path) if isfile(join(path, f))]:
                pos = fname.find(".csv")
                if pos <= 0:
                    continue
                predicate = fname[:pos]
                entry = [
                    'EDB{}_predname={}\n'.format(index, predicate),
                    'EDB{}_type=INMEMORY\n'.format(index),
                    'EDB{}_param0={}\n'.format(index, path),
                    'EDB{}_param1={}\n'.format(index, predicate),
                ]
                if with_param2:
                    entry.append('EDB{}_param2=t\n'.format(index))
                entry.append('\n')
                outfile.writelines(entry)
                index += 1

def convert_to_triples_and_write_to_file(path, filename, tables_dict, separator):
    """Convert the fact files in ``path`` to tab-separated triples in ``filename``.

    The predicate of a file is its name up to the first "."; its arity is
    looked up in ``tables_dict`` (a missing predicate raises KeyError).
    For arity 1 each line becomes ``line<TAB>typeOf<TAB>predicate``; otherwise
    the line is split on ``separator`` and written as
    ``first<TAB>predicate<TAB>second`` (further fields are ignored).

    Args:
        path: Directory holding the fact files (sub-directories are ignored).
        filename: Output file, overwritten.
        tables_dict: Mapping from predicate name to arity.
        separator: Field separator of the input lines.
    """
    # List the inputs before the output file is created.
    sources = [entry for entry in listdir(path) if isfile(join(path, entry))]
    with open(filename, 'w') as fp:
        for fname in sources:
            table = fname.partition('.')[0]
            arity = tables_dict[table]
            with open(join(path, fname)) as infile:
                for raw in infile:
                    row = raw.replace('\n', '')
                    if arity == 1:
                        triple = (row, 'typeOf', table)
                    else:
                        fields = row.split(separator)
                        triple = (fields[0], table, fields[1])
                    fp.write('{}\t{}\t{}\n'.format(*triple))
                                
def create_tables_dict(rules):
    """Map every predicate occurring in the file ``rules`` to its arity.

    Each line is scanned with ``atom_pattern``; for every atom found, the
    text before its first "(" is the predicate and the arity is the number of
    commas in the atom plus one. A later occurrence of a predicate overwrites
    the arity recorded for an earlier one.

    Args:
        rules: Path of the rules file.

    Returns:
        A dict from predicate name to arity.
    """
    tables_dict = {}
    with open(rules) as infile:
        for text in infile:
            for atom in re.findall(atom_pattern, text):
                name = atom[:atom.find("(")]
                tables_dict[name] = 1 + atom.count(',')
    return tables_dict

def load_entity_map(entityfileName):
    """Read a tab-separated file of ``identifier<TAB>entity`` lines.

    Args:
        entityfileName: Path of the file; the first column must parse as an
            int, the second column is the entity name.

    Returns:
        A dict mapping each entity name to its integer identifier.
    """
    entity_map = {}
    with open(entityfileName) as fp:
        for raw in fp:
            fields = raw.replace('\n', '').split("\t")
            entity_map[fields[1]] = int(fields[0])
    return entity_map

def load_relation_map(relationfileName):
    """Read a tab-separated file of ``identifier<TAB>relation`` lines.

    Args:
        relationfileName: Path of the file; the first column must parse as an
            int, the second column is the relation name.

    Returns:
        A dict mapping each relation name to its integer identifier.
    """
    relation_map = {}
    with open(relationfileName) as fp:
        rows = (text.replace('\n', '').split("\t") for text in fp)
        for columns in rows:
            identifier = int(columns[0])
            relation_map[columns[1]] = identifier
    return relation_map

def getLeaves(dertree:dict):
    """Collect the facts at the leaves of a derivation tree.

    A node is a dict with a 'fact' entry and, optionally, a 'parents' entry
    holding a list of nodes. A node without 'parents' is a leaf and yields its
    own fact. Otherwise the leaves of all parents are gathered recursively,
    skipping every parent whose fact starts with 'mgc'.

    Args:
        dertree: Root node of the (sub)tree.

    Returns:
        A list of leaf facts, in traversal order.
    """
    parents = dertree.get('parents')
    if parents is None:
        return [dertree.get('fact')]

    leaves = []
    for parent in parents:
        if not parent.get('fact').startswith('mgc'):
            leaves.extend(getLeaves(parent))
    return leaves
    
def get_head_atom(rule, sep=":-"):
    """Return the part of ``rule`` that precedes the first ``sep``.

    The text is returned as-is, including any whitespace next to the
    separator. When ``sep`` does not occur, the whole rule is returned.

    Args:
        rule: Rule text of the form ``head <sep> body``.
        sep: Separator between head and body. Defaults to ":-" so the
            function can be called with just the rule.
    """
    return rule.split(sep)[0]
    
def get_body_atoms(rule, sep=":-"):
    """Return the atoms found in the body of ``rule``.

    The body is the second piece obtained by splitting ``rule`` on ``sep``;
    it is scanned with ``atom_pattern`` and every match is returned in order.
    A rule that does not contain ``sep`` raises IndexError.

    Args:
        rule: Rule text of the form ``head <sep> body``.
        sep: Separator between head and body. Defaults to ":-" so the
            function can be called with just the rule.
    """
    body = rule.split(sep)[1]
    return re.findall(atom_pattern, body)

def all_atoms_are_edbs(edb_dict, atoms):
    """Return True when the predicate of every atom is contained in ``edb_dict``.

    An empty ``atoms`` collection yields True.
    """
    return all(getPredicate(atom) in edb_dict for atom in atoms)
   
def all_atoms_are_idbs(idb_dict, atoms):
    """Return True when the predicate of every atom is contained in ``idb_dict``.

    An empty ``atoms`` collection yields True.
    """
    unknown = (atom for atom in atoms if getPredicate(atom) not in idb_dict)
    return next(unknown, None) is None

def store_atoms_on_disk(path, atoms):
    """Append each atom to the file ``<predicate>.csv`` inside ``path``.

    For every atom, its arguments are written as one tab-separated line at
    the end of the file named after its predicate (created if missing).

    Args:
        path: Directory receiving the files.
        atoms: Iterable of atom strings.
    """
    for atom in atoms:
        predicate = getPredicate(atom)
        arguments = getArguments(atom)
        # The context manager closes (and flushes) the handle after each write;
        # the original left every opened file unclosed.
        with open(join(path, predicate + ".csv"), "a") as outfile:
            outfile.write("\t".join(arguments) + "\n")