import sys
sys.path.append('/home/marshal/Downloads/recourse-adaptive-preference/')

import autograd.numpy as np
import scipy.stats as stats
import heapq
import torch
from methods.reup.chebysev import chebysev_center, sdp_cost
from methods.reup import bayesian_utils, bayesian_inference
      
"""
search x_i and x_j such that the entropy given p(A)--> Wishart distribution is maximum
input:
Sigma: scale matrix --> PSD matrix (p by p, p = dim(x))
m : degree of freedom --> float -->  positive
n_samples : number of samples --> int --> positive
p : dimensionality of x
feasible_set : set of accepted applicants
x_0 : rejected applicant
epsilon : tolerance rate

output:
x_i, x_j in feasible_set, and the entropy given by x_i and x_j
"""
def max_entropy_search(A_0, Sigma, m, feasible_set, x_0, history, size=50):
    """Exhaustively pick the unasked pair of feasible points with the highest entropy.

    Every pair ``(i, j)`` with ``i < j`` that is not in ``history`` (in either order)
    is scored with ``bayesian_utils.entropy_McKay``. The best pair is turned into a
    question whose answer is simulated with ``A_0``.

    Args:
        A_0: matrix used to simulate the answer: ``R_ij = 1`` if
            ``trace(A_0 @ M_ij) <= 0`` else ``-1``.
        Sigma: scale matrix of the current Wishart belief.
        m: degrees of freedom of the current Wishart belief.
        feasible_set: indexable collection of candidate points; each row is
            reshaped to a column vector.
        x_0: the rejected applicant (presumably a column vector -- confirm).
        history: index pairs already asked.
        size: unused; kept for interface compatibility.

    Returns:
        ``((i, j), entropy_ij, R_ij, M_ij)`` for the selected pair.

    Raises:
        ValueError: if no unasked pair exists (fewer than two points, or every
            pair is already in ``history``).
    """
    best_entropy = -np.inf
    best_pair = None

    for i in range(len(feasible_set) - 1):
        for j in range(i + 1, len(feasible_set)):
            if (i, j) in history or (j, i) in history:
                continue
            x_i = feasible_set[i].reshape(-1, 1)
            x_j = feasible_set[j].reshape(-1, 1)

            entropy_ij = bayesian_utils.entropy_McKay(x_i, x_j, x_0, Sigma, m)
            if best_entropy < entropy_ij:
                best_entropy = entropy_ij
                best_pair = (i, j)

    if best_pair is None:
        raise ValueError("max_entropy_search: no unasked pair of feasible points is left")

    # Build the question (and its simulated answer) only for the winning pair.
    i, j = best_pair
    x_i = feasible_set[i].reshape(-1, 1)
    x_j = feasible_set[j].reshape(-1, 1)
    M_ij = bayesian_utils.compute_M(x_i, x_j, x_0)
    R_ij = 1 if np.trace(A_0 @ M_ij) <= 0 else -1
    return best_pair, best_entropy, R_ij, M_ij

