import distbnnn
import numpy as np

def gridSearchTournament(fixedParams, nSamples, fastConst, nAllPairs, halfUnif):
    """Compare the slow and fast tournaments over a grid of settings.

    A ``distbnnn.DistbnNN(K, N)`` is populated either with
    ``distbnnn.generateHalfUnif(K, N)`` (when ``halfUnif`` is true) or with
    ``distbnnn.generateZipf(K, N, 2)``. ``Q`` distinct test ids are drawn
    once with ``np.random.choice`` and reused for every grid point.

    For each sample size ``S`` in ``nSamples``, fast-mode constant ``C`` in
    ``fastConst`` and all-pairs count ``L`` in ``nAllPairs``, every test id
    is sampled ``T`` times. Both ``runTournament(fast=False)`` and
    ``runTournament(fast=True, fastParam=C)`` are run on the same sample.
    The fraction of runs whose winner equals the true id, and the average
    reported op count, are printed for each tournament.

    Args:
        fixedParams: dict providing 'K', 'N', 'Q' and 'T'.
        nSamples: iterable of sample sizes.
        fastConst: iterable of ``fastParam`` values for the fast tournament.
        nAllPairs: iterable of ``nAllPairs`` values passed to both tournaments.
        halfUnif: selects the half-uniform generator instead of Zipf(2).
    """
    K, N, Q, T = (fixedParams[key] for key in ('K', 'N', 'Q', 'T'))
    print('K {}, N {}, Q {}, T {}'.format(K, N, Q, T))

    distbns = distbnnn.DistbnNN(K, N)
    if halfUnif:
        candidates = distbnnn.generateHalfUnif(K, N)
    else:
        candidates = distbnnn.generateZipf(K, N, 2)
    distbns.setDistbns(candidates)

    testIDs = np.random.choice(K, Q, replace=False)

    for S in nSamples:
        for C in fastConst:
            for L in nAllPairs:
                print('S {}, C {}, L {}'.format(S, C, L))
                slowCorrect, fastCorrect = 0, 0
                slowOps, fastOps = 0, 0
                for q in range(Q):
                    trueID = testIDs[q]
                    for _ in range(T):
                        sample = distbns.sampleDistbn(trueID, S)

                        # Both tournaments see the identical sample so their
                        # results are directly comparable.
                        slowWinner, slowCost = distbns.runTournament(sample, fast=False, nAllPairs=L)
                        slowCorrect += 1 if slowWinner == trueID else 0
                        slowOps += slowCost

                        fastWinner, fastCost = distbns.runTournament(sample, fast=True, fastParam=C, nAllPairs=L)
                        fastCorrect += 1 if fastWinner == trueID else 0
                        fastOps += fastCost

                runs = Q * T
                print('SlowTournament: Accuracy {}, AvgOps {}'.format(slowCorrect / runs, slowOps / runs))
                print('FastTournament: Accuracy {}, AvgOps {}'.format(fastCorrect / runs, fastOps / runs))
                print()

def testSubset(half_unifs, testIDs, K, N, Q, T, S, L, l):
    """Evaluate the elimination and subset algorithms for one (S, L, l) setting.

    Builds ``distbnnn.DistbnNN(K, N, numSubsets=L, subsetSize=l,
    preprocessScheffe=False)`` and loads ``half_unifs`` into it. Then, for
    each of the first ``Q`` entries of ``testIDs``, it draws ``T`` samples of
    size ``S``. Every sample is passed to ``eliminateUnifDistbns`` (with the
    candidate set ``{0, ..., K-1}``) and to ``runSubsetAlgo``. Both are
    expected to return ``(winner, nOps, seconds)``. Accuracy, average ops
    and average seconds are printed for each algorithm.

    Returns:
        Fraction of the ``Q * T`` runs in which ``runSubsetAlgo`` returned
        the true id.
    """
    print('S {}, L {}, l {}'.format(S, L, l))
    distbns = distbnnn.DistbnNN(K, N, numSubsets=L, subsetSize=l, preprocessScheffe=False)
    distbns.setDistbns(half_unifs)

    # Running totals per algorithm, laid out as [correct, ops, seconds].
    elimTotals = [0, 0, 0]
    subsetTotals = [0, 0, 0]

    def record(totals, result, trueID):
        # Fold one (winner, nOps, seconds) result into the running totals.
        winner, nOps, secs = result
        if winner == trueID:
            totals[0] += 1
        totals[1] += nOps
        totals[2] += secs

    for q in range(Q):
        trueID = testIDs[q]
        for _ in range(T):
            sample = distbns.sampleDistbn(trueID, S)
            # The candidate set is rebuilt for every call.
            # NOTE(review): eliminateUnifDistbns may modify the set it is
            # given (unverified), so do not hoist this out of the loop.
            record(elimTotals, distbns.eliminateUnifDistbns(sample, set(range(K))), trueID)
            record(subsetTotals, distbns.runSubsetAlgo(sample), trueID)

    runs = Q * T
    print('EliminationAlgo: Accuracy {:.4f}, AvgOps {:.0f}, AvgSecs {:.4f}'.format(*(v / runs for v in elimTotals)))
    print('SubsetAlgo:      Accuracy {:.4f}, AvgOps {:.0f}, AvgSecs {:.4f}'.format(*(v / runs for v in subsetTotals)))
    print()
    return subsetTotals[0] / runs
 

