import time
from scipy.optimize import least_squares, lsq_linear, shgo
import cvxpy
from clusterer import *

class ClusterLabelPredictor:
    """Base class for assigning a digit label to every cluster of images.

    Each example ``ex`` is accessed as follows (inferred from the code below;
    confirm against whatever produces ``examples``):

    * ``ex[1]`` -- the target numeric value of the example.
    * ``ex[2]`` -- a sequence of rows; each row is a sequence of image ids
      (``None`` marks an empty slot), the first slot being the most
      significant digit.

    The numbers read from all rows are summed, so an example is satisfied
    when that sum equals ``ex[1]``.
    """

    # Number of distinct digit labels (0..NUM_LABELS-1).
    NUM_LABELS = 10

    def __init__(self, n_clusters, examples):
        self.n_clusters = n_clusters
        self.examples = examples

    def isResultUnique(self, result):
        """Return True if every label ``0..NUM_LABELS-1`` occurs in ``result``.

        When ``len(result) == NUM_LABELS`` this means ``result`` is a
        permutation, i.e. no two clusters share a label.
        """
        return set(result).issuperset(range(self.NUM_LABELS))

    def evaluateExample(self, ex, cluster_ids, cluster_labels):
        """Evaluate one example under a candidate cluster->label assignment.

        Args:
            ex: the example (see class docstring).
            cluster_ids: indexable from image id to cluster index.
            cluster_labels: indexable from cluster index to digit label.

        Returns:
            ``(value, imageIds, imageLabels)`` where ``value`` is the sum of
            the numbers read from the rows minus ``ex[1]`` (0 means the
            example is satisfied), ``imageIds`` lists the non-empty image ids
            in reading order and ``imageLabels`` maps each to its label.
        """
        value = -ex[1]
        imageIds = []
        imageLabels = {}
        for row in ex[2]:
            width = len(row)
            for m, imageId in enumerate(row):
                if imageId is None:
                    # Empty slots still occupy a digit position.
                    continue
                label = cluster_labels[cluster_ids[imageId]]
                imageIds.append(imageId)
                imageLabels[imageId] = label
                value += pow(10, width - m - 1) * label
        return value, imageIds, imageLabels

    def _neighbour_labels(self, label, is_duplicate):
        """Return the candidate labels to try for a cluster currently at ``label``."""
        max_label = self.NUM_LABELS - 1
        if label == 0:
            options = [0, 1]
        elif label == max_label:
            options = [max_label - 1, max_label]
        else:
            options = [label - 1, label, label + 1]

        # A label shared by several clusters is certainly wrong for some of
        # them, so widen the search radius to +/-2 for those clusters.
        if is_duplicate:
            if label in (0, 1):
                options.append(label + 2)
            elif label in (max_label - 1, max_label):
                options.append(label - 2)
            else:
                options += [label - 2, label + 2]
        return options

    def fix_cluster_labels(self, cluster_ids, result):
        """Locally search label assignments near ``result`` for a lower loss.

        Every cluster may move to a neighbouring label (see
        ``_neighbour_labels``). All combinations that use every label
        (``isResultUnique``), plus ``result`` itself, are scored by the sum of
        ``abs(value)`` from ``evaluateExample`` over ``self.examples``.

        The number of combinations grows exponentially with ``n_clusters``.

        Args:
            cluster_ids: indexable from image id to cluster index.
            result: current label per cluster (length ``n_clusters``).

        Returns:
            ``(labels, loss)``: the best assignment as a list of ints and its
            loss. Ties keep the earliest candidate, starting with ``result``.
        """
        counts = np.bincount(np.asarray(result))
        options = [self._neighbour_labels(label, counts[label] > 1) for label in result]

        # Cartesian product of the per-cluster options, one row per combination.
        grid = np.array(np.meshgrid(*options)).T.reshape(-1, self.n_clusters)

        candidates = [result]
        seen = {tuple(result)}
        for comb in grid:
            key = tuple(comb)
            if key not in seen and self.isResultUnique(comb):
                seen.add(key)
                candidates.append(comb)

        best = None
        best_loss = None
        for candidate in candidates:
            loss = sum(
                abs(self.evaluateExample(ex, cluster_ids, candidate)[0])
                for ex in self.examples
            )
            if best is None or loss < best_loss:
                best = candidate
                best_loss = loss

        return [int(v) for v in best], best_loss