def sampling_max_entropy_search(A_0, Sigma, m, feasible_set, x_0, history, size=50, cost="l1-norm", kappa=100000000000):
    """Pick the highest-entropy unasked pair among a random subset of feasible points.

    Up to ``size`` distinct rows of ``feasible_set`` are drawn uniformly at random.
    Every unasked pair among them is scored with ``bayesian_utils.entropy_McKay``, and
    the best pair is turned into a question with a simulated answer.

    Args:
        A_0: matrix used for the answer's objective when ``cost == "mahalanobis"``.
        Sigma: scale matrix of the current Wishart belief.
        m: degrees of freedom of the current Wishart belief.
        feasible_set: 2-D array of candidate points, one per row.
        x_0: the rejected applicant (presumably a column vector -- confirm).
        history: index pairs already asked (checked in both orders).
        size: maximum number of rows sampled.
        cost: ``"mahalanobis"`` uses ``trace(A_0 @ M_ij)``; any other value uses
            ``||x_i - x_0||_1 - ||x_j - x_0||_1``.
        kappa: answer noise. Values above 9999 give a deterministic answer
            (``R_ij = 1`` iff objective <= 0). Otherwise ``R_ij = -1`` with
            probability ``sigmoid(kappa * objective)``.

    Returns:
        ``((i, j), entropy_ij, R_ij, M_ij)`` with ``i < j`` row indices.

    Raises:
        ValueError: if fewer than two rows exist or every sampled pair is in ``history``.
    """
    n_rows = feasible_set.shape[0]
    if n_rows < 2:
        raise ValueError("sampling_max_entropy_search: feasible_set needs at least two rows")

    # Uniform sample of distinct rows. Sorting keeps every pair as (smaller, larger),
    # the orientation that is recorded in ``history``.
    sampled_index = np.sort(np.random.choice(n_rows, size=min(size, n_rows), replace=False))

    best_entropy = -np.inf
    best_pair = None
    for a in range(len(sampled_index) - 1):
        for b in range(a + 1, len(sampled_index)):
            i, j = sampled_index[a], sampled_index[b]
            if (i, j) in history or (j, i) in history:
                continue
            x_i = feasible_set[i].reshape(-1, 1)
            x_j = feasible_set[j].reshape(-1, 1)

            entropy_ij = bayesian_utils.entropy_McKay(x_i, x_j, x_0, Sigma, m)
            if best_entropy < entropy_ij:
                best_entropy = entropy_ij
                best_pair = (i, j)

    if best_pair is None:
        raise ValueError("sampling_max_entropy_search: every sampled pair is already in history")

    # Build the question and draw its (possibly noisy) answer once, for the winner only.
    i, j = best_pair
    x_i = feasible_set[i].reshape(-1, 1)
    x_j = feasible_set[j].reshape(-1, 1)
    M_ij = bayesian_utils.compute_M(x_i, x_j, x_0)

    if cost == "mahalanobis":
        objective = np.trace(A_0 @ M_ij)
    else:
        l_i = np.linalg.norm((x_i - x_0), ord=1)
        l_j = np.linalg.norm((x_j - x_0), ord=1)
        objective = l_i - l_j

    if kappa > 9999:
        R_ij = 1 if objective <= 0 else -1
    else:
        prob = 1 / (1 + np.exp(-kappa * objective))
        R_ij = np.random.choice([1, -1], p=[1 - prob, prob]).item()

    return best_pair, best_entropy, R_ij, M_ij

def random_search(A_0, x_0, feasible_set, history, size):
    """Pick a random unasked pair of distinct feasible points and simulate its answer.

    Up to ``size`` distinct rows are drawn uniformly. One unasked pair among them is
    chosen uniformly at random, and the answer is simulated with ``A_0``.

    Args:
        A_0: matrix used to simulate the answer: ``R_ij = 1`` if
            ``trace(A_0 @ M_ij) <= 0`` else ``-1``.
        x_0: the rejected applicant (presumably a column vector -- confirm).
        feasible_set: 2-D array of candidate points, one per row.
        history: index pairs already asked (checked in both orders).
        size: maximum number of rows sampled.

    Returns:
        ``((i, j), None, R_ij, M_ij)``. The entropy slot is ``None`` because no
        entropy is computed.

    Raises:
        ValueError: if fewer than two rows exist or every sampled pair is in ``history``.
    """
    n_rows = feasible_set.shape[0]
    if n_rows < 2:
        raise ValueError("random_search: feasible_set needs at least two rows")

    sampled_index = np.sort(np.random.choice(n_rows, size=min(size, n_rows), replace=False))

    # Enumerating the candidates (instead of rejection sampling) guarantees i != j
    # and terminates even when every sampled pair has already been asked.
    candidates = [
        (i, j)
        for pos, i in enumerate(sampled_index)
        for j in sampled_index[pos + 1:]
        if (i, j) not in history and (j, i) not in history
    ]
    if not candidates:
        raise ValueError("random_search: every sampled pair is already in history")

    i, j = candidates[np.random.randint(len(candidates))]
    x_i, x_j = feasible_set[i].reshape(-1, 1), feasible_set[j].reshape(-1, 1)
    M_ij = bayesian_utils.compute_M(x_i, x_j, x_0)
    R_ij = 1 if np.trace(A_0 @ M_ij) <= 0 else -1
    return (i, j), None, R_ij, M_ij


