from dataset.dataset import Dataset
from dataset.domain import Domain
from dataset.workloads import downward_closure, downward_closure_strict
from utils import *

from algebra import Workload, ResidualWorkload, ResidualWorkload2, MarginalWorkload, VStack, HStack, Block
from linear_operator.operators import KroneckerProductLinearOperator, IdentityLinearOperator, DenseLinearOperator, CatLinearOperator
import torch
import numpy as np
import scipy

from tqdm import tqdm
import cvxpy as cp

def updateCandidateSet(candidate_type, candidates, new, M, target_closure, return_new = False):
    """Refresh the candidate pool after ``new`` has been measured.

    Every strategy goes through ``updateCandidateSetFrontier``; the only difference is
    that the 'frontier' strategy drops ``new`` from the pool, while every other
    strategy keeps it available for re-measurement.
    """
    drop_selected = candidate_type == 'frontier'
    return updateCandidateSetFrontier(candidates, new, M, target_closure,
                                      remove = drop_selected, return_new = return_new)
    
def updateCandidateSetFrontier(candidates, new, M, target_closure, remove, return_new):
    """Return the candidate pool after ``new`` has been added to the measurements ``M``.

    :param candidates: list of candidate index tuples; NOT modified in place
    :param new: index tuple that was just measured
    :param M: measurement stack exposing ``.workloads``, each with a ``.cols`` tuple
    :param target_closure: iterable of index tuples that may become candidates
    :param remove: if True, drop ``new`` from the pool (raises ValueError if absent)
    :param return_new: if True, also return the list of newly added candidates
    :return: de-duplicated candidate list, or ``(candidates, new_candidates)``
    """
    M_idx = {wkload.cols for wkload in M.workloads}
    # Work on a copy so the caller's list is never mutated behind its back.
    candidates = list(candidates)
    if remove:
        candidates.remove(new)
    existing = set(candidates)
    # A tuple one level above ``new`` joins the pool once everything returned by
    # downward_closure_strict for it has been measured, and it is neither measured
    # nor already pooled. The length test comes first so the closure is only computed
    # for tuples of the right size.
    new_idx = [tup for tup in target_closure
               if len(tup) == len(new) + 1
               and set(downward_closure_strict([tup])).issubset(M_idx)
               and tup not in M_idx
               and tup not in existing]
    # dict.fromkeys de-duplicates while preserving insertion order, so the result does
    # not depend on hash randomization of string attribute names (list(set(...)) does).
    merged = list(dict.fromkeys(candidates + new_idx))
    if return_new:
        return merged, list(dict.fromkeys(new_idx))
    return merged

def upwardClosure(candidate, target_closure):
    """Return how many tuples in ``target_closure`` are supersets of ``candidate``.

    ``candidate`` itself is counted when it appears in ``target_closure``.
    """
    required = set(candidate)
    return sum(1 for tup in target_closure if required <= set(tup))