def gridSearchSubset(fixedParams, nSamples, numSubsets, subsetSize):
    """Run ``testSubset`` for every (S, L, l) combination of the given grids.

    The half-uniform distributions and the ``Q`` distinct test ids are
    generated once, so every grid point is evaluated on the same data.

    Args:
        fixedParams: dict providing 'K', 'N', 'Q' and 'T'.
        nSamples: iterable of sample sizes S.
        numSubsets: iterable of subset counts L.
        subsetSize: iterable of subset sizes l.
    """
    K, N, Q, T = [fixedParams[name] for name in 'KNQT']
    print('K {}, N {}, Q {}, T {}'.format(K, N, Q, T))
    half_unifs = distbnnn.generateHalfUnif(K, N)
    testIDs = np.random.choice(K, Q, replace=False)

    # Same visiting order as three nested loops: S outermost, l innermost.
    grid = ((S, L, size) for S in nSamples for L in numSubsets for size in subsetSize)
    for S, L, size in grid:
        testSubset(half_unifs, testIDs, K, N, Q, T, S, L, size)

def minimizeNumSubsets(fixedParams, nSamples, subsetSize, start=1000, growth=1.2, accTol=0.95, maxSubsets=None):
    """Search, per (S, l), for a number of subsets reaching the accuracy target.

    For each sample size ``S`` in ``nSamples`` and subset size ``l`` in
    ``subsetSize``, the search starts at ``L = start``. It calls
    ``testSubset`` and grows ``L`` geometrically by ``growth`` until the
    returned SubsetAlgo accuracy is ``>= accTol``. The first ``L`` that
    meets the target is printed. The half-uniform distributions and test
    ids are generated once and shared by every search.

    Args:
        fixedParams: dict providing 'K', 'N', 'Q' and 'T'.
        nSamples: iterable of sample sizes S.
        subsetSize: iterable of subset sizes l.
        start: first number of subsets tried.
        growth: multiplicative step for L. Each step advances L by at least
            one, so small ``start``/``growth`` values cannot stall the search.
        accTol: accuracy that must be reached (inclusive).
        maxSubsets: optional upper bound on L. If the next candidate would
            exceed it, the search for that (S, l) stops and reports failure.
            ``None`` (the default) leaves the search unbounded.
    """
    K = fixedParams['K']
    N = fixedParams['N']
    Q = fixedParams['Q']
    T = fixedParams['T']
    print('K {}, N {}, Q {}, T {}'.format(K, N, Q, T))
    half_unifs = distbnnn.generateHalfUnif(K, N)
    testIDs = np.random.choice(K, Q, replace=False)
    for S in nSamples:
        for l in subsetSize:
            print('Searching for number of subsets')
            L = start
            while True:
                acc = testSubset(half_unifs, testIDs, K, N, Q, T, S, L, l)
                if acc >= accTol:
                    print('Required subsets: {}\n'.format(L))
                    break
                # int(L * growth) can equal L, e.g. L=0, L=1 with growth=1.2,
                # or any growth <= 1. That would retry the same L forever, so
                # always advance by at least one.
                nextL = max(L + 1, int(L * growth))
                if maxSubsets is not None and nextL > maxSubsets:
                    print('Accuracy {} not reached with up to {} subsets\n'.format(accTol, L))
                    break
                L = nextL


def tournamentCompare():
    """Run the slow-vs-fast tournament grid search with the current settings.

    Uses K=4096, N=250, Q=20 test ids and T=5 trials per id on Zipf(2)
    distributions (``halfUnif=False``). It sweeps sample sizes
    [20, 30, 40], fast constants [5, 10, 15] and all-pairs counts
    [0, 5, 10, 15, 20].
    """
    # Settings explored previously: K=8192, N=500 (with np.random.seed(1832));
    # nSamples up to [30, 40, 50, 60]; fastConst [5, 10, 15, 20] with
    # nAllPairs [0, 10, 20, 30].
    fixedParams = dict(K=4096, N=250, Q=20, T=5)
    gridSearchTournament(
        fixedParams,
        [20, 30, 40],
        [5, 10, 15],
        [0, 5, 10, 15, 20],
        halfUnif=False,
    )

def subsetCompare():
    """Search for the number of subsets needed for perfect SubsetAlgo accuracy.

    Uses K=50000, N=250, Q=100 test ids, T=1 trial per id, sample size 60
    and subset size 3. The search starts at 200 subsets and grows by 1.5x
    until accuracy reaches 1.00, i.e. every query answered correctly.
    """
    # A fixed-grid alternative is
    # gridSearchSubset(fixedParams, [60], [5000], [3]).
    fixedParams = dict(K=50000, N=250, Q=100, T=1)
    minimizeNumSubsets(
        fixedParams,
        [60],
        [3],
        start=200,
        growth=1.5,
        accTol=1.00,
    )


if __name__ == '__main__':
    # Script entry point: runs the subset-count experiment. tournamentCompare,
    # defined above, is the alternative experiment and is not invoked here.
    subsetCompare()