def bayesian_PE(A_0, Sigma, m, x_0, d, feasible_set, sessions, iterations, lr, tau, size=50):
    """Sequential Bayesian preference elicitation over a Wishart belief.

    Each of the ``sessions`` rounds does the following:
      1. choose a question with ``sampling_max_entropy_search`` under the current prior;
      2. record the returned response;
      3. run ``bayesian_inference.posterior_inference``;
      4. use the resulting posterior as the next round's prior.

    Args:
        A_0: forwarded to ``sampling_max_entropy_search``.
        Sigma: initial prior scale matrix.
        m: initial prior degrees of freedom.
        x_0: rejected applicant; ``x_0.shape[0]`` is the data dimension.
        d: ignored -- it is overwritten with ``x_0.shape[0]``.
        feasible_set: candidate points the questions are drawn from.
        sessions: number of question/answer rounds.
        iterations, lr, tau: forwarded to ``posterior_inference``.
        size: forwarded to the question search.

    Returns:
        ``(post_Sigma, post_m, log_dict)``.
        ``post_Sigma``/``post_m`` come from the last round, or are ``None`` when
        ``sessions == 0``.
        ``log_dict`` maps each session index to ``{'iterations', 'losses'}``. It also
        holds the per-round lists under 'lst_ind', 'lst_responses', 'lst_Sigma' and 'lst_m'.
    """
    # NOTE(review): the ``d`` argument is ignored; the dimension always comes from x_0.
    d = x_0.shape[0]

    log_dict = {}
    asked_pairs = []      # also passed as the search's ``history`` and grows in place
    responses = []
    sigma_trace = []
    m_trace = []

    cur_Sigma, cur_m = Sigma, m
    post_Sigma = post_m = None

    for session in range(sessions):
        # values d, d+1, ..., cur_m handed to posterior_inference
        # (presumably the candidate degrees of freedom)
        candidate_m = np.arange(d, cur_m + 1)

        question = sampling_max_entropy_search(A_0, cur_Sigma, cur_m, feasible_set, x_0, asked_pairs, size)
        asked_pairs.append(question[0])
        response, M = question[2], question[3]

        post_Sigma, post_m, losses = bayesian_inference.posterior_inference(
            candidate_m, response, M, cur_Sigma, cur_m, tau, d, iterations, lr)

        # this round's posterior is the next round's prior
        cur_Sigma, cur_m = post_Sigma, post_m

        responses.append(response)
        sigma_trace.append(post_Sigma)
        m_trace.append(post_m)
        log_dict[session] = {'iterations': iterations, 'losses': losses}

    log_dict.update({
        'lst_ind': asked_pairs,
        'lst_responses': responses,
        'lst_Sigma': sigma_trace,
        'lst_m': m_trace,
    })
    return post_Sigma, post_m, log_dict

def bayesian_mean_rank(pos_Sigma, pos_m, data, x_0, A_0, top_k, n_samples=1000):
    """Monte-Carlo distribution of ``compute_mean_rank`` under a Wishart posterior.

    Draws ``n_samples`` matrices ``A ~ Wishart(df=pos_m, scale=pos_Sigma)``. Each one
    is scored with ``compute_mean_rank(data, x_0, A_0, A, top_k)``.

    Args:
        pos_Sigma: posterior scale matrix.
        pos_m: posterior degrees of freedom.
        data, x_0, A_0, top_k: forwarded to ``compute_mean_rank``.
        n_samples: number of Wishart draws (default 1000, the previous fixed value).

    Returns:
        1-D array with one mean-rank value per sampled matrix.
    """
    dim = np.atleast_2d(pos_Sigma).shape[0]
    # scipy squeezes size-1 axes (one draw, or a 1x1 scale matrix), so restore a
    # uniform (n_samples, dim, dim) stack before iterating over the matrices.
    A_samples = np.reshape(stats.wishart.rvs(pos_m, pos_Sigma, size=n_samples), (n_samples, dim, dim))
    return np.array([compute_mean_rank(data, x_0, A_0, A, top_k) for A in A_samples])

