import random
import argparse



def make_arguments():
    """Build the command-line parser used by this script.

    Options (all optional as far as argparse is concerned):
        --cnf      path of the CNF file to modify
        --core     path of the core clause file
        --save     path the modified CNF is written to
        --origin   path of the origin clause file (default None)
        --drat     path of a DRAT proof file (default None)
        --add_var  flag; attach_new_lit / loose_big_core bump the variable
                   count before prepending a literal
    """
    parser = argparse.ArgumentParser()
    # Registration order is kept so --help lists options the same way.
    for flag in ("--cnf", "--core", "--save"):
        parser.add_argument(flag, type=str)
    for flag in ("--origin", "--drat"):
        parser.add_argument(flag, type=str, default=None)
    parser.add_argument("--add_var", action="store_true", default=False)
    return parser


def clause_in_core(core_clause, origin_content):
    """Tell whether *core_clause* equals some clause of *origin_content*.

    Clauses are compared as sets of space-separated tokens, so literal order
    and repeated tokens do not matter; the trailing terminator token (with
    its newline) takes part in the comparison like any other token.
    """
    wanted = set(core_clause.split(' '))
    return any(wanted == set(candidate.split(' ')) for candidate in origin_content)


def update_clause(new_clause, old_clause, content):
    """Replace or delete the clause of *content* that matches *old_clause*.

    Clauses are matched as sets of space-separated tokens (order-insensitive).

    Args:
        new_clause: replacement clause text, or '-' to delete the match.
        old_clause: clause text to look for.
        content: list of clause strings; modified in place when a match is
            found.

    Returns:
        The updated clause list.  A deletion moves the last clause into the
        freed slot and returns a new list one element shorter.  If no clause
        matches, *content* is returned unchanged.
    """
    target = set(old_clause.split(' '))
    for idx, clause in enumerate(content):
        if set(clause.split(' ')) == target:
            break
    else:
        # No match: leave the list alone instead of deleting an unrelated
        # (last) clause, raising on an empty list, or returning None.
        return content

    if new_clause == '-':
        # Swap-with-last deletion: cheap, but does not preserve clause order.
        content[idx] = content[-1]
        return content[:-1]

    content[idx] = new_clause
    return content


def _strip_leading(lines, markers):
    """Return *lines* without leading blank lines and leading lines whose
    first whitespace-separated token is one of *markers*."""
    start = 0
    while start < len(lines):
        tokens = lines[start].split()
        if tokens and tokens[0] not in markers:
            break
        start += 1
    return lines[start:]


def _strip_trailing(lines):
    """Return *lines* without trailing lines holding at most one token
    (blank lines or a lone terminator token)."""
    end = len(lines)
    while end > 0 and len(lines[end - 1].split()) <= 1:
        end -= 1
    return lines[:end]


def read_files(args):
    """Load the CNF, core, and optional origin / DRAT files named in *args*.

    Args:
        args: namespace from make_arguments(); reads args.cnf, args.core,
            args.origin and args.drat.

    Returns:
        (content, num_vars, core_content, origin_content, drat_content):
        - content: clause lines of the CNF with leading 'c' comment lines,
          the 'p' header line(s) and trailing lines of at most one token
          removed; newlines are kept.
        - num_vars: the variable count from the CNF's 'p' header.
        - core_content: lines of the core file (decoded as windows-1252)
          with leading 'c'/'p' lines and trailing short lines removed.
        - origin_content: same treatment applied to args.origin, or None
          when args.origin is None.
        - drat_content: lines of args.drat without its last line, or None
          when args.origin or args.drat is None (the DRAT file is only read
          when an origin file is also given).

    Raises:
        ValueError: if the CNF has no 'p' header after its leading comments.
    """
    with open(args.cnf) as cnf:
        content = _strip_leading(cnf.readlines(), ('c',))
    if not content or content[0].split()[0] != 'p':
        raise ValueError(f"{args.cnf}: missing 'p cnf' header line")
    # split() rather than split(' ') so repeated spaces in the header do not
    # shift the field index.
    num_vars = int(content[0].split()[2])
    content = _strip_trailing(_strip_leading(content, ('p',)))

    with open(args.core, encoding='windows-1252') as core:
        core_content = _strip_trailing(_strip_leading(core.readlines(), ('c', 'p')))

    if args.origin is None:
        return content, num_vars, core_content, None, None

    with open(args.origin) as origin:
        origin_content = _strip_trailing(_strip_leading(origin.readlines(), ('c', 'p')))

    if args.drat is None:
        return content, num_vars, core_content, origin_content, None

    with open(args.drat) as drat:
        # The final line is dropped; presumably the proof's closing empty
        # clause -- TODO confirm against the DRAT producer.
        drat_content = drat.readlines()[:-1]

    return content, num_vars, core_content, origin_content, drat_content





