import math
import numpy as np
import itertools
import subprocess
import sys
import utils
import time
import pandas as pd
import datetime
import argparse

def save_graph(matrices, k, gamma, fname):
    """Write the unipartite graph in the text format read by the clique counter.

    The file starts with three header lines: the number of nodes
    (``BA.shape[0]``), ``k``, and ``gamma`` formatted with two decimals.
    Every ordered pair ``(i, j)`` with ``BA[i, j] == 1`` then gets one line,
    in row-major order: ``i j`` when ``matrices['T']`` is ``None``, otherwise
    ``i j T[i, j] W[i, j]``.

    Args:
        matrices: mapping with the keys ``'BA'``, ``'T'`` and ``'W'``.
        k: value written on the second header line.
        gamma: value written on the third header line.
        fname: path of the file to (over)write.
    """
    adjacency, text_sim, weights = (matrices[key] for key in ('BA', 'T', 'W'))
    size = adjacency.shape[0]
    has_text = text_sim is not None
    with open(fname, 'w') as out:
        out.write(f'{size}\n{k}\n{gamma:.02f}\n')
        for src in range(size):
            for dst in range(size):
                if adjacency[src, dst] != 1:
                    continue
                if not has_text:
                    out.write(f'{src} {dst}\n')
                    continue
                out.write(f'{src} {dst} {text_sim[src, dst]} {weights[src, dst]}\n')


def save_graph_fraction(matrices, k, gamma, fname, m0, frac=0.8):
    """Write the unipartite graph using the fractional edge rule.

    Unipartite setting with stricter requirements for an edge to exist.
    For every ordered pair ``(i, j)`` two counts are taken over the rows of
    the matrices:

    * ``b``: rows where ``bid_matrix[:, i] > 0`` and ``author_matrix[:, j] > 0``
    * ``p``: rows where ``coi_matrix[:, i] == 0`` and ``author_matrix[:, j] > 0``

    The edge ``i j`` is written when ``p >= m0`` and ``b / p >= frac``.
    A pair with ``p == 0`` never yields an edge, because the fraction is
    undefined there. Before the edges, the file gets three header lines:
    the node count (``bid_matrix.shape[1]``), ``k``, and ``gamma`` formatted
    with two decimals.

    Args:
        matrices: mapping with the keys ``'bid_matrix'``, ``'author_matrix'``
            and ``'coi_matrix'``.
            NOTE(review): presumably rows are papers and columns are
            people; confirm against ``utils.get_author_only_matrices``.
        k: value written on the second header line.
        gamma: value written on the third header line.
        fname: path of the file to (over)write.
        m0: minimum value of ``p`` required for an edge.
        frac: minimum value of ``b / p`` required for an edge.
    """
    B = matrices['bid_matrix']
    A = matrices['author_matrix']
    COI = matrices['coi_matrix']
    n = B.shape[1]

    # The element-wise tests do not depend on the pair being examined, so
    # evaluate them once instead of once per (i, j).
    has_bid = B > 0
    is_author = A > 0
    no_conflict = COI == 0

    with open(fname, 'w') as f:
        f.write(f'{n}\n')
        f.write(f'{k}\n')
        f.write(f'{gamma:.02f}\n')
        for i in range(n):
            bid_i = has_bid[:, i]
            free_i = no_conflict[:, i]
            for j in range(n):
                authored_j = is_author[:, j]
                b = (bid_i & authored_j).sum()
                p = (free_i & authored_j).sum()
                # `p > 0` keeps a non-positive m0 from reaching a division by
                # zero, whose nan/inf result could otherwise create an edge.
                if p > 0 and p >= m0 and b / p >= frac:
                    f.write(f'{i} {j}\n')


def save_graph_bipartite(matrices, k, gamma, fname):
    """Write the per-pair counts consumed by the bipartite clique counter.

    After three header lines (node count ``bid_matrix.shape[1]``, ``k``, and
    ``gamma`` with two decimals) the file holds one line ``i j b p`` for
    every ordered pair ``(i, j)`` in row-major order, where

    * ``b`` counts rows with ``bid_matrix[:, i] > 0`` and
      ``author_matrix[:, j] > 0``;
    * ``p`` counts rows with ``coi_matrix[:, i] == 0`` and
      ``author_matrix[:, j] > 0``.

    Unlike the fractional writer, no pair is filtered out here.
    """
    bids = matrices['bid_matrix']
    authorship = matrices['author_matrix']
    conflicts = matrices['coi_matrix']
    num_nodes = bids.shape[1]
    with open(fname, 'w') as out:
        for header in (num_nodes, k, f'{gamma:.02f}'):
            out.write(f'{header}\n')
        for src in range(num_nodes):
            for dst in range(num_nodes):
                authored = authorship[:, dst] > 0
                bid_count = ((bids[:, src] > 0) & authored).sum()
                possible_count = ((conflicts[:, src] == 0) & authored).sum()
                out.write(f'{src} {dst} {bid_count} {possible_count}\n')