def compute_mean_rank(data, x_0, A_0, A_opt, top_k, cost="mahalanobis"):
    """Normalized true-cost rank of the ``top_k`` points preferred by ``A_opt``.

    Every row of ``data`` is ranked by its cost to ``x_0`` under ``A_0`` (0 = cheapest,
    ties broken by row index). The ``top_k`` rows that are cheapest under ``A_opt`` are
    looked up in that ranking and their ranks are summed.

    The sum is normalized as ``(sum - r_min) / r_max``, where
    ``r_min = 0 + 1 + ... + (top_k - 1)`` and
    ``r_max = (N - top_k) + ... + (N - 1)``.
    A value of 0 therefore means ``A_opt`` picks exactly the true top_k.

    NOTE(review): the normalization divides by r_max rather than r_max - r_min, so the
    worst case is below 1 -- confirm this is the intended metric.

    Cost functions:
        ``cost == "mahalanobis"``: ``bayesian_utils.compute_mahalanobis``.
        otherwise: ``bayesian_utils.l1_norm_diag``, called with ``diag=False`` for
        A_0 but with its default ``diag`` for A_opt.
        NOTE(review): that asymmetry looks unintended -- confirm.

    Rows are reshaped to column vectors before costing (x_0 presumably matches -- confirm).
    """
    n_points = data.shape[0]
    use_mahalanobis = cost == "mahalanobis"

    lowest = (top_k - 1) * top_k / 2
    highest = (2 * n_points - top_k - 1) * top_k / 2

    def _true_cost(col):
        if use_mahalanobis:
            return bayesian_utils.compute_mahalanobis(col, x_0, A_0)
        return bayesian_utils.l1_norm_diag(col, x_0, A_0, diag=False)

    def _estimated_cost(col):
        if use_mahalanobis:
            return bayesian_utils.compute_mahalanobis(col, x_0, A_opt)
        return bayesian_utils.l1_norm_diag(col, x_0, A_opt)

    columns = [data[idx].reshape(-1, 1) for idx in range(n_points)]

    # Position of every row in the ground-truth ordering (ties -> lower index first).
    true_costs = [_true_cost(col) for col in columns]
    true_order = sorted(range(n_points), key=lambda idx: (true_costs[idx], idx))
    position = {idx: pos for pos, idx in enumerate(true_order)}

    # Rows ordered by the estimated cost (stable: ties keep row order).
    est_costs = [_estimated_cost(col) for col in columns]
    est_order = sorted(range(n_points), key=lambda idx: est_costs[idx])

    total = sum(position[idx] for idx in est_order[:top_k])
    return (total - lowest) / highest

#--------------------------------------------------------------------------------------------

def compute_M(x_0, x_i, x_j):
    """Question matrix for "is x_i cheaper than x_j?" relative to x_0.

    The expansion equals
    ``(x_i - x_0)(x_i - x_0)^T - (x_j - x_0)(x_j - x_0)^T`` (up to rounding).
    Hence ``trace(A @ M)`` is the difference of the two quadratic costs under ``A``.

    Note the argument order ``(x_0, x_i, x_j)``; ``bayesian_utils.compute_M`` is
    called elsewhere as ``(x_i, x_j, x_0)``.
    """
    shift = x_j - x_i
    M = np.outer(x_i, x_i) - np.outer(x_j, x_j)
    M = M + np.outer(shift, x_0)
    M = M + np.outer(x_0, shift)
    return M

def exhaustive_search(A_opt, x_0, data):
    """Return the most ambiguous question over all pairs of rows under ``A_opt``.

    Each pair ``(i, j)``, ``i < j``, is scored by ``|<A_opt, M_ij>| / ||M_ij||``, with
    ``M_ij = compute_M(x_0, data[i], data[j])``. The pair with the smallest score wins;
    ties keep the first pair found.

    Returns:
        ``(M_ij, (data[i], data[j]))`` for the winning pair.

    Raises:
        ValueError: if ``data`` has fewer than two rows or no pair gives a finite score.
    """
    best_obj = np.inf
    best = None

    for i in range(len(data) - 1):
        for j in range(i + 1, len(data)):
            M_ij = compute_M(x_0, data[i], data[j])
            norm = np.linalg.norm(M_ij)
            if norm == 0:
                # M_ij == 0 (e.g. identical points): the normalized score is undefined.
                continue
            obj = np.abs(np.sum(np.multiply(A_opt, M_ij))) / norm
            if obj < best_obj:
                best_obj = obj
                best = M_ij, (data[i], data[j])

    if best is None:
        raise ValueError("exhaustive_search: no usable pair of rows in data")
    return best