def remove_core(args):
    """Delete every non-origin core clause from the CNF and save the result.

    Inputs come from read_files(args).  A core clause whose token set equals
    an origin clause is kept (that origin entry is consumed, so duplicates are
    paired one-to-one); otherwise its first token-set match in the CNF is
    removed.  The remaining clauses are written to args.save under a fresh
    DIMACS header.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)

    def pop_match(clauses, key):
        """Delete the first entry of *clauses* whose token set is *key*; report success."""
        for pos, candidate in enumerate(clauses):
            if set(candidate.split(' ')) == key:
                del clauses[pos]
                return True
        return False

    for core_clause in core_content:
        key = set(core_clause.split(' '))
        if pop_match(origin_content, key):
            continue
        pop_match(content, key)

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        out_file.writelines(content)


def add_variables(args):
    """Randomly add or remove one literal in each non-origin core clause.

    Inputs come from read_files(args); the result is written to args.save.

    "Free" variables are 1..num_vars minus every variable that occurs in a
    core clause.  For each core clause, in file order:
      * if an origin clause has the same token set, that origin clause is
        consumed and the core clause is left alone;
      * otherwise its first token-set match in the CNF is modified: with
        probability 1/2 a literal of a random, not-yet-used free variable
        (random sign) is prepended; otherwise, if the clause has at least two
        tokens before the terminator, one randomly chosen literal is removed.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)
    if origin_content is None:
        # No --origin file given: no core clause is protected.
        origin_content = []

    # Despite its name, core_vars holds the variables NOT used by the core;
    # the name is kept so the diagnostic output stays unchanged.
    core_vars = list(range(1, num_vars + 1))
    print(F'BEFORE REMOVE, len(core_vars)={len(core_vars)}')
    used_vars = {abs(int(lit))
                 for core_clause in core_content
                 for lit in core_clause.split(' ')[:-1]}
    # One filtering pass instead of a list.remove per literal; ascending order
    # is preserved, so random.choice sees the same sequence as before.
    core_vars = [var for var in core_vars if var not in used_vars]
    print(F'AFTER REMOVE, len(core_vars)={len(core_vars)}')

    for core_clause in core_content:
        core_tokens = set(core_clause.split(' '))
        origin_flag = False
        for origin_clause in origin_content:
            if core_tokens == set(origin_clause.split(' ')):
                # Consume the match so duplicate core clauses pair up 1:1.
                origin_content.remove(origin_clause)
                origin_flag = True
                break
        if origin_flag:
            continue
        for idx, clause in enumerate(content):
            if core_tokens != set(clause.split(' ')):
                continue
            add_lit = random.randint(0, 1)
            tokens = clause.split(' ')
            if add_lit:
                if not core_vars:
                    # random.choice([]) would raise IndexError.
                    print('no free variable left to add; clause unchanged')
                    break
                var = random.choice(core_vars)
                core_vars.remove(var)
                sign = random.randint(0, 1)
                lit = var if sign else -var
                content[idx] = f'{lit} ' + clause
                print(f'add vars {lit} in {content[idx]}')
            elif len(tokens) > 2:
                # The last token is the terminator and is never chosen.
                pos = random.randint(0, len(tokens) - 2)
                var = tokens[pos] + ' '
                print(f'remove vars {var}in {content[idx]}')
                # Drop by token position: str.find on the raw text could hit
                # the same characters inside another literal ('1 ' in '21 ').
                content[idx] = ' '.join(tokens[:pos] + tokens[pos + 1:])
            break

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        for clause in content:
            out_file.write(clause)

    return