def count_cliques_c(matrices, k, gamma, dataset, timeout, bipartite, m0=None):
    """Dump the graph to disk and run the external clique-counting binary on it.

    The graph writer and the executable are chosen from the mode:
    bipartite, fractional unipartite (``m0`` given), or plain unipartite.
    The binary's stdout is expected to hold the total count on the first
    line and one histogram bucket per following line; the buckets must sum
    to the total.

    Args:
        matrices: input passed through to the selected graph writer.
        k, gamma: parameters written into the graph file header.
        dataset: name used to build the graph file name.
        timeout: time limit for the binary, in minutes.
        bipartite: selects the bipartite writer and executable.
        m0: when not ``None`` (and not bipartite), selects the fractional writer.

    Returns:
        ``(total, histogram)``, or ``(-1, [])`` when the binary timed out.
    """
    if bipartite:
        graph_path = f'_graph_{dataset}_bipartite.txt'
        binary = "./count_cliques_bipartite.out"
        save_graph_bipartite(matrices, k, gamma, graph_path)
    elif m0 is None:
        graph_path = f'_graph_{dataset}.txt'
        binary = "./count_cliques_c.out"
        save_graph(matrices, k, gamma, graph_path)
    else:
        graph_path = f'_graph_{dataset}_frac.txt'
        binary = "./count_cliques_c.out"
        save_graph_fraction(matrices, k, gamma, graph_path, m0)

    try:
        completed = subprocess.run([binary, graph_path], stdout=subprocess.PIPE, timeout=(timeout*60))
    except subprocess.TimeoutExpired as expired:
        print(f'Timeout({expired.cmd})')
        return -1, []

    # First line is the total, the remaining lines are the histogram buckets.
    total_line, *bucket_lines = completed.stdout.decode('utf-8').strip().split('\n')
    total = int(total_line)
    buckets = [int(line) for line in bucket_lines]
    assert sum(buckets) == total
    return total, buckets