def exhaustive_search_k(A_opt, x_0, data, k):
    """Print the ``k`` smallest normalized objectives over all pairs of rows.

    The objective of pair ``(i, j)``, ``i < j``, is ``|<A_opt, M_ij>| / ||M_ij||``, with
    ``M_ij = compute_M(x_0, data[i], data[j])``. Diagnostic helper: the values are
    printed and nothing is returned.
    """
    n_rows = len(data)
    pair_matrices = (
        compute_M(x_0, data[a], data[b])
        for a in range(n_rows - 1)
        for b in range(a + 1, n_rows)
    )
    ranked = sorted(np.abs(np.sum(np.multiply(A_opt, M))) / np.linalg.norm(M) for M in pair_matrices)
    print(ranked[:k])

def similar_cost_heuristics(A_opt, x_0, data, A_0, epsilon, prev):
    """Choose the next question among neighbours in the ``A_opt`` cost ordering.

    Rows of ``data`` are sorted by ``(x - x_0)^T A_opt (x - x_0)``. Every pair of
    consecutive rows in that order is scored by ``|<A_opt, M>| / ||M||``. The
    lowest-scoring pair not in ``prev`` is returned, oriented by ``A_0``.

    Args:
        A_opt: current estimate of the cost matrix.
        x_0: rejected applicant (1-D, like the rows of ``data``).
        data: 2-D array of candidate points, one per row.
        A_0: answers the question: ``M`` is kept if ``<A_0, M> <= epsilon`` and
            negated otherwise.
        epsilon: answer threshold.
        prev: row-index pairs ``(i, j)`` already asked.

    Returns:
        ``(oriented M, (i, j), (data[i], data[j]))``.

    Raises:
        ValueError: if ``data`` has fewer than two rows or every usable consecutive
            pair is already in ``prev``.
    """
    n_points = data.shape[0]
    scores = np.zeros(n_points)
    for idx in range(n_points):
        offset = data[idx] - x_0
        scores[idx] = offset.T @ A_opt @ offset
    # stable sort: rows with equal scores keep their original order
    order = sorted(range(n_points), key=lambda idx: scores[idx])

    # A list, not a dict keyed by the objective, so pairs with equal objectives are all kept.
    candidates = []
    for a, b in zip(order[:-1], order[1:]):
        M_ab = compute_M(x_0, data[a], data[b])
        norm = np.linalg.norm(M_ab)
        if norm == 0:
            # M == 0 (e.g. identical points): an empty question with a NaN score.
            continue
        obj = np.abs(np.sum(np.multiply(A_opt, M_ab))) / norm
        candidates.append((obj, (a, b), M_ab))
    candidates.sort(key=lambda cand: cand[0])

    for _, pair, M_ab in candidates:
        if pair in prev:
            continue
        answer = np.sum(np.multiply(A_0, M_ab))
        oriented = M_ab if answer <= epsilon else -M_ab
        return oriented, pair, (data[pair[0]], data[pair[1]])

    raise ValueError("similar_cost_heuristics: no unasked consecutive pair is left")
    