def exponential(R, scores, sensitivity, epsilon):
    """Sample one element of ``R`` with the exponential mechanism.

    :param R: sequence of candidates
    :param scores: numpy array of utility scores aligned with ``R`` (higher is better)
    :param sensitivity: sensitivity of the score function
    :param epsilon: privacy parameter of the selection step
    :return: the sampled element of ``R`` (uses the global ``np.random`` state)
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not expose
    # ``scipy.special`` on older SciPy releases.
    from scipy.special import softmax

    # P(r) is proportional to exp(epsilon * score(r) / (2 * sensitivity)).
    probabilities = softmax((0.5 * epsilon / sensitivity) * scores)
    # Drawing a scalar index consumes the same single uniform draw as the former
    # ``choice(range(len(R)), 1, p=...)[0]``.
    index = np.random.choice(len(R), p=probabilities)
    return R[index]

def getMarginalWorkload(indices, domain):
    """Build a Workload whose factors select the marginal over ``indices``.

    For each attribute of ``domain`` (in ``domain.attrs`` order), attributes listed
    in ``indices`` get an identity factor; all others get a 1 x size row of ones,
    which sums that attribute out.
    """
    factors = []
    for col in domain.attrs:
        size = domain[col]
        if col in indices:
            factors.append(IdentityLinearOperator(size))
        else:
            factors.append(DenseLinearOperator(torch.ones(1, size)))
    return Workload(factors)

def scoreWorstResidual(candidate, data, M_plus, y, norm = 1):
    """Score a candidate residual by how badly the current model reconstructs it.

    :param candidate: index tuple of the residual to score
    :param data: dataset exposing ``.domain``
    :param M_plus: pseudo-inverse of the current measurement stack
    :param y: current (noisy) measurements
    :param norm: order passed to ``torch.linalg.norm``
    :return: norm of (noiseless residual answers - answers inferred from ``M_plus @ y``)
    """
    # Build the residual workload once and reuse it (mirrors scoreWorstMarginal).
    wkload = ResidualWorkload2(candidate, data.domain)
    true = wkload.getAnswers(data, sigma = 0)
    inferred = (wkload @ M_plus) @ y
    # NOTE(review): torch.linalg.norm(x, 1) is the vector L1 norm only for 1-D x; for a
    # 2-D tensor it is the matrix 1-norm (max column sum). Confirm the answers are 1-D.
    return torch.linalg.norm(true - inferred, norm).item()

def scoreWorstMarginal(candidate, data, M_plus, y, norm = 1):
    """Score a candidate marginal by the reconstruction error of the current model.

    :return: ``norm``-norm of (noiseless marginal answers - answers inferred from
        ``M_plus @ y``), as a Python float
    """
    marginal = MarginalWorkload(candidate, data.domain)
    error = marginal.getAnswers(data, sigma = 0) - (marginal @ M_plus) @ y
    return torch.linalg.norm(error, norm).item()

def scoreExpectedResidualL1(candidate, data, M_plus, y, rho, sigmas):
    """Worst-residual L1 score minus the noise error expected after measuring ``candidate``.

    :param rho: budget the candidate would be measured with (passed to ``sigma``,
        which is not defined in this file — presumably from ``utils``)
    :param sigmas: dict mapping already-measured index tuples to their noise scale
    """
    fresh_sigma = sigma(candidate, data.domain, rho)
    if candidate not in sigmas:
        combined_sigma = fresh_sigma
    else:
        # Inverse-variance combination of the existing and the proposed measurement.
        combined_sigma = ((sigmas[candidate] ** -2) + (fresh_sigma ** -2)) ** -0.5
    # sqrt(2/pi) is E|Z| for a standard normal Z.
    expected_error = combined_sigma * ((2 / np.pi) ** 0.5)
    if candidate != ():
        # NOTE(review): per-attribute factor sqrt(2) * (size - 1) — presumably the
        # residual query's size/scaling; confirm against ResidualWorkload2.
        expected_error = expected_error * np.prod([(2 ** 0.5) * (data.domain[col] - 1) for col in candidate])
    return scoreWorstResidual(candidate, data, M_plus, y, norm = 1) - expected_error

def sensitivityWorstResidual(candidates, domain):
    """Return the largest row count among the ResidualWorkloads of ``candidates``."""
    row_counts = [ResidualWorkload(cand, domain).shape[0] for cand in candidates]
    return np.max(row_counts)

def sensitivityWorstResidual2(candidates):
    """Return ``2 ** k``, where k is the largest number of attributes in any candidate."""
    per_candidate = [2 ** len(cand) for cand in candidates]
    return np.max(per_candidate)

def scoreFunc(candidate, data, M_plus, y, score, rho, sigmas):
    """Dispatch to the scoring rule named by ``score``.

    ``rho`` and ``sigmas`` are only consulted by 'expected_residual_L1'.
    :raises Exception: if ``score`` is not a supported name
    """
    if score == 'worst_residual':
        result = scoreWorstResidual(candidate, data, M_plus, y)
    elif score == 'worst_marginal':
        result = scoreWorstMarginal(candidate, data, M_plus, y)
    elif score == 'expected_residual_L1':
        result = scoreExpectedResidualL1(candidate, data, M_plus, y, rho, sigmas)
    else:
        raise Exception('score not supported')
    return result
        
def sensitivityFunc(candidates, domain, score):
    """Return the sensitivity of the scoring rule named by ``score``.

    ``domain`` is accepted for interface compatibility but not used.
    :raises Exception: if ``score`` is not a supported name
    """
    if score == 'worst_residual':
        return sensitivityWorstResidual2(candidates)
    if score == 'worst_marginal':
        return 1.0
    if score != 'expected_residual_L1':
        raise Exception('score not supported')
    return sensitivityWorstResidual2(candidates)
        
def weightFunc(weights, candidate, target_closure):
    """Return the weight applied to ``candidate``'s score.

    'uniform' gives every candidate weight 1.0; 'upward_closure' uses the number of
    tuples in ``target_closure`` that contain the candidate.
    :raises Exception: if ``weights`` is not a supported name
    """
    if weights == 'uniform':
        return 1.0
    if weights != 'upward_closure':
        raise Exception('weights not supported')
    return upwardClosure(candidate, target_closure)
        
def initialCandidates(candidate_type, M_idx, target_closure):
    """Return the starting candidate pool given the initially measured tuples ``M_idx``.

    'all_remeasure' keeps every tuple of ``target_closure`` whose
    ``downward_closure_strict`` is already measured. Any other strategy additionally
    restricts the pool to 2-attribute tuples.
    """
    def _unlocked(tup):
        # every tuple returned by downward_closure_strict (presumably the proper
        # subsets of ``tup``) has been measured already
        return set(downward_closure_strict([tup])).issubset(M_idx)

    if candidate_type != 'all_remeasure':
        return [tup for tup in target_closure if len(tup) == 2 and _unlocked(tup)]
    return [tup for tup in target_closure if _unlocked(tup)]

    

class ddrp:
    def __init__(self, target_marginals, rho, rounds, scoreFunction, 
                 gamma = 0.1, alpha = 0.5, weights = 'uniform', candidate_type = 'all_remeasure', initialization = 1, annealing = False):
        """
        Instantiates ddrp mechanism
        :param target_marginals: list of tuples of indices
        :param rho: scalar; privacy budget
        :param rounds: int
        :param scoreFunction: 'worst_residual', 'worst_marginal', or 'expected_residual_L1'
        :param gamma: scalar; proportion of budget spent on initialization
        :param alpha: scalar; proportion of budget spent on selection each round
        :param weights: 'uniform', 'upward_closure'
        :param candidates: 'frontier', 'frontier_remeasure', 'all_remeasure'
        :param initialization: 1, 2, 3, ...
        :param annealing: boolean
        """
        
        self.target_marginals = target_marginals
        #self.target_workload = VStack([getMarginalWorkload(tup) for tup in self.target_marginals])
        self.target_closure = downward_closure(self.target_marginals)
        self.rounds = rounds
        self.rho = rho
        self.gamma = gamma
        self.rho_init = self.rho * self.gamma
        self.rho_round = (self.rho - self.rho_init) / self.rounds
        self.rho_used = 0
        self.alpha = alpha
        self.score = scoreFunction
        self.weights = weights
        self.candidate_type = candidate_type
        self.initialization = initialization
        self.annealing = annealing
        
        
    def run(self, data):
        ## get all 0D & 1D residual queries 
        init_idx = [tup for tup in self.target_closure if len(tup) < self.initialization]
        
        ## create M, y, sigmas
        M = VStack([ResidualWorkload2(tup, data.domain) for tup in init_idx])
        y, init_sigmas = M.getAnswers(data, rho = self.rho_init/len(M.workloads), return_sigma = True)
        sigmas = {init_idx[i] : init_sigmas[i] for i in range(len(init_idx))}
        self.rho_used += self.rho_init
        # print(self.calcTargetWorkloadError(data, M, y))
        
        ## initial candidates and scores
        M_idx = set([wkload.cols for wkload in M.workloads])
        C_idx = initialCandidates(self.candidate_type, 
                                  M_idx, 
                                  self.target_closure)
        # print('Candidates:' + str(C_idx)) 
        
        for t in tqdm(range(self.rounds)):
            # get scores and sensitivity
            weights = { cand : weightFunc(self.weights, cand, self.target_closure) for cand in C_idx }
            M_plus = M.pinv()
            scores = np.array([weights[cand] * scoreFunc(cand, data, M_plus, y, self.score, (1 - self.alpha) * self.rho_round, sigmas) for cand in C_idx])
            sensitivity = max(weights.values()) * sensitivityFunc(C_idx, data.domain, self.score)
            # print(C_idx)
            # print(scores.round(5))
            
            # run exp mechanism and measure selected workload
            c_star = exponential(C_idx, scores, sensitivity, (self.alpha * self.rho_round * 8) ** 0.5)
            # print('Measuring:' + str(c_star))
            c_star_wkload = ResidualWorkload2(c_star, data.domain)
            c_star_answer, c_star_sigma = c_star_wkload.getAnswers(data, rho = (1 - self.alpha) * self.rho_round, return_sigma = True)
            # print(c_star_answer)
            self.rho_used += self.rho_round
            
            if self.annealing:
                c_star_model_old = (c_star_wkload @ M.pinv()) @ y
            
            ## update M, y, candidates, sigmas
            if c_star in sigmas.keys():
                c_star_index = [wkload.cols for wkload in M.workloads].index(c_star)
                # print('old y:' + str(y[c_star_index]))
                # print('old sigma:' + str(sigmas[c_star]))
                new_y_num = y[c_star_index]/(sigmas[c_star] ** 2) + c_star_answer/ (c_star_sigma ** 2)
                new_y_den = (sigmas[c_star] ** -2) + (c_star_sigma ** -2)
                y[c_star_index] = new_y_num/new_y_den
                sigmas[c_star] = new_y_den ** -0.5
                # print('new y:' + str(y[c_star_index]))
                # print('new sigmas:' + str(sigmas[c_star]))
            else:
                M = M.append(c_star_wkload)
                y.append(c_star_answer)
                sigmas[c_star] = c_star_sigma
            C_idx = updateCandidateSet(candidates = C_idx, 
                                       new = c_star, 
                                       M = M, 
                                       target_closure = self.target_closure,
                                       candidate_type = self.candidate_type)
            # print('Candidates:' + str(C_idx))
            # print('# of Measurements:' + str(len(M.workloads)))
            # print('# of Candidates:' + str(len(C_idx)))
            
            # print(self.calcTargetWorkloadError(data, M, y))
            
            if self.annealing:
                if self.rho - self.rho_used == 0:
                    break
                c_star_model_new = (c_star_wkload @ M.pinv()) @ y
                expected_improvement = sigmas[c_star] * ((2/np.pi) ** 0.5) * max(1, np.prod([(2 ** 0.5) * (data.domain[col] - 1) for col in c_star]))
                if torch.linalg.norm(c_star_model_new - c_star_model_old, 1).item() < expected_improvement:
                    # print('Increasing budget')
                    self.rho_round = 2 * self.rho_round
                if self.rho - self.rho_used < 2 * self.rho_round:
                    self.rho_round = self.rho - self.rho_used
            
        return M, y
    
    def getTargetWorkload(self, data):
        return VStack([MarginalWorkload(tup, data.domain) for tup in self.target_marginals])
    
    def calcTargetWorkloadError(self, data, M, y):
        Q = self.getTargetWorkload(data)
        QM_plus = Q @ M.pinv()
        inferred = QM_plus @ y
        true_answers = Q.getAnswers(data, sigma = 0)
        errors = [torch.linalg.vector_norm((inferred[idx] - true_answers[idx]), 2).item() / data.df.shape[0] for idx in range(len(true_answers))]        
        avg_error = np.mean(errors)
        max_error = np.max(errors)
        return (avg_error, max_error)    
    
    
class ddrp_SOR:
    def __init__(self, target_marginals, rho, rounds, scoreFunction, 
                 gamma = 0.1, alpha = 0.5, weights = 'uniform', candidate_type = 'all_remeasure', initialization = 1, annealing = False, eta = 2.0):
        """
        Instantiates ddrp mechanism with optimized scores for residuals
        :param target_marginals: list of tuples of indices
        :param rho: scalar; privacy budget
        :param rounds: int
        :param scoreFunction: 'worst_residual', 'worst_marginal', or 'expected_residual_L1'
        :param gamma: scalar; proportion of budget spent on initialization
        :param alpha: scalar; proportion of budget spent on selection each round
        :param weights: 'uniform', 'upward_closure'
        :param candidates: 'frontier', 'frontier_remeasure', 'all_remeasure'
        :param initialization: 1, 2, 3, ...
        :param annealing: boolean
        :param eta: float; annealing parameter
        """
        
        self.target_marginals = target_marginals
        #self.target_workload = VStack([getMarginalWorkload(tup) for tup in self.target_marginals])
        self.target_closure = downward_closure(self.target_marginals)
        self.rounds = rounds
        self.rho = rho
        self.gamma = gamma
        self.rho_init = self.rho * self.gamma
        self.rho_round = (self.rho - self.rho_init) / self.rounds
        self.rho_used = 0
        self.alpha = alpha
        self.score = scoreFunction
        self.weights = weights
        self.candidate_type = candidate_type
        self.initialization = initialization
        self.annealing = annealing
        self.eta = eta
        
        
    def run(self, data):
        ## get all 0D & 1D residual queries 
        init_idx = [tup for tup in self.target_closure if len(tup) < self.initialization]
        
        ## create M, y, sigmas
        M = VStack([ResidualWorkload2(tup, data.domain) for tup in init_idx])
        y, init_sigmas = M.getAnswers(data, rho = self.rho_init/len(M.workloads), return_sigma = True)
        sigmas = {init_idx[i] : init_sigmas[i] for i in range(len(init_idx))}
        self.rho_used += self.rho_init
        # print(self.calcTargetWorkloadError(data, M, y))
        
        ## initial candidates and scores
        M_idx = set([wkload.cols for wkload in M.workloads])
        C_idx = initialCandidates(self.candidate_type, 
                                  M_idx, 
                                  self.target_closure)
        # print('Candidates:' + str(C_idx))
        weights = { cand : weightFunc(self.weights, cand, self.target_closure) for cand in C_idx }
        M_plus = M.pinv()
        scores = { cand : weights[cand] * scoreFunc(cand, data, M_plus, y, self.score, (1 - self.alpha) * self.rho_round, sigmas) for cand in C_idx }
        
        for t in tqdm(range(self.rounds)):
            # get sensitivity
            sensitivity = max(weights.values()) * sensitivityFunc(C_idx, data.domain, self.score)
            # print(C_idx)
            # print(scores.round(5))
            
            # run exp mechanism and measure selected workload
            c_star = exponential(list(scores.keys()), np.array(list(scores.values())), sensitivity, (self.alpha * self.rho_round * 8) ** 0.5)
            # print('Measuring:' + str(c_star))
            c_star_wkload = ResidualWorkload2(c_star, data.domain)
            c_star_answer, c_star_sigma = c_star_wkload.getAnswers(data, rho = (1 - self.alpha) * self.rho_round, return_sigma = True)
            # print(c_star_answer)
            self.rho_used += self.rho_round
            
            if self.annealing:
                c_star_model_old = (c_star_wkload @ M_plus) @ y
            
            ## update M, y, candidates, sigmas
            if c_star in sigmas.keys():
                c_star_index = [wkload.cols for wkload in M.workloads].index(c_star)
                # print('old y:' + str(y[c_star_index]))
                # print('old sigma:' + str(sigmas[c_star]))
                new_y_num = y[c_star_index]/(sigmas[c_star] ** 2) + c_star_answer/ (c_star_sigma ** 2)
                new_y_den = (sigmas[c_star] ** -2) + (c_star_sigma ** -2)
                y[c_star_index] = new_y_num/new_y_den
                sigmas[c_star] = new_y_den ** -0.5
                # print('new y:' + str(y[c_star_index]))
                # print('new sigmas:' + str(sigmas[c_star]))
            else:
                M = M.append(c_star_wkload)
                y.append(c_star_answer)
                sigmas[c_star] = c_star_sigma
                
            C_idx, new_idx = updateCandidateSet(candidates = C_idx, 
                                       new = c_star, 
                                       M = M, 
                                       target_closure = self.target_closure,
                                       candidate_type = self.candidate_type,
                                       return_new = True)
            # print('Candidates:' + str(C_idx))
            # print('# of Measurements:' + str(len(M.workloads)))
            # print('# of Candidates:' + str(len(C_idx)))
            
            # print(self.calcTargetWorkloadError(data, M, y))
            
            weights = { cand : weightFunc(self.weights, cand, self.target_closure) for cand in C_idx }
            M_plus = M.pinv()
            
            measure_all_scores = False
            if self.annealing:
                if self.rho - self.rho_used == 0:
                    break
                c_star_model_new = (c_star_wkload @ M_plus) @ y
                expected_improvement = sigmas[c_star] * ((2/np.pi) ** 0.5) * max(1, np.prod([(2 ** 0.5) * (data.domain[col] - 1) for col in c_star]))
                if torch.linalg.norm(c_star_model_new - c_star_model_old, 1).item() < expected_improvement:
                    # print('Increasing budget')
                    self.rho_round = self.eta * self.rho_round
                    measure_all_scores = True
                if self.rho - self.rho_used < 2 * self.rho_round:
                    self.rho_round = self.rho - self.rho_used
                    measure_all_scores = True
            
            if measure_all_scores:
                scores = { cand : weights[cand] * scoreFunc(cand, data, M_plus, y, self.score, (1 - self.alpha) * self.rho_round, sigmas) for cand in C_idx }

            else:
                if self.candidate_type == 'frontier':
                    scores[c_star] = None
                else:
                    scores[c_star] = weights[c_star] * scoreFunc(c_star, data, M_plus, y, self.score, (1 - self.alpha) * self.rho_round, sigmas)
                for cand in new_idx:
                    scores[cand] = weights[cand] * scoreFunc(cand, data, M_plus, y, self.score, (1 - self.alpha) * self.rho_round, sigmas)
            
        return M, y
    
    def getTargetWorkload(self, data):
        return VStack([MarginalWorkload(tup, data.domain) for tup in self.target_marginals])
    
    def calcTargetWorkloadError(self, data, M, y):
        Q = self.getTargetWorkload(data)
        QM_plus = Q @ M.pinv()
        inferred = QM_plus @ y
        true_answers = Q.getAnswers(data, sigma = 0)
        errors = [torch.linalg.vector_norm((inferred[idx] - true_answers[idx]), 2).item() / data.df.shape[0] for idx in range(len(true_answers))]        
        avg_error = np.mean(errors)
        max_error = np.max(errors)
        return (avg_error, max_error)   
    

    
class ddrp_np:
    """Non-private counterpart of the ddrp mechanism.

    Each round greedily picks the highest-scoring candidate residual (argmax instead of
    the exponential mechanism) and measures it exactly (``sigma = 0``).
    """
    def __init__(self, target_marginals, rounds, scoreFunction, 
                 weights = 'uniform', candidate_type = 'all_remeasure', initialization = 1):
        """
        Instantiates the non-private ddrp mechanism
        :param target_marginals: list of tuples of indices
        :param rounds: int; maximum number of measurement rounds
        :param scoreFunction: 'worst_residual', 'worst_marginal', or 'expected_residual_L1'
        :param weights: 'uniform', 'upward_closure'
        :param candidate_type: 'frontier', 'frontier_remeasure', 'all_remeasure'
        :param initialization: int; residuals on fewer than this many attributes are measured up front
        """
        self.target_marginals = target_marginals
        self.target_closure = downward_closure(self.target_marginals)
        self.rounds = rounds
        self.score = scoreFunction
        self.weights = weights
        self.candidate_type = candidate_type
        self.initialization = initialization

    def run(self, data):
        """
        Runs the greedy, noiseless selection on ``data``.
        :return: (M, y) — stack of measured residual workloads and their exact answers
        """
        ## residuals on fewer than `initialization` attributes
        ## (with the default of 1 this is only the empty tuple)
        init_idx = [tup for tup in self.target_closure if len(tup) < self.initialization]

        ## create M, y
        M = VStack([ResidualWorkload2(tup, data.domain) for tup in init_idx])
        y = M.getAnswers(data, sigma = 0)

        ## initial candidates
        M_idx = set([wkload.cols for wkload in M.workloads])
        C_idx = initialCandidates(self.candidate_type, 
                                  M_idx, 
                                  self.target_closure)

        for _ in tqdm(range(self.rounds)):
            if not C_idx:
                # A frontier pool runs dry once the closure is fully measured; np.argmax
                # on an empty score array would raise.
                break

            weights = { cand : weightFunc(self.weights, cand, self.target_closure) for cand in C_idx }
            M_plus = M.pinv()
            # NOTE(review): rho=0 and an empty sigmas dict are passed through; confirm that
            # 'expected_residual_L1' scoring (which calls sigma(..., rho)) is meaningful here.
            scores = np.array([weights[cand] * scoreFunc(cand, data, M_plus, y, self.score, 0, {}) for cand in C_idx])

            c_star = C_idx[np.argmax(scores)]
            c_star_wkload = ResidualWorkload2(c_star, data.domain)
            c_star_answer = c_star_wkload.getAnswers(data, sigma = 0)

            # NOTE(review): in remeasure modes an already-measured c_star is appended again;
            # the exact answer adds no information but does grow M.
            M = M.append(c_star_wkload)
            y.append(c_star_answer)
            C_idx = updateCandidateSet(candidates = C_idx, 
                                       new = c_star, 
                                       M = M, 
                                       target_closure = self.target_closure,
                                       candidate_type = self.candidate_type)

        return M, y

    def getTargetWorkload(self, data):
        """Return the VStack of marginal workloads for the target marginals."""
        return VStack([MarginalWorkload(tup, data.domain) for tup in self.target_marginals])

    def calcTargetWorkloadError(self, data, M, y):
        """Return (mean, max) over target marginals of the L2 reconstruction error of
        ``M^+ y``, each divided by the number of records ``data.df.shape[0]``."""
        Q = self.getTargetWorkload(data)
        QM_plus = Q @ M.pinv()
        inferred = QM_plus @ y
        true_answers = Q.getAnswers(data, sigma = 0)
        errors = [torch.linalg.vector_norm((inferred[idx] - true_answers[idx]), 2).item() / data.df.shape[0] for idx in range(len(true_answers))]
        avg_error = np.mean(errors)
        max_error = np.max(errors)
        return (avg_error, max_error)