def merge_clauses(args):
    """Replace one random non-origin core clause by its merge with another clause.

    Core clauses from read_files(args) are visited in random order.  One whose
    token set matches an origin clause is skipped (consuming that origin
    entry).  For the first other core clause found in the CNF, the remaining
    CNF clauses are tried in random order until one is found whose variable
    set is a subset or superset of the core clause's.  The core clause is then
    replaced by the union of both clauses' literals, sorted by variable, and
    the search stops.  The CNF is written to args.save whether or not a merge
    happened.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)
    if origin_content is None:
        # No --origin file given: no core clause is protected.
        origin_content = []

    merge_flag = False
    random.shuffle(core_content)
    for core_clause in core_content:
        if merge_flag:
            break
        core_tokens = set(core_clause.split(' '))
        origin_flag = False
        for origin_clause in origin_content:
            if core_tokens == set(origin_clause.split(' ')):
                origin_content.remove(origin_clause)
                origin_flag = True
                break
        if origin_flag:
            continue
        for idx, clause in enumerate(content):
            if core_tokens != set(clause.split(' ')):
                continue
            # Variable set of this clause; the trailing '0' token maps to
            # variable 0, shared by every terminated clause.
            vars1 = {abs(int(x)) for x in clause.split(' ')}
            lits1 = set(clause.split(' ')[:-1])
            candidates = content.copy()
            random.shuffle(candidates)
            for clause2 in candidates:
                if clause2 == clause:
                    continue
                vars2 = {abs(int(x)) for x in clause2.split(' ')}
                if vars2 <= vars1 or vars1 <= vars2:
                    # NOTE(review): if clause2 holds the negation of one of
                    # clause's literals the union is a tautology -- confirm
                    # this is intended.
                    merged = lits1 | set(clause2.split(' ')[:-1])
                    # Sorted so the output is reproducible; iterating a set of
                    # strings depends on per-process hash randomisation.
                    ordered = sorted(merged, key=lambda lit: (abs(int(lit)), int(lit)))
                    content[idx] = ' '.join(ordered) + ' 0\n'
                    print(f'change: {content[idx]}')
                    merge_flag = True
                    break
            if merge_flag:
                # Stop here; otherwise a duplicate of the core clause later in
                # the CNF would be merged as well.
                break
    if not merge_flag:
        print('Not change')
    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        for clause in content:
            out_file.write(clause)

    return



def loose_core(args):
    """Negate core literals that occur only in non-origin core clauses.

    core_lits collects the literals of core clauses matched (as token sets)
    by an origin clause; not_core_lits collects the literals of all other
    core clauses, minus core_lits.  If not_core_lits is empty, a message is
    printed and nothing is written.  Otherwise, for every such literal, each
    core clause containing it is looked up in the CNF and rewritten with that
    literal negated (moved to the end of the clause), and the CNF is written
    to args.save.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)
    if origin_content is None:
        # No --origin file given: every core clause counts as non-origin.
        origin_content = []

    core_lits = set()
    not_core_lits = set()
    for core_clause in core_content:
        clause_lits = {int(x) for x in core_clause.split(' ')[:-1]}
        core_tokens = set(core_clause.split(' '))
        for origin_clause in origin_content:
            if core_tokens == set(origin_clause.split(' ')):
                origin_content.remove(origin_clause)
                core_lits |= clause_lits
                break
        else:
            # Accumulate across clauses; a plain '=' here kept only the
            # literals of the last non-origin clause.
            not_core_lits |= clause_lits

    not_core_lits -= core_lits

    if not not_core_lits:
        print(f'No variable in small core could be flip, failed.')
        return
    for selected_lit in not_core_lits:
        print(f'selected var: {selected_lit}')
        # flip selected var in core of content
        for core_clause in core_content:
            clause_lits = [int(x) for x in core_clause.split(' ')[:-1]]
            if selected_lit not in clause_lits:
                continue
            core_tokens = set(core_clause.split(' '))
            for idx, clause in enumerate(content):
                if core_tokens == set(clause.split(' ')):
                    # NOTE(review): core_content is not updated, so a clause
                    # already rewritten for an earlier literal no longer
                    # matches here and is not flipped again -- confirm intended.
                    clause_lits.remove(selected_lit)
                    clause_lits.append(-selected_lit)
                    content[idx] = ' '.join(str(x) for x in clause_lits) + ' 0\n'
                    break

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        for clause in content:
            out_file.write(clause)

    return