def run(dataset, param_list, timeout, bipartite, m0=None):
    """Count cliques for a sequence of (gamma, k) settings and save them to CSV.

    Each setting is handed to ``count_cliques_c``. The accumulated results
    are rewritten to ``results/clique_results_<dataset>[_bipartite]_<iso
    timestamp>.csv`` after every setting, so finished work survives an
    interrupted run.

    Rows have the columns ``k, gamma, text_sim, num_cliques, dataset, time``:

    * Non-bipartite: one row with ``text_sim == -1`` holding the total
      count (``-1`` on timeout), then one row per histogram bucket ``b``
      with ``text_sim == b / 100``.
    * Bipartite: one row per histogram bucket with ``gamma == bucket / 100``
      and ``num_cliques`` equal to the count in that bucket and all higher
      ones. Nothing is recorded on timeout.

    A setting that times out marks every setting with smaller-or-equal
    gamma and larger-or-equal k as infeasible, and those are skipped.

    Args:
        dataset: dataset name, forwarded to ``utils`` and used in file names.
        param_list: iterable of ``(gamma, k)`` pairs, or the string
            ``'frontier'``. Frontier mode starts at ``(1.0, 2)``, then
            raises ``k`` after a non-zero count and lowers ``gamma`` by 0.1
            after a zero count. It stops on a timeout or once ``gamma``
            reaches 0.1.
        timeout: per-setting time limit in minutes.
        bipartite: use the bipartite counter.
        m0: optional minimum-possible parameter of the fractional setting.

    Returns:
        pandas.DataFrame with the rows described above. It is empty, but
        has the right columns, when no setting was evaluated.
    """
    import os  # function-scope import: `os` is not imported at the top of this file

    columns = ['k', 'gamma', 'text_sim', 'num_cliques', 'dataset', 'time']

    frontier = isinstance(param_list, str) and param_list == 'frontier'
    if frontier:
        param_list = [(1.0, 2)]

    if bipartite or (m0 is not None):
        matrices = utils.get_author_only_matrices(dataset)
    else:
        matrices = dict(zip(
            ['BA', 'T', 'W'],
            utils.make_BA(dataset, return_text_weights=True, authors_only=True),
        ))

    datestring = datetime.datetime.now().isoformat()
    fname = f'results/clique_results_{dataset}' + ('_bipartite' if bipartite else '') + f'_{datestring}.csv'
    # Create the output directory now, so a missing directory cannot make
    # the first checkpoint fail after the (possibly very long) first count.
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    print(fname)

    param_list = list(param_list)
    results = []
    infeasible_set = set()
    # Defined before the loop so the final return is valid even when the
    # parameter list is empty or every setting is skipped.
    df = pd.DataFrame(columns=columns)
    i = 0
    # Index-based loop: frontier mode appends new settings while iterating.
    while i < len(param_list):
        gamma, k = param_list[i]
        if any(gamma <= g_inf and k >= k_inf for (g_inf, k_inf) in infeasible_set):
            print(f'({k}, {gamma}) : skipping due to known infeasibility')
            i += 1
            continue
        print(f'({k}, {gamma}) : {datetime.datetime.now()}')

        t0 = time.time()
        c, hist = count_cliques_c(matrices, k, gamma, dataset, timeout=timeout, bipartite=bipartite, m0=m0)
        t1 = time.time()
        mins = (t1-t0)/60

        if bipartite:
            # The bipartite counter reports one bucket per 0.01 step of gamma.
            assert c == -1 or len(hist) == 101
            if c != -1:
                # Reverse cumulative sum: entry b counts buckets b..100.
                sums = np.cumsum(hist[::-1])[::-1]
                for bucket, cumulative_count in enumerate(sums):
                    g = bucket / 100
                    results.append((k, g, -1, cumulative_count, dataset, mins))
                for j, count in enumerate(sums[10::10]):
                    g = (j+1) / 10
                    print(f'({k}, {g}) \t {count} \t ({mins:.02f})')
        else:
            print(f'({k}, {gamma}) \t {c} \t ({mins:.02f})')
            results.append((k, gamma, -1, c, dataset, mins))
            # Per-bucket counts split on text similarity (empty on timeout).
            for b, h in enumerate(hist):
                t = b / 100
                results.append((k, gamma, t, h, dataset, mins))

        # Checkpoint everything gathered so far.
        df = pd.DataFrame.from_records(results, columns=columns)
        df.to_csv(fname)

        if frontier:
            if (c == -1) or (gamma == 0.1):
                return df
            elif c == 0:
                # Rounded so the sequence stays on exact one-decimal values
                # and the `gamma == 0.1` stop test above can match.
                param_list.append((np.round(gamma - 0.1, 1), k))
            else:
                param_list.append((gamma, k + 1))
        if c == -1:
            infeasible_set.add((gamma, k))
        i += 1
    return df


def main():
    """Parse the command line and launch the clique-counting experiment.

    The dataset name is validated by argparse (``choices``), so an unknown
    dataset produces a usage error and exit status 2 instead of relying on
    an ``assert`` that disappears under ``python -O``.

    In bipartite mode the gamma range is forced to the single value 1 with
    step 0. When ``--k_max`` is omitted it defaults to 7 (bipartite) or 11
    (otherwise). With ``--frontier`` the explicit k/gamma ranges are not
    used and ``run`` receives the string ``'frontier'`` instead of a list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset', choices=['aamas_sub3', 'wu'])
    parser.add_argument('-t', '--time_limit', type=int, default=1440, help='Time limit in minutes')
    parser.add_argument('-kn', '--k_min', type=int, default=2)
    parser.add_argument('-kx', '--k_max', type=int, default=None)
    parser.add_argument('-gn', '--gamma_min', type=float, default=0.5)
    parser.add_argument('-gx', '--gamma_max', type=float, default=1)
    parser.add_argument('-gs', '--gamma_step', type=float, default=0.1)
    parser.add_argument('-bp', '--bipartite', action='store_true')
    parser.add_argument('-f', '--frontier', action='store_true', help='Trace the frontier of feasible params')
    parser.add_argument('-m', '--min_possible', type=int, default=None, help='Parameter for fractional unipartite setting')

    args = parser.parse_args()
    if args.bipartite:
        # A single gamma value is requested in bipartite mode.
        args.gamma_min = 1
        args.gamma_max = 1
        args.gamma_step = 0
    if args.k_max is None:
        args.k_max = 7 if args.bipartite else 11

    if args.frontier:
        param_list = 'frontier'
    else:
        param_list = utils.make_param_list(args.k_min, args.k_max, args.gamma_min, args.gamma_max, args.gamma_step)
    run(args.dataset, param_list, args.time_limit, args.bipartite, m0=args.min_possible)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