def similar_cost_heuristics_k(A_opt, x_0, data, A_0, epsilon, prev, k=3):
    """Choose a window of ``k`` similarly-priced points and build ``k - 1`` questions.

    Rows are sorted by their ``A_opt`` cost. Every window of ``k`` consecutive
    positions in that order is scored: for each member taken as a "center", sum
    ``|<A_opt, M>| / ||M||`` against the other members, and keep the largest sum.
    Windows are then tried in ascending order of that score. The first window not in
    ``prev`` is used.

    Within the chosen window, the member that is cheapest under ``A_0`` is compared
    against each other member. Each question ``M`` is kept if ``<A_0, M> <= epsilon``
    and negated otherwise.

    Args:
        A_opt: current estimate of the cost matrix.
        x_0: rejected applicant (1-D, like the rows of ``data``).
        data: 2-D array of candidate points, one per row.
        A_0: matrix used to pick the anchor and answer the questions.
        epsilon: answer threshold.
        prev: windows already used.
        k: window size, >= 1.

    Returns:
        ``(questions, window)``. ``window`` is a list of positions in the sorted order
        (not row indices of ``data``), as before.

    Raises:
        ValueError: if ``k < 1`` or no unused window exists.
    """
    if k < 1:
        raise ValueError("similar_cost_heuristics_k: k must be at least 1")

    n_points = data.shape[0]

    def _quad(idx, A):
        offset = data[idx] - x_0
        return offset.T @ A @ offset

    scores = np.zeros(n_points)
    for idx in range(n_points):
        scores[idx] = _quad(idx, A_opt)
    # stable sort: rows with equal scores keep their original order
    order = sorted(range(n_points), key=lambda idx: scores[idx])

    def _pair_obj(p, q):
        M = compute_M(x_0, data[order[p]], data[order[q]])
        return np.abs(np.sum(np.multiply(A_opt, M))) / np.linalg.norm(M)

    windows = []
    # ``+ 1`` so the last window, ending at position n_points - 1, is included.
    for start in range(n_points - k + 1):
        members = list(range(start, start + k))
        best_sum, best = -np.inf, (np.inf, start, members)
        for center in members:
            total = sum(_pair_obj(center, other) for other in members if other != center)
            if total > best_sum:
                best_sum = total
                best = (total, center, members)
        windows.append(best)

    for _, _, members in sorted(windows):
        if members in prev:
            continue
        member_costs = [_quad(order[p], A_0) for p in members]
        anchor = members[np.argmin(member_costs)]
        questions = []
        for p in members:
            if p == anchor:
                continue
            M = compute_M(x_0, data[order[anchor]], data[order[p]])
            answer = np.sum(np.multiply(A_0, M))
            questions.append(M if answer <= epsilon else -M)
        return questions, members

    raise ValueError("similar_cost_heuristics_k: no unused window of k points is left")

def similar_cost_heuristics_kpairs(A_opt, x_0, data, k):
    """Return the objective of the ``k``-th pair produced by a heap-ordered pair walk.

    Rows of ``data`` are sorted by ``(x - x_0)^T A_opt (x - x_0)``. The heap starts
    with every consecutive pair ``(p, p + 1)`` of that order, scored by
    ``|<A_opt, M>| / ||M||``. It then pops the smallest pair ``(p, q)`` ``k`` times,
    pushing ``(p, q + 1)`` after each pop.

    NOTE(review): this is the classic "k-th smallest pair distance" walk. It is exact
    only if the objective is nondecreasing in ``q`` for a fixed ``p``; for a general
    ``A_opt`` it looks like a heuristic. Pairs with ``M == 0`` score NaN, which can
    disturb the heap order.

    Args:
        A_opt: current estimate of the cost matrix.
        x_0: rejected applicant (1-D, like the rows of ``data``).
        data: 2-D array of candidate points, one per row.
        k: which popped pair to report (1 = first); must be >= 1.

    Returns:
        The objective of the ``k``-th popped pair.

    Raises:
        ValueError: if ``k < 1`` or fewer than ``k`` pairs exist.
    """
    if k < 1:
        raise ValueError("similar_cost_heuristics_kpairs: k must be at least 1")

    n_points = data.shape[0]
    scores = np.zeros(n_points)
    for idx in range(n_points):
        offset = data[idx] - x_0
        scores[idx] = offset.T @ A_opt @ offset
    # stable sort: rows with equal scores keep their original order
    order = sorted(range(n_points), key=lambda idx: scores[idx])

    def _pair_obj(p, q):
        M = compute_M(x_0, data[order[p]], data[order[q]])
        return np.abs(np.sum(np.multiply(A_opt, M))) / np.linalg.norm(M)

    heap = [(_pair_obj(pos, pos + 1), pos, pos + 1) for pos in range(n_points - 1)]
    heapq.heapify(heap)

    kth_obj = None
    for _ in range(k):
        if not heap:
            raise ValueError("similar_cost_heuristics_kpairs: fewer than k pairs available")
        kth_obj, root, nei = heapq.heappop(heap)
        if nei + 1 < n_points:
            heapq.heappush(heap, (_pair_obj(nei + 1, root), root, nei + 1))

    return kth_obj