def remove_partial_core(args):
    """Delete a single randomly chosen non-origin core clause from the CNF.

    Core clauses from read_files(args) are shuffled.  A core clause whose
    token set matches an origin clause is skipped and that origin entry is
    consumed.  The first other core clause that is found in the CNF is
    deleted, and the search stops.  The CNF is then written to args.save.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)

    def find_index(clauses, tokens):
        """Position of the first entry of *clauses* with token set *tokens*, or None."""
        return next(
            (pos for pos, text in enumerate(clauses) if set(text.split(' ')) == tokens),
            None,
        )

    random.shuffle(core_content)
    for core_clause in core_content:
        tokens = set(core_clause.split(' '))
        hit = find_index(origin_content, tokens)
        if hit is not None:
            del origin_content[hit]
            continue
        hit = find_index(content, tokens)
        if hit is not None:
            del content[hit]
            break

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        out_file.writelines(content)



def trace_back(args, delete=False):
    """Use DRAT literals to flip (or delete) non-origin core clauses.

    Walks the DRAT lines returned by read_files(args), skipping blank lines
    and deletion lines ('d ...').  For each remaining line, its first token is
    taken as a literal ``lit``, and core clauses containing ``-lit`` that are
    not also origin clauses (see clause_in_core) are selected:
      * delete=True: the first such clause is removed from the CNF;
      * delete=False: every such clause is rewritten in the CNF with ``-lit``
        replaced by ``lit``.
    Processing stops after the first DRAT line that changed something.  If
    nothing changed, a message is printed and no file is written; otherwise
    the CNF is written to args.save.
    """
    content, num_vars, core_content, origin_content, drat_content = read_files(args)
    if origin_content is None:
        # No --origin file given: no core clause is protected.
        origin_content = []

    # Initialised before the loop: with an empty or deletion-only DRAT file
    # the check after the loop previously raised UnboundLocalError.
    flip_flag = False
    for drat_line in drat_content:
        tokens = drat_line.split()
        if not tokens or tokens[0] == 'd':
            continue
        # split() drops the newline, so it never leaks into a rebuilt clause.
        drat_lit = tokens[0]
        neg_drat_lit = str(-int(drat_lit))

        for core_clause in core_content:
            if neg_drat_lit not in core_clause.split(' ')[:-1]:
                continue
            if clause_in_core(core_clause, origin_content):
                continue  # belongs to the origin core; leave it alone

            if delete:
                content = update_clause('-', core_clause, content)
                flip_flag = True
                break
            fliped_clause = core_clause.split(' ')[:-1]
            fliped_clause.remove(neg_drat_lit)
            fliped_clause.append(drat_lit)
            fliped_clause = ' '.join(fliped_clause) + ' 0\n'
            content = update_clause(fliped_clause, core_clause, content)
            flip_flag = True

        if flip_flag:
            break

    if not flip_flag:
        print(f'No lit could be flip!')
        return

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        for clause in content:
            out_file.write(clause)

    return



def attach_new_lit(args):
    """Prepend a variable literal to one randomly chosen non-origin core clause.

    Core clauses from read_files(args) are shuffled.  A core clause whose
    token set matches an origin clause is skipped and that origin entry is
    consumed.  The first other core clause found in the CNF gets the literal
    ``num_vars`` prepended; with --add_var, num_vars is incremented first, so
    a new variable is used and the header count grows.  The CNF is then
    written to args.save.
    """
    content, num_vars, core_content, origin_content, _ = read_files(args)

    def locate(clauses, wanted):
        """Index of the first entry of *clauses* with token set *wanted*, or -1."""
        for position, text in enumerate(clauses):
            if set(text.split(' ')) == wanted:
                return position
        return -1

    random.shuffle(core_content)
    for core_clause in core_content:
        wanted = set(core_clause.split(' '))
        origin_pos = locate(origin_content, wanted)
        if origin_pos >= 0:
            origin_content.pop(origin_pos)
            continue
        target = locate(content, wanted)
        if target < 0:
            continue
        if args.add_var:
            num_vars += 1
        content[target] = f"{num_vars} " + content[target]
        break

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        out_file.writelines(content)



def loose_big_core(args):
    """Prepend a variable literal to one randomly chosen core clause.

    Like attach_new_lit, but origin clauses are ignored: the core clauses from
    read_files(args) are shuffled, and the first one found in the CNF (by
    token set) gets the literal ``num_vars`` prepended.  With --add_var,
    num_vars is incremented first.  The CNF is then written to args.save.
    """
    content, num_vars, core_content, _, _ = read_files(args)

    random.shuffle(core_content)
    hit = None
    for core_clause in core_content:
        wanted = set(core_clause.split(' '))
        hit = next(
            (pos for pos, text in enumerate(content) if set(text.split(' ')) == wanted),
            None,
        )
        if hit is not None:
            break

    if hit is not None:
        if args.add_var:
            num_vars += 1
        content[hit] = f"{num_vars} " + content[hit]

    with open(args.save, 'w') as out_file:
        out_file.write("c generated by G2SAT lcg\n")
        out_file.write("p cnf {} {}\n".format(num_vars, len(content)))
        out_file.writelines(content)



if __name__ == "__main__":
    # Only attach_new_lit is wired to the command line; the other
    # transformations in this module are selected by editing this call.
    args = make_arguments().parse_args()
    attach_new_lit(args=args)