class ClusterLabelPredictorCVXPYMajority(ClusterLabelPredictor):
    """Predict cluster labels by solving integer L1 regressions with CVXPY.

    The examples are split into ``reps`` chunks, and an integer program is
    solved per chunk with GLPK_MI. Among the distinct label assignments
    produced, the one that exactly satisfies the most examples of the *full*
    example set is returned. Despite the class name, the winner is chosen by
    that satisfied-example count, not by how many chunks produced it.
    """

    def objective(self, examples, cluster_ids):
        """Build the linear system ``A @ x ~= y`` for ``examples``.

        Row ``k`` of ``A`` holds, for each cluster, the sum of place values
        ``10**p`` over every slot of example ``k`` whose image belongs to that
        cluster. ``y[k]`` is the example's target ``ex[1]``. Both are divided
        by ``10 ** (num_digits / 2)``, where ``num_digits`` is the width of the
        example's last row (presumably to keep magnitudes moderate for the
        solver -- TODO confirm).

        Args:
            examples: sequence of examples; ``ex[1]`` is the target value and
                ``ex[2]`` a sequence of rows of image ids (``None`` = empty).
            cluster_ids: indexable from image id to a cluster index in
                ``range(self.n_clusters)``.

        Returns:
            ``(A, y)`` numpy arrays; for non-empty ``examples`` their shapes
            are ``(len(examples), n_clusters)`` and ``(len(examples),)``.
        """
        A = []
        y = []
        for ex in examples:
            rows = ex[2]
            # Computed per example so an example without rows does not reuse
            # a previous example's width.
            num_digits = len(rows[-1]) if len(rows) > 0 else 0
            scale = pow(10, num_digits / 2)

            # Integer place-value totals per cluster; divided once at the end.
            coeff = [0] * self.n_clusters
            for row in rows:
                width = len(row)
                for j, image_id in enumerate(row):
                    if image_id is not None:
                        coeff[cluster_ids[image_id]] += pow(10, width - j - 1)

            A.append([c / scale for c in coeff])
            y.append(ex[1] / scale)

        return np.array(A), np.array(y)

    def predict_cluster_labels_once(self, cluster_ids, examples):
        """Solve ``min ||A x - y||_1`` over integer labels ``0 <= x <= 9``.

        Args:
            cluster_ids: indexable from image id to cluster index.
            examples: non-empty sequence of examples to fit.

        Returns:
            ``(labels, cost, seconds)``: the rounded label per cluster, the
            optimal objective value and the wall time spent.

        Raises:
            RuntimeError: if the solver produced no solution.
        """
        start_time = time.perf_counter()
        x = cvxpy.Variable(self.n_clusters, integer=True)
        A, y = self.objective(examples, cluster_ids)

        prob = cvxpy.Problem(
            cvxpy.Minimize(cvxpy.norm(A @ x - y, 1)),
            [x >= 0, x <= 9],
        )
        prob.solve(solver='GLPK_MI')

        if x.value is None:
            raise RuntimeError(
                'GLPK_MI returned no solution (status: {})'.format(prob.status))

        labels = [int(np.round(v, 0)) for v in x.value]
        return labels, prob.value, time.perf_counter() - start_time

    def predict_cluster_labels(self, cluster_ids, reps=10):
        """Predict a label per cluster from ``self.examples``.

        Args:
            cluster_ids: indexable from image id to cluster index.
            reps: number of chunks to split ``self.examples`` into; the last
                chunk also takes the remainder. Empty chunks (when there are
                fewer examples than ``reps``) are skipped.

        Returns:
            ``(labels, num_sat_examples, total_time_taken)``: the chosen label
            list, how many of ``self.examples`` it satisfies exactly, and the
            summed solver time. Ties go to the assignment found first.

        Raises:
            ValueError: if there are no examples or ``reps < 1``.
        """
        if reps < 1:
            raise ValueError('reps must be at least 1')
        n_examples = len(self.examples)
        if n_examples == 0:
            raise ValueError('no examples to fit cluster labels to')

        chunk_size = n_examples // reps
        results = {}
        total_time_taken = 0
        for i in range(reps):
            end = (i + 1) * chunk_size if i < reps - 1 else n_examples
            chunk_examples = self.examples[i * chunk_size:end]
            if len(chunk_examples) == 0:
                continue

            result, cost, time_taken = self.predict_cluster_labels_once(cluster_ids, chunk_examples)
            total_time_taken += time_taken

            entry = results.setdefault(
                tuple(result),
                {'count': 0, 'cost': 0, 'time_taken': 0, 'num_sat_examples': 0},
            )
            entry['count'] += 1
            entry['cost'] += cost
            entry['time_taken'] += time_taken

        # Score each distinct assignment on the full example set.
        for labels, entry in results.items():
            entry['num_sat_examples'] = sum(
                1 for ex in self.examples
                if self.evaluateExample(ex, cluster_ids, list(labels))[0] == 0
            )

        # max() returns the first maximal key, i.e. the earliest-found one.
        best_labels = max(results, key=lambda k: results[k]['num_sat_examples'])
        return list(best_labels), results[best_labels]['num_sat_examples'], total_time_taken