def mean_rank(data, x_0, A_0, A_opt, top_k):
    """Normalized ``A_0``-rank of the ``top_k`` rows that are cheapest under ``A_opt``.

    Costs are the quadratic forms ``(x - x_0)^T A (x - x_0)``. Rows are ranked by
    their ``A_0`` cost (0 = cheapest, ties broken by row index). The ranks of the
    ``top_k`` rows cheapest under ``A_opt`` are summed.

    The sum is normalized as ``(sum - r_min) / r_max``, where
    ``r_min = 0 + ... + (top_k - 1)`` and ``r_max = (N - top_k) + ... + (N - 1)``.
    A value of 0 means A_opt recovers the true top_k exactly.
    NOTE(review): dividing by r_max (not r_max - r_min) keeps the worst case below 1.
    """
    n_points = data.shape[0]
    lowest = (top_k - 1) * top_k / 2
    highest = (2 * n_points - top_k - 1) * top_k / 2

    def _quad(idx, A):
        offset = data[idx] - x_0
        return offset.T @ A @ offset

    true_costs = [_quad(idx, A_0) for idx in range(n_points)]
    true_order = sorted(range(n_points), key=lambda idx: (true_costs[idx], idx))
    position = {idx: pos for pos, idx in enumerate(true_order)}

    est_costs = [_quad(idx, A_opt) for idx in range(n_points)]
    est_order = sorted(range(n_points), key=lambda idx: est_costs[idx])

    total = sum(position[idx] for idx in est_order[:top_k])
    return (total - lowest) / highest

def question_correction_gd(A_opt, A_0, x_0, x_i_opt, x_j_opt, alpha, max_iter, epsilon):
    """Refine a question pair by gradient descent so it is maximally ambiguous under A_opt.

    The two answer points are free variables. The function minimizes
    ``(<M_ij, A_opt> / ||M_ij||_F) ** 2`` with SGD (lr=0.01, momentum=0.9), where
    ``M_ij = x_i x_i^T - x_j x_j^T + (x_j - x_i) x_0^T + x_0 (x_j - x_i)^T``.

    Args:
        A_opt: matrix under which the question should be ambiguous.
        A_0: matrix used to orient (answer) the returned question.
        x_0: reference point (1-D).
        x_i_opt, x_j_opt: starting points of the optimization (1-D, floating point).
        alpha: currently unused; the learning rate is fixed at 0.01.
        max_iter: number of SGD steps.
        epsilon: orientation threshold.

    Returns:
        numpy array holding the ``M_ij`` with the smallest objective seen (the
        starting question if ``max_iter == 0``). It is negated when
        ``<A_0, M_ij> > epsilon``.
    """
    A_opt = torch.tensor(A_opt, requires_grad=False)
    A_0 = torch.tensor(A_0, requires_grad=False)
    x_0 = torch.tensor(x_0, requires_grad=False)
    x_i = torch.tensor(x_i_opt, requires_grad=True)
    x_j = torch.tensor(x_j_opt, requires_grad=True)

    def _build_M():
        shift = x_j - x_i
        return torch.outer(x_i, x_i) - torch.outer(x_j, x_j) + torch.outer(shift, x_0) + torch.outer(x_0, shift)

    optimizer = torch.optim.SGD([x_i, x_j], lr=0.01, momentum=0.9)

    with torch.no_grad():
        best_M = _build_M()
    best_obj = np.inf

    for _ in range(max_iter):
        # Clear last step's gradients; without this they accumulate across iterations.
        optimizer.zero_grad()
        M_ij = _build_M()
        obj = (torch.sum(torch.multiply(M_ij, A_opt)) / torch.norm(M_ij)) ** 2
        obj.backward()
        optimizer.step()

        if obj.item() < best_obj:
            best_obj = obj.item()
            best_M = M_ij.detach().clone()

    with torch.no_grad():
        orientation = torch.sum(torch.multiply(A_0, best_M))
    result = best_M if orientation <= epsilon else -best_M
    return result.detach().numpy()
         
def question_correction_gd_k(A_opt, A_0, x_0, x_i_opt, x_j_opt, alpha, max_iter, epsilon):
    """Refine a question pair by gradient descent so it is maximally ambiguous under A_opt.

    The two answer points are free variables. The function minimizes
    ``(<M_ij, A_opt> / ||M_ij||_F) ** 2`` with SGD (lr=0.01, momentum=0.9), where
    ``M_ij = x_i x_i^T - x_j x_j^T + (x_j - x_i) x_0^T + x_0 (x_j - x_i)^T``.
    The previous body referenced an undefined name ``_0`` in the last outer product,
    so every call raised NameError.

    Args:
        A_opt: matrix under which the question should be ambiguous.
        A_0: matrix used to orient (answer) the returned question.
        x_0: reference point (1-D).
        x_i_opt, x_j_opt: starting points of the optimization (1-D, floating point).
        alpha: currently unused; the learning rate is fixed at 0.01.
        max_iter: number of SGD steps.
        epsilon: orientation threshold.

    Returns:
        numpy array holding the ``M_ij`` with the smallest objective seen (the
        starting question if ``max_iter == 0``). It is negated when
        ``<A_0, M_ij> > epsilon``.
    """
    A_opt = torch.tensor(A_opt, requires_grad=False)
    A_0 = torch.tensor(A_0, requires_grad=False)
    x_0 = torch.tensor(x_0, requires_grad=False)
    x_i = torch.tensor(x_i_opt, requires_grad=True)
    x_j = torch.tensor(x_j_opt, requires_grad=True)

    def _build_M():
        shift = x_j - x_i
        return torch.outer(x_i, x_i) - torch.outer(x_j, x_j) + torch.outer(shift, x_0) + torch.outer(x_0, shift)

    optimizer = torch.optim.SGD([x_i, x_j], lr=0.01, momentum=0.9)

    with torch.no_grad():
        best_M = _build_M()
    best_obj = np.inf

    for _ in range(max_iter):
        # Clear last step's gradients; without this they accumulate across iterations.
        optimizer.zero_grad()
        M_ij = _build_M()
        obj = (torch.sum(torch.multiply(M_ij, A_opt)) / torch.norm(M_ij)) ** 2
        obj.backward()
        optimizer.step()

        if obj.item() < best_obj:
            best_obj = obj.item()
            best_M = M_ij.detach().clone()

    with torch.no_grad():
        orientation = torch.sum(torch.multiply(A_0, best_M))
    result = best_M if orientation <= epsilon else -best_M
    return result.detach().numpy()

def find_q(x_0, data, T, A_0, epsilon, cost_correction):
    """Ask ``T`` questions, one per round, and return the accumulated constraints.

    Each round does the following:
      1. call ``chebysev_center`` on the constraints collected so far
         (``init=True`` only in the first round);
      2. ask the ``similar_cost_heuristics`` question for the resulting center;
      3. append the oriented question matrix to the constraints.

    A final ``chebysev_center`` call over all ``T`` constraints yields the returned
    estimate. (A k-questions-per-round variant via ``similar_cost_heuristics_k`` is
    currently disabled.)

    Parameters:
        x_0: input instance; ``x_0.shape[0]`` is passed to chebysev_center as the dimension.
        data: training data the questions are built from.
        T: number of questions.
        A_0: matrix used to answer (orient) each question.
        epsilon: forwarded to chebysev_center and similar_cost_heuristics.
        cost_correction: unused.

    Returns:
        ``(P, A_opt, rank_l, centers)``:
          - ``P``: the list of question matrices.
          - ``A_opt``: the final center.
          - ``rank_l``: always an empty list (rank logging is disabled).
          - ``centers``: the centers from rounds 2..T followed by the final one
            (the first round's center is dropped).
    """
    dim = x_0.shape[0]
    constraints = []   # oriented question matrices (P)
    asked = []         # row-index pairs already used as questions
    rank_l = []        # never filled; returned for compatibility
    centers = []

    for step in range(T):
        _radius, center = chebysev_center(dim, constraints, epsilon, step == 0)
        centers.append(center)

        M_ij, pair_idx, _pair_points = similar_cost_heuristics(center, x_0, data, A_0, epsilon, asked)
        asked.append(pair_idx)
        constraints.append(M_ij)

    _radius, A_opt = chebysev_center(dim, constraints, epsilon, False)
    centers.append(A_opt)

    return constraints, A_opt, rank_l, centers[1:]


if __name__ == '__main__':
    # Synthetic smoke run: a random PSD ground-truth matrix, one random applicant,
    # and 100 random candidate points in 2-D.
    ground_truth = np.random.rand(2, 2)
    ground_truth = ground_truth @ ground_truth.T
    applicant = np.random.rand(2)
    points = np.random.rand(100, 2)

    P = find_q(applicant, points, 10, ground_truth, epsilon=1e-3, cost_correction=